ruff fix
This commit is contained in:
parent
277ed41e23
commit
26329ff059
|
@ -275,13 +275,7 @@ class FastAttention(nn.Module):
|
|||
self.no_projection = no_projection
|
||||
|
||||
self.causal = causal
|
||||
if causal:
|
||||
try:
|
||||
import fast_transformers.causal_product.causal_product_cuda
|
||||
self.causal_linear_fn = partial(causal_linear_attention)
|
||||
except ImportError:
|
||||
print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
|
||||
self.causal_linear_fn = causal_linear_attention_noncuda
|
||||
|
||||
@torch.no_grad()
|
||||
def redraw_projection_matrix(self):
|
||||
projections = self.create_projection()
|
||||
|
@ -294,11 +288,6 @@ class FastAttention(nn.Module):
|
|||
if self.no_projection:
|
||||
q = q.softmax(dim = -1)
|
||||
k = torch.exp(k) if self.causal else k.softmax(dim = -2)
|
||||
|
||||
elif self.generalized_attention:
|
||||
create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
|
||||
q, k = map(create_kernel, (q, k))
|
||||
|
||||
else:
|
||||
create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
|
||||
|
||||
|
@ -341,7 +330,7 @@ class SelfAttention(nn.Module):
|
|||
#torch.nn.init.zeros_(self.name_embedding)
|
||||
#print (torch.sum(self.name_embedding))
|
||||
def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
|
||||
b, n, _, h, gh = *x.shape, self.heads, self.global_heads
|
||||
_, _, _, h, gh = *x.shape, self.heads, self.global_heads
|
||||
|
||||
cross_attend = exists(context)
|
||||
|
||||
|
|
Loading…
Reference in New Issue