Update losses.py

Change back to the previous loss, and do not change the loss until it is confirmed that this can work normally
This commit is contained in:
YuriHead 2023-06-20 10:14:33 +08:00 committed by GitHub
parent 43fc6ac84e
commit 12be9b0a20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 2 deletions

View File

@ -55,8 +55,7 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
z_mask = z_mask.float()
#print(logs_p)
kl = logs_p - logs_q - 0.5
# kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl += 0.5 * (torch.exp(2.*logs_q)+(z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l