ruff fix

parent 277ed41e23
commit 26329ff059
@@ -275,13 +275,7 @@ class FastAttention(nn.Module):
         self.no_projection = no_projection
 
         self.causal = causal
-        if causal:
-            try:
-                import fast_transformers.causal_product.causal_product_cuda
-                self.causal_linear_fn = partial(causal_linear_attention)
-            except ImportError:
-                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
-                self.causal_linear_fn = causal_linear_attention_noncuda
 
     @torch.no_grad()
     def redraw_projection_matrix(self):
         projections = self.create_projection()
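This hunk drops the optional import of the fast_transformers CUDA kernel, so the causal path relies on the pure-PyTorch causal_linear_attention_noncuda fallback. As a rough illustration of what that fallback computes (a minimal sketch assuming (batch, heads, seq, dim) tensors; the function name, shapes, and eps are illustrative, not the library's exact code):

import torch

def causal_linear_attention_noncuda_sketch(q, k, v, eps=1e-6):
    # q, k: non-negative feature maps, shape (b, h, n, d); v: (b, h, n, e).
    # Prefix sums let each query attend only to keys at or before its position.
    k_cumsum = k.cumsum(dim=-2)                                       # (b, h, n, d)
    context = torch.einsum('bhnd,bhne->bhnde', k, v).cumsum(dim=-3)   # (b, h, n, d, e)
    out = torch.einsum('bhnd,bhnde->bhne', q, context)
    denom = torch.einsum('bhnd,bhnd->bhn', q, k_cumsum) + eps
    return out / denom.unsqueeze(-1)

Materializing the (n, d, e) prefix tensor is what makes this version memory-inefficient relative to the CUDA kernel, which streams the same recurrence without storing it.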
@@ -294,11 +288,6 @@ class FastAttention(nn.Module):
         if self.no_projection:
             q = q.softmax(dim = -1)
             k = torch.exp(k) if self.causal else k.softmax(dim = -2)
-
-        elif self.generalized_attention:
-            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
-            q, k = map(create_kernel, (q, k))
-
         else:
             create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
 
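With the generalized_attention branch removed, the remaining default path builds the FAVOR+ softmax kernel: q and k are passed through a random Gaussian projection so their dot products approximate exp(q·k/√d) softmax attention. A minimal sketch of such a feature map (the name, eps, and stabilization details here are assumptions, not the library's exact code):

import math
import torch

def softmax_kernel_sketch(data, projection_matrix, is_query, eps=1e-4):
    # Positive random features: phi(x) = exp(x W^T - |x|^2 / 2) / sqrt(m),
    # with W an (m, d) matrix of (near-)orthogonal Gaussian rows.
    d = data.shape[-1]
    data = data * (d ** -0.25)  # fold the 1/sqrt(d) softmax scaling into the features
    proj = torch.einsum('...nd,md->...nm', data, projection_matrix)
    diag = (data ** 2).sum(dim=-1, keepdim=True) / 2
    if is_query:
        stab = proj.amax(dim=-1, keepdim=True)        # stabilize each query row separately
    else:
        stab = proj.amax(dim=(-1, -2), keepdim=True)  # one global max across all keys
    return torch.exp(proj - diag - stab) / math.sqrt(projection_matrix.shape[0]) + eps

Because both feature maps are strictly positive, the downstream linear attention denominator stays well-defined.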
@@ -341,7 +330,7 @@ class SelfAttention(nn.Module):
         #torch.nn.init.zeros_(self.name_embedding)
         #print (torch.sum(self.name_embedding))
     def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
-        b, n, _, h, gh = *x.shape, self.heads, self.global_heads
+        _, _, _, h, gh = *x.shape, self.heads, self.global_heads
 
         cross_attend = exists(context)
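The one in-place change in this hunk is a typical ruff autofix: b and n were unpacked but never used afterwards, so they become underscores to satisfy ruff's unused-variable rules (the F841 family). The pattern in isolation, with hypothetical shapes:

import torch

x = torch.randn(2, 16, 64)   # (batch, seq, dim)
heads, global_heads = 8, 2

# before: b, n, _, h, gh = *x.shape, heads, global_heads  -> b and n unused
_, _, _, h, gh = *x.shape, heads, global_heads
print(h, gh)  # 8 2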