Merge pull request #182 from RVC-Boss/patch-1

Update km_train.py
This commit is contained in:
YuriHead 2023-05-18 00:15:50 +08:00 committed by GitHub
commit fd6031b855
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 80 additions and 55 deletions

View File

@ -8,7 +8,7 @@ import argparse
from cluster.kmeans import KMeansGPU
import torch
import numpy as np
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans,MiniBatchKMeans
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -51,5 +51,30 @@ def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=
return x
if __name__ == "__main__":
res=train_cluster("/data/docker/dataset/12b-co256tensor",1000,use_minibatch=False,verbose=False,use_gpu=True)
pdb.set_trace()
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=Path, default="./dataset/44k",
help='path of training data directory')
parser.add_argument('--output', type=Path, default="logs/44k",
help='path of model output directory')
args = parser.parse_args()
checkpoint_dir = args.output
dataset = args.dataset
n_clusters = 1000
ckpt = {}
for spk in os.listdir(dataset):
if os.path.isdir(dataset/spk):
print(f"train kmeans for {spk}...")
in_dir = dataset/spk
x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=True)
ckpt[spk] = x
checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(
ckpt,
checkpoint_path,
)