diff --git a/modules/commons.py b/modules/commons.py index 761379d..238845d 100644 --- a/modules/commons.py +++ b/modules/commons.py @@ -65,20 +65,19 @@ def rand_gumbel_like(x): def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret + # Slice segments + gather_indices = ids_str[:, None, None] + torch.arange( + segment_size, device=x.device + ) + return torch.gather(x, 2, gather_indices) def rand_slice_segments(x, x_lengths=None, segment_size=4): b, d, t = x.size() if x_lengths is None: x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) ret = slice_segments(x, ids_str, segment_size) return ret, ids_str