diff --git a/cluster/train_cluster.py b/cluster/train_cluster.py index dfa55e6..135f179 100644 --- a/cluster/train_cluster.py +++ b/cluster/train_cluster.py @@ -14,7 +14,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):#gpu_minibatch真拉,虽然库支持但是也不考虑 - if in_dir.endswith(".ipynb_checkpoints"): + if str(in_dir).endswith(".ipynb_checkpoints"): logger.info(f"Ignore {in_dir}") logger.info(f"Loading features from {in_dir}")