Delete fairseq_onnx directory
This commit is contained in:
parent cf1724c1eb
commit 120444b293
@@ -1,489 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from omegaconf import II

from fairseq import utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import (
    EXTRACTOR_MODE_CHOICES,
    MASKING_DISTRIBUTION_CHOICES,
    LAYER_TYPE_CHOICES,
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from fairseq.modules import LayerNorm
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
)

logger = logging.getLogger(__name__)


@dataclass
class HubertConfig(FairseqDataclass):
    label_rate: float = II("task.label_rate")

    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )
    layer_type: LAYER_TYPE_CHOICES = field(
        default="transformer", metadata={"help": "layer type in encoder"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a transformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim if <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking
    mask_length: int = field(default=10, metadata={"help": "mask length"})
    mask_prob: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )

    checkpoint_activations: bool = field(
        default=False,
        metadata={"help": "recompute activations and save memory for extra compute"},
    )

    # FP16 optimization
    required_seq_len_multiple: int = field(
        default=2,
        metadata={
            "help": "pad the input to encoder such that the sequence length is divisible by multiple"
        },
    )

    # Conformer
    depthwise_conv_kernel_size: int = field(
        default=31,
        metadata={
            "help": "depthwise-conv-kernel-size for convolution in conformer layer"
        },
    )
    attn_type: str = field(
        default="",
        metadata={"help": "if espnet use ESPNET MHA"},
    )
    pos_enc_type: str = field(
        default="abs",
        metadata={"help": "Positional encoding type to use in conformer"},
    )
    fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
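
For reference, the conv_feature_layers default above is parsed with eval() into (dim, kernel_size, stride) tuples, and the product of the strides fixes the waveform-to-frame downsampling. A minimal standalone check (illustrative, not part of the deleted file):

import numpy as np

layers = eval("[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
ds_rate = int(np.prod([s for _, _, s in layers]))
print(ds_rate)  # 320: one feature frame per 320 samples, i.e. 20 ms at 16 kHz
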
@register_model("hubert", dataclass=HubertConfig)
class HubertModel(BaseFairseqModel):
    def __init__(
        self,
        cfg: HubertConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
    ) -> None:
        super().__init__()
        logger.info(f"HubertModel Config: {cfg}")

        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.untie_final_proj = cfg.untie_final_proj
        if self.untie_final_proj:
            self.final_proj = nn.Linear(
                cfg.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        # modules below are not needed during fine-tuning
        if any([d is None for d in dictionaries]):
            logger.info("cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask):
        """Build a new model instance."""

        model = HubertModel(cfg, task.cfg, task.dictionaries)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices
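
For reference, compute_mask_indices returns a boolean (B, T) numpy array of wav2vec 2.0-style span masks, and apply_mask writes the single learned mask_emb into every masked frame. A rough standalone sketch of that effect (shapes and the hand-built mask are illustrative, not fairseq API):

import torch

B, T, C = 2, 100, 768
x = torch.randn(B, T, C)
mask_emb = torch.randn(C)
mask = torch.zeros(B, T, dtype=torch.bool)
mask[:, 10:20] = True  # stand-in for a compute_mask_indices result
x[mask] = mask_emb     # every masked frame receives the same embedding
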
    def compute_nce(self, x, pos, negs):
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def forward(
        self,
        source: torch.Tensor,
    ) -> torch.Tensor:
        source = source.squeeze(0)
        output_layer = 9  # hard-coded: export the 9th transformer layer's output
        features = self.forward_features(source)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        padding_mask = torch.zeros(size=(1, features.shape[1]), dtype=torch.bool)
        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)
        features = self.dropout_input(features)
        x = features
        x = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )
        return self.final_proj(x)

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # NOTE: stale in this ONNX copy; forward() above no longer accepts
        # these keyword arguments or returns a result dict.
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x.float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []

        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self):
        self.target_glu = None
        self.final_proj = None
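
This fairseq_onnx copy of HubertModel rewrites forward() to take only the raw waveform and return the projected layer-9 encoder output, which makes the module traceable. A hedged sketch of how such a module could be exported (the loaded `model` and the output file name are illustrative, not the project's actual export script):

import torch

model.eval()
dummy = torch.randn(1, 1, 16000)  # the leading dim is squeezed inside forward()
torch.onnx.export(
    model,
    (dummy,),
    "hubert_layer9.onnx",
    input_names=["source"],
    output_names=["proj"],
    dynamic_axes={"source": {2: "n_samples"}},
)
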
@@ -1,21 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch.nn.functional as F


def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer():
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
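
pad_to_multiple right-pads the chosen dimension up to the next multiple and also returns how much padding was added, which is what required_seq_len_multiple in the HuBERT config relies on. A quick worked example (illustrative):

import torch

x = torch.randn(2, 7)
y, n = pad_to_multiple(x, multiple=4, dim=-1)  # 7 -> 8
assert y.shape == (2, 8) and n == 1
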
@@ -1,630 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import logging
import math
from typing import Optional, Tuple
from omegaconf import II
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.modules import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    GumbelVectorQuantizer,
    KmeansVectorQuantizer,
    TransposeLast,
)
from fairseq.tasks import FairseqTask
from fairseq.utils import buffered_arange


logger = logging.getLogger(__name__)


AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"])
PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"])
ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"])
VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"])


@dataclass
class Wav2VecConfig(FairseqDataclass):
    prediction_steps: int = field(
        default=12, metadata={"help": "number of steps ahead to predict"}
    )
    sample_distance: Optional[int] = field(
        default=None,
        metadata={
            "help": "sample distance from target. does not work properly with cross-sampling"
        },
    )
    cross_sample_negatives: int = field(
        default=0, metadata={"help": "num of cross sampled negatives"}
    )
    num_negatives: int = field(
        default=10, metadata={"help": "num of sampled negatives"}
    )
    conv_feature_layers: str = field(
        default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]",
        metadata={
            "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]"
        },
    )
    conv_aggregator_layers: str = field(
        default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]",
        metadata={
            "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]"
        },
    )
    dropout: float = field(
        default=0.0, metadata={"help": "dropout to apply within the model"}
    )
    dropout_features: float = field(
        default=0.0, metadata={"help": "dropout to apply to the features"}
    )
    dropout_agg: float = field(
        default=0.0, metadata={"help": "dropout to apply after aggregation step"}
    )
    aggregator: AGGREGATOR_CHOICES = field(
        default="cnn", metadata={"help": "type of aggregator to use"}
    )
    gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"})
    no_conv_bias: bool = field(
        default=False, metadata={"help": "if set, does not learn bias for conv layers"}
    )
    agg_zero_pad: bool = field(
        default=False,
        metadata={"help": "if set, zero pads in aggregator instead of repl pad"},
    )
    skip_connections_feat: bool = field(
        default=False,
        metadata={"help": "if set, adds skip connections to the feature extractor"},
    )
    skip_connections_agg: bool = field(
        default=True,
        metadata={"help": "if set, adds skip connections to the aggregator"},
    )
    residual_scale: float = field(
        default=0.5, metadata={"help": "scales residual by sqrt(value)"}
    )
    log_compression: bool = field(
        default=True,
        metadata={"help": "if set, adds a log compression to feature extractor"},
    )
    balanced_classes: bool = field(
        default=False,
        metadata={"help": "if set, loss is scaled to balance for number of negatives"},
    )
    project_features: PROJECT_FEATURES_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, features are projected using the (same or new) aggregator"
        },
    )
    non_affine_group_norm: bool = field(
        default=False, metadata={"help": "if set, group norm is not affine"}
    )
    offset: str = field(
        default="auto",
        metadata={
            "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value"
        },
    )
    activation: ACTIVATION_CHOICES = field(
        default="relu",
        metadata={"help": "which activation function to use"},
    )
    vq_type: VQ_TYPE_CHOICES = field(
        default="none", metadata={"help": "which type of quantizer to use"}
    )
    vq_vars: int = field(
        default=320,
        metadata={"help": "project to this many vector quantized variables per group"},
    )
    vq_groups: int = field(
        default=2, metadata={"help": "number of groups of latent variables"}
    )
    vq_dim: int = field(
        default=0,
        metadata={
            "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups"
        },
    )
    vq_depth: int = field(
        default=1, metadata={"help": "number of layers for vq weight projection"}
    )
    combine_groups: bool = field(
        default=False, metadata={"help": "if set, variables are shared among groups"}
    )
    vq_temp: Tuple[float, float, float] = field(
        default=(2.0, 0.5, 0.999995),
        metadata={
            "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)"
        },
    )
    vq_gamma: float = field(
        default=0.25,
        metadata={"help": "gamma parameter for kmeans style vector quantization"},
    )
    infonce: bool = II("criterion.infonce")
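
For reference, the default feature extractor above has strides 5, 4, 2, 2, 2 (the trailing stride-1 layers keep the rate), so it emits one frame per 5*4*2*2*2 = 160 samples, i.e. every 10 ms of 16 kHz audio. A quick check (illustrative):

import math

layers = eval(
    "[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2),"
    " (512, 1, 1), (512, 1, 1), (512, 1, 1)]"
)
print(math.prod(s for _, _, s in layers))  # 160
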
@register_model("wav2vec", dataclass=Wav2VecConfig)
|
||||
class Wav2VecModel(BaseFairseqModel):
|
||||
@classmethod
|
||||
def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask):
|
||||
"""Build a new model instance."""
|
||||
|
||||
model = Wav2VecModel(cfg)
|
||||
logger.info(model)
|
||||
return model
|
||||
|
||||
def __init__(self, cfg: Wav2VecConfig):
|
||||
super().__init__()
|
||||
|
||||
self.prediction_steps = cfg.prediction_steps
|
||||
offset = cfg.offset
|
||||
|
||||
if cfg.activation == "relu":
|
||||
activation = nn.ReLU()
|
||||
elif cfg.activation == "gelu":
|
||||
activation = nn.GELU()
|
||||
else:
|
||||
raise Exception("unknown activation " + cfg.activation)
|
||||
|
||||
feature_enc_layers = eval(cfg.conv_feature_layers)
|
||||
self.feature_extractor = ConvFeatureExtractionModel(
|
||||
conv_layers=feature_enc_layers,
|
||||
dropout=0.0,
|
||||
log_compression=cfg.log_compression,
|
||||
skip_connections=cfg.skip_connections_feat,
|
||||
residual_scale=cfg.residual_scale,
|
||||
non_affine_group_norm=cfg.non_affine_group_norm,
|
||||
activation=activation,
|
||||
)
|
||||
embed = feature_enc_layers[-1][0]
|
||||
|
||||
self.vector_quantizer = None
|
||||
if cfg.vq_type == "gumbel":
|
||||
self.vector_quantizer = GumbelVectorQuantizer(
|
||||
dim=embed,
|
||||
num_vars=cfg.vq_vars,
|
||||
temp=cfg.vq_temp,
|
||||
groups=cfg.vq_groups,
|
||||
combine_groups=cfg.combine_groups,
|
||||
vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
|
||||
time_first=False,
|
||||
activation=activation,
|
||||
weight_proj_depth=cfg.vq_depth,
|
||||
weight_proj_factor=2,
|
||||
)
|
||||
elif cfg.vq_type == "kmeans":
|
||||
self.vector_quantizer = KmeansVectorQuantizer(
|
||||
dim=embed,
|
||||
num_vars=cfg.vq_vars,
|
||||
groups=cfg.vq_groups,
|
||||
combine_groups=cfg.combine_groups,
|
||||
vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
|
||||
time_first=False,
|
||||
gamma=cfg.vq_gamma,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
cfg.vq_type == "none" or cfg.vq_type is None
|
||||
), "Unknown quantizer type"
|
||||
|
||||
if cfg.offset == "auto":
|
||||
jin = 0
|
||||
rin = 0
|
||||
for _, k, stride in feature_enc_layers:
|
||||
if rin == 0:
|
||||
rin = k
|
||||
rin = rin + (k - 1) * jin
|
||||
if jin == 0:
|
||||
jin = stride
|
||||
else:
|
||||
jin *= stride
|
||||
offset = math.ceil(rin / jin)
|
||||
|
||||
offset = int(offset)
|
||||
        def make_aggregator():
            if cfg.aggregator == "cnn":
                agg_layers = eval(cfg.conv_aggregator_layers)
                agg_dim = agg_layers[-1][0]
                feature_aggregator = ConvAggegator(
                    conv_layers=agg_layers,
                    embed=embed,
                    dropout=cfg.dropout,
                    skip_connections=cfg.skip_connections_agg,
                    residual_scale=cfg.residual_scale,
                    non_affine_group_norm=cfg.non_affine_group_norm,
                    conv_bias=not cfg.no_conv_bias,
                    zero_pad=cfg.agg_zero_pad,
                    activation=activation,
                )
            elif cfg.aggregator == "gru":
                agg_dim = cfg.gru_dim
                feature_aggregator = nn.Sequential(
                    TransposeLast(),
                    nn.GRU(
                        input_size=embed,
                        hidden_size=agg_dim,
                        num_layers=1,
                        dropout=cfg.dropout,
                    ),
                    TransposeLast(deconstruct_idx=0),
                )
            else:
                raise Exception("unknown aggregator type " + cfg.aggregator)

            return feature_aggregator, agg_dim

        self.feature_aggregator, agg_dim = make_aggregator()

        self.wav2vec_predictions = Wav2VecPredictionsModel(
            in_dim=agg_dim,
            out_dim=embed,
            prediction_steps=cfg.prediction_steps,
            n_negatives=cfg.num_negatives,
            cross_sample_negatives=cfg.cross_sample_negatives,
            sample_distance=cfg.sample_distance,
            dropout=cfg.dropout,
            offset=offset,
            balanced_classes=cfg.balanced_classes,
            infonce=cfg.infonce,
        )

        self.dropout_feats = nn.Dropout(p=cfg.dropout_features)
        self.dropout_agg = nn.Dropout(p=cfg.dropout_agg)

        if cfg.project_features == "none":
            self.project_features = None
        elif cfg.project_features == "same":
            self.project_features = self.feature_aggregator
        elif cfg.project_features == "new":
            self.project_features, _ = make_aggregator()

    def forward(self, source):
        result = {}

        features = self.feature_extractor(source)
        if self.vector_quantizer:
            q_res = self.vector_quantizer(features)
            features = q_res["x"]
            for k in q_res.keys():
                if k != "x":
                    result[k] = q_res[k]

        x = self.dropout_feats(features)
        x = self.feature_aggregator(x)
        x = self.dropout_agg(x)

        if self.project_features is not None:
            features = self.project_features(features)
        x, targets = self.wav2vec_predictions(x, features)
        result["cpc_logits"] = x
        result["cpc_targets"] = targets

        return result

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)

    def max_positions(self):
        """Maximum length supported by the model."""
        return sys.maxsize

    def get_logits(self, net_output):
        logits = net_output["cpc_logits"]
        return logits

    def get_targets(self, sample, net_output):
        t = net_output["cpc_targets"]
        if isinstance(t, tuple):
            t = t[0]
        return t.contiguous()

    def get_target_weights(self, targets, net_output):
        targets = net_output["cpc_targets"]
        if isinstance(targets, tuple) and targets[-1] is not None:
            return targets[-1]
        return None

    def get_extra_losses(self, net_output):
        loss = None
        if "prob_perplexity" in net_output:
            loss = net_output["num_vars"] - net_output["prob_perplexity"]
        elif "kmeans_loss" in net_output:
            loss = net_output["kmeans_loss"]

        return loss
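
The "auto" offset computed in __init__ above is the feature extractor's receptive field measured in output frames: rin accumulates the receptive field in samples and jin the total stride. Replaying that loop on the default layers (illustrative):

import math

jin = rin = 0
for k, stride in [(10, 5), (8, 4), (4, 2), (4, 2), (4, 2), (1, 1), (1, 1), (1, 1)]:
    if rin == 0:
        rin = k
    rin = rin + (k - 1) * jin
    jin = stride if jin == 0 else jin * stride
print(math.ceil(rin / jin))  # 3 frames for the default config
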
def norm_block(is_layer_norm, dim, affine=True):
    if is_layer_norm:
        mod = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(dim, elementwise_affine=affine),
            TransposeLast(),
        )
    else:
        mod = Fp32GroupNorm(1, dim, affine=affine)

    return mod


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers,
        dropout,
        log_compression,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        activation,
    ):
        super().__init__()

        def block(n_in, n_out, k, stride):
            return nn.Sequential(
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
                nn.Dropout(p=dropout),
                norm_block(
                    is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm
                ),
                activation,
            )

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for dim, k, stride in conv_layers:
            self.conv_layers.append(block(in_d, dim, k, stride))
            in_d = dim

        self.log_compression = log_compression
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            residual = x
            x = conv(x)
            if self.skip_connections and x.size(1) == residual.size(1):
                tsz = x.size(2)
                r_tsz = residual.size(2)
                residual = residual[..., :: r_tsz // tsz][..., :tsz]
                x = (x + residual) * self.residual_scale

        if self.log_compression:
            x = x.abs()
            x = x + 1
            x = x.log()

        return x


class ZeroPad1d(nn.Module):
    def __init__(self, pad_left, pad_right):
        super().__init__()
        self.pad_left = pad_left
        self.pad_right = pad_right

    def forward(self, x):
        return F.pad(x, (self.pad_left, self.pad_right))


class ConvAggegator(nn.Module):
    def __init__(
        self,
        conv_layers,
        embed,
        dropout,
        skip_connections,
        residual_scale,
        non_affine_group_norm,
        conv_bias,
        zero_pad,
        activation,
    ):
        super().__init__()

        def block(n_in, n_out, k, stride):
            # padding dims only really make sense for stride = 1
            ka = k // 2
            kb = ka - 1 if k % 2 == 0 else ka

            pad = (
                ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0))
            )

            return nn.Sequential(
                pad,
                nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias),
                nn.Dropout(p=dropout),
                norm_block(False, n_out, affine=not non_affine_group_norm),
                activation,
            )

        in_d = embed
        self.conv_layers = nn.ModuleList()
        self.residual_proj = nn.ModuleList()
        for dim, k, stride in conv_layers:
            if in_d != dim and skip_connections:
                self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False))
            else:
                self.residual_proj.append(None)

            self.conv_layers.append(block(in_d, dim, k, stride))
            in_d = dim
        self.conv_layers = nn.Sequential(*self.conv_layers)
        self.skip_connections = skip_connections
        self.residual_scale = math.sqrt(residual_scale)

    def forward(self, x):
        for rproj, conv in zip(self.residual_proj, self.conv_layers):
            residual = x
            x = conv(x)
            if self.skip_connections:
                if rproj is not None:
                    residual = rproj(residual)
                x = (x + residual) * self.residual_scale
        return x
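
ConvAggegator left-pads every convolution by k - 1 (ka + kb with stride 1), so each output position sees only current and past frames and sequence length is preserved. A standalone check of that padding scheme (illustrative):

import torch
import torch.nn as nn

k = 3
pad = nn.ReplicationPad1d((k - 1, 0))  # same left-only padding as block() above
conv = nn.Conv1d(4, 4, k)
x = torch.randn(1, 4, 10)
assert conv(pad(x)).shape[-1] == x.shape[-1]  # causal and length-preserving
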
class Wav2VecPredictionsModel(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        prediction_steps,
        n_negatives,
        cross_sample_negatives,
        sample_distance,
        dropout,
        offset,
        balanced_classes,
        infonce,
    ):
        super().__init__()

        self.n_negatives = n_negatives
        self.cross_sample_negatives = cross_sample_negatives
        self.sample_distance = sample_distance
        self.project_to_steps = nn.ConvTranspose2d(
            in_dim, out_dim, (1, prediction_steps)
        )
        self.dropout = nn.Dropout(p=dropout)
        self.offset = offset
        self.balanced_classes = balanced_classes
        self.infonce = infonce

    def sample_negatives(self, y):
        bsz, fsz, tsz = y.shape

        y = y.transpose(0, 1)  # BCT -> CBT
        y = y.contiguous().view(fsz, -1)  # CBT => C(BxT)

        cross_high = tsz * bsz
        high = tsz if self.sample_distance is None else min(tsz, self.sample_distance)
        assert high > 1

        neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz))

        with torch.no_grad():
            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(tsz)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * tsz)
                )
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(tsz)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * tsz),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[..., neg_idxs.view(-1)]
        negs = negs.view(
            fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz
        ).permute(
            2, 1, 0, 3
        )  # to NxBxCxT

        return negs

    def forward(self, x, y):

        x = x.unsqueeze(-1)
        x = self.project_to_steps(x)  # BxCxTxS
        x = self.dropout(x)

        negatives = self.sample_negatives(y)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)  # Copies x B x C x T

        copies = targets.size(0)
        bsz, dim, tsz, steps = x.shape
        steps = min(steps, tsz - self.offset)

        predictions = x.new(
            bsz * copies * (tsz - self.offset + 1) * steps
            - ((steps + 1) * steps // 2) * copies * bsz
        )
        if self.infonce:
            labels = predictions.new_full(
                (predictions.shape[0] // copies,), 0, dtype=torch.long
            )
        else:
            labels = torch.zeros_like(predictions)
        weights = (
            torch.full_like(labels, 1 / self.n_negatives)
            if self.balanced_classes and not self.infonce
            else None
        )

        start = end = 0
        for i in range(steps):
            offset = i + self.offset
            end = start + (tsz - offset) * bsz * copies
            if self.infonce:
                predictions[start:end] = torch.einsum(
                    "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:]
                ).flatten()
            else:
                pos_num = (end - start) // copies
                predictions[start:end] = torch.einsum(
                    "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:]
                ).flatten()
                labels[start : start + pos_num] = 1.0
                if weights is not None:
                    weights[start : start + pos_num] = 1.0
            start = end
        assert end == predictions.numel(), "{} != {}".format(end, predictions.numel())

        if self.infonce:
            predictions = predictions.view(-1, copies)
        else:
            if weights is not None:
                labels = (labels, weights)

        return predictions, labels
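
With infonce=True, forward() returns predictions reshaped to (num_predictions, copies) where column 0 scores the positive (y is concatenated ahead of the negatives), so the labels are all zeros and plain cross-entropy applies. A sketch with stand-in shapes (illustrative):

import torch
import torch.nn.functional as F

copies = 1 + 10                         # one positive plus n_negatives
predictions = torch.randn(512, copies)  # stand-in for cpc_logits
labels = predictions.new_zeros(512, dtype=torch.long)  # positive is index 0
loss = F.cross_entropy(predictions, labels)
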
File diff suppressed because it is too large
@@ -1,48 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        export = True
    if not export and torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
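
Fp32LayerNorm casts the input and parameters to float32 for the normalization and casts the result back, avoiding fp16 overflow in the variance computation. Usage sketch (illustrative):

import torch

ln = Fp32LayerNorm(768).half()
x = torch.randn(4, 100, 768, dtype=torch.float16)
y = ln(x)                        # statistics computed in float32
assert y.dtype == torch.float16  # output cast back via type_as
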
@@ -1,842 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import collections
import contextlib
import copy
import importlib
import logging
import os
import sys
import warnings
from itertools import accumulate
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

if TYPE_CHECKING:
    from fairseq.modules.multihead_attention import MultiheadAttention

try:
    from amp_C import multi_tensor_l2norm

    multi_tensor_l2norm_available = True
except ImportError:
    multi_tensor_l2norm_available = False

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


logger = logging.getLogger(__name__)


MANIFOLD_PATH_SEP = "|"


class FileContentsAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super(FileContentsAction, self).__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        from fairseq.file_io import PathManager

        if PathManager.isfile(values):
            with PathManager.open(values) as f:
                argument = f.read().strip()
        else:
            argument = values
        setattr(namespace, self.dest, argument)


def split_paths(paths: str, separator=os.pathsep) -> List[str]:
    return (
        paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP)
    )


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    from fairseq import checkpoint_utils

    deprecation_warning(
        "utils.load_ensemble_for_inference is deprecated. "
        "Please use checkpoint_utils.load_model_ensemble instead."
    )
    return checkpoint_utils.load_model_ensemble(
        filenames, arg_overrides=model_arg_overrides, task=task
    )


def apply_to_sample(f, sample):
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            # OrderedDict has attributes that need to be preserved
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        elif isinstance(x, tuple):
            return tuple(_apply(x) for x in x)
        elif isinstance(x, set):
            return {_apply(x) for x in x}
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample, device=None):
    device = device or torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking is ignored if tensor is not pinned, so we can always set
        # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620)
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)


def move_to_cpu(sample):
    def _move_to_cpu(tensor):
        # PyTorch has poor support for half tensors (float16) on CPU.
        # Move any such tensors to float32.
        if tensor.dtype in {torch.bfloat16, torch.float16}:
            tensor = tensor.to(dtype=torch.float32)
        return tensor.cpu()

    return apply_to_sample(_move_to_cpu, sample)
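
apply_to_sample recursively maps a tensor function over arbitrarily nested dicts, lists, tuples and sets, which is what move_to_cuda and move_to_cpu build on. For instance (illustrative sample structure):

import torch

sample = {"net_input": {"src": torch.zeros(2, 3)}, "ids": [torch.tensor(1)]}
halved = apply_to_sample(lambda t: t.half(), sample)
assert halved["net_input"]["src"].dtype == torch.float16
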
def move_to_tpu(sample):

    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    def _move_to_tpu(tensor):
        return tensor.to(device)

    return apply_to_sample(_move_to_tpu, sample)


def get_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
    """Helper for getting incremental state for an nn.Module."""
    return module.get_incremental_state(incremental_state, key)


def set_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
    value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        result = module.set_incremental_state(incremental_state, key, value)
        if result is not None:
            incremental_state = result
    return incremental_state


def load_align_dict(replace_unk):
    if replace_unk is None:
        align_dict = None
    elif isinstance(replace_unk, str) and len(replace_unk) > 0:
        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
        align_dict = {}
        with open(replace_unk, "r") as f:
            for line in f:
                cols = line.split()
                align_dict[cols[0]] = cols[1]
    else:
        # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
        # original source word.
        align_dict = {}
    return align_dict


def print_embed_overlap(embed_dict, vocab_dict):
    embed_keys = set(embed_dict.keys())
    vocab_keys = set(vocab_dict.symbols)
    overlap = len(embed_keys & vocab_keys)
    logger.info("found {}/{} types in embedding file".format(overlap, len(vocab_dict)))


def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.

    The first line can have vocabulary size and dimension. The following lines
    should contain word and embedding separated by spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor(
                [float(weight) for weight in pieces[1:]]
            )
    return embed_dict


def load_embedding(embed_dict, vocab, embedding):
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding


def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    from fairseq import tokenizer

    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ["<eos>"]
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the aligned dictionary or just copy the original value.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return " ".join(hypo_tokens)


def post_process_prediction(
    hypo_tokens,
    src_str,
    alignment,
    align_dict,
    tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=None,
):
    hypo_str = tgt_dict.string(
        hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore
    )
    if align_dict is not None:
        hypo_str = replace_unk(
            hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string()
        )
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluating with unk replacement or without BPE
        # Note that the dictionary can be modified inside the method.
        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment


def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx


def strip_pad(tensor, pad):
    return tensor[tensor.ne(pad)]


def buffered_arange(max):
    if not hasattr(buffered_arange, "buf"):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        buffered_arange.buf.resize_(max)
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]
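
make_positions numbers non-pad tokens from padding_idx + 1 upward and maps every pad position back to padding_idx. For example (illustrative):

import torch

tokens = torch.tensor([[5, 6, 1, 1]])         # 1 is the pad index
print(make_positions(tokens, padding_idx=1))  # tensor([[2, 3, 1, 1]])
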
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    buffered = torch.empty(0).long()
    if max_len > 0:
        torch.arange(max_len, out=buffered)
    range = buffered.type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_tokens.gather(1, index)


def item(tensor):
    # tpu-comment: making this a no-op for xla devices.
    if torch.is_tensor(tensor) and tensor.device.type == "xla":
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor


def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
    per_device_grads = {}
    norms = []
    for grad in grads:
        device = grad.device
        cur_device_grads = per_device_grads.get(device)
        if cur_device_grads is None:
            cur_device_grads = []
            per_device_grads[device] = cur_device_grads
        cur_device_grads.append(grad)
    for device in per_device_grads.keys():
        cur_device_grads = per_device_grads[device]
        if device.type == "cuda":
            # TODO(msb) return has_inf
            has_inf = torch.zeros((1, 1), dtype=torch.int, device=device)
            with torch.cuda.device(device):
                norm = multi_tensor_l2norm(
                    chunk_size, has_inf, [cur_device_grads], False
                )
            norms.append(norm[0].to(torch.cuda.current_device()))
        else:
            norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads]
    total_norm = torch.norm(torch.stack(norms))
    return total_norm


@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)
    grads = [
        p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert")
    ]
    expert_grads = [
        p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert")
    ]

    if len(grads) == 0:
        if len(params) > 0:
            return params[0].new_tensor(0.0)
        else:
            return torch.tensor(0.0)

    if len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    else:
        if multi_tensor_l2norm_available:
            total_norm = multi_tensor_total_norm(grads)
        else:
            if torch.cuda.is_available():
                warnings.warn(
                    "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                    "you may get better performance by installing NVIDIA's apex library"
                )
                device = torch.cuda.current_device()
            elif grads[0].device.type == "xla":
                device = grads[0].device
            else:
                device = torch.device("cpu")
            total_norm = torch.norm(
                torch.stack(
                    [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
                )
            )

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads:
            g.mul_(clip_coef)
    return total_norm


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)
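
clip_grad_norm_ returns the pre-clipping total norm and rescales gradients in place (expert-tagged parameters are scaled too but excluded from the norm). A small check (illustrative):

import torch

p = torch.nn.Parameter(torch.ones(10))
p.grad = torch.full((10,), 2.0)
norm = clip_grad_norm_([p], max_norm=1.0)
print(float(norm))  # sqrt(10 * 2^2) ~= 6.32; p.grad now has total norm ~1
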
|
||||
def _match_types(arg1, arg2):
|
||||
"""Convert the numerical argument to the same type as the other argument"""
|
||||
|
||||
def upgrade(arg_number, arg_structure):
|
||||
if isinstance(arg_structure, tuple):
|
||||
return tuple([arg_number] * len(arg_structure))
|
||||
elif isinstance(arg_structure, dict):
|
||||
arg = copy.deepcopy(arg_structure)
|
||||
for k in arg:
|
||||
arg[k] = upgrade(arg_number, arg_structure[k])
|
||||
return arg
|
||||
else:
|
||||
return arg_number
|
||||
|
||||
if isinstance(arg1, float) or isinstance(arg1, int):
|
||||
return upgrade(arg1, arg2), arg2
|
||||
elif isinstance(arg2, float) or isinstance(arg2, int):
|
||||
return arg1, upgrade(arg2, arg1)
|
||||
|
||||
return arg1, arg2
|
||||
|
||||
|
||||
def resolve_max_positions(*args):
|
||||
"""Resolve max position constraints from multiple sources."""
|
||||
|
||||
def map_value_update(d1, d2):
|
||||
updated_value = copy.deepcopy(d1)
|
||||
for key in d2:
|
||||
if key not in updated_value:
|
||||
updated_value[key] = d2[key]
|
||||
else:
|
||||
updated_value[key] = min(d1[key], d2[key])
|
||||
return updated_value
|
||||
|
||||
def nullsafe_min(l):
|
||||
minim = None
|
||||
for item in l:
|
||||
if minim is None:
|
||||
minim = item
|
||||
elif item is not None and item < minim:
|
||||
minim = item
|
||||
return minim
|
||||
|
||||
max_positions = None
|
||||
for arg in args:
|
||||
if max_positions is None:
|
||||
max_positions = arg
|
||||
elif arg is not None:
|
||||
max_positions, arg = _match_types(max_positions, arg)
|
||||
if isinstance(arg, float) or isinstance(arg, int):
|
||||
max_positions = min(max_positions, arg)
|
||||
elif isinstance(arg, dict):
|
||||
max_positions = map_value_update(max_positions, arg)
|
||||
else:
|
||||
max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))
|
||||
|
||||
return max_positions


def import_user_module(args):
    module_path = getattr(args, "user_dir", None)
    if module_path is not None:
        module_path = os.path.abspath(args.user_dir)
        if not os.path.exists(module_path) and not os.path.isfile(
            os.path.dirname(module_path)
        ):
            fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir)
            if os.path.exists(fairseq_rel_path):
                module_path = fairseq_rel_path
            else:
                fairseq_rel_path = os.path.join(
                    os.path.dirname(__file__), "..", args.user_dir
                )
                if os.path.exists(fairseq_rel_path):
                    module_path = fairseq_rel_path
                else:
                    raise FileNotFoundError(module_path)

        # ensure that user modules are only imported once
        import_user_module.memo = getattr(import_user_module, "memo", set())
        if module_path not in import_user_module.memo:
            import_user_module.memo.add(module_path)

            module_parent, module_name = os.path.split(module_path)
            if module_name not in sys.modules:
                sys.path.insert(0, module_parent)
                importlib.import_module(module_name)

                tasks_path = os.path.join(module_path, "tasks")
                if os.path.exists(tasks_path):
                    from fairseq.tasks import import_tasks

                    import_tasks(tasks_path, f"{module_name}.tasks")

                models_path = os.path.join(module_path, "models")
                if os.path.exists(models_path):
                    from fairseq.models import import_models

                    import_models(models_path, f"{module_name}.models")
            elif module_path in sys.modules[module_name].__path__:
                logger.info(f"--user-dir={module_path} has already been imported.")
            else:
                raise ImportError(
                    "Failed to import --user-dir={} because the corresponding module name "
                    "({}) is not globally unique. Please rename the directory to "
                    "something unique and try again.".format(module_path, module_name)
                )
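
# Usage sketch (illustrative; "plugins" is a hypothetical user directory that
# contains an __init__.py plus optional tasks/ and models/ subpackages):
#
#     >>> from argparse import Namespace
#     >>> import_user_module(Namespace(user_dir="plugins"))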


def softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


def log_softmax(x, dim: int, onnx_trace: bool = False):
    if onnx_trace:
        return F.log_softmax(x.float(), dim=dim)
    else:
        return F.log_softmax(x, dim=dim, dtype=torch.float32)


def get_perplexity(loss, round=2, base=2):
    from fairseq.logging.meters import safe_round

    if loss is None:
        return 0.0
    try:
        return safe_round(base**loss, round)
    except OverflowError:
        return float("inf")


def deprecation_warning(message, stacklevel=3):
    # don't use DeprecationWarning, since it's ignored by default
    warnings.warn(message, stacklevel=stacklevel)


def relu_squared(x: torch.Tensor):
    return F.relu(x).pow(2)


def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`"""
    from fairseq.modules import gelu, gelu_accurate

    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        deprecation_warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # functional form of SiLU (a.k.a. swish); the torch.nn.SiLU class
        # itself is not directly callable as an activation
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
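
# Usage sketch (illustrative):
#
#     >>> act = get_activation_fn("relu")
#     >>> act(torch.tensor([-1.0, 2.0]))
#     tensor([0., 2.])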


def get_available_activation_fns() -> List:
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated
        "gelu_accurate",
        "tanh",
        "linear",
    ]


@contextlib.contextmanager
def model_eval(model):
    is_training = model.training
    model.eval()
    yield
    model.train(is_training)
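
# Usage sketch (illustrative): temporarily switch to eval mode, restoring the
# previous training/eval state on exit:
#
#     >>> model = torch.nn.Linear(2, 2).train()
#     >>> with model_eval(model):
#     ...     assert not model.training
#     >>> assert model.training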


def has_parameters(module):
    try:
        next(module.parameters())
        return True
    except StopIteration:
        return False


def get_rng_state():
    state = {"torch_rng_state": torch.get_rng_state()}
    if xm is not None:
        state["xla_rng_state"] = xm.get_rng_state()
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state):
    torch.set_rng_state(state["torch_rng_state"])
    if xm is not None:
        xm.set_rng_state(state["xla_rng_state"])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["cuda_rng_state"])


class set_torch_seed(object):
    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = get_rng_state()

        torch.manual_seed(seed)
        if xm is not None:
            xm.set_rng_state(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        set_rng_state(self.rng_state)
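
# Usage sketch (illustrative): sampling inside the block is reproducible, and
# the surrounding RNG state is restored afterwards:
#
#     >>> with set_torch_seed(1):
#     ...     a = torch.rand(1)
#     >>> with set_torch_seed(1):
#     ...     b = torch.rand(1)
#     >>> assert torch.equal(a, b)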


def parse_alignment(line):
    """
    Parses a single line from the alignment file.

    Args:
        line (str): String containing the alignment of the format:
            <src_idx_1>-<tgt_idx_1> <src_idx_2>-<tgt_idx_2> ..
            <src_idx_m>-<tgt_idx_m>. All indices are 0 indexed.

    Returns:
        torch.IntTensor: packed alignments of shape (2 * m).
    """
    alignments = line.strip().split()
    parsed_alignment = torch.IntTensor(2 * len(alignments))
    for idx, alignment in enumerate(alignments):
        src_idx, tgt_idx = alignment.split("-")
        parsed_alignment[2 * idx] = int(src_idx)
        parsed_alignment[2 * idx + 1] = int(tgt_idx)
    return parsed_alignment
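
# Usage sketch (illustrative):
#
#     >>> parse_alignment("0-0 1-2 2-1")
#     tensor([0, 0, 1, 2, 2, 1], dtype=torch.int32)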


def get_token_to_word_mapping(tokens, exclude_list):
    n = len(tokens)
    word_start = [int(token not in exclude_list) for token in tokens]
    word_idx = list(accumulate(word_start))
    token_to_word = {i: word_idx[i] for i in range(n)}
    return token_to_word
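
# Usage sketch (illustrative): excluded tokens (e.g. pad/eos ids) do not start
# a new word and inherit the index of the preceding word:
#
#     >>> get_token_to_word_mapping([5, 7, 2, 7], exclude_list=[2])
#     {0: 1, 1: 2, 2: 2, 3: 3}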


def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
    tgt_valid = (
        ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)
    )
    src_invalid = (
        ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)
    )
    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])
    alignment = []
    if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent):
        attn_valid = attn[tgt_valid]
        attn_valid[:, src_invalid] = float("-inf")
        _, src_indices = attn_valid.max(dim=1)
        for tgt_idx, src_idx in zip(tgt_valid, src_indices):
            alignment.append(
                (
                    src_token_to_word[src_idx.item()] - 1,
                    tgt_token_to_word[tgt_idx.item()] - 1,
                )
            )
    return alignment


def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    tgt_valid = ((tgt_sent != pad)).nonzero(as_tuple=False)
    src_valid = ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1)
    alignment = []
    if len(tgt_valid) != 0 and len(src_valid) != 0:
        attn_valid = attn[tgt_valid, src_valid]
        alignment = [
            ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid
        ]
    return alignment


def new_arange(x, *size):
    """
    Return a Tensor of `size` filled with a range function on the device of x.
    If `size` is empty, use the size of the variable x.
    """
    if len(size) == 0:
        size = x.size()
    return torch.arange(size[-1], device=x.device).expand(*size).contiguous()
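
# Usage sketch (illustrative): the last dimension is filled with 0..n-1 and
# broadcast across the leading dimensions:
#
#     >>> new_arange(torch.zeros(2, 3))
#     tensor([[0, 1, 2],
#             [0, 1, 2]])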


def get_tpu_device():
    return xm.xla_device()


def tpu_data_loader(itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, "n", 0),
        total=len(itr),
    )


def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"


def index_put(tensor, indices, value):
    if is_xla_tensor(tensor):
        for _ in range(indices.dim(), tensor.dim()):
            indices = indices.unsqueeze(-1)
        if indices.size(-1) < tensor.size(-1):
            indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor
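
# Usage sketch (illustrative): on XLA the masked update is expressed as
# arithmetic instead of in-place indexing; elsewhere it is plain assignment:
#
#     >>> x = torch.zeros(2, 2)
#     >>> mask = torch.tensor([[True, False], [False, True]])
#     >>> index_put(x, mask, 1.0)
#     tensor([[1., 0.],
#             [0., 1.]])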


def xla_device_to_cpu(dat):
    import torch_xla.core.xla_model as xm

    return xm._maybe_convert_to_cpu(dat)


class CudaEnvironment(object):
    def __init__(self):
        cur_device = torch.cuda.current_device()
        prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device))
        self.name = prop.name
        self.major = prop.major
        self.minor = prop.minor
        self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Given a list of CudaEnvironment objects, pretty print them
        """
        num_workers = len(cuda_env_list)
        center = "CUDA environments for all {} workers".format(num_workers)
        banner_len = 40 - len(center) // 2
        first_line = "*" * banner_len + center + "*" * banner_len
        logger.info(first_line)
        for r, env in enumerate(cuda_env_list):
            logger.info(
                "rank {:3d}: ".format(r)
                + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor)
                + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB)
                + "name = {:40s}".format(env.name)
            )
        logger.info(first_line)


def csv_str_list(x):
    return x.split(",")


def eval_str_list(x, type=float):
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    try:
        return list(map(type, x))
    except TypeError:
        return [type(x)]
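
# Usage sketch (illustrative): string arguments are eval'd, then coerced
# element-wise; a bare scalar hits the TypeError branch and is wrapped:
#
#     >>> eval_str_list("[0.1, 0.2]")
#     [0.1, 0.2]
#     >>> eval_str_list("0.5")
#     [0.5]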


def eval_str_dict(x, type=dict):
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    return x


def eval_bool(x, default=False):
    if x is None:
        return default
    try:
        return bool(eval(x))
    except TypeError:
        return default


def reset_logging():
    root = logging.getLogger()
    # iterate over a copy, since removeHandler mutates root.handlers
    for handler in list(root.handlers):
        root.removeHandler(handler)
    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)


def safe_getattr(obj, k, default=None):
    """Returns obj[k] if it exists and is not None, otherwise returns default."""
    from omegaconf import OmegaConf

    if OmegaConf.is_config(obj):
        return obj[k] if k in obj and obj[k] is not None else default

    # match the docstring for plain objects too: an attribute explicitly set
    # to None falls back to the default
    v = getattr(obj, k, None)
    return v if v is not None else default


def safe_hasattr(obj, k):
    """Returns True if the given key exists and is not None."""
    return getattr(obj, k, None) is not None