diff --git a/modules/losses.py b/modules/losses.py index 0cf24e5..cd21799 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -55,8 +55,7 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): z_mask = z_mask.float() #print(logs_p) kl = logs_p - logs_q - 0.5 - # kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) - kl += 0.5 * (torch.exp(2.*logs_q)+(z_p - m_p)**2) * torch.exp(-2. * logs_p) + kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) kl = torch.sum(kl * z_mask) l = kl / torch.sum(z_mask) return l