ylzz1997 2023-07-22 20:32:37 +08:00
parent 277ed41e23
commit 26329ff059
1 changed file with 2 additions and 13 deletions

@@ -275,13 +275,7 @@ class FastAttention(nn.Module):
        self.no_projection = no_projection
        self.causal = causal
        if causal:
            try:
                import fast_transformers.causal_product.causal_product_cuda
                self.causal_linear_fn = partial(causal_linear_attention)
            except ImportError:
                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
                self.causal_linear_fn = causal_linear_attention_noncuda
    @torch.no_grad()
    def redraw_projection_matrix(self):
        projections = self.create_projection()
@@ -294,11 +288,6 @@ class FastAttention(nn.Module):
        if self.no_projection:
            q = q.softmax(dim = -1)
            k = torch.exp(k) if self.causal else k.softmax(dim = -2)
        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
            q, k = map(create_kernel, (q, k))
        else:
            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
@@ -341,7 +330,7 @@ class SelfAttention(nn.Module):
        #torch.nn.init.zeros_(self.name_embedding)
        #print (torch.sum(self.name_embedding))
    def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
        b, n, _, h, gh = *x.shape, self.heads, self.global_heads
        _, _, _, h, gh = *x.shape, self.heads, self.global_heads
        cross_attend = exists(context)
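
The first hunk drops the optional import of `fast_transformers.causal_product.causal_product_cuda`, so the causal path relies on the pure-PyTorch fallback named in the removed message (`causal_linear_attention_noncuda`, the "memory inefficient non-cuda version"). Below is a minimal sketch of what such a non-CUDA causal linear attention computes; the function name, tensor shapes, and `eps` handling are assumptions for illustration, not the repository's exact implementation.

```python
import torch

def causal_linear_attention_noncuda_sketch(q, k, v, eps=1e-6):
    # Illustrative only. q, k: (batch, heads, seq, dim) already passed through the
    # Performer feature map; v: (batch, heads, seq, dim_head).
    # Prefix sums over the sequence dimension enforce causality without building
    # the full seq x seq attention matrix, at the cost of an O(n * d * e) buffer,
    # which is why the fallback is described as memory inefficient.
    kv = torch.einsum('bhnd,bhne->bhnde', k, v).cumsum(dim=2)  # running sum of k_j outer v_j
    z = k.cumsum(dim=2)                                        # running sum of k_j (normalizer)
    num = torch.einsum('bhnd,bhnde->bhne', q, kv)              # q_i . S_i
    den = torch.einsum('bhnd,bhnd->bhn', q, z) + eps           # q_i . z_i
    return num / den.unsqueeze(-1)
```

The dedicated CUDA kernel computes the same quantity without materializing the cumulative outer products, which is the trade-off the removed try/except used to negotiate at import time.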