Debug Snake

This commit is contained in:
YuriHead 2023-07-01 14:12:27 +08:00
parent 63f889572c
commit 8cf44b0a56
2 changed files with 7 additions and 3 deletions

View File

@ -85,12 +85,14 @@ class LowPassFilter1d(nn.Module):
self.register_buffer("filter", filter)
self.conv1d_block = None
if C is not None:
self.conv1d_block = (nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False), 1)
self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),]
self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1))
self.conv1d_block[0].requires_grad_(False)
#input [B, C, T]
def forward(self, x):
if self.conv1d_block[0].weight.device != x.device:
self.conv1d_block[0] = self.conv1d_block[0].to(x.device)
if self.conv1d_block is None:
_, C, _ = x.shape

View File

@ -22,13 +22,13 @@ class UpSample1d(nn.Module):
self.register_buffer("filter", filter)
self.conv_transpose1d_block = None
if C is not None:
self.conv_transpose1d_block = (nn.ConvTranspose1d(C,
self.conv_transpose1d_block = [nn.ConvTranspose1d(C,
C,
kernel_size=self.kernel_size,
stride=self.stride,
groups=C,
bias=False
), 1)
),]
self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone())
self.conv_transpose1d_block[0].requires_grad_(False)
@ -36,6 +36,8 @@ class UpSample1d(nn.Module):
# x: [B, C, T]
def forward(self, x, C=None):
if self.conv_transpose1d_block[0].weight.device != x.device:
self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device)
if self.conv_transpose1d_block is None:
if C is None:
_, C, _ = x.shape