Accelerate up random slice segments

This commit is contained in:
Leng Yue 2023-11-09 22:35:05 -08:00 committed by GitHub
parent 730930d337
commit 0ee0b0899e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 8 deletions

View File

@ -65,20 +65,19 @@ def rand_gumbel_like(x):
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
# Slice segments
gather_indices = ids_str[:, None, None] + torch.arange(
segment_size, device=x.device
)
return torch.gather(x, 2, gather_indices)
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str