Accelerate up random slice segments
This commit is contained in:
parent
730930d337
commit
0ee0b0899e
|
@ -65,20 +65,19 @@ def rand_gumbel_like(x):
|
||||||
|
|
||||||
|
|
||||||
def slice_segments(x, ids_str, segment_size=4):
|
def slice_segments(x, ids_str, segment_size=4):
|
||||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
# Slice segments
|
||||||
for i in range(x.size(0)):
|
gather_indices = ids_str[:, None, None] + torch.arange(
|
||||||
idx_str = ids_str[i]
|
segment_size, device=x.device
|
||||||
idx_end = idx_str + segment_size
|
)
|
||||||
ret[i] = x[i, :, idx_str:idx_end]
|
return torch.gather(x, 2, gather_indices)
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||||
b, d, t = x.size()
|
b, d, t = x.size()
|
||||||
if x_lengths is None:
|
if x_lengths is None:
|
||||||
x_lengths = t
|
x_lengths = t
|
||||||
ids_str_max = x_lengths - segment_size + 1
|
ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
|
||||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||||
ret = slice_segments(x, ids_str, segment_size)
|
ret = slice_segments(x, ids_str, segment_size)
|
||||||
return ret, ids_str
|
return ret, ids_str
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue