from model import Decoder_ as Decoder  # TODO:

from lightning import LightningModule
import numpy as np
from omegaconf import OmegaConf
from scipy.ndimage import gaussian_filter1d
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LinearLR,
    ReduceLROnPlateau,
    SequentialLR,
)
from torchaudio.functional import edit_distance
from torchmetrics.text import WordErrorRate


class LightningModel_(LightningModule):
    """Wrapper for training and logging."""

    def __init__(self, config):
        """Build the decoder and set up loss weights, optimizer settings and metrics.

        Args:
            config: Mapping of hyperparameters. Keys read here: ``ce_coeff``,
                ``ctc_coeff``, ``whisper_name``, ``kernel_size_1``, ``stride_2``,
                ``dropout``, ``last_phoneme_layer``, ``attn_window_size``,
                ``learning_rate``, ``learning_rate_min``, ``weight_decay``,
                ``scheduler``, plus any embedder keys (see
                ``_collect_embedder_args``). The optional key
                ``dataset_args_path`` overrides the location of the YAML file
                holding the dataset arguments.

        Raises:
            ValueError: If both loss coefficients are zero.
        """
        super().__init__()

        self.save_hyperparameters()

        self.ce_coeff = config["ce_coeff"]
        self.ctc_coeff = config["ctc_coeff"]
        if self.ce_coeff == 0.0 and self.ctc_coeff == 0.0:
            raise ValueError("At least one loss term coefficient has to be >0")

        # load Card's experiment parameters
        # (the path below stays the default; ``dataset_args_path`` in the
        # config lets other machines point at their own copy)
        dataset_args_path = config.get(
            "dataset_args_path",
            "/data/data/XXXXXXX/whisper-based/t15_pretrained_rnn_baseline/checkpoint/args.yaml",
        )
        dataset_args = OmegaConf.load(dataset_args_path)["dataset"]
        self.transform_args = dataset_args["data_transforms"]

        # retrieve args for each embedder
        embedders_args = self._collect_embedder_args(
            config,
            [
                "num_features",
                "encodings",
                "day_projections",
                "num_days",
                "r",
            ],
        )

        self.model = Decoder(
            pretrained_whisper_name_or_path=config["whisper_name"],
            embedders_args=embedders_args,
            kernel_size_1=config["kernel_size_1"],
            stride_2=config["stride_2"],
            sessions=dataset_args["sessions"],
            dropout=config["dropout"],
            last_phoneme_layer=config["last_phoneme_layer"],
            attn_window_size=config["attn_window_size"],
            num_classes=41,
            freeze_whisper_decoder=True,
            english_spelling_mapping="english.json",
        )
        self.lr = config["learning_rate"]
        self.lr_min = config["learning_rate_min"]
        self.wd = config["weight_decay"]
        self.scheduler = config["scheduler"]

        # Validation accumulators exist only for the loss terms that are active.
        if self.ce_coeff > 0.0:
            self.wer = WordErrorRate()
            self.val_transcriptions = []
            self.val_targets = []

        if self.ctc_coeff > 0.0:
            self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=False)
            self.val_total_edit_distance = 0
            self.val_total_seq_length = 0

    @staticmethod
    def _collect_embedder_args(config, arg_names):
        """Group embedder settings found in ``config`` by embedder suffix.

        A key contributes when it is exactly one of ``arg_names`` (stored under
        the empty suffix ``""``) or one of them followed by a single separator
        character and a suffix, e.g. ``num_days_foo`` -> suffix ``"foo"``.

        Args:
            config: Mapping of configuration keys to values.
            arg_names: Argument names accepted by an embedder.

        Returns:
            Dict mapping each suffix to a ``{arg_name: value}`` dict.
        """
        grouped = {}
        for key, value in config.items():
            for arg_name in arg_names:
                if not key.startswith(arg_name):
                    continue
                remainder = key[len(arg_name) :]
                # A bare prefix match is not enough: with the one-letter name
                # "r", every key that merely begins with "r" would be captured
                # and its suffix mangled. Accept only an exact match or a
                # prefix followed by a (non-alphanumeric) separator.
                if remainder and remainder[0].isalnum():
                    continue
                grouped.setdefault(remainder[1:], {})[arg_name] = value
        return grouped

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def configure_optimizers(self):
        optimizer = Adam(
            [  # NOTE: these groups of parameters have to be always present
                {
                    "params": self.model.embedders.parameters(),
                    "lr": self.lr,
                    "weight_decay": self.wd,
                },
                {
                    "params": self.model.whisper.model.encoder.parameters(),
                    "lr": self.lr,
                    "weight_decay": self.wd,
                },
            ]
            + (
                [
                    {
                        "params": self.model.whisper.model.decoder.parameters(),  # NOTE: proj_out should be included due to weight tying
                        "lr": 1e-4,  # TODO:
                        "weight_decay": self.wd,
                    }
                ]
                if self.ce_coeff > 0.0
                else []
            )
            + (
                [
                    {
                        "params": self.model.phone_head.parameters(),
                        "lr": self.lr,
                        "weight_decay": self.wd,
                    },
                ]
                if self.ctc_coeff > 0.0
                else []
            )
        )

        if self.scheduler == "cosine":
            scheduler = {
                "scheduler": CosineAnnealingLR(
                    optimizer, T_max=self.trainer.max_epochs, eta_min=self.lr_min
                ),
                "interval": "epoch",
            }
        elif self.scheduler == "linear":
            scheduler = {
                "scheduler": LinearLR(
                    optimizer,
                    start_factor=1.0,
                    end_factor=self.lr_min / self.lr,
                    total_iters=self.trainer.max_epochs,
                ),
                "interval": "epoch",
            }
        elif self.scheduler == "sequential":
            milestone = 100
            assert self.trainer.max_epochs > milestone

            scheduler_1 = CosineAnnealingLR(
                optimizer, T_max=milestone, eta_min=self.lr_min
            )
            scheduler_2 = LinearLR(
                optimizer,
                start_factor=self.lr_min / self.lr,
                end_factor=0.0,
                total_iters=self.trainer.max_epochs - milestone,
            )
            scheduler = SequentialLR(
                optimizer,
                schedulers=[
                    scheduler_1,
                    scheduler_2,
                ],
                milestones=[milestone],
            )
        else:
            raise Exception("Invalid LR scheduler")

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
        }

    def transform_data(self, features, n_time_steps, mode="train"):
        """Applies various augmentations and smoothing to data.
        From https://github.com/Neuroprosthetics-Lab/nejm-brain-to-text/blob/main/model_training/rnn_trainer.py
        """

        data_shape = features.shape
        batch_size = data_shape[0]
        channels = data_shape[-1]

        # we only apply these augmentations in training
        if mode == "train":
            # add static gain noise
            if self.transform_args["static_gain_std"] > 0:
                warp_mat = torch.tile(
                    torch.unsqueeze(torch.eye(channels), dim=0), (batch_size, 1, 1)
                )
                warp_mat += (
                    torch.randn_like(warp_mat, device=self.device)
                    * self.transform_args["static_gain_std"]
                )

                features = torch.matmul(features, warp_mat)

            # add white noise
            if self.transform_args["white_noise_std"] > 0:
                features += (
                    torch.randn(data_shape, device=self.device)
                    * self.transform_args["white_noise_std"]
                )

            # add constant offset noise
            if self.transform_args["constant_offset_std"] > 0:
                features += (
                    torch.randn((batch_size, 1, channels), device=self.device)
                    * self.transform_args["constant_offset_std"]
                )

            # add random walk noise
            if self.transform_args["random_walk_std"] > 0:
                features += torch.cumsum(
                    torch.randn(data_shape, device=self.device)
                    * self.transform_args["random_walk_std"],
                    dim=self.transform_args["random_walk_axis"],
                )

            # randomly cutoff part of the data timecourse
            if self.transform_args["random_cut"] > 0:
                cut = np.random.randint(0, self.transform_args["random_cut"])
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # apply Gaussian smoothing to data
        # NOTE: this is done in both training and validation
        if self.transform_args["smooth_data"]:
            features = gauss_smooth(
                inputs=features,
                device=self.device,
                smooth_kernel_std=self.transform_args["smooth_kernel_std"],
                smooth_kernel_size=self.transform_args["smooth_kernel_size"],
            )

        return features, n_time_steps

    def datasets_to_idxs(self, sources):
        sbj_idx = torch.zeros(len(sources), dtype=torch.long)
        for i, s in enumerate(sources):
            if s == "card":
                sbj_idx[i] = 0
            elif s == "willet":
                sbj_idx[i] = 1
            else:
                raise Exception(f"\"{s}\" is an unrecognized dataset source")

        return sbj_idx

    def training_step(self, train_batch, batch_idx):
        """Run one training step and return the weighted loss.

        Reads ``neural_feats``, ``neural_time_bins``, ``sentence`` and ``day``
        from the batch, ``source_dataset`` when present, and ``phone_seq`` /
        ``phone_seq_len`` when the CTC term is active. Logs ``train_ce_loss``,
        ``train_ctc_loss`` (for the active terms) and ``train_loss``.

        Args:
            train_batch: Mapping holding the batch entries listed above.
            batch_idx: Index of the batch (unused).

        Returns:
            Sum of the active loss terms, each scaled by its coefficient.
        """
        batch = train_batch
        feats, feat_lens = self.transform_data(
            batch["neural_feats"], batch["neural_time_bins"], mode="train"
        )

        # tokenize the target sentences; padding positions get -100
        token_ids = self.model.tokenizer(batch["sentence"]).input_ids
        label_rows = [
            torch.tensor(ids, dtype=torch.long, device=feats.device)
            for ids in token_ids
        ]
        padded = pad_sequence(label_rows, batch_first=True, padding_value=-100)
        labels = padded[:, 1:]  # remove SOS IDs

        # subject indices are only available when the batch names its source
        if "source_dataset" in batch:
            subject_idx = self.datasets_to_idxs(batch["source_dataset"]).to(
                feats.device
            )
        else:
            subject_idx = None

        output = self(
            feats,
            feat_lens,
            batch["day"],
            sbj_idx=subject_idx,
            labels=labels,
        )

        total = 0.0

        # cross entropy loss
        if self.ce_coeff > 0.0:
            ce_term = output["loss"]
            self.log("train_ce_loss", ce_term, prog_bar=True)
            total = total + self.ce_coeff * ce_term

        # CTC loss
        if self.ctc_coeff > 0.0:
            # nn.CTCLoss wants the batch on the second axis, hence the permute
            log_probs = F.log_softmax(output["phone_logits"].permute(1, 0, 2), dim=-1)
            ctc_term = self.ctc_loss(
                log_probs=log_probs,
                targets=batch["phone_seq"],
                input_lengths=output["x_len"],
                target_lengths=batch["phone_seq_len"],
            )
            self.log("train_ctc_loss", ctc_term, prog_bar=True)
            total = total + self.ctc_coeff * ctc_term

        self.log("train_loss", total, prog_bar=True)

        return total

    def validation_step(self, val_batch, batch_idx):
        """Run one validation step and accumulate statistics for WER / PER.

        Logs ``val_ce_loss`` / ``val_ctc_loss`` for the active loss terms and
        ``val_loss``. When the cross-entropy term is active, generated
        transcriptions and target sentences are stored for the epoch-end WER.
        When the CTC term is active, the greedy-decoding edit distance and the
        target lengths are added to the running totals for the epoch-end PER.

        Args:
            val_batch: Mapping with ``neural_feats``, ``neural_time_bins``,
                ``sentence`` and ``day``; ``source_dataset`` when present; and
                ``phone_seq`` / ``phone_seq_len`` when the CTC term is active.
            batch_idx: Index of the batch (unused).
        """
        x, x_len = self.transform_data(
            val_batch["neural_feats"], val_batch["neural_time_bins"], mode="val"
        )

        token_ids = self.model.tokenizer(val_batch["sentence"]).input_ids
        token_labels = pad_sequence(
            [torch.tensor(ids, dtype=torch.long, device=x.device) for ids in token_ids],
            batch_first=True,
            padding_value=-100,
        )
        token_labels = token_labels[:, 1:]  # remove SOS IDs

        output = self(
            x,
            x_len,
            val_batch["day"],
            # same subject routing as in training_step
            sbj_idx=(
                self.datasets_to_idxs(val_batch["source_dataset"]).to(x.device)
                if "source_dataset" in val_batch
                else None
            ),
            labels=token_labels,
        )

        loss = 0.0

        if self.ce_coeff > 0.0:
            ce_loss = output["loss"]
            self.log("val_ce_loss", ce_loss, prog_bar=True)

            loss += self.ce_coeff * ce_loss

        if self.ctc_coeff > 0.0:
            ctc_loss = self.ctc_loss(
                log_probs=F.log_softmax(
                    output["phone_logits"].permute(1, 0, 2), dim=-1
                ),
                targets=val_batch["phone_seq"],
                input_lengths=output["x_len"],
                target_lengths=val_batch["phone_seq_len"],
            )
            self.log("val_ctc_loss", ctc_loss, prog_bar=True)

            loss += self.ctc_coeff * ctc_loss

        self.log("val_loss", loss, prog_bar=True)

        if self.ce_coeff > 0.0:
            # store transcriptions for WER computation
            # NOTE(review): generate() is not given subject indices because its
            # signature is not visible here - confirm whether it needs them.
            generated_ids = self.model.generate(
                x,
                x_len,
                val_batch["day"],
            )
            transcriptions = self.model.processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )

            self.val_transcriptions += transcriptions
            self.val_targets += val_batch["sentence"]

        if self.ctc_coeff > 0.0:
            # compute batch PER
            # from https://github.com/Neuroprosthetics-Lab/nejm-brain-to-text/blob/main/model_training/rnn_trainer.py
            logits = output["phone_logits"].detach()
            adjusted_lens = output["x_len"]
            phone_seqs = val_batch["phone_seq"]
            phone_seq_lens = val_batch["phone_seq_len"]
            batch_edit_distance = 0
            for sample_logits, n_frames, true_seq, true_len in zip(
                logits, adjusted_lens, phone_seqs, phone_seq_lens
            ):
                # greedy CTC decoding: best class per frame, merge repeats,
                # then drop the blank (class 0)
                decoded_seq = torch.argmax(sample_logits[:n_frames], dim=-1)
                decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
                decoded_seq = decoded_seq[decoded_seq != 0]

                batch_edit_distance += edit_distance(
                    decoded_seq.tolist(), true_seq[:true_len].tolist()
                )

            self.val_total_edit_distance += batch_edit_distance
            self.val_total_seq_length += torch.sum(phone_seq_lens).item()

    def on_validation_epoch_end(self):
        if self.ce_coeff > 0.0:
            # normalize text
            transcriptions = [self.model.normalizer(s) for s in self.val_transcriptions]
            targets = [self.model.normalizer(s) for s in self.val_targets]

            wer = self.wer(transcriptions, targets)
            self.log("val_wer", wer, prog_bar=True)

            self.val_transcriptions = []
            self.val_targets = []

        if self.ctc_coeff > 0.0:
            per = self.val_total_edit_distance / self.val_total_seq_length
            self.log("val_per", per, prog_bar=True)

            self.val_total_edit_distance = 0
            self.val_total_seq_length = 0
