Merge pull request #182 from RVC-Boss/patch-1

Update km_train.py
YuriHead 2023-05-18 00:15:50 +08:00 committed by GitHub
commit fd6031b855
1 changed file with 80 additions and 55 deletions
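
The change replaces the script's hard-coded debug entry point (a fixed dataset path followed by pdb.set_trace()) with an argparse command line that trains one k-means model per speaker subdirectory and bundles all speakers into a single checkpoint. With the new defaults, an invocation such as python km_train.py --dataset ./dataset/44k --output logs/44k (the script's location is an assumption, not stated in the diff) writes logs/44k/kmeans_1000.pt, a dict mapping each speaker name to its cluster statistics. The core routine can also be driven directly for a single speaker folder; a minimal sketch, assuming km_train.py is importable and that dataset/44k/speaker0 holds the *.soft.pt feature files:

    from pathlib import Path
    from km_train import train_cluster  # assumes the script is on the import path

    # Mirrors the script's own per-speaker call; the folder name is a stand-in.
    x = train_cluster(Path("dataset/44k/speaker0"), n_clusters=1000,
                      use_minibatch=False, verbose=False, use_gpu=True)
    print(x["cluster_centers_"].shape)  # (1000, feature_dim)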

@@ -1,55 +1,80 @@
 import time,pdb
 import tqdm
 from time import time as ttime
 import os
 from pathlib import Path
 import logging
 import argparse
 from cluster.kmeans import KMeansGPU
 import torch
 import numpy as np
-from sklearn.cluster import KMeans
+from sklearn.cluster import KMeans,MiniBatchKMeans
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 from time import time as ttime
 import pynvml,torch
 
 def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False, use_gpu=False):  # gpu_minibatch is terrible; the library supports it, but it's not worth considering
     logger.info(f"Loading features from {in_dir}")
     features = []
     nums = 0
     for path in tqdm.tqdm(in_dir.glob("*.soft.pt")):
         # for name in os.listdir(in_dir):
         #     path = "%s/%s" % (in_dir, name)
         features.append(torch.load(path, map_location="cpu").squeeze(0).numpy().T)
         # print(features[-1].shape)
     features = np.concatenate(features, axis=0)
     print(nums, features.nbytes / 1024**2, "MB , shape:", features.shape, features.dtype)
     features = features.astype(np.float32)
     logger.info(f"Clustering features of shape: {features.shape}")
     t = time.time()
     if use_gpu == False:
         if use_minibatch:
             kmeans = MiniBatchKMeans(n_clusters=n_clusters, verbose=verbose, batch_size=4096, max_iter=80).fit(features)
         else:
             kmeans = KMeans(n_clusters=n_clusters, verbose=verbose).fit(features)
     else:
         kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0, max_iter=500, tol=1e-2)
         features = torch.from_numpy(features)  # .to(device)
         labels = kmeans.fit_predict(features)
     print(time.time() - t, "s")
     x = {
         "n_features_in_": kmeans.n_features_in_ if use_gpu == False else features.shape[0],
         "_n_threads": kmeans._n_threads if use_gpu == False else 4,
         "cluster_centers_": kmeans.cluster_centers_ if use_gpu == False else kmeans.centroids.cpu().numpy(),
     }
     print("end")
     return x
 
 if __name__ == "__main__":
-    res = train_cluster("/data/docker/dataset/12b-co256tensor", 1000, use_minibatch=False, verbose=False, use_gpu=True)
-    pdb.set_trace()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=Path, default="./dataset/44k",
+                        help='path of training data directory')
+    parser.add_argument('--output', type=Path, default="logs/44k",
+                        help='path of model output directory')
+    args = parser.parse_args()
+
+    checkpoint_dir = args.output
+    dataset = args.dataset
+    n_clusters = 1000
+
+    ckpt = {}
+    for spk in os.listdir(dataset):
+        if os.path.isdir(dataset / spk):
+            print(f"train kmeans for {spk}...")
+            in_dir = dataset / spk
+            x = train_cluster(in_dir, n_clusters, use_minibatch=False, verbose=False, use_gpu=True)
+            ckpt[spk] = x
+
+    checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
+    checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
+    torch.save(
+        ckpt,
+        checkpoint_path,
+    )
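
For downstream use, a minimal sketch of consuming the saved checkpoint follows. The file name matches the script's output; the speaker key, the input frames, and the nearest-centroid helper are illustrative assumptions. Because centroids are stored as plain NumPy arrays, inference needs neither scikit-learn nor the fitted GPU object:

    import numpy as np
    import torch

    # Load the per-speaker checkpoint written by km_train.py.
    # Newer PyTorch may require weights_only=False, since the dict holds NumPy arrays.
    ckpt = torch.load("logs/44k/kmeans_1000.pt", map_location="cpu")
    centers = ckpt["speaker0"]["cluster_centers_"]  # (n_clusters, feature_dim); "speaker0" is a stand-in name

    def nearest_centroid(frames: np.ndarray) -> np.ndarray:
        # Squared Euclidean distance from every frame to every centroid,
        # then pick the closest centroid index per frame.
        d2 = ((frames[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        return d2.argmin(axis=1)

    dummy = np.random.randn(16, centers.shape[1]).astype(np.float32)  # stand-in feature frames
    print(nearest_centroid(dummy))  # 16 cluster ids in [0, n_clusters)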