From 12be9b0a2055550f4c750bcb6cd3424f6aa9cbf8 Mon Sep 17 00:00:00 2001 From: YuriHead Date: Tue, 20 Jun 2023 10:14:33 +0800 Subject: [PATCH] Update losses.py Change back to the previous loss, and do not change the loss until it is confirmed that this can work normally --- modules/losses.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/losses.py b/modules/losses.py index 0cf24e5..cd21799 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -55,8 +55,7 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): z_mask = z_mask.float() #print(logs_p) kl = logs_p - logs_q - 0.5 - # kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) - kl += 0.5 * (torch.exp(2.*logs_q)+(z_p - m_p)**2) * torch.exp(-2. * logs_p) + kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) kl = torch.sum(kl * z_mask) l = kl / torch.sum(z_mask) return l