import copy
import functools
import os
import blobfile as bf
import torch as th
import torch.nn.functional as F

import numpy as np
from cm.karras_diffusion_v3 import karras_sample
# from cm.random_util import get_generator
# from torchvision.utils import make_grid, save_image
# import datetime
# import dnnlib
import pickle
import glob
import scipy
import soundfile as sf
from cm.compute_objective_metrics import *

import wandb

@th.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Blend source parameters into target parameters in place (exponential moving average).

    Each target becomes ``rate * target + (1 - rate) * source``.

    :param target_params: iterable of tensors updated in place.
    :param source_params: iterable of tensors paired element-wise with ``target_params``.
    :param rate: EMA decay; values closer to 1 move the target more slowly.
    """
    source_weight = 1 - rate
    for tgt_param, src_param in zip(target_params, source_params):
        tgt_param.detach().mul_(rate).add_(src_param, alpha=source_weight)

class TrainLoop:
    """Base training loop: holds models, optimizers, dataloaders and step counters.

    Subclasses override ``run_loop`` / ``run_step`` / ``forward_backward``.
    ``sampling`` reads ``self.latent_decoder``, which this base class never sets;
    a subclass must provide it.
    """

    def __init__(
        self,
        model,
        discriminator,
        diffusion,
        data,
        accelerator,
        opt,
        d_opt,
        resume_epoch=0,
        resume_step=0,
        resume_global_step=0,
        eval_dataloader=None,
        args=None,
    ):
        if args is None:
            raise ValueError("TrainLoop requires an `args` namespace")
        self.args = args
        self.accelerator = accelerator
        self.model = model
        self.discriminator = discriminator
        self.diffusion = diffusion  # KarrasDenoiser

        self.train_dataloader = data
        self.eval_dataloader = eval_dataloader
        self.batch_size = args.per_device_train_batch_size
        self.lr = args.lr

        # Resume bookkeeping: the live counters start from the restored values, so
        # ``self.step`` / ``self.global_step`` already include the resumed offset.
        self.resume_epoch = resume_epoch
        self.resume_step = resume_step
        self.resume_global_step = resume_global_step
        self.first_epoch = resume_epoch
        self.step = resume_step
        self.global_step = resume_global_step
        self.global_batch = self.batch_size * self.accelerator.num_processes

        self.x_T = th.randn(*(self.batch_size,
                              self.args.latent_channels,
                              self.args.latent_t_size,
                              self.args.latent_f_size),
                            device=self.accelerator.device) * self.args.sigma_max

        self.opt = opt
        self.d_opt = d_opt

    def run_loop(self):
        """Legacy loop kept for subclasses that provide its dependencies.

        NOTE(review): relies on ``self.data`` (an iterator of ``(batch, cond)``),
        ``self.mp_trainer`` and ``self._update_ema``, none of which are set in this
        file; ``CMTrainLoop`` overrides this method.
        """
        while not self.args.lr_anneal_steps or self.step < self.args.lr_anneal_steps:
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.args.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.args.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """Legacy single step (see ``run_loop``); overridden by ``CMTrainLoop``."""
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self.step += 1
            self._update_ema()
        self._anneal_lr()

    def forward_backward(self, batch, cond):
        raise NotImplementedError

    def _anneal_lr(self):
        """Linearly decay the LR of every ``opt`` param group to 0 over ``lr_anneal_steps``.

        No-op when ``args.lr_anneal_steps`` is falsy. ``self.step`` already includes
        ``resume_step`` (set in ``__init__``), so it is not added again. The fraction
        is capped at 1 so the LR never goes negative after the horizon.
        """
        if not self.args.lr_anneal_steps:
            return
        frac_done = min(self.step / self.args.lr_anneal_steps, 1.0)
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def sampling(self, model, sampler, cond=None, ctm=None, teacher=False, step=-1, batch_size=-1, rate=0.999, guidance_scale=None):
        """Generate audio with ``karras_sample``, decode it, and write 16 kHz ``.wav`` files.

        Files go to ``<generated_path>/<rate>/<global_step>/<guidance_scale>/<step>/<name>.wav``.
        Each waveform is truncated to 160000 samples.

        :param model: model to sample from; switched to eval mode unless ``teacher``.
        :param sampler: sampler name forwarded to ``karras_sample``.
        :param cond: indexable pair. ``cond[0]`` is the conditioning passed to
            ``karras_sample``; ``cond[1]`` is an iterable of paths whose basenames
            name the output files.
        :param ctm: forwarded to ``karras_sample``; when ``None``, True iff
            ``args.training_mode`` is ``'ctm'``.
        :param teacher: if true, the model's train/eval mode is left untouched.
        :param step: sampling steps; -1 uses ``args.sampling_steps``.
        :param batch_size: -1 uses ``args.per_device_eval_batch_size``.
        :param rate: only used to name the output directory.
        :param guidance_scale: forwarded to ``karras_sample`` and used in the path.
        :raises ValueError: if ``cond`` is None.
        """
        if cond is None:
            raise ValueError("sampling() requires `cond` = (conditioning, output file paths)")
        if step == -1:
            step = self.args.sampling_steps
        if batch_size == -1:
            batch_size = self.args.per_device_eval_batch_size
        if ctm is None:
            ctm = self.args.training_mode.lower() == 'ctm'
        out_dir = os.path.join(self.args.generated_path, str(rate), str(self.global_step), str(guidance_scale), str(step))

        if not teacher:
            model.eval()
        try:
            with th.no_grad():
                model_kwargs = {}
                sample = karras_sample(
                    diffusion=self.diffusion,
                    model=model,
                    shape=(batch_size, self.args.latent_channels, self.args.latent_t_size, self.args.latent_f_size),
                    steps=step,
                    cond=cond[0],
                    guidance_scale=guidance_scale,
                    model_kwargs=model_kwargs,
                    device=self.accelerator.device,
                    clip_denoised=False,
                    sampler=sampler,
                    teacher=teacher,
                    ctm=ctm,
                    x_T=None,
                    clip_output=self.args.clip_output,
                    sigma_min=self.args.sigma_min,
                    sigma_max=self.args.sigma_max,
                    train=False,
                )
                # TODO: need to check vae stuff
                unwrapped_vae = self.accelerator.unwrap_model(self.latent_decoder)
                sample_mel = unwrapped_vae.decode_first_stage(sample)  # Get mel-spectrogram
                sample_waveform = unwrapped_vae.decode_to_waveform(sample_mel)

                # Clip before the int16 cast: a full-scale sample (1.0 * 32768)
                # would otherwise wrap around to -32768.
                scaled = np.clip(sample_waveform.cpu().numpy() * 32768, -32768, 32767)
                all_wavs = scaled.astype("int16")
                all_wavs = all_wavs[:, :160000]  # truncate length (10 s at 16 kHz)

                os.makedirs(out_dir, exist_ok=True)
                for wavs, cond_path in zip(all_wavs, cond[1]):
                    name = os.path.splitext(os.path.basename(cond_path))[0]
                    sf.write(os.path.join(out_dir, f'{name}.wav'), wavs, samplerate=16000)
        finally:
            # Restore training mode even if sampling/decoding/writing raised.
            if not teacher:
                model.train()

class CMTrainLoop(TrainLoop):
    """Consistency-model (CTM / CD) training loop with an adversarial discriminator.

    Extends ``TrainLoop`` with an EMA target model, an optional frozen teacher, the
    latent decoder used to turn latents into mel/waveform, and an STFT module. Each
    micro-step updates the discriminator (once ``discriminator_start_itr`` is reached)
    and then the generator, with gradient accumulation driven through ``accelerator``.
    """

    # Generator loss terms: (key returned by ``diffusion.get_gen_loss``, ``args``
    # attribute holding its weight, logged name, logged weighted name).
    _GEN_LOSS_TERMS = (
        ("consistency_loss", "consistency_weight", "ctm_loss", "lambda_ctm_loss"),
        ("denoising_loss", "denoising_weight", "dsm_loss", "lambda_dsm_loss"),
        ("g_loss", "discriminator_weight", "gen_loss", "lambda_gen_loss"),
        ("fm_loss", "fm_weight", "fm_loss", "lambda_fm_loss"),
    )

    def __init__(
        self,
        *,
        target_model,
        teacher_model,
        latent_decoder,
        stft,
        ema_scale_fn,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.training_mode = self.args.training_mode
        self.ema_scale_fn = ema_scale_fn
        self.target_model = target_model
        self.teacher_model = teacher_model
        self.latent_decoder = latent_decoder
        self.stft = stft
        self.total_training_steps = self.args.total_training_steps

        if teacher_model is not None:
            # The teacher is only queried, never trained.
            self.teacher_model.requires_grad_(False)
            self.teacher_model.eval()

    def _is_ctm(self):
        """True when ``training_mode`` selects CTM (case-insensitive)."""
        return self.training_mode.lower() == 'ctm'

    def run_loop(self):
        """Train until ``args.total_training_steps`` optimizer steps or epochs run out.

        Each dataloader batch is unpacked as ``(text, audios, _)`` and fed to
        ``run_step``. A checkpoint is written every ``args.save_interval`` optimizer
        steps (``-1`` disables periodic saves) and once more when training stops.
        """
        # ``global_step`` only advances at the end of a gradient-accumulation window,
        # so remember the last saved step to avoid re-saving the same checkpoint on
        # every micro-step of the window.
        last_saved_step = None
        for epoch in range(self.first_epoch, self.args.num_train_epochs):
            for batch in self.train_dataloader:
                text, audios, _ = batch
                self.run_step(audios, text)
                th.cuda.empty_cache()

                if (
                    self.global_step
                    and self.args.save_interval != -1
                    and self.global_step % self.args.save_interval == 0
                    and self.global_step != last_saved_step
                ):
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.sync_gradients:
                        self.save(epoch)
                        last_saved_step = self.global_step
                        th.cuda.empty_cache()
                    self.accelerator.wait_for_everyone()

                if self.global_step >= self.args.total_training_steps:
                    if self.global_step != last_saved_step:
                        self.save(epoch)
                    # Leave both loops; a bare ``break`` would resume in the next epoch.
                    return

    def evaluation(self, model, rate, cond=None, guidance_scale=3.0):
        """Sample with the step counts appropriate for ``training_mode`` (see ``eval``)."""
        if self.args.training_mode.lower() == 'ctm':
            if self.args.consistency_weight > 0.:
                self.eval(model, step=1, rate=rate, ctm=True, cond=cond, guidance_scale=guidance_scale)
                if self.args.compute_ema_fads:  # False
                    self.eval(model, step=2, rate=rate, ctm=True, cond=cond, guidance_scale=guidance_scale)
                    self.eval(model, step=4, rate=rate, ctm=True, cond=cond, guidance_scale=guidance_scale)
                self.eval(model, step=18, rate=rate, ctm=True, cond=cond, guidance_scale=guidance_scale)
            else:
                self.eval(model, step=18, sampler='heun_cfg', teacher=True, ctm=True, rate=rate, cond=cond)
        elif self.args.training_mode.lower() == 'cd':
            self.eval(model, step=1, sampler='onestep', rate=rate, ctm=False, cond=cond)

    def run_step(self, batch, cond):
        """Run one micro-step: discriminator update (once enabled), then generator update.

        For ``gradient_accumulation_steps - 1`` micro-steps gradients accumulate under
        ``accelerator.no_sync``. On the micro-step that closes the window the optimizers
        step; if ``accelerator.sync_gradients`` the target EMA is updated,
        ``global_step`` advances and metrics are logged on the main process.

        :param batch: audio batch forwarded to ``diffusion.get_samples`` as ``wavs``.
        :param cond: conditioning forwarded to ``diffusion.get_samples`` as ``cond``.
        :raises ValueError: if no discriminator was configured.
        """
        if self.discriminator is None:
            raise ValueError("CMTrainLoop.run_step requires a discriminator")
        th.cuda.empty_cache()

        estimate, target, x_start, mel, waveform, prompt, t, s = self.get_samples(batch, cond)
        # True on the micro-step that closes a gradient-accumulation window.
        accum_boundary = (self.step + 1) % self.args.gradient_accumulation_steps == 0

        if self.global_step >= self.args.discriminator_start_itr:
            dis_loss = self._discriminator_backward(
                estimate, target, x_start, mel, waveform, prompt, accum_boundary
            )
        else:
            dis_loss = 0.0

        with self._sync_context(self.model, accum_boundary):
            losses = self.compute_gen_loss(estimate, target, x_start, mel, waveform, prompt, t, s)
            th.cuda.empty_cache()
            loss = self._weighted_gen_loss(losses)
            self.accelerator.backward(loss)

        if accum_boundary:
            if self.accelerator.sync_gradients:
                self._clip_grad_norm(self.model, self.args.model_grad_clip_value)
            self.opt.step()
            self.opt.zero_grad()
            th.cuda.empty_cache()

            if self.accelerator.sync_gradients:
                if self.target_model is not None:
                    self._update_target_ema()
                    th.cuda.empty_cache()
                self.global_step += 1
                if self.accelerator.is_main_process:
                    self._log_losses(losses, dis_loss)
                self._anneal_lr()  # NOTE: no-op unless lr_anneal_steps is set (CM/CTM papers don't anneal).
        self.step += 1

    def _sync_context(self, module, sync):
        """Return a no-op context if ``sync`` is true, else ``accelerator.no_sync(module)``."""
        import contextlib

        if sync:
            return contextlib.nullcontext()
        return self.accelerator.no_sync(module)

    def _clip_grad_norm(self, module, max_norm):
        """Clip ``module``'s gradients through the accelerator.

        Falls back to ``module.module`` if the direct call fails (kept from the original
        code; presumably for wrapped modules -- TODO confirm which error is expected).
        """
        try:
            self.accelerator.clip_grad_norm_(module.parameters(), max_norm)
        except Exception:
            self.accelerator.clip_grad_norm_(module.module.parameters(), max_norm)

    def _discriminator_backward(self, estimate, target, x_start, mel, waveform, prompt, accum_boundary):
        """Backpropagate the discriminator loss; step ``d_opt`` at accumulation boundaries.

        :returns: the detached ``d_loss`` mean, for logging.
        :raises KeyError: if ``compute_disc_loss`` returns no ``'d_loss'``.
        """
        with self._sync_context(self.discriminator, accum_boundary):
            losses = self.compute_disc_loss(estimate, target, x_start, mel, waveform, prompt)
            th.cuda.empty_cache()
            if 'd_loss' not in losses:
                raise KeyError("compute_disc_loss() did not return a 'd_loss' entry")
            loss = losses["d_loss"].mean()
            dis_loss = loss.detach().float()
            self.accelerator.backward(loss)

        if accum_boundary:
            th.cuda.empty_cache()
            if self.accelerator.sync_gradients:
                self._clip_grad_norm(self.discriminator, self.args.disc_grad_clip_value)
            self.d_opt.step()
            self.d_opt.zero_grad()
            th.cuda.empty_cache()
        return dis_loss

    def _weighted_gen_loss(self, losses):
        """Sum ``weight * term.mean()`` over the generator loss terms present in ``losses``.

        Terms are added in ``_GEN_LOSS_TERMS`` order; each weight is read from ``args``
        only when its term is present.

        :raises KeyError: if none of the known generator loss terms is present.
        """
        total = None
        for key, weight_name, _, _ in self._GEN_LOSS_TERMS:
            if key in losses:
                term = getattr(self.args, weight_name) * losses[key].mean()
                total = term if total is None else total + term
        if total is None:
            known = ", ".join(term[0] for term in self._GEN_LOSS_TERMS)
            raise KeyError(f"compute_gen_loss() returned none of: {known}")
        return total

    def _log_losses(self, losses, dis_loss):
        """Log raw and weighted generator losses plus ``dis_loss`` to wandb and the accelerator."""
        result = {"step": self.step, "global_step": self.global_step}
        for key, weight_name, log_name, lambda_name in self._GEN_LOSS_TERMS:
            if key in losses:
                value = losses[key].mean().detach().float()
                result[log_name] = value
                result[lambda_name] = getattr(self.args, weight_name) * value
            else:
                result[log_name] = 0.0
                result[lambda_name] = 0.0
        result["disc_loss"] = dis_loss
        wandb.log(result)
        self.accelerator.log(result, step=self.global_step)

    def _update_target_ema(self):
        """EMA-update ``target_model.ctm_unet`` from the online model's ``ctm_unet``.

        The decay comes from ``ema_scale_fn(global_step)[0]``.
        """
        target_ema, _scales = self.ema_scale_fn(self.global_step)
        # Fall back to ``.module`` when the online model is wrapped (e.g. by DDP).
        # Only the attribute lookup is guarded, so errors in update_ema surface as-is.
        try:
            source_unet = self.model.ctm_unet
        except AttributeError:
            source_unet = self.model.module.ctm_unet
        with th.no_grad():
            update_ema(
                list(self.target_model.ctm_unet.parameters()),
                list(source_unet.parameters()),
                rate=target_ema,
            )

    def get_samples(self, batch, cond):
        """Delegate to ``diffusion.get_samples``; returns (estimate, target, x_start, mel, waveform, prompt, t, s)."""
        estimate, target, x_start, mel, waveform, prompt, t, s = self.diffusion.get_samples(
            step=self.global_step,
            model=self.model,
            wavs=batch,
            cond=cond,
            model_kwargs=None,
            target_model=self.target_model,
            teacher_model=self.teacher_model,
            stage1_model=self.latent_decoder,
            stft=self.stft,
            accelerator=self.accelerator,
            noise=None,
            ctm=self._is_ctm(),
        )
        return estimate, target, x_start, mel, waveform, prompt, t, s

    def compute_disc_loss(self, estimate, target, x_start, mel, waveform, prompt):
        """Delegate to ``diffusion.get_disc_loss`` and return its loss dict."""
        return self.diffusion.get_disc_loss(
            step=self.global_step,
            model=self.model,
            estimate=estimate,
            target=target,
            x_start=x_start,
            mel=mel,
            waveform=waveform,
            prompt=prompt,
            stage1_model=self.latent_decoder,
            accelerator=self.accelerator,
            discriminator=self.discriminator,
        )

    def compute_gen_loss(self, estimate, target, x_start, mel, waveform, prompt, t, s):
        """Delegate to ``diffusion.get_gen_loss`` and return its loss dict."""
        return self.diffusion.get_gen_loss(
            step=self.global_step,
            model=self.model,
            estimate=estimate,
            target=target,
            x_start=x_start,
            mel=mel,
            waveform=waveform,
            prompt=prompt,
            t=t,
            s=s,
            teacher_model=self.teacher_model,
            stage1_model=self.latent_decoder,
            accelerator=self.accelerator,
            discriminator=self.discriminator,
            model_kwargs=None,
        )

    def forward_backward(self, batch, cond, gen_backword=False):
        """Compute ``diffusion.ctm_losses`` for one batch (generator pass if ``gen_backword``)."""
        compute_losses = functools.partial(
            self.diffusion.ctm_losses,
            step=self.global_step,
            model=self.model,
            wavs=batch,
            cond=cond,
            model_kwargs=None,
            target_model=self.target_model,
            teacher_model=self.teacher_model,
            stage1_model=self.latent_decoder,
            accelerator=self.accelerator,
            stft=self.stft,
            discriminator=self.discriminator,
            ctm=self._is_ctm(),
        )
        return compute_losses(gen_backword=bool(gen_backword))

    @th.no_grad()
    def eval(self, model, step=1, sampler='exact_cfg', teacher=False, ctm=False, rate=0.999, cond=None, guidance_scale=None):
        """Run ``sampling`` with the eval batch size."""
        self.sampling(model=model, cond=cond, sampler=sampler, teacher=teacher, step=step,
                      batch_size=self.args.per_device_eval_batch_size, rate=rate, ctm=ctm, guidance_scale=guidance_scale)

    def save(self, epoch):
        """Save target ``ctm_unet`` weights, progress counters and the accelerator state.

        Only the main process writes; everything goes under
        ``<output_dir>/<global_step:06d>/``.

        :param epoch: stored as ``completed_epochs`` in ``progress_state.pth``.
        """
        def save_checkpoint(rate):
            # Fall back to ``.module`` when the target model is wrapped.
            try:
                unet = self.target_model.ctm_unet
            except AttributeError:
                unet = self.target_model.module.ctm_unet
            state_dict = unet.state_dict()

            self.accelerator.print(f"saving model {rate}...")
            if not rate:
                filename = f"model{self.global_step:06d}.pt"
            else:
                filename = f"ema_{rate}_{self.global_step:06d}.pt"
            step_dir = os.path.join(self.args.output_dir, f"{self.global_step:06d}")
            os.makedirs(step_dir, exist_ok=True)
            self.accelerator.save(state_dict, os.path.join(step_dir, filename))

        if self.accelerator.is_main_process:
            save_checkpoint(float(self.args.ema_rate))
            self.accelerator.print("saving state...")
            step_dir = os.path.join(self.args.output_dir, f"{self.global_step:06d}")
            progress_state_dict = {
                'completed_epochs': int(epoch),
                'completed_steps': int(self.step),
                'completed_global_steps': int(self.global_step),
            }
            self.accelerator.save(progress_state_dict, os.path.join(step_dir, "progress_state.pth"))
            self.accelerator.save_state("{}/{}".format(self.args.output_dir, f"{self.global_step:06d}"))

    def d_save(self):
        """Save the discriminator ``state_dict`` to ``<output_dir>/d_model<global_step:06d>.pt``.

        All processes must call this because it ends with ``wait_for_everyone``.
        """
        # ``Accelerator`` has no ``info`` method; ``print`` emits on the main process only.
        self.accelerator.print("saving d_optimizer state...")
        if self.accelerator.is_main_process:
            state_dict = self.discriminator.state_dict()
            d_para_output_dir = os.path.join(self.args.output_dir, f"d_model{self.global_step:06d}.pt")
            self.accelerator.save(state_dict, d_para_output_dir)
        self.accelerator.wait_for_everyone()


def parse_resume_step_from_filename(filename):
    """
    Extract the step count from a checkpoint name such as ``path/to/modelNNNNNN.pt``.

    The text between the last ``"model"`` and the first ``"."`` after it is parsed
    as an integer. Returns 0 when ``"model"`` does not occur or that text is not
    an integer.
    """
    _, marker, tail = filename.rpartition("model")
    if not marker:
        return 0
    digits = tail.partition(".")[0]
    try:
        return int(digits)
    except ValueError:
        return 0


# def get_blob_logdir():
#     # You can change this to be a separate path to save checkpoints to
#     # a blobstore or some external drive.
#     return logger.get_dir()


def find_resume_checkpoint():
    """
    Return a checkpoint path to resume from automatically, or ``None``.

    Hook for infrastructure-specific discovery (e.g. locating the latest checkpoint
    on blob storage); this default implementation never finds one.
    """
    return None


def find_ema_checkpoint(main_checkpoint, step, rate):
    """
    Locate ``ema_<rate>_<step:06d>.pt`` next to ``main_checkpoint``.

    :param main_checkpoint: path of the main checkpoint, or None.
    :param step: step number, zero-padded to six digits in the file name.
    :param rate: EMA rate as it appears in the file name.
    :return: the EMA checkpoint path if it exists, else None.
    """
    if main_checkpoint is not None:
        candidate = bf.join(bf.dirname(main_checkpoint), f"ema_{rate}_{step:06d}.pt")
        if bf.exists(candidate):
            return candidate
    return None


def log_loss_dict(losses, logger):
    """
    Log the mean and standard deviation of every loss tensor in ``losses``.

    :param losses: mapping from loss name to a tensor of values.
    :param logger: a ``logging.Logger``-compatible object (``info(msg, *args)``).
    """
    for key, values in losses.items():
        # Pass values as %-style args with matching placeholders; an extra arg
        # without a placeholder makes ``logging`` fail to format the record.
        logger.info("%s mean: %s", key, values.mean().item())
        logger.info("%s std: %s", key, values.std().item())
