Debug Snake
This commit is contained in:
parent
63f889572c
commit
8cf44b0a56
|
@ -85,12 +85,14 @@ class LowPassFilter1d(nn.Module):
|
|||
self.register_buffer("filter", filter)
|
||||
self.conv1d_block = None
|
||||
if C is not None:
|
||||
self.conv1d_block = (nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False), 1)
|
||||
self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),]
|
||||
self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1))
|
||||
self.conv1d_block[0].requires_grad_(False)
|
||||
|
||||
#input [B, C, T]
|
||||
def forward(self, x):
|
||||
if self.conv1d_block[0].weight.device != x.device:
|
||||
self.conv1d_block[0] = self.conv1d_block[0].to(x.device)
|
||||
if self.conv1d_block is None:
|
||||
_, C, _ = x.shape
|
||||
|
||||
|
|
|
@ -22,13 +22,13 @@ class UpSample1d(nn.Module):
|
|||
self.register_buffer("filter", filter)
|
||||
self.conv_transpose1d_block = None
|
||||
if C is not None:
|
||||
self.conv_transpose1d_block = (nn.ConvTranspose1d(C,
|
||||
self.conv_transpose1d_block = [nn.ConvTranspose1d(C,
|
||||
C,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
groups=C,
|
||||
bias=False
|
||||
), 1)
|
||||
),]
|
||||
self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone())
|
||||
self.conv_transpose1d_block[0].requires_grad_(False)
|
||||
|
||||
|
@ -36,6 +36,8 @@ class UpSample1d(nn.Module):
|
|||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x, C=None):
|
||||
if self.conv_transpose1d_block[0].weight.device != x.device:
|
||||
self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device)
|
||||
if self.conv_transpose1d_block is None:
|
||||
if C is None:
|
||||
_, C, _ = x.shape
|
||||
|
|
Loading…
Reference in New Issue