Delete fairseq_onnx directory
This commit is contained in:
parent
cf1724c1eb
commit
120444b293
|
@ -1,489 +0,0 @@
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the MIT license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from omegaconf import II
|
|
||||||
|
|
||||||
from fairseq import utils
|
|
||||||
from fairseq.data.data_utils import compute_mask_indices
|
|
||||||
from fairseq.data.dictionary import Dictionary
|
|
||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
|
||||||
from fairseq.models import BaseFairseqModel, register_model
|
|
||||||
from fairseq.models.wav2vec.wav2vec2 import (
|
|
||||||
EXTRACTOR_MODE_CHOICES,
|
|
||||||
MASKING_DISTRIBUTION_CHOICES,
|
|
||||||
LAYER_TYPE_CHOICES,
|
|
||||||
ConvFeatureExtractionModel,
|
|
||||||
TransformerEncoder,
|
|
||||||
)
|
|
||||||
from fairseq.modules import LayerNorm
|
|
||||||
from fairseq.tasks.hubert_pretraining import (
|
|
||||||
HubertPretrainingConfig,
|
|
||||||
HubertPretrainingTask,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class HubertConfig(FairseqDataclass):
    """Configuration for HuBERT pretraining.

    Groups the feature-extractor, transformer-encoder, masking,
    channel-masking, positional-embedding and loss-related
    hyperparameters consumed by :class:`HubertModel`.
    """

    # interpolated from the task config at runtime
    label_rate: float = II("task.label_rate")

    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )
    layer_type: LAYER_TYPE_CHOICES = field(
        default="transformer", metadata={"help": "layer type in encoder"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    encoder_layerdrop: float = field(
        default=0.0,
        # typo fix: was "tarnsformer"
        metadata={"help": "probability of dropping a transformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
        default=0,
        metadata={
            # typo fix: was "set to encoder_embed_dim is <= 0"
            "help": "project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim if <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking
    mask_length: int = field(default=10, metadata={"help": "mask length"})
    mask_prob: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            # typo fix: was "compute_mask_indicesh"
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            # typo fix: was "compute_mask_indicesh"
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )

    checkpoint_activations: bool = field(
        default=False,
        metadata={"help": "recompute activations and save memory for extra compute"},
    )

    # FP16 optimization
    required_seq_len_multiple: int = field(
        default=2,
        metadata={
            "help": "pad the input to encoder such that the sequence length is divisible by multiple"
        },
    )

    # Conformer
    depthwise_conv_kernel_size: int = field(
        default=31,
        metadata={
            "help": "depthwise-conv-kernel-size for convolution in conformer layer"
        },
    )
    attn_type: str = field(
        default="",
        metadata={"help": "if espnet use ESPNET MHA"},
    )
    pos_enc_type: str = field(
        default="abs",
        metadata={"help": "Positional encoding type to use in conformer"},
    )
    fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
|
|
||||||
@register_model("hubert", dataclass=HubertConfig)
class HubertModel(BaseFairseqModel):
    """HuBERT model.

    NOTE(review): this copy appears modified for ONNX export — ``forward``
    takes only the raw waveform, hard-codes ``output_layer = 9`` and returns
    the final projection directly, instead of the training-time dict output.
    Confirm against the upstream fairseq implementation before reuse.
    """

    def __init__(
        self,
        cfg: HubertConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
    ) -> None:
        """Build all sub-modules from the model/task configs.

        Args:
            cfg: model hyperparameters (see HubertConfig).
            task_cfg: pretraining task config; only ``sample_rate`` is read here.
            dictionaries: one label dictionary per target stream; entries may
                be ``None`` when the model will only be fine-tuned.
        """
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")

        # cfg.conv_feature_layers is a python-literal string such as
        # "[(512,10,5)] + ..."; eval is only safe for trusted configs.
        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        # output channel count of the last conv layer
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        # overall downsampling factor of the conv stack = product of strides
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        # ratio between feature frames and label frames (used to align targets)
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        # project conv features to the encoder dim only when they differ
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        # time-masking hyperparameters (see apply_mask)
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # channel-masking hyperparameters (see apply_mask)
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        # final_dim <= 0 means "use the encoder dim"
        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        # learned embedding substituted at masked time steps
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        # optional projection+GLU applied to targets
        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        # one wide projection sliced per target stream, or a single shared one
        self.untie_final_proj = cfg.untie_final_proj
        if self.untie_final_proj:
            self.final_proj = nn.Linear(
                cfg.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any([d is None for d in dictionaries]):
            logger.info("cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            # one embedding per label class, concatenated over all streams
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask):
        """Build a new model instance."""

        model = HubertModel(cfg, task.cfg, task.dictionaries)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        """Apply time masking (replace with mask_emb) and channel masking
        (zero out) to features ``x`` of shape (B, T, C) in place.

        Returns the masked ``x`` and the boolean time-mask indices
        (``None`` when ``mask_prob == 0``).
        """
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # broadcast the (B, C) channel mask across all T time steps
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def compute_nce(self, x, pos, negs):
        """Cosine-similarity logits of ``x`` against one positive and the
        negatives, scaled by ``1 / logit_temp``; negatives identical to the
        positive are masked to -inf so they cannot win.
        """
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        """Run the conv feature extractor; detach the graph entirely when
        feature_grad_mult == 0 (gradient scaling itself happens elsewhere)."""
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        # nearest-label index for each feature frame
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample a sample-level padding mask to the feature frame rate:
        a frame is padding only if all samples it covers are padding."""
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor
    ) -> torch.Tensor:
        """ONNX-export forward: raw waveform in, layer-9 projections out.

        No masking is applied and the padding mask is all-False (assumes a
        single unpadded utterance — batch size 1 after the squeeze).
        """
        source = source.squeeze(0)
        # fixed intermediate transformer layer used for export
        output_layer = 9
        features = self.forward_features(source)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        padding_mask = torch.zeros(size=(1,features.shape[1]), dtype=torch.bool)
        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)
        features = self.dropout_input(features)
        x = features
        x = self.encoder(
            x,
            padding_mask=padding_mask,
            # encoder layers are 0-indexed internally
            layer=None if output_layer is None else output_layer - 1,
        )
        return self.final_proj(x)

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # NOTE(review): this calls forward() with keyword arguments that the
        # ONNX-export forward above no longer accepts (it takes only
        # `source`), so this method would raise TypeError here. It is a
        # leftover from the original training model — confirm before use.
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        """Return per-stream float logits for masked or unmasked frames,
        skipping streams whose logits are None."""
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        # the positive is always at index 0 of each logit row
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        """Collect auxiliary losses (currently only the feature penalty)."""
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        # drop heads that are only needed for pretraining
        self.target_glu = None
        self.final_proj = None
|
|
@ -1,21 +0,0 @@
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the MIT license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import math
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Pad ``x`` along ``dim`` so its size becomes a multiple of ``multiple``.

    Args:
        x: tensor to pad, or ``None`` (passed through as ``(None, 0)``).
        multiple: target divisor for the size of ``dim``.
        dim: dimension to pad; expected to be a negative index, since the
            F.pad offset below is computed as ``(-1 - dim) * 2``.
        value: fill value for the padded positions.

    Returns:
        ``(padded_tensor, remainder)`` where ``remainder`` is the number of
        positions added (0 when no padding was needed).
    """
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    # Integer arithmetic instead of float division + math.ceil: exact for
    # arbitrarily large sizes, where `tsz / multiple` could lose precision.
    remainder = -tsz % multiple
    if remainder == 0:
        return x, 0
    # F.pad's pad tuple is ordered from the last dimension backwards, so
    # skip (0, 0) pairs for every dimension after `dim`.
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
|
|
|
@ -1,630 +0,0 @@
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the MIT license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
from omegaconf import II
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
|
||||||
from fairseq.models import BaseFairseqModel, register_model
|
|
||||||
from fairseq.modules import (
|
|
||||||
Fp32GroupNorm,
|
|
||||||
Fp32LayerNorm,
|
|
||||||
GumbelVectorQuantizer,
|
|
||||||
KmeansVectorQuantizer,
|
|
||||||
TransposeLast,
|
|
||||||
)
|
|
||||||
from fairseq.tasks import FairseqTask
|
|
||||||
from fairseq.utils import buffered_arange
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# Closed option sets for Wav2VecConfig fields.
AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"])
PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"])
ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"])
VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"])
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Wav2VecConfig(FairseqDataclass):
    """Configuration for the original wav2vec model.

    Covers the conv feature extractor, the (conv or GRU) context
    aggregator, the optional vector quantizer, and the CPC-style
    prediction head hyperparameters.
    """

    prediction_steps: int = field(
        default=12, metadata={"help": "number of steps ahead to predict"}
    )
    sample_distance: Optional[int] = field(
        default=None,
        metadata={
            "help": "sample distance from target. does not work properly with cross-sampling"
        },
    )
    cross_sample_negatives: int = field(
        default=0, metadata={"help": "num of cross sampled negatives"}
    )
    num_negatives: int = field(
        default=10, metadata={"help": "num of sampled negatives"}
    )
    conv_feature_layers: str = field(
        default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]",
        metadata={
            "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]"
        },
    )
    conv_aggregator_layers: str = field(
        default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]",
        metadata={
            "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]"
        },
    )
    dropout: float = field(
        default=0.0, metadata={"help": "dropout to apply within the model"}
    )
    dropout_features: float = field(
        default=0.0, metadata={"help": "dropout to apply to the features"}
    )
    dropout_agg: float = field(
        default=0.0, metadata={"help": "dropout to apply after aggregation step"}
    )
    aggregator: AGGREGATOR_CHOICES = field(
        default="cnn", metadata={"help": "type of aggregator to use"}
    )
    gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"})
    no_conv_bias: bool = field(
        default=False, metadata={"help": "if set, does not learn bias for conv layers"}
    )
    agg_zero_pad: bool = field(
        default=False,
        metadata={"help": "if set, zero pads in aggregator instead of repl pad"},
    )
    skip_connections_feat: bool = field(
        default=False,
        metadata={"help": "if set, adds skip connections to the feature extractor"},
    )
    skip_connections_agg: bool = field(
        default=True,
        metadata={"help": "if set, adds skip connections to the aggregator"},
    )
    residual_scale: float = field(
        default=0.5, metadata={"help": "scales residual by sqrt(value)"}
    )
    log_compression: bool = field(
        default=True,
        metadata={"help": "if set, adds a log compression to feature extractor"},
    )
    balanced_classes: bool = field(
        default=False,
        metadata={"help": "if set, loss is scaled to balance for number of negatives"},
    )
    project_features: PROJECT_FEATURES_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, features are projected using the (same or new) aggregator"
        },
    )
    non_affine_group_norm: bool = field(
        default=False, metadata={"help": "if set, group norm is not affine"}
    )
    offset: str = field(
        default="auto",
        metadata={
            "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value"
        },
    )
    activation: ACTIVATION_CHOICES = field(
        default="relu",
        # fix: help text was copy-pasted from the `offset` field and described
        # auto-computation from the receptive field, which does not apply here
        metadata={"help": "which activation function to use"},
    )
    vq_type: VQ_TYPE_CHOICES = field(
        default="none", metadata={"help": "which type of quantizer to use"}
    )
    vq_vars: int = field(
        default=320,
        metadata={"help": "project to this many vector quantized variables per group"},
    )
    vq_groups: int = field(
        default=2, metadata={"help": "number of groups of latent variables"}
    )
    vq_dim: int = field(
        default=0,
        metadata={
            "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups"
        },
    )
    vq_depth: int = field(
        default=1, metadata={"help": "number of layers for vq weight projection"}
    )
    combine_groups: bool = field(
        default=False, metadata={"help": "if set, variables are shared among groups"}
    )
    vq_temp: Tuple[float, float, float] = field(
        default=(2.0, 0.5, 0.999995),
        metadata={
            "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)"
        },
    )
    vq_gamma: float = field(
        default=0.25,
        metadata={"help": "gamma parameter for kmeans style vector quantization"},
    )
    # interpolated from the criterion config at runtime
    infonce: bool = II("criterion.infonce")
|
||||||
|
|
||||||
|
|
||||||
@register_model("wav2vec", dataclass=Wav2VecConfig)
|
|
||||||
class Wav2VecModel(BaseFairseqModel):
|
|
||||||
@classmethod
|
|
||||||
def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask):
|
|
||||||
"""Build a new model instance."""
|
|
||||||
|
|
||||||
model = Wav2VecModel(cfg)
|
|
||||||
logger.info(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
def __init__(self, cfg: Wav2VecConfig):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.prediction_steps = cfg.prediction_steps
|
|
||||||
offset = cfg.offset
|
|
||||||
|
|
||||||
if cfg.activation == "relu":
|
|
||||||
activation = nn.ReLU()
|
|
||||||
elif cfg.activation == "gelu":
|
|
||||||
activation = nn.GELU()
|
|
||||||
else:
|
|
||||||
raise Exception("unknown activation " + cfg.activation)
|
|
||||||
|
|
||||||
feature_enc_layers = eval(cfg.conv_feature_layers)
|
|
||||||
self.feature_extractor = ConvFeatureExtractionModel(
|
|
||||||
conv_layers=feature_enc_layers,
|
|
||||||
dropout=0.0,
|
|
||||||
log_compression=cfg.log_compression,
|
|
||||||
skip_connections=cfg.skip_connections_feat,
|
|
||||||
residual_scale=cfg.residual_scale,
|
|
||||||
non_affine_group_norm=cfg.non_affine_group_norm,
|
|
||||||
activation=activation,
|
|
||||||
)
|
|
||||||
embed = feature_enc_layers[-1][0]
|
|
||||||
|
|
||||||
self.vector_quantizer = None
|
|
||||||
if cfg.vq_type == "gumbel":
|
|
||||||
self.vector_quantizer = GumbelVectorQuantizer(
|
|
||||||
dim=embed,
|
|
||||||
num_vars=cfg.vq_vars,
|
|
||||||
temp=cfg.vq_temp,
|
|
||||||
groups=cfg.vq_groups,
|
|
||||||
combine_groups=cfg.combine_groups,
|
|
||||||
vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
|
|
||||||
time_first=False,
|
|
||||||
activation=activation,
|
|
||||||
weight_proj_depth=cfg.vq_depth,
|
|
||||||
weight_proj_factor=2,
|
|
||||||
)
|
|
||||||
elif cfg.vq_type == "kmeans":
|
|
||||||
self.vector_quantizer = KmeansVectorQuantizer(
|
|
||||||
dim=embed,
|
|
||||||
num_vars=cfg.vq_vars,
|
|
||||||
groups=cfg.vq_groups,
|
|
||||||
combine_groups=cfg.combine_groups,
|
|
||||||
vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
|
|
||||||
time_first=False,
|
|
||||||
gamma=cfg.vq_gamma,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
cfg.vq_type == "none" or cfg.vq_type is None
|
|
||||||
), "Unknown quantizer type"
|
|
||||||
|
|
||||||
if cfg.offset == "auto":
|
|
||||||
jin = 0
|
|
||||||
rin = 0
|
|
||||||
for _, k, stride in feature_enc_layers:
|
|
||||||
if rin == 0:
|
|
||||||
rin = k
|
|
||||||
rin = rin + (k - 1) * jin
|
|
||||||
if jin == 0:
|
|
||||||
jin = stride
|
|
||||||
else:
|
|
||||||
jin *= stride
|
|
||||||
offset = math.ceil(rin / jin)
|
|
||||||
|
|
||||||
offset = int(offset)
|
|
||||||
|
|
||||||
def make_aggregator():
|
|
||||||
if cfg.aggregator == "cnn":
|
|
||||||
agg_layers = eval(cfg.conv_aggregator_layers)
|
|
||||||
agg_dim = agg_layers[-1][0]
|
|
||||||
feature_aggregator = ConvAggegator(
|
|
||||||
conv_layers=agg_layers,
|
|
||||||
embed=embed,
|
|
||||||
dropout=cfg.dropout,
|
|
||||||
skip_connections=cfg.skip_connections_agg,
|
|
||||||
residual_scale=cfg.residual_scale,
|
|
||||||
non_affine_group_norm=cfg.non_affine_group_norm,
|
|
||||||
conv_bias=not cfg.no_conv_bias,
|
|
||||||
zero_pad=cfg.agg_zero_pad,
|
|
||||||
activation=activation,
|
|
||||||
)
|
|
||||||
elif cfg.aggregator == "gru":
|
|
||||||
agg_dim = cfg.gru_dim
|
|
||||||
feature_aggregator = nn.Sequential(
|
|
||||||
TransposeLast(),
|
|
||||||
nn.GRU(
|
|
||||||
input_size=embed,
|
|
||||||
hidden_size=agg_dim,
|
|
||||||
num_layers=1,
|
|
||||||
dropout=cfg.dropout,
|
|
||||||
),
|
|
||||||
TransposeLast(deconstruct_idx=0),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception("unknown aggregator type " + cfg.aggregator)
|
|
||||||
|
|
||||||
return feature_aggregator, agg_dim
|
|
||||||
|
|
||||||
self.feature_aggregator, agg_dim = make_aggregator()
|
|
||||||
|
|
||||||
self.wav2vec_predictions = Wav2VecPredictionsModel(
|
|
||||||
in_dim=agg_dim,
|
|
||||||
out_dim=embed,
|
|
||||||
prediction_steps=cfg.prediction_steps,
|
|
||||||
n_negatives=cfg.num_negatives,
|
|
||||||
cross_sample_negatives=cfg.cross_sample_negatives,
|
|
||||||
sample_distance=cfg.sample_distance,
|
|
||||||
dropout=cfg.dropout,
|
|
||||||
offset=offset,
|
|
||||||
balanced_classes=cfg.balanced_classes,
|
|
||||||
infonce=cfg.infonce,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.dropout_feats = nn.Dropout(p=cfg.dropout_features)
|
|
||||||
self.dropout_agg = nn.Dropout(p=cfg.dropout_agg)
|
|
||||||
|
|
||||||
if cfg.project_features == "none":
|
|
||||||
self.project_features = None
|
|
||||||
elif cfg.project_features == "same":
|
|
||||||
self.project_features = self.feature_aggregator
|
|
||||||
elif cfg.project_features == "new":
|
|
||||||
self.project_features, _ = make_aggregator()
|
|
||||||
|
|
||||||
def forward(self, source):
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
features = self.feature_extractor(source)
|
|
||||||
if self.vector_quantizer:
|
|
||||||
q_res = self.vector_quantizer(features)
|
|
||||||
features = q_res["x"]
|
|
||||||
for k in q_res.keys():
|
|
||||||
if k != "x":
|
|
||||||
result[k] = q_res[k]
|
|
||||||
|
|
||||||
x = self.dropout_feats(features)
|
|
||||||
x = self.feature_aggregator(x)
|
|
||||||
x = self.dropout_agg(x)
|
|
||||||
|
|
||||||
if self.project_features is not None:
|
|
||||||
features = self.project_features(features)
|
|
||||||
x, targets = self.wav2vec_predictions(x, features)
|
|
||||||
result["cpc_logits"] = x
|
|
||||||
result["cpc_targets"] = targets
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old checkpoints; no model-specific renames are needed here,
        so this simply defers to the base class."""
        super().upgrade_state_dict_named(state_dict, name)
|
|
||||||
|
|
||||||
    def max_positions(self):
        """Maximum length supported by the model."""
        # Convolutional model: no positional limit, so report the largest int.
        return sys.maxsize
|
|
||||||
|
|
||||||
def get_logits(self, net_output):
|
|
||||||
logits = net_output["cpc_logits"]
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def get_targets(self, sample, net_output):
|
|
||||||
t = net_output["cpc_targets"]
|
|
||||||
if isinstance(t, tuple):
|
|
||||||
t = t[0]
|
|
||||||
return t.contiguous()
|
|
||||||
|
|
||||||
def get_target_weights(self, targets, net_output):
|
|
||||||
targets = net_output["cpc_targets"]
|
|
||||||
if isinstance(targets, tuple) and targets[-1] is not None:
|
|
||||||
return targets[-1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_extra_losses(self, net_output):
|
|
||||||
loss = None
|
|
||||||
if "prob_perplexity" in net_output:
|
|
||||||
loss = net_output["num_vars"] - net_output["prob_perplexity"]
|
|
||||||
elif "kmeans_loss" in net_output:
|
|
||||||
loss = net_output["kmeans_loss"]
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
def norm_block(is_layer_norm, dim, affine=True):
    """Build a normalization module for (B, C, T) activations.

    Layer norm operates over the channel dim, so the tensor is transposed to
    channels-last around it; otherwise a single-group fp32 group norm is used.
    """
    if not is_layer_norm:
        return Fp32GroupNorm(1, dim, affine=affine)
    return nn.Sequential(
        TransposeLast(),
        Fp32LayerNorm(dim, elementwise_affine=affine),
        TransposeLast(),
    )
|
|
||||||
|
|
||||||
|
|
||||||
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D conv blocks that turns raw waveform (B, T) into features
    (B, C, T'), with optional strided skip connections and log compression."""

    def __init__(
        self,
        conv_layers,
        dropout,
        log_compression,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        activation,
    ):
        # conv_layers: iterable of (dim, kernel_size, stride) tuples.
        super().__init__()

        def block(n_in, n_out, k, stride):
            # One conv block: conv -> dropout -> group norm -> activation.
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
                nn.Dropout(p=dropout),
                norm_block(
                    is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm
                ),
                activation,
            )

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for dim, k, stride in conv_layers:
            self.conv_layers.append(block(in_d, dim, k, stride))
            in_d = dim

        self.log_compression = log_compression
        self.skip_connections = skip_connections
        # sqrt keeps the variance of (x + residual) roughly constant.
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            residual = x
            x = conv(x)
            # Skip connections only apply when channel counts match.
            if self.skip_connections and x.size(1) == residual.size(1):
                tsz = x.size(2)
                r_tsz = residual.size(2)
                # Subsample the residual to the (strided) output length.
                residual = residual[..., :: r_tsz // tsz][..., :tsz]
                x = (x + residual) * self.residual_scale

        if self.log_compression:
            # log(1 + |x|): compress dynamic range while staying finite at 0.
            x = x.abs()
            x = x + 1
            x = x.log()

        return x
|
|
||||||
|
|
||||||
|
|
||||||
class ZeroPad1d(nn.Module):
    """Zero-pad the last dimension of a tensor by fixed amounts per side."""

    def __init__(self, pad_left, pad_right):
        super().__init__()
        # Number of zeros prepended / appended along the last axis.
        self.pad_left = pad_left
        self.pad_right = pad_right

    def forward(self, x):
        """Pad the trailing dimension of ``x`` with zeros."""
        left, right = self.pad_left, self.pad_right
        return F.pad(x, (left, right))
|
|
||||||
|
|
||||||
|
|
||||||
class ConvAggegator(nn.Module):
    """Causal conv aggregator: turns local features (B, C, T) into context
    vectors of the same length using left-padded 1-D convolutions, with
    optional projected skip connections."""

    def __init__(
        self,
        conv_layers,
        embed,
        dropout,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        conv_bias,
        zero_pad,
        activation,
    ):
        # conv_layers: iterable of (dim, kernel_size, stride) tuples;
        # embed: input channel count.
        super().__init__()

        def block(n_in, n_out, k, stride):
            # padding dims only really make sense for stride = 1
            ka = k // 2
            kb = ka - 1 if k % 2 == 0 else ka

            # All padding goes on the left, keeping the conv causal.
            pad = (
                ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0))
            )

            return nn.Sequential(
                pad,
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias),
                nn.Dropout(p=dropout),
                norm_block(False, n_out, affine=not non_affine_group_norm),
                activation,
            )

        in_d = embed
        self.conv_layers = nn.ModuleList()
        self.residual_proj = nn.ModuleList()
        for dim, k, stride in conv_layers:
            # 1x1 projection so the residual matches the new channel count;
            # None entries mean the residual can be added directly.
            if in_d != dim and skip_connections:
                self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False))
            else:
                self.residual_proj.append(None)

            self.conv_layers.append(block(in_d, dim, k, stride))
            in_d = dim
        self.conv_layers = nn.Sequential(*self.conv_layers)
        self.skip_connections = skip_connections
        # sqrt keeps the variance of (x + residual) roughly constant.
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # residual_proj and conv_layers are parallel lists built together.
        for rproj, conv in zip(self.residual_proj, self.conv_layers):
            residual = x
            x = conv(x)
            if self.skip_connections:
                if rproj is not None:
                    residual = rproj(residual)
                x = (x + residual) * self.residual_scale
        return x
|
|
||||||
|
|
||||||
|
|
||||||
class Wav2VecPredictionsModel(nn.Module):
    """CPC prediction head for wav2vec: projects aggregated context to
    several future steps and scores them against the true features plus
    sampled negative features."""

    def __init__(
        self,
        in_dim,
        out_dim,
        prediction_steps,
        n_negatives,
        cross_sample_negatives,
        sample_distance,
        dropout,
        offset,
        balanced_classes,
        infonce,
    ):
        super().__init__()

        # Negatives drawn from the same utterance vs. from the whole batch.
        self.n_negatives = n_negatives
        self.cross_sample_negatives = cross_sample_negatives
        self.sample_distance = sample_distance
        # ConvTranspose2d with kernel (1, prediction_steps) produces one
        # projection of each timestep per future step: (B, C, T) -> (B, C', T, S).
        self.project_to_steps = nn.ConvTranspose2d(
            in_dim, out_dim, (1, prediction_steps)
        )
        self.dropout = nn.Dropout(p=dropout)
        self.offset = offset
        self.balanced_classes = balanced_classes
        self.infonce = infonce

    def sample_negatives(self, y):
        """Sample negative feature vectors for each timestep of ``y`` (B, C, T);
        returns (N, B, C, T) where N = n_negatives + cross_sample_negatives."""
        bsz, fsz, tsz = y.shape

        y = y.transpose(0, 1)  # BCT -> CBT
        y = y.contiguous().view(fsz, -1)  # CBT => C(BxT)

        cross_high = tsz * bsz
        high = tsz if self.sample_distance is None else min(tsz, self.sample_distance)
        assert high > 1

        # NOTE(review): this draw is overwritten below when n_negatives > 0
        # and replaced by cross_neg_idxs when n_negatives == 0 — it looks
        # redundant; confirm before removing.
        neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz))

        with torch.no_grad():
            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(tsz)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                # Draw from [0, high-1) and shift indices >= own position by 1
                # so a timestep is never its own negative.
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * tsz)
                )
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(tsz)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * tsz),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Offset each batch row into its own slice of the flattened C(BxT).
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[..., neg_idxs.view(-1)]
        negs = negs.view(
            fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz
        ).permute(
            2, 1, 0, 3
        )  # to NxBxCxT

        return negs

    def forward(self, x, y):
        """Score predictions ``x`` (aggregated context, B x C x T) against
        targets ``y`` (local features, B x C x T); returns (predictions, labels)
        flattened for the CPC / InfoNCE criterion."""

        x = x.unsqueeze(-1)
        x = self.project_to_steps(x)  # BxCxTxS
        x = self.dropout(x)

        negatives = self.sample_negatives(y)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)  # Copies x B x C x T

        copies = targets.size(0)
        bsz, dim, tsz, steps = x.shape
        steps = min(steps, tsz - self.offset)

        # Total number of (source t, step i) scores: later steps have fewer
        # valid positions, hence the triangular-number correction term.
        predictions = x.new(
            bsz * copies * (tsz - self.offset + 1) * steps
            - ((steps + 1) * steps // 2) * copies * bsz
        )
        if self.infonce:
            # InfoNCE: the positive is always at index 0 of each copies-group.
            labels = predictions.new_full(
                (predictions.shape[0] // copies,), 0, dtype=torch.long
            )
        else:
            labels = torch.zeros_like(predictions)
        weights = (
            torch.full_like(labels, 1 / self.n_negatives)
            if self.balanced_classes and not self.infonce
            else None
        )

        # Fill `predictions` step by step; `start:end` covers all valid
        # positions for the current prediction distance.
        start = end = 0
        for i in range(steps):
            offset = i + self.offset
            end = start + (tsz - offset) * bsz * copies
            if self.infonce:
                predictions[start:end] = torch.einsum(
                    "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:]
                ).flatten()
            else:
                pos_num = (end - start) // copies
                predictions[start:end] = torch.einsum(
                    "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:]
                ).flatten()
                # Copy 0 of the targets is the positive example.
                labels[start : start + pos_num] = 1.0
                if weights is not None:
                    weights[start : start + pos_num] = 1.0
            start = end
        assert end == predictions.numel(), "{} != {}".format(end, predictions.numel())

        if self.infonce:
            predictions = predictions.view(-1, copies)
        else:
            if weights is not None:
                labels = (labels, weights)

        return predictions, labels
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,48 +0,0 @@
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the MIT license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        # apex's fused CUDA kernel, wrapped so TorchScript skips it and so the
        # kernel always launches on the input tensor's own CUDA device.
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    # apex not installed: the LayerNorm() factory falls back to torch.nn.LayerNorm.
    has_fused_layernorm = False
|
|
||||||
|
|
||||||
|
|
||||||
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Factory returning apex's fused layer norm when usable, otherwise the
    stock ``torch.nn.LayerNorm``.

    ``export=True`` (or TorchScript scripting/tracing) forces the portable
    implementation.
    """
    # TorchScript cannot call into the apex kernel, so force the fallback.
    portable_only = export or torch.jit.is_scripting() or torch.jit.is_tracing()
    if not portable_only and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
|
||||||
|
|
||||||
|
|
||||||
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 for numerical stability and
    casts the result back to the input dtype (useful under fp16 training)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        # Upcast parameters alongside the input; either may be absent when
        # elementwise_affine=False.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normed = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normed.type_as(input)
|
|
|
@ -1,842 +0,0 @@
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the MIT license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import collections
|
|
||||||
import contextlib
|
|
||||||
import copy
|
|
||||||
import importlib
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
from itertools import accumulate
|
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from fairseq.modules.multihead_attention import MultiheadAttention
|
|
||||||
|
|
||||||
try:
|
|
||||||
from amp_C import multi_tensor_l2norm
|
|
||||||
|
|
||||||
multi_tensor_l2norm_available = True
|
|
||||||
except ImportError:
|
|
||||||
multi_tensor_l2norm_available = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
except ImportError:
|
|
||||||
xm = None
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
MANIFOLD_PATH_SEP = "|"
|
|
||||||
|
|
||||||
|
|
||||||
class FileContentsAction(argparse.Action):
    """argparse action: when the value names an existing file, substitute the
    file's stripped contents; otherwise keep the value verbatim."""

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        # A file-contents option only makes sense for a single value.
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Imported lazily so merely importing this module stays cheap.
        from fairseq.file_io import PathManager

        if not PathManager.isfile(values):
            setattr(namespace, self.dest, values)
            return
        with PathManager.open(values) as handle:
            setattr(namespace, self.dest, handle.read().strip())
|
|
||||||
|
|
||||||
|
|
||||||
def split_paths(paths: str, separator=os.pathsep) -> List[str]:
    """Split a composite path string into its components.

    URL-like strings (containing "://") may legitimately contain the platform
    separator, so they are split on the manifold separator instead.
    """
    if "://" in paths:
        return paths.split(MANIFOLD_PATH_SEP)
    return paths.split(separator)
|
|
||||||
|
|
||||||
|
|
||||||
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Deprecated alias for ``checkpoint_utils.load_model_ensemble``.

    Kept only for backward compatibility; emits a deprecation warning and
    forwards all arguments.
    """
    from fairseq import checkpoint_utils

    deprecation_warning(
        "utils.load_ensemble_for_inference is deprecated. "
        "Please use checkpoint_utils.load_model_ensemble instead."
    )
    return checkpoint_utils.load_model_ensemble(
        filenames, arg_overrides=model_arg_overrides, task=task
    )
|
|
||||||
|
|
||||||
|
|
||||||
def apply_to_sample(f, sample):
    """Recursively apply ``f`` to every tensor inside ``sample``.

    Containers (dict, OrderedDict, list, tuple, set) are rebuilt with mapped
    contents; non-tensor leaves pass through unchanged. An empty sized sample
    maps to an empty dict.
    """
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _map(obj):
        if torch.is_tensor(obj):
            return f(obj)
        if isinstance(obj, collections.OrderedDict):
            # OrderedDict may carry extra attributes; preserve them.
            mapped = collections.OrderedDict(
                (key, _map(value)) for key, value in obj.items()
            )
            mapped.__dict__ = obj.__dict__
            return mapped
        if isinstance(obj, dict):
            return {key: _map(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [_map(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_map(item) for item in obj)
        if isinstance(obj, set):
            return {_map(item) for item in obj}
        return obj

    return _map(sample)
|
|
||||||
|
|
||||||
|
|
||||||
def move_to_cuda(sample, device=None):
    """Move every tensor inside ``sample`` onto a CUDA device.

    Args:
        sample: arbitrarily nested container of tensors (see apply_to_sample).
        device: target CUDA device; defaults to the current CUDA device.

    Returns:
        The same structure with every tensor moved to the device.
    """
    # Bug fix: the old `device or torch.cuda.current_device()` discarded an
    # explicit device index 0, because 0 is falsy.
    if device is None:
        device = torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored if tensor is not pinned, so we can always set
        # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620)
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)
|
|
||||||
|
|
||||||
|
|
||||||
def move_to_cpu(sample):
    """Copy every tensor inside ``sample`` to CPU memory.

    Half-precision tensors are upcast to float32 first, since float16/bfloat16
    are poorly supported on CPU.
    """

    def _to_cpu(tensor):
        # PyTorch has poor support for half tensors (float16) on CPU.
        # Move any such tensors to float32.
        if tensor.dtype in (torch.bfloat16, torch.float16):
            return tensor.to(dtype=torch.float32).cpu()
        return tensor.cpu()

    return apply_to_sample(_to_cpu, sample)
|
|
||||||
|
|
||||||
|
|
||||||
def move_to_tpu(sample):
    """Move every tensor inside ``sample`` onto the XLA (TPU) device.

    Imports torch_xla lazily so non-TPU environments never pay for it.
    """

    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    def _move_to_tpu(tensor):
        return tensor.to(device)

    return apply_to_sample(_move_to_tpu, sample)
|
|
||||||
|
|
||||||
|
|
||||||
def get_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
    """Helper for getting incremental state for an nn.Module."""
    # Thin wrapper: the module owns the key-scoping logic.
    return module.get_incremental_state(incremental_state, key)
|
|
||||||
|
|
||||||
|
|
||||||
def set_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
    value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
    """Helper for setting incremental state for an nn.Module.

    Delegates to the module's own ``set_incremental_state``; when the module
    returns a replacement state dict, that replacement is returned instead of
    the original.
    """
    if incremental_state is None:
        return None
    result = module.set_incremental_state(incremental_state, key, value)
    return incremental_state if result is None else result
|
|
||||||
|
|
||||||
|
|
||||||
def load_align_dict(replace_unk):
    """Build the alignment dictionary used for <unk> replacement.

    Returns ``None`` when replacement is disabled (``replace_unk is None``),
    a word->word mapping loaded from the given file path when one is
    provided, or an empty dict — meaning "copy the aligned source word" —
    otherwise.
    """
    if replace_unk is None:
        return None
    if not (isinstance(replace_unk, str) and replace_unk):
        # No dictionary file: unknown words are replaced by copying the
        # aligned source word verbatim.
        return {}
    # Each line maps a source word to its replacement.
    with open(replace_unk, "r") as f:
        entries = (line.split() for line in f)
        return {cols[0]: cols[1] for cols in entries}
|
|
||||||
|
|
||||||
|
|
||||||
def print_embed_overlap(embed_dict, vocab_dict):
    """Log how many vocabulary symbols also appear in the embedding file."""
    overlap = len(set(embed_dict.keys()) & set(vocab_dict.symbols))
    logger.info("found {}/{} types in embedding file".format(overlap, len(vocab_dict)))
|
|
||||||
|
|
||||||
|
|
||||||
def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.

    The first line can have vocabulary size and dimension. The following lines
    should contain word and embedding separated by spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
            word, *weights = line.rstrip().split(" ")
            embed_dict[word] = torch.Tensor([float(w) for w in weights])
    return embed_dict
|
|
||||||
|
|
||||||
|
|
||||||
def load_embedding(embed_dict, vocab, embedding):
    """Initialize rows of ``embedding.weight`` from pretrained vectors.

    Rows whose token is missing from ``embed_dict`` keep their current
    (random) initialization. Mutates ``embedding`` in place and returns it.
    """
    # Index-based iteration on purpose: vocab is indexed by position
    # (presumably a fairseq Dictionary — its iteration semantics may differ).
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding
|
|
||||||
|
|
||||||
|
|
||||||
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    """Replace each ``unk`` token in ``hypo_str`` using the aligned source word.

    For every hypothesis token equal to ``unk``, the alignment picks a source
    token, which is mapped through ``align_dict`` (falling back to the source
    token itself). Returns the space-joined repaired hypothesis.
    """
    from fairseq import tokenizer

    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ["<eos>"]
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the aligned dictionary or just copy the original value.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return " ".join(hypo_tokens)
|
|
||||||
|
|
||||||
|
|
||||||
def post_process_prediction(
    hypo_tokens,
    src_str,
    alignment,
    align_dict,
    tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=None,
):
    """Detokenize a hypothesis and optionally apply <unk> replacement.

    Converts ``hypo_tokens`` to a string via the target dictionary (removing
    BPE markers when requested), performs unknown-word replacement when an
    alignment dictionary is supplied, and re-encodes to tokens when either
    transformation changed the string.

    Returns:
        (hypo_tokens, hypo_str, alignment)
    """
    hypo_str = tgt_dict.string(
        hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore
    )
    if align_dict is not None:
        hypo_str = replace_unk(
            hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string()
        )
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluating with unk replacement or without BPE
        # Note that the dictionary can be modified inside the method.
        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment
|
|
||||||
|
|
||||||
|
|
||||||
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    # Assumes `tensor` is (batch, seq) of token ids — TODO confirm.
    mask = tensor.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
|
|
||||||
|
|
||||||
|
|
||||||
def strip_pad(tensor, pad):
    """Return ``tensor`` with every element equal to ``pad`` removed
    (boolean-mask indexing, so the result is flattened)."""
    keep = tensor.ne(pad)
    return tensor[keep]
|
|
||||||
|
|
||||||
|
|
||||||
def buffered_arange(max):
    """Return ``torch.arange(max)`` as a view into a shared, grow-only buffer
    cached on the function object, avoiding reallocation on repeated calls."""
    buf = getattr(buffered_arange, "buf", None)
    if buf is None:
        buf = buffered_arange.buf = torch.LongTensor()
    if max > buf.numel():
        # Grow the cached buffer and refill it; smaller requests reuse it.
        buf.resize_(max)
        torch.arange(max, out=buf)
    return buf[:max]
|
|
||||||
|
|
||||||
|
|
||||||
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move padding from one side of each row of ``src_tokens`` to the other.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set. Returns
    the input unchanged when there is no padding or it is already on the
    requested side; otherwise returns a rotated copy built with gather.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    buffered = torch.empty(0).long()
    if max_len > 0:
        torch.arange(max_len, out=buffered)
    range = buffered.type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Rotate each row by its own pad count; remainder wraps the indices so the
    # padding ends up on the opposite side.
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_tokens.gather(1, index)
|
|
||||||
|
|
||||||
|
|
||||||
def item(tensor):
    """Best-effort extraction of a Python scalar from ``tensor``.

    XLA tensors are returned detached (``.item()`` would force a sync);
    otherwise ``.item()`` is used, then first-element indexing, then the
    value itself.
    """
    # tpu-comment: making this a no-op for xla devices.
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if on_xla:
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
    """Compute the global L2 norm over ``grads``.

    CUDA tensors are batched per device through apex's fused
    ``multi_tensor_l2norm`` kernel; tensors on other devices fall back to
    per-tensor ``torch.norm`` in float32.
    """
    per_device = {}
    for grad in grads:
        per_device.setdefault(grad.device, []).append(grad)

    norms = []
    for device, device_grads in per_device.items():
        if device.type == "cuda":
            # TODO(msb) return has_inf
            has_inf = torch.zeros((1, 1), dtype=torch.int, device=device)
            with torch.cuda.device(device):
                norm = multi_tensor_l2norm(chunk_size, has_inf, [device_grads], False)
                norms.append(norm[0].to(torch.cuda.current_device()))
        else:
            norms.extend(torch.norm(g, p=2, dtype=torch.float32) for g in device_grads)

    return torch.norm(torch.stack(norms))
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    """Clip gradients of ``params`` in place so their global L2 norm does not
    exceed ``max_norm``; returns the pre-clip total norm.

    Parameters carrying an ``expert`` attribute are excluded from the norm
    computation but still scaled by the resulting clip coefficient.
    ``aggregate_norm_fn`` (e.g. an all-reduce) may combine norms across
    workers before clipping.
    """

    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)
    grads = [
        p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert")
    ]
    expert_grads = [
        p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert")
    ]

    if len(grads) == 0:
        # No gradients at all: report a zero norm on a matching device/dtype.
        if len(params) > 0:
            return params[0].new_tensor(0.0)
        else:
            return torch.tensor(0.0)

    if len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    else:
        if multi_tensor_l2norm_available:
            # apex fused kernel handles many tensors in one launch.
            total_norm = multi_tensor_total_norm(grads)
        else:
            if torch.cuda.is_available():
                warnings.warn(
                    "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                    "you may get better performance by installing NVIDIA's apex library"
                )
                device = torch.cuda.current_device()
            elif grads[0].device.type == "xla":
                device = grads[0].device
            else:
                device = torch.device("cpu")
            total_norm = torch.norm(
                torch.stack(
                    [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
                )
            )

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        # clamp(max=1) makes this a no-op when the norm is already small.
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads:
            g.mul_(clip_coef)
    return total_norm
|
|
||||||
|
|
||||||
|
|
||||||
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    neg_inf = float("-inf")
    return t.float().fill_(neg_inf).type_as(t)
|
|
||||||
|
|
||||||
|
|
||||||
def _match_types(arg1, arg2):
|
|
||||||
"""Convert the numerical argument to the same type as the other argument"""
|
|
||||||
|
|
||||||
def upgrade(arg_number, arg_structure):
|
|
||||||
if isinstance(arg_structure, tuple):
|
|
||||||
return tuple([arg_number] * len(arg_structure))
|
|
||||||
elif isinstance(arg_structure, dict):
|
|
||||||
arg = copy.deepcopy(arg_structure)
|
|
||||||
for k in arg:
|
|
||||||
arg[k] = upgrade(arg_number, arg_structure[k])
|
|
||||||
return arg
|
|
||||||
else:
|
|
||||||
return arg_number
|
|
||||||
|
|
||||||
if isinstance(arg1, float) or isinstance(arg1, int):
|
|
||||||
return upgrade(arg1, arg2), arg2
|
|
||||||
elif isinstance(arg2, float) or isinstance(arg2, int):
|
|
||||||
return arg1, upgrade(arg2, arg1)
|
|
||||||
|
|
||||||
return arg1, arg2
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources."""

    # Element-wise min of two dicts; keys present in only one side pass through.
    def map_value_update(d1, d2):
        updated_value = copy.deepcopy(d1)
        for key in d2:
            if key not in updated_value:
                updated_value[key] = d2[key]
            else:
                updated_value[key] = min(d1[key], d2[key])
        return updated_value

    # min() that skips None entries; returns None only when all are None.
    def nullsafe_min(l):
        minim = None
        for item in l:
            if minim is None:
                minim = item
            elif item is not None and item < minim:
                minim = item
        return minim

    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            # Numbers are broadcast to match tuple/dict structure first.
            max_positions, arg = _match_types(max_positions, arg)
            if isinstance(arg, float) or isinstance(arg, int):
                max_positions = min(max_positions, arg)
            elif isinstance(arg, dict):
                max_positions = map_value_update(max_positions, arg)
            else:
                max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))

    return max_positions
|
|
||||||
|
|
||||||
|
|
||||||
def import_user_module(args):
    """Import the module named by ``args.user_dir`` so its tasks/models register.

    Resolution order: the path as given, then relative to the fairseq package
    directory, then relative to the package's parent. Each resolved path is
    imported at most once per process.

    Raises:
        FileNotFoundError: if no candidate path exists.
        ImportError: if a different module with the same basename is already
            loaded (the basename must be globally unique).
    """
    module_path = getattr(args, "user_dir", None)
    if module_path is not None:
        module_path = os.path.abspath(args.user_dir)
        if not os.path.exists(module_path) and not os.path.isfile(
            os.path.dirname(module_path)
        ):
            # Fall back to paths relative to the fairseq installation.
            fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir)
            if os.path.exists(fairseq_rel_path):
                module_path = fairseq_rel_path
            else:
                fairseq_rel_path = os.path.join(
                    os.path.dirname(__file__), "..", args.user_dir
                )
                if os.path.exists(fairseq_rel_path):
                    module_path = fairseq_rel_path
                else:
                    raise FileNotFoundError(module_path)

        # ensure that user modules are only imported once
        import_user_module.memo = getattr(import_user_module, "memo", set())
        if module_path not in import_user_module.memo:
            import_user_module.memo.add(module_path)

            module_parent, module_name = os.path.split(module_path)
            if module_name not in sys.modules:
                sys.path.insert(0, module_parent)
                importlib.import_module(module_name)

                # Also register any tasks/ and models/ subpackages so their
                # @register_* decorators run.
                tasks_path = os.path.join(module_path, "tasks")
                if os.path.exists(tasks_path):
                    from fairseq.tasks import import_tasks

                    import_tasks(tasks_path, f"{module_name}.tasks")

                models_path = os.path.join(module_path, "models")
                if os.path.exists(models_path):
                    from fairseq.models import import_models

                    import_models(models_path, f"{module_name}.models")
            elif module_path in sys.modules[module_name].__path__:
                logger.info(f"--user-dir={module_path} has already been imported.")
            else:
                raise ImportError(
                    "Failed to import --user-dir={} because the corresponding module name "
                    "({}) is not globally unique. Please rename the directory to "
                    "something unique and try again.".format(module_path, module_name)
                )
|
|
||||||
|
|
||||||
|
|
||||||
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over `dim`, accumulating in float32 for numerical stability.

    The ONNX tracer cannot handle the `dtype` kwarg, so when tracing we
    up-cast the input first and call the plain softmax instead.
    """
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    return F.softmax(x, dim=dim, dtype=torch.float32)
def log_softmax(x, dim: int, onnx_trace: bool = False):
    """Log-softmax over `dim`, accumulating in float32 for numerical stability.

    Mirrors `softmax` above: the ONNX tracer does not support the `dtype`
    kwarg, so the traced path up-casts the input explicitly.
    """
    if onnx_trace:
        return F.log_softmax(x.float(), dim=dim)
    return F.log_softmax(x, dim=dim, dtype=torch.float32)
def get_perplexity(loss, round=2, base=2):
    """Convert a loss (log base `base`) into a perplexity value.

    Returns 0.0 when no loss is available, and inf when base**loss overflows.
    The result is rounded to `round` digits via fairseq's safe_round.
    """
    from fairseq.logging.meters import safe_round

    if loss is None:
        return 0.0
    try:
        return safe_round(base**loss, round)
    except OverflowError:
        return float("inf")
def deprecation_warning(message, stacklevel=3):
    """Emit *message* as a (visible) warning.

    Deliberately uses the default UserWarning category rather than
    DeprecationWarning, which Python filters out by default.
    """
    warnings.warn(message, stacklevel=stacklevel)
def relu_squared(x: torch.Tensor):
    """Squared-ReLU activation: relu(x) ** 2."""
    return torch.square(F.relu(x))
def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`.

    Args:
        activation: one of the names listed by get_available_activation_fns()
            (plus "relu_squared" and "swish").

    Returns:
        A callable mapping a tensor to a tensor.

    Raises:
        RuntimeError: if the name is not recognized.
    """
    from fairseq.modules import gelu, gelu_accurate

    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        deprecation_warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # BUG FIX: previously returned the torch.nn.SiLU *class*. Calling the
        # result as activation_fn(x) then constructed a module with x bound to
        # the `inplace` flag instead of applying SiLU to x. Return the
        # functional form so the result is an activation like all the others.
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
def get_available_activation_fns() -> List:
    """Names accepted by get_activation_fn().

    Previously this list omitted "relu_squared" and "swish" even though
    get_activation_fn() supports them; they are now included so CLI choices
    stay consistent with the implementation.
    """
    return [
        "relu",
        "relu_squared",
        "gelu",
        "gelu_fast",  # deprecated alias for gelu_accurate
        "gelu_accurate",
        "tanh",
        "linear",
        "swish",
    ]
@contextlib.contextmanager
def model_eval(model):
    """Context manager that puts *model* in eval mode for the duration of the
    `with` block and restores the previous training mode on exit.

    BUG FIX: the previous version did not restore the mode when the managed
    block raised, leaving the model stuck in eval mode; the restore now runs
    in a `finally` clause.
    """
    is_training = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(is_training)
def has_parameters(module):
    """Return True iff *module* exposes at least one parameter."""
    for _ in module.parameters():
        return True
    return False
def get_rng_state():
    """Snapshot the torch RNG state, plus XLA and CUDA state when available.

    The returned dict is the format expected by set_rng_state().
    """
    state = dict(torch_rng_state=torch.get_rng_state())
    # `xm` is the file-level torch_xla handle (None when torch_xla is absent).
    if xm is not None:
        state.update(xla_rng_state=xm.get_rng_state())
    if torch.cuda.is_available():
        state.update(cuda_rng_state=torch.cuda.get_rng_state())
    return state
def set_rng_state(state):
    """Restore RNG state from a dict produced by get_rng_state()."""
    torch.set_rng_state(state["torch_rng_state"])
    # Restore device-specific generators only where that backend exists.
    if xm is not None:
        xm.set_rng_state(state["xla_rng_state"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["cuda_rng_state"])
class set_torch_seed(object):
    """Context manager that seeds all torch RNGs (CPU, XLA, CUDA) on entry and
    restores the previous RNG state on exit."""

    def __init__(self, seed):
        assert isinstance(seed, int)
        # Capture the current state first so __exit__ can roll back.
        self.rng_state = get_rng_state()

        torch.manual_seed(seed)
        if xm is not None:
            xm.set_rng_state(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        set_rng_state(self.rng_state)
def parse_alignment(line):
    """
    Parses a single line from the alignment file.

    Args:
        line (str): whitespace-separated ``<src_idx>-<tgt_idx>`` pairs,
            all 0-indexed.

    Returns:
        torch.IntTensor: packed alignments of shape (2 * m), laid out as
            src_0, tgt_0, src_1, tgt_1, ...
    """
    pairs = line.strip().split()
    packed = torch.IntTensor(2 * len(pairs))
    for i, pair in enumerate(pairs):
        src, tgt = pair.split("-")
        packed[2 * i] = int(src)
        packed[2 * i + 1] = int(tgt)
    return packed
def get_token_to_word_mapping(tokens, exclude_list):
    """Map each token position to a 1-based word index.

    A token starts a new word unless it appears in *exclude_list*
    (e.g. pad/eos), in which case it inherits the running word index.
    """
    starts = [0 if tok in exclude_list else 1 for tok in tokens]
    word_ids = accumulate(starts)
    return dict(enumerate(word_ids))
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Extract hard (argmax) word alignments from an attention matrix.

    Pad/eos positions are excluded on the target side and masked to -inf on
    the source side before taking the per-target-token argmax. Returns a list
    of (src_word_idx, tgt_word_idx) pairs, 0-indexed at the word level.
    """
    tgt_keep = (tgt_sent != pad) & (tgt_sent != eos)
    src_drop = (src_sent == pad) | (src_sent == eos)
    tgt_valid = tgt_keep.nonzero(as_tuple=False).squeeze(dim=-1)
    src_invalid = src_drop.nonzero(as_tuple=False).squeeze(dim=-1)
    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])
    alignment = []
    if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent):
        attn_valid = attn[tgt_valid]
        # Invalid source positions can never win the argmax.
        attn_valid[:, src_invalid] = float("-inf")
        _, src_indices = attn_valid.max(dim=1)
        for tgt_idx, src_idx in zip(tgt_valid, src_indices):
            pair = (
                src_token_to_word[src_idx.item()] - 1,
                tgt_token_to_word[tgt_idx.item()] - 1,
            )
            alignment.append(pair)
    return alignment
def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Extract soft alignments: for every non-pad target token, the attention
    probabilities over non-pad source tokens, formatted as 6-decimal strings.

    Returns an empty list when either side has no valid token.
    """
    tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False)
    src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1)
    if len(tgt_valid) == 0 or len(src_valid) == 0:
        return []
    attn_valid = attn[tgt_valid, src_valid]
    return [
        ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid
    ]
def new_arange(x, *size):
    """
    Return a Tensor of `size` filled with a range function on the device of x.
    If size is empty, using the size of the variable x.
    """
    shape = size if size else x.size()
    return torch.arange(shape[-1], device=x.device).expand(*shape).contiguous()
def get_tpu_device():
    """Return this process's XLA (TPU) device; requires torch_xla to be loaded."""
    return xm.xla_device()
def tpu_data_loader(itr):
    """Wrap an epoch iterator in an XLA ParallelLoader for TPU training,
    preserving fairseq's CountingIterator bookkeeping (current position and
    total length)."""
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    per_device = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    return iterators.CountingIterator(
        per_device,
        start=getattr(itr, "n", 0),
        total=len(itr),
    )
def is_xla_tensor(tensor):
    """Return True iff *tensor* is a torch tensor living on an XLA device."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
def index_put(tensor, indices, value):
    """Masked assignment `tensor[indices] = value` that also works on XLA.

    On XLA devices boolean advanced-index assignment is avoided; the update
    is expressed as arithmetic masking instead (zero out selected entries,
    then add the value there). Returns the (possibly new) tensor.
    """
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if not on_xla:
        tensor[indices] = value
        return tensor
    # Broadcast the boolean mask up to tensor's rank and shape.
    while indices.dim() < tensor.dim():
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
def xla_device_to_cpu(dat):
    """Move a (possibly nested) structure of XLA tensors onto the CPU."""
    import torch_xla.core.xla_model as xm

    # NOTE(review): relies on a private torch_xla helper; may break across
    # torch_xla versions.
    return xm._maybe_convert_to_cpu(dat)
class CudaEnvironment(object):
    """Snapshot of the current CUDA device: name, compute capability, memory."""

    def __init__(self):
        device_id = torch.cuda.current_device()
        props = torch.cuda.get_device_properties("cuda:{}".format(device_id))
        self.name = props.name
        self.major = props.major
        self.minor = props.minor
        # bytes -> GB
        self.total_memory_in_GB = props.total_memory / 1024 / 1024 / 1024

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Given a list of CudaEnvironments, pretty print them
        """
        num_workers = len(cuda_env_list)
        center = "CUDA enviroments for all {} workers".format(num_workers)
        banner_len = 40 - len(center) // 2
        first_line = "*" * banner_len + center + "*" * banner_len
        logger.info(first_line)
        for rank, env in enumerate(cuda_env_list):
            logger.info(
                "rank {:3d}: ".format(rank)
                + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor)
                + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB)
                + "name = {:40s}".format(env.name)
            )
        logger.info(first_line)
def csv_str_list(x):
    """argparse type helper: split a comma-separated string into a list."""
    return x.split(",")
def eval_str_list(x, type=float):
    """Coerce *x* into a list of `type`.

    Accepts None (passed through), a string holding a Python literal
    (evaluated), an iterable (mapped element-wise), or a bare scalar
    (wrapped in a singleton list).

    SECURITY NOTE: string input goes through eval(); only feed trusted
    config values here.
    """
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    try:
        return [type(v) for v in x]
    except TypeError:
        # x was not iterable: treat it as a single scalar.
        return [type(x)]
def eval_str_dict(x, type=dict):
    """Evaluate a string representation of a dict; non-strings pass through.

    SECURITY NOTE: string input goes through eval(); trusted config only.
    """
    if x is None:
        return None
    return eval(x) if isinstance(x, str) else x
def eval_bool(x, default=False):
    """eval() *x* and coerce the result to bool; fall back to *default*
    when x is None or not evaluable (TypeError)."""
    try:
        return default if x is None else bool(eval(x))
    except TypeError:
        return default
def reset_logging():
    """Reset the root logger to a single stdout StreamHandler with fairseq's
    standard format, honoring the LOGLEVEL environment variable.

    BUG FIX: the previous version removed handlers while iterating
    ``root.handlers`` directly; ``removeHandler`` mutates that same list, so
    every other handler was skipped. Iterating over a copy removes them all.
    """
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)
def safe_getattr(obj, k, default=None):
    """Returns obj[k] if it exists and is not None, otherwise returns default.

    OmegaConf config objects are handled via item access; anything else falls
    back to plain getattr (where an existing-but-None attribute IS returned).
    """
    from omegaconf import OmegaConf

    if OmegaConf.is_config(obj):
        if k in obj and obj[k] is not None:
            return obj[k]
        return default
    return getattr(obj, k, default)
def safe_hasattr(obj, k):
    """Returns True if the given key exists and is not None."""
    value = getattr(obj, k, None)
    return value is not None
Loading…
Reference in New Issue