from email.policy import strict
import torch
import os

import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
import numpy as np
from audioldm_train.modules.diffusionmodules.ema import *

from torch.optim.lr_scheduler import LambdaLR
from audioldm_train.modules.diffusionmodules.model import Encoder, Decoder
from audioldm_train.modules.diffusionmodules.distributions import (
    DiagonalGaussianDistribution,
)

import wandb
from audioldm_train.utilities.model_util import instantiate_from_config
import soundfile as sf

from audioldm_train.utilities.model_util import get_vocoder
from audioldm_train.utilities.tools import synth_one_sample
import itertools


class AutoencoderKL(pl.LightningModule):
    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=None,
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
    ):
        """Build the encoder/decoder pair, the loss and optional pretrained weights.

        Args:
            ddconfig: Keyword arguments forwarded to both ``Encoder`` and
                ``Decoder``. Must contain ``"mel_bins"``, a truthy
                ``"double_z"`` and ``"z_channels"``.
            lossconfig: Config passed to ``instantiate_from_config`` to build
                ``self.loss``.
            batchsize: Accepted but not used by this constructor.
            embed_dim: Channel count of the latent; sizes ``quant_conv`` and
                ``post_quant_conv``.
            time_shuffle: Stored on the instance as ``self.time_shuffle``.
            subband: Number of frequency sub-bands (cast to ``int``).
            sampling_rate: Stored as ``self.sampling_rate``.
            ckpt_path: If given, weights are restored via ``init_from_ckpt``.
            reload_from_ckpt: If given, shape-compatible weights are copied
                from this checkpoint at the end of construction.
            ignore_keys: Key prefixes skipped by ``init_from_ckpt``;
                ``None`` means no prefixes.
            image_key: Which batch entry the model works on; a vocoder is
                created only for ``"fbank"``.
            colorize_nlabels: If given, registers a random ``colorize``
                buffer of shape ``(3, colorize_nlabels, 1, 1)``.
            monitor: If given, stored as ``self.monitor``.
            base_learning_rate: Learning rate used by ``configure_optimizers``.
        """
        super().__init__()
        # Both optimizers are stepped by hand in ``training_step``.
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.loss = instantiate_from_config(lossconfig)
        self.subband = int(subband)

        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        assert ddconfig["double_z"]
        # The encoder output carries 2 * z_channels (see the double_z assert);
        # quant_conv maps it to the 2 * embed_dim posterior moments.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(
                ckpt_path, ignore_keys=[] if ignore_keys is None else ignore_keys
            )
        self.learning_rate = float(base_learning_rate)
        print("Initial learning rate %s" % self.learning_rate)

        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None

        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None

        if not self.reloaded and self.reload_from_ckpt is not None:
            print("--> Reload weight of autoencoder from %s" % self.reload_from_ckpt)
            # Load onto the CPU so a checkpoint saved on a GPU can be restored
            # on any host; load_state_dict copies values into the existing
            # parameters, so their device is unaffected.
            checkpoint = torch.load(self.reload_from_ckpt, map_location="cpu")

            pretrained_state_dict = checkpoint["state_dict"]
            current_state_dict = self.state_dict()
            # Keep only the entries whose name and shape both match.
            load_todo_keys = {}
            for key, current_value in current_state_dict.items():
                if (
                    key in pretrained_state_dict
                    and pretrained_state_dict[key].size() == current_value.size()
                ):
                    load_todo_keys[key] = pretrained_state_dict[key]
                else:
                    print("Key %s mismatch during loading, seems fine" % key)

            self.load_state_dict(load_todo_keys, strict=False)
            self.reloaded = True
        else:
            print("Train from scratch")

    def get_log_dir(self):
        return os.path.join(
            self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name
        )

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name
        self.logger_exp_group_name = exp_group_name

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` and return the posterior as a ``DiagonalGaussianDistribution``.

        The input goes through ``freq_split_subband``, the encoder and
        ``quant_conv``; the result is used as the distribution's moments.
        """
        # x = self.time_shuffle_operation(x)
        subband_input = self.freq_split_subband(x)
        moments = self.quant_conv(self.encoder(subband_input))
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
        dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        """Turn a decoded spectrogram batch into waveforms.

        For ``image_key == "fbank"`` the vocoder is used through
        ``vocoder_infer``; for ``"stft"`` the batch is passed to
        ``self.wave_decoder``. In both cases dim 1 is squeezed and the last
        two dims are swapped first.

        NOTE(review): ``self.wave_decoder`` is not assigned anywhere in the
        visible part of this class — confirm where it is defined.
        """
        from audioldm_train.utilities.model_util import vocoder_infer

        key = self.image_key
        if key == "fbank":
            spec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_infer(spec, self.vocoder)
        elif key == "stft":
            spec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(spec)
        # Any other image_key leaves the result unbound, as in the original.
        return wav_reconstruction

    def visualize_latent(self, input):
        """Dump masked inputs, their sampled latents and latent images to disk.

        Two probes are run on a copy of ``input``. Each overwrites a region
        with the constant ``-11.59``, encodes the result, samples a latent,
        and saves ``.npy`` arrays plus one PNG per batch item of the
        channel-averaged latent. All files go to the current working
        directory.

        NOTE(review): the probe saved as ``time_*.npy`` masks the first 32
        entries of the last axis and writes ``freq_<i>.png``; the probe saved
        as ``freq_*.npy`` masks the first 512 entries of axis 2 and writes
        ``time_<i>.png``. Confirm that this naming cross-over is intended.
        """
        import matplotlib.pyplot as plt

        def _as_array(tensor):
            return tensor.cpu().detach().numpy()

        np.save("input.npy", _as_array(input))

        # (masked-input file, latent file, figure pattern, axis-2 slice, axis-3 slice)
        probes = (
            ("time_input.npy", "time_latent.npy", "freq_%s.png", slice(None), slice(None, 32)),
            ("freq_input.npy", "freq_latent.npy", "time_%s.png", slice(None, 512), slice(None)),
        )

        for input_file, latent_file, figure_pattern, axis2, axis3 in probes:
            masked = input.clone()
            # Zero the region, then shift it to the constant fill value.
            masked[:, :, axis2, axis3] *= 0
            masked[:, :, axis2, axis3] -= 11.59
            np.save(input_file, _as_array(masked))

            latent = self.encode(masked).sample()
            np.save(latent_file, _as_array(latent))

            channel_mean = torch.mean(latent, dim=1)
            for i in range(channel_mean.size(0)):
                plt.imshow(_as_array(channel_mean[i]).T)
                plt.savefig(figure_pattern % i)
                plt.close()

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()

        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False

        dec = self.decode(z)

        return dec, posterior

    def get_input(self, batch):
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        # if(self.time_shuffle != 1):
        #     if(fbank.size(1) % self.time_shuffle != 0):
        #         pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
        #         fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))

        ret = {}

        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )

        return ret

    # def time_shuffle_operation(self, fbank):
    #     if(self.time_shuffle == 1):
    #         return fbank

    #     shuffled_fbank = []
    #     for i in range(self.time_shuffle):
    #         shuffled_fbank.append(fbank[:,:, i::self.time_shuffle,:])
    #     return torch.cat(shuffled_fbank, dim=1)

    # def time_unshuffle_operation(self, shuffled_fbank, bs, timesteps, fbins):
    #     if(self.time_shuffle == 1):
    #         return shuffled_fbank

    #     buffer = torch.zeros((bs, 1, timesteps, fbins)).to(shuffled_fbank.device)
    #     for i in range(self.time_shuffle):
    #         buffer[:,0,i::self.time_shuffle,:] = shuffled_fbank[:,i,:,:]
    #     return buffer

    def freq_split_subband(self, fbank):
        if self.subband == 1 or self.image_key != "stft":
            return fbank

        bs, ch, tstep, fbins = fbank.size()

        assert fbank.size(-1) % self.subband == 0
        assert ch == 1

        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, sub_ch, tstep, fbins = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

    def training_step(self, batch, batch_idx):
        """Run one manually-optimized step: discriminator first, then autoencoder.

        ``automatic_optimization`` is disabled in ``__init__``, so this method
        zeroes, backpropagates and steps both optimizers itself. The two
        optimizers are unpacked in the order returned by
        ``configure_optimizers`` (autoencoder, discriminator).

        Args:
            batch: Raw batch, converted by ``get_input``.
            batch_idx: Index of the batch; used to decide when to log images.
        """
        g_opt, d_opt = self.optimizers()
        inputs_dict = self.get_input(batch)
        inputs = inputs_dict[self.image_key]
        waveform = inputs_dict["waveform"]

        # Log example images/audio every 5000 batches, on local rank 0 only.
        if batch_idx % 5000 == 0 and self.local_rank == 0:
            print("Log train image")
            self.log_images(inputs, waveform=waveform)

        # One forward pass; its outputs feed both loss evaluations below.
        reconstructions, posterior = self(inputs)

        if self.image_key == "stft":
            rec_waveform = self.decode_to_waveform(reconstructions)
        else:
            rec_waveform = None

        # train the discriminator
        # If working on waveform, inputs is STFT, reconstructions are the waveform
        # If working on the melspec, inputs is melspec, reconstruction are also mel spec
        discloss, log_dict_disc = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=1,
            global_step=self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )

        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(
            log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
        )
        # Discriminator update.
        d_opt.zero_grad()
        self.manual_backward(discloss)
        d_opt.step()

        # Internal step counter maintained by this class (incremented below).
        self.log(
            "train_step",
            self.train_step,
            prog_bar=False,
            logger=False,
            on_step=True,
            on_epoch=False,
        )

        self.log(
            "global_step",
            float(self.global_step),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        # Autoencoder loss, evaluated after the discriminator has been stepped.
        aeloss, log_dict_ae = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=0,
            global_step=self.global_step,
            last_layer=self.get_last_layer(),
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        # NOTE(review): the value logged as "posterior_std" is the mean of
        # posterior.var — confirm whether variance or standard deviation is intended.
        self.log(
            "posterior_std",
            torch.mean(posterior.var),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        self.log_dict(
            log_dict_ae, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

        # Autoencoder update.
        self.train_step += 1
        g_opt.zero_grad()
        self.manual_backward(aeloss)
        g_opt.step()

    def validation_step(self, batch, batch_idx):
        inputs_dict = self.get_input(batch)
        inputs = inputs_dict[self.image_key]
        waveform = inputs_dict["waveform"]

        if batch_idx <= 3:
            print("Log val image")
            self.log_images(inputs, train=False, waveform=waveform)

        reconstructions, posterior = self(inputs)

        if self.image_key == "stft":
            rec_waveform = self.decode_to_waveform(reconstructions)
        else:
            rec_waveform = None

        aeloss, log_dict_ae = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=0,
            global_step=self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        discloss, log_dict_disc = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=1,
            global_step=self.global_step,
            last_layer=self.get_last_layer(),
            split="val",
        )

        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def test_step(self, batch, batch_idx):
        """Reconstruct a test batch and write the resulting audio to disk.

        Files go under
        ``<log_dir>/autoencoder_result_audiocaps/<global_step>/``. For
        ``image_key == "stft"`` the decoded waveform is saved in
        ``stft_wav_prediction``. Otherwise ``synth_one_sample`` produces a
        ground-truth rendering and a prediction, saved in
        ``fbank_vocoder_gt_wave`` and ``fbank_wav_prediction``.
        """
        data = self.get_input(batch)
        spec = data[self.image_key]
        waveform = data["waveform"]
        fnames = data["fname"]

        recon, _posterior = self(spec)
        out_root = os.path.join(
            self.get_log_dir(), "autoencoder_result_audiocaps", str(self.global_step)
        )

        if self.image_key != "stft":
            wav_vocoder_gt, wav_prediction = synth_one_sample(
                spec.squeeze(1),
                recon.squeeze(1),
                labels="validation",
                vocoder=self.vocoder,
            )
            gt_dir = os.path.join(out_root, "fbank_vocoder_gt_wave")
            self.save_wave(wav_vocoder_gt, fnames, gt_dir)
            pred_dir = os.path.join(out_root, "fbank_wav_prediction")
            self.save_wave(wav_prediction, fnames, pred_dir)
            return

        wav_prediction = self.decode_to_waveform(recon)
        self.save_wave(
            wav_prediction, fnames, os.path.join(out_root, "stft_wav_prediction")
        )

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform to ``save_dir`` under the base name of its file name.

        Args:
            batch_wav: Iterable of waveforms, paired positionally with ``fname``.
            fname: Iterable of file names or paths; only the base name is used.
            save_dir: Output directory, created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)

        targets = (os.path.join(save_dir, os.path.basename(path)) for path in fname)
        for wav, target in zip(batch_wav, targets):
            sf.write(target, wav, samplerate=self.sampling_rate)

    def configure_optimizers(self):
        lr = self.learning_rate
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )

        if self.image_key == "stft":
            params += list(self.wave_decoder.parameters())

        opt_ae = torch.optim.Adam(params, lr=lr, betas=(0.5, 0.9))

        if self.image_key == "fbank":
            disc_params = self.loss.discriminator.parameters()
        elif self.image_key == "stft":
            disc_params = itertools.chain(
                self.loss.msd.parameters(), self.loss.mpd.parameters()
            )

        opt_disc = torch.optim.Adam(disc_params, lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec

        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        if train:
            name = "train"
        else:
            name = "val"

        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()

        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x


class IdentityFirstStage(torch.nn.Module):
    """First-stage stand-in whose encode/decode/forward return the input unchanged.

    Args:
        *args: Accepted and ignored.
        vq_interface: When true, ``quantize`` returns a 3-tuple
            ``(x, None, [None, None, None])`` instead of just ``x``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialise nn.Module before assigning any attribute on the instance.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``, wrapped in a 3-tuple when ``vq_interface`` is set."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x
