from collections import deque
from functools import partial
from inspect import isfunction

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def extract(a, t, x_shape):
    # Gather the schedule values `a` at timesteps `t` and reshape them so they
    # broadcast over a tensor of shape `x_shape`.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps, max_beta=0.02):
    """
    linear schedule
    """
    betas = np.linspace(1e-4, max_beta, timesteps)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


beta_schedule = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
}


class GaussianDiffusion(nn.Module):
    def __init__(self,
                 denoise_fn,
                 out_dims=128,
                 timesteps=1000,
                 k_step=1000,
                 max_beta=0.02,
                 spec_min=-12,
                 spec_max=2):
        super().__init__()
        self.denoise_fn = denoise_fn
        self.out_dims = out_dims
        betas = beta_schedule['linear'](timesteps, max_beta=max_beta)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.k_step = k_step if k_step > 0 and k_step < timesteps else timesteps

        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas))
        # Mel-spectrogram normalization range used by norm_spec / denorm_spec.
        self.register_buffer('spec_min', torch.FloatTensor([spec_min])[None, None, :out_dims])
        self.register_buffer('spec_max', torch.FloatTensor([spec_max])[None, None, :out_dims])

    # ... (the remaining schedule buffers and the q_sample / p_losses /
    # p_sample / p_sample_ddim / p_sample_plms helpers referenced below are
    # omitted here) ...

    def forward(self, condition, gt_spec=None, infer=True, infer_speedup=10,
                method='dpm-solver', k_step=300, use_tqdm=True):
        """
        Conditional denoising diffusion: `condition` is the encoder output,
        shaped [B, T, C]; the output is a mel-spectrogram shaped [B, T, M].
        """
        cond = condition.transpose(1, 2)
        b, device = condition.shape[0], condition.device

        if not infer:
            # Training: sample a timestep and compute the diffusion loss.
            spec = self.norm_spec(gt_spec)
            t = torch.randint(0, self.k_step, (b,), device=device).long()
            norm_spec = spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            return self.p_losses(norm_spec, t, cond=cond)

        # Inference: start either from pure noise or from a noised gt_spec.
        shape = (cond.shape[0], 1, self.out_dims, cond.shape[2])
        if gt_spec is None:
            t = self.k_step
            x = torch.randn(shape, device=device)
        else:
            t = k_step
            norm_spec = self.norm_spec(gt_spec)
            norm_spec = norm_spec.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            x = self.q_sample(x_start=norm_spec, t=torch.tensor([t - 1], device=device).long())

        if method is not None and infer_speedup > 1:
            if method == 'dpm-solver' or method == 'dpm-solver++':
                from .dpm_solver_pytorch import (
                    DPM_Solver,
                    NoiseScheduleVP,
                    model_wrapper,
                )

                # 1. Define the noise schedule.
                noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])

                # 2. Convert your discrete-time `model` to the continuous-time
                # noise prediction model. Here is an example for a diffusion model
                # `model` with the noise prediction type ("noise").
                def my_wrapper(fn):
                    def wrapped(x, t, **kwargs):
                        ret = fn(x, t, **kwargs)
                        if use_tqdm:
                            self.bar.update(1)
                        return ret

                    return wrapped

                model_fn = model_wrapper(
                    my_wrapper(self.denoise_fn),
                    noise_schedule,
                    model_type="noise",  # or "x_start" or "v" or "score"
                    model_kwargs={"cond": cond}
                )

                # 3. Define dpm-solver and sample by multistep DPM-Solver.
                # You can adjust the `steps` to balance the computation
                # costs and the sample quality.
                if method == 'dpm-solver':
                    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
                elif method == 'dpm-solver++':
                    dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

                steps = t // infer_speedup
                if use_tqdm:
                    self.bar = tqdm(desc="sample time step", total=steps)
                x = dpm_solver.sample(
                    x,
                    steps=steps,
                    order=2,
                    skip_type="time_uniform",
                    method="multistep",
                )
                if use_tqdm:
                    self.bar.close()
            elif method == 'pndm':
                self.noise_list = deque(maxlen=4)
                if use_tqdm:
                    for i in tqdm(
                            reversed(range(0, t, infer_speedup)),
                            desc='sample time step',
                            total=t // infer_speedup,
                    ):
                        x = self.p_sample_plms(
                            x, torch.full((b,), i, device=device, dtype=torch.long),
                            infer_speedup, cond=cond
                        )
                else:
                    for i in reversed(range(0, t, infer_speedup)):
                        x = self.p_sample_plms(
                            x, torch.full((b,), i, device=device, dtype=torch.long),
                            infer_speedup, cond=cond
                        )
            elif method == 'ddim':
                if use_tqdm:
                    for i in tqdm(
                            reversed(range(0, t, infer_speedup)),
                            desc='sample time step',
                            total=t // infer_speedup,
                    ):
                        x = self.p_sample_ddim(
                            x, torch.full((b,), i, device=device, dtype=torch.long),
                            infer_speedup, cond=cond
                        )
                else:
                    for i in reversed(range(0, t, infer_speedup)):
                        x = self.p_sample_ddim(
                            x, torch.full((b,), i, device=device, dtype=torch.long),
                            infer_speedup, cond=cond
                        )
            elif method == 'unipc':
                from .uni_pc import NoiseScheduleVP, UniPC, model_wrapper

                # 1. Define the noise schedule.
                noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas[:t])

                # 2. Convert your discrete-time `model` to the continuous-time
                # noise prediction model. Here is an example for a diffusion model
                # `model` with the noise prediction type ("noise").
                def my_wrapper(fn):
                    def wrapped(x, t, **kwargs):
                        ret = fn(x, t, **kwargs)
                        if use_tqdm:
                            self.bar.update(1)
                        return ret

                    return wrapped

                model_fn = model_wrapper(
                    my_wrapper(self.denoise_fn),
                    noise_schedule,
                    model_type="noise",  # or "x_start" or "v" or "score"
                    model_kwargs={"cond": cond}
                )

                # 3. Define uni_pc and sample by multistep UniPC.
                # You can adjust the `steps` to balance the computation
                # costs and the sample quality.
                uni_pc = UniPC(model_fn, noise_schedule, variant='bh2')
                steps = t // infer_speedup
                if use_tqdm:
                    self.bar = tqdm(desc="sample time step", total=steps)
                x = uni_pc.sample(
                    x,
                    steps=steps,
                    order=2,
                    skip_type="time_uniform",
                    method="multistep",
                )
                if use_tqdm:
                    self.bar.close()
            else:
                raise NotImplementedError(method)
        else:
            if use_tqdm:
                for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
                    x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            else:
                for i in reversed(range(0, t)):
                    x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
        x = x.squeeze(1).transpose(1, 2)  # [B, T, M]
        return self.denorm_spec(x)

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
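

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# Assumptions: this file sits inside the diffusion package next to
# `dpm_solver_pytorch.py`, so the relative import above resolves when run as
# `python -m <package>.<this_module>`; `DummyDenoiser` is a hypothetical
# stand-in that only matches the call convention used in `forward`:
#     denoise_fn(x, t, cond=cond) -> predicted noise with the same shape as x,
# where x is [B, 1, M, T], t is [B], and cond is [B, C, T].
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyDenoiser(nn.Module):
        def __init__(self, cond_channels=256, out_dims=128):
            super().__init__()
            self.proj = nn.Conv1d(cond_channels, out_dims, kernel_size=1)

        def forward(self, x, t, cond=None):
            # Predict "noise" from the condition alone, broadcast to x's shape.
            return self.proj(cond).unsqueeze(1) + 0.0 * x

    model = GaussianDiffusion(DummyDenoiser(cond_channels=256, out_dims=128),
                              out_dims=128, timesteps=1000, k_step=1000)
    condition = torch.randn(1, 100, 256)  # [B, T, C] encoder output
    with torch.no_grad():
        mel = model(condition, infer=True, infer_speedup=10,
                    method='dpm-solver', use_tqdm=False)
    print(mel.shape)  # expected: torch.Size([1, 100, 128])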