
import gc
import typing as tp

import torch
from torch.nn import functional as F
import torchaudio
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from ema_pytorch import EMA
from einops import rearrange
from safetensors.torch import save_file
import wandb

from stable_audio_tools.utils.audio_utils import float_to_int16_audio
from ..models.lm import AudioLanguageModelWrapper
from .scheduler import create_optimizer_from_config, create_scheduler_from_config
from .viz import audio_spectrogram_image


class AudioLanguageModelTrainingWrapper(pl.LightningModule):
    """Lightning module that trains an ``AudioLanguageModelWrapper`` on token codes.

    The training loss is the cross entropy between the LM's logits and the
    target codes, computed per codebook and averaged over codebooks. The
    model's pretransform is frozen; optionally an EMA copy of the model is
    maintained, and :meth:`export_model` prefers it when present.

    Args:
        model: The audio language model wrapper to train.
        lr: Learning rate for the default AdamW optimizer. Ignored (with a
            printed warning) when ``optimizer_configs`` is also given.
        use_ema: If True, wrap ``model`` in ``EMA`` with ``beta=0.99`` and
            ``update_every=10``.
        ema_copy: Optional model instance passed to ``EMA`` as ``ema_model``.
        optimizer_configs: Config dict with an ``"lm"`` entry holding an
            ``"optimizer"`` config and, optionally, a ``"scheduler"`` config.
        pre_encoded: If True, batches already contain codes and the
            pretransform's ``tokenize`` step is skipped.
    """

    def __init__(
        self,
        model: AudioLanguageModelWrapper,
        lr=1e-4,
        use_ema=False,
        ema_copy=None,
        optimizer_configs: dict = None,
        pre_encoded: bool = False
    ):
        super().__init__()

        self.model = model
        # The pretransform only produces the training targets; it is not trained.
        self.model.pretransform.requires_grad_(False)

        self.model_ema = None
        if use_ema:
            self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10)

        assert lr or optimizer_configs, "Must specify either lr or optimizer_configs in training config"

        if optimizer_configs is None:
            # Default: a single AdamW optimizer for the LM built from `lr`.
            optimizer_configs = {
                "lm": {
                    "optimizer": {
                        "type": "AdamW",
                        "config": {
                            "lr": lr,
                            "betas": (0.9, 0.95),
                            "weight_decay": 0.1
                        }
                    }
                }
            }
        elif lr:
            print("WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")

        self.optimizer_configs = optimizer_configs
        self.pre_encoded = pre_encoded

    def configure_optimizers(self):
        """Build the LM optimizer and, if configured, a per-step LR scheduler."""
        lm_opt_config = self.optimizer_configs['lm']
        opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters())

        if "scheduler" not in lm_opt_config:
            return [opt_lm]

        sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm)
        # Ask Lightning to step the scheduler every optimizer step, not every epoch.
        sched_lm_config = {
            "scheduler": sched_lm,
            "interval": "step"
        }
        return [opt_lm], [sched_lm_config]

    # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license
    # License can be found in LICENSES/LICENSE_META.txt

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.

        The cross entropy is computed per codebook to provide codebook-level
        cross entropy. Valid timesteps for each codebook are taken from the
        mask; positions where the mask is False are ignored.

        A codebook with no valid position at all contributes 0 to the loss,
        and 0 to the per-codebook list, instead of NaN. ``F.cross_entropy``
        with mean reduction over an empty selection would otherwise return
        NaN and corrupt every gradient.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Boolean mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks.
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        _, K, _ = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        card = logits.size(-1)
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, card)  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            if ce_targets.numel() == 0:
                # Summing the empty selection gives an exact 0 that stays
                # connected to the autograd graph (zero gradient), rather than
                # the NaN that mean-reduced cross entropy returns here.
                q_ce = ce_logits.sum()
            else:
                q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # Average the cross entropy across codebooks.
        ce = ce / K
        return ce, ce_per_codebook

    def training_step(self, batch, batch_idx):
        """Compute the LM loss for one batch and log loss/perplexity metrics.

        ``batch`` is a ``(reals, metadata)`` pair. ``reals`` is tokenized with
        the pretransform unless ``pre_encoded`` is set, in which case it is used
        as codes directly. Each metadata entry must carry a ``"padding_mask"``;
        positions where it is False are excluded from the loss.
        """
        reals, metadata = batch

        # Drop a leading singleton dimension from 4-D inputs.
        if reals.ndim == 4 and reals.shape[0] == 1:
            reals = reals[0]

        if not self.pre_encoded:
            codes = self.model.pretransform.tokenize(reals)
        else:
            codes = reals

        # Use each item's padding mask directly when 1-D, else its first row.
        padding_masks = []
        for md in metadata:
            if md["padding_mask"].ndim == 1:
                padding_masks.append(md["padding_mask"])
            else:
                padding_masks.append(md["padding_mask"][0])

        padding_masks = torch.stack(padding_masks, dim=0).to(self.device)  # (batch_size, sequence_length)

        # Resample the padding masks to the code sequence length; the result is
        # [b, 1, t] and broadcasts over the codebook axis of the logits mask.
        padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool()

        condition_tensors = None

        # If the model is conditioned, get the conditioning tensors. An explicit
        # None check avoids depending on nn.Module truthiness.
        if self.model.conditioner is not None:
            condition_tensors = self.model.conditioner(metadata, self.device)

        lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)

        logits = lm_output.logits  # [b, k, t, c]
        logits_mask = lm_output.mask  # [b, k, t]

        logits_mask = logits_mask & padding_masks

        cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask)

        loss = cross_entropy

        log_dict = {
            'train/loss': loss.detach(),
            'train/cross_entropy': cross_entropy.detach(),
            'train/perplexity': torch.exp(cross_entropy).detach(),
            'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
        }

        for k, ce_q in enumerate(cross_entropy_per_codebook):
            log_dict[f'cross_entropy_q{k + 1}'] = ce_q
            log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q)

        self.log_dict(log_dict, prog_bar=True, on_step=True)
        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Lightning hook run before gradients are zeroed; advances the EMA copy."""
        if self.model_ema is not None:
            self.model_ema.update()

    def export_model(self, path, use_safetensors=False):
        """Save the model weights to ``path``.

        The EMA weights are saved when EMA is enabled, otherwise the raw model
        weights. With ``use_safetensors`` the bare state dict is written with
        safetensors; otherwise ``{"state_dict": ...}`` is written with
        ``torch.save``.
        """
        model = self.model_ema.ema_model if self.model_ema is not None else self.model

        if use_safetensors:
            save_file(model.state_dict(), path)
        else:
            torch.save({"state_dict": model.state_dict()}, path)


class AudioLanguageModelDemoCallback(pl.Callback):
    """Periodically samples audio from the LM and logs it to the experiment logger.

    Args:
        demo_every: Generate demos when ``global_step - 1`` is a multiple of this.
        num_demos: Batch size passed to ``generate_audio``.
        sample_size: Demo length in audio samples. It is converted to a token
            count using the pretransform's ``downsampling_ratio``.
        sample_rate: Sample rate used for the saved WAV and the logged audio.
        demo_conditioning: Conditioning passed unchanged to ``generate_audio``.
        demo_cfg_scales: CFG scales to sample with; one demo set per scale.
            Defaults to ``[3, 5, 7]`` when None.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 demo_every=2000,
                 num_demos=8,
                 sample_size=65536,
                 sample_rate=48000,
                 demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
                 demo_cfg_scales: tp.Optional[tp.List[int]] = None,
                 **kwargs
                 ):
        super().__init__()

        self.demo_every = demo_every
        self.num_demos = num_demos
        self.demo_samples = sample_size
        self.sample_rate = sample_rate
        self.last_demo_step = -1
        self.demo_conditioning = demo_conditioning
        # Resolve the default here instead of in the signature, so no list
        # object is shared between instances; also copy a caller's list.
        self.demo_cfg_scales = [3, 5, 7] if demo_cfg_scales is None else list(demo_cfg_scales)

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx):
        """Generate, save and log demo audio for each configured CFG scale.

        Runs on rank zero only and without gradients. It fires when
        ``global_step - 1`` is a multiple of ``demo_every`` and no demo has
        been generated at this step yet.

        For each CFG scale the demos are concatenated along time and written
        to ``demo_cfg_{scale}_{step:08}.wav`` in the working directory. They
        are then logged as audio plus a spectrogram image. The module is put
        back in train mode, and memory is cleaned up, even if generation raises.
        """
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return

        module.eval()

        print("Generating demo")
        self.last_demo_step = trainer.global_step

        demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio

        # Sample from the raw (non-EMA) model weights.
        model = module.model

        try:
            print("Getting conditioning")

            for cfg_scale in self.demo_cfg_scales:
                print(f"Generating demo for cfg scale {cfg_scale}")
                fakes = model.generate_audio(
                    batch_size=self.num_demos,
                    max_gen_len=demo_length_tokens,
                    conditioning=self.demo_conditioning,
                    cfg_scale=cfg_scale,
                    temp=1.0,
                    top_p=0.95
                )

                # Put the demos together end to end: [b, d, n] -> [d, b * n].
                fakes = rearrange(fakes, 'b d n -> d (b n)')

                log_dict = {}

                fakes = float_to_int16_audio(fakes)

                filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
                torchaudio.save(filename, fakes, self.sample_rate)

                log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
                                                                sample_rate=self.sample_rate,
                                                                caption='Reconstructed')

                log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))

                trainer.logger.experiment.log(log_dict)
        finally:
            # Any exception propagates unchanged; cleanup always runs.
            gc.collect()
            torch.cuda.empty_cache()
            module.train()
