from datetime import datetime
import json
import math
import numbers
import numpy as np
from omegaconf import OmegaConf
from scipy.ndimage import gaussian_filter1d
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoProcessor,
    WhisperForConditionalGeneration,
    WhisperModel,
    WhisperPreTrainedModel,
    WhisperTokenizer,
)
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.whisper.english_normalizer import (
    EnglishTextNormalizer,
    BasicTextNormalizer,
)
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
    WhisperDecoder,
)


def build_window_mask(seq_len, window_size, device="cpu"):
    """Build an additive, non-causal sliding-window attention mask.

    Position ``i`` may attend to position ``j`` only when
    ``|i - j| <= window_size // 2``.

    Args:
        seq_len: Side length of the (square) mask.
        window_size: Total width of the attention window; must be odd so the
            window is centred on the query position.
        device: Device on which the mask is allocated.

    Returns:
        A ``(seq_len, seq_len)`` float tensor that is ``0`` where attention is
        allowed and ``-inf`` where it is disallowed.
    """
    assert window_size % 2 == 1

    half_window = window_size // 2

    # |i - j| for every (query, key) pair
    positions = torch.arange(seq_len, device=device)
    offsets = (positions[None, :] - positions[:, None]).abs()

    # start from a fully permissive mask and block the pairs outside the window
    permissive = torch.zeros(seq_len, seq_len, device=device)
    return permissive.masked_fill(offsets > half_window, float("-inf"))


class SessionsToDays(nn.Module):
    """Maps a session index to the corresponding day number (number of days from the first session).

    Session names are sorted as strings and parsed with the ``t15.%Y.%m.%d``
    format; the session index used in ``forward`` is the position of the
    session in that sorted list.
    """

    def __init__(self, sessions):
        """
        Args:
            sessions: Iterable of session names such as ``"t15.2023.08.11"``.
        """
        super().__init__()

        fmt = "t15.%Y.%m.%d"
        datetimes = [datetime.strptime(s, fmt) for s in sorted(sessions)]

        # day offset of every session w.r.t. the first (earliest sorted) one
        day_offsets = torch.tensor(
            [(dt - datetimes[0]).days for dt in datetimes], dtype=torch.long
        )

        # `register_buffer` (as used by PositionalEncoding) keeps the same
        # attribute name / state-dict key and also works on PyTorch releases
        # that do not provide `nn.Buffer`
        self.register_buffer("session_to_idx_map", day_offsets)

    def forward(self, session_idx):
        """Return the day number(s) of the given session index/indices."""
        return self.session_to_idx_map[session_idx]


class DaysToMonths(nn.Module):
    """Buckets day numbers into fixed-length "months".

    Every day inside the same ``days_per_month``-long bucket is mapped to the
    same month index (floor division of the day number).
    """

    def __init__(self, days_per_month=30):
        super().__init__()
        self.days_per_month = days_per_month

    def forward(self, days):
        """Return the month index of each entry of ``days``."""
        bucket_length = self.days_per_month
        return days // bucket_length


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encodings.

    A table of shape ``(max_len, 1, d_model)`` is precomputed and stored in the
    ``pe`` buffer; ``forward`` looks rows up by integer index and applies
    dropout to the result.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        Args:
            d_model: Size of each encoding vector (may be even or odd).
            dropout: Dropout probability applied to the looked-up encodings.
            max_len: Number of positions stored in the table; indices passed
                to ``forward`` must be smaller than this value.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # shape: (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # for an odd d_model there is one cosine column less than sine columns
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # shape: (max_len, 1, d_model)
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, input):
        """Return the (dropout-regularised) encodings of the integer indices in ``input``."""
        x = self.pe[input]

        return self.dropout(x)


class WhisperEmbedder(nn.Module):
    """Embeds the input neural features for the custom Whisper encoder.
    Replaces convolutions to maintain compatibility with the Hugging Face generation logic.

    The pipeline is: optional day-specific affine layer (+ Softsign + dropout),
    zero-padding of the time axis, two GELU-activated 1D convolutions and the
    addition of a day encoding to the last ``embed_dim - embed_dim // 2``
    channels (the first channels are left untouched for positional encodings).
    """

    def __init__(
        self,
        num_features,
        embed_dim,
        kernel_size_1,
        kernel_size_2,
        stride_2,
        max_source_positions,
        encodings="sin",
        sessions=None,
        day_projections=False,
        num_days=0,
        r=0,  # rank of day-specific projection matrices, 0 to disable low-rank decompositions
    ):
        """
        Args:
            num_features: Number of input channels per time bin.
            embed_dim: Number of output channels of both convolutions.
            kernel_size_1: (Odd) kernel size of the first, stride-1 convolution.
            kernel_size_2: (Odd) kernel size of the second, strided convolution.
            stride_2: Stride of the second convolution.
            max_source_positions: Number of frames expected after the
                convolutions; the input is zero-padded up to
                ``max_source_positions * stride_2`` time bins.
            encodings: ``"sin"`` (sinusoidal encodings of the day number) or
                ``"learn"`` (learned embedding of the day index).
            sessions: Session names; required when ``encodings == "sin"`` or
                ``r > 0``.
            day_projections: Whether to apply a day-specific input layer.
            num_days: Number of distinct day indices.
            r: Rank of the low-rank day-specific weight deltas that are added
                to month-specific weights; ``0`` uses one full matrix per day.

        Raises:
            AssertionError: If an argument combination is invalid.
            Exception: If ``sessions`` is missing or ``encodings`` is unknown.
        """
        super().__init__()

        # kernel sizes have to be odd (so that `padding=k // 2` keeps the frames aligned)
        assert kernel_size_1 % 2 == 1
        assert kernel_size_2 % 2 == 1

        assert 0 <= r < num_features
        assert r == 0 or day_projections is True

        # initialize the module for mapping session indexes to day numbers
        if encodings == "sin" or r > 0:
            if not sessions:
                raise Exception(
                    "Session list has to be provided when the encodings are sinusoidal"
                )
            self.sessions_to_days = SessionsToDays(sessions)

        if r > 0:
            # initialize the module for mapping days to months
            self.days_to_months = DaysToMonths()

            # month-specific projections (entries of months without sessions are None)
            months = self.days_to_months(
                self.sessions_to_days.session_to_idx_map
            ).unique()  # NOTE: sorted
            self.month_weights = nn.ParameterList(
                [
                    nn.Parameter(torch.eye(num_features)) if i in months else None
                    for i in range(months[-1] + 1)
                ]
            )
            self.month_biases = nn.ParameterList(
                [
                    nn.Parameter(torch.zeros(1, num_features)) if i in months else None
                    for i in range(months[-1] + 1)
                ]
            )

        # day-specific input layers
        # from https://github.com/Neuroprosthetics-Lab/nejm-brain-to-text/blob/main/model_training/rnn_model.py
        if day_projections:
            assert num_days > 0

            self.day_layer_activation = nn.Softsign()

            if r == 0:
                # Card's projections
                self.day_weights = nn.ParameterList(
                    [nn.Parameter(torch.eye(num_features)) for _ in range(num_days)]
                )
            else:
                # low-rank day-specific matrices (B starts at zero, so the
                # initial delta A @ B is zero)
                self.day_As = nn.ParameterList(
                    [
                        nn.Parameter(torch.randn(num_features, r))
                        for _ in range(num_days)
                    ]
                )
                self.day_Bs = nn.ParameterList(
                    [
                        nn.Parameter(torch.zeros(r, num_features))
                        for _ in range(num_days)
                    ]
                )

            self.day_biases = nn.ParameterList(
                [nn.Parameter(torch.zeros(1, num_features)) for _ in range(num_days)]
            )

            self.day_layer_dropout = nn.Dropout(0.2)

        self.day_projections = day_projections
        self.r = r

        # updated convolutional embedder
        # NOTE: in the original Whisper encoder, after the convolutions we get 1500 time bins (1 every 20ms) storing information from 65ms (45ms overlap) of the original signal.
        # NOTE: here, after the convolutions we get 1500 time bins (1 every stride*20ms)
        # NOTE: setting stride>1 is like speeding-up attempted speech because Whisper works at 50Hz
        self.conv1 = nn.Conv1d(
            num_features,
            embed_dim,
            kernel_size=kernel_size_1,
            padding=kernel_size_1 // 2,
        )
        self.conv2 = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=kernel_size_2,
            stride=stride_2,
            padding=kernel_size_2 // 2,
        )

        # day encodings
        if encodings == "sin":
            self.de = PositionalEncoding(embed_dim - embed_dim // 2)
        elif encodings == "learn":
            assert num_days > 0
            self.de = nn.Embedding(
                num_embeddings=num_days, embedding_dim=embed_dim - embed_dim // 2
            )
        else:
            raise Exception('Encodings\' type has to be either "sin" or "learn"')
        self.encodings = encodings

        self.max_source_positions = max_source_positions

    def _apply_day_projections(self, x, day_idx):
        """Apply the per-sample day-specific affine layer, Softsign and dropout."""
        day_biases = torch.cat(
            [self.day_biases[i] for i in day_idx], dim=0
        ).unsqueeze(1)

        if self.r == 0:
            # one full matrix per day
            weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
            biases = day_biases
        else:
            # month of every sample, computed once for both weights and biases
            month_idx = self.days_to_months(self.sessions_to_days(day_idx))

            # month-specific weights plus day-specific low-rank "delta" weights
            weights = torch.stack(
                [
                    self.month_weights[m] + self.day_As[d] @ self.day_Bs[d]
                    for m, d in zip(month_idx, day_idx)
                ],
                dim=0,
            )
            month_biases = torch.cat(
                [self.month_biases[m] for m in month_idx], dim=0
            ).unsqueeze(1)
            biases = month_biases + day_biases

        x = torch.einsum("btd,bdk->btk", x, weights) + biases
        x = self.day_layer_activation(x)

        # apply dropout to the output of the day-specific layer
        return self.day_layer_dropout(x)

    def forward(self, input_features, input_len, day_idx):
        """Embed a batch of neural features.

        Args:
            input_features: Tensor of shape ``(batch_size, time, num_features)``.
            input_len: Tensor with the number of valid time bins of each sample.
            day_idx: 1-D integer tensor with the day/session index of each sample.

        Returns:
            A tuple ``(inputs_embeds, input_len)`` where ``inputs_embeds`` has
            shape ``(batch_size, embed_dim, frames)`` and ``input_len`` holds
            the number of valid frames of each sample after the strided
            convolution (as ``torch.long``).
        """
        if self.day_projections:
            input_features = self._apply_day_projections(input_features, day_idx)

        input_features = input_features.permute(
            0, 2, 1
        )  # shape: (batch_size, num_features, time)

        stride = self.conv2.stride[0]

        # pad the last dimension of the input features to the length expected by Whisper with 0s
        expected_seq_length = self.max_source_positions * stride
        pad_size = expected_seq_length - input_features.shape[-1]
        if pad_size > 0:
            input_features = F.pad(
                input_features, (0, pad_size), mode="constant", value=0
            )

        # embed
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)).permute(0, 2, 1)

        # update the input lengths: with an odd kernel and `padding=k // 2`,
        # output frame t is centred on input bin t * stride, so a sample with
        # `input_len` valid bins yields ceil(input_len / stride) valid frames
        # (same value as `len // stride + len % stride` for stride 1 and 2)
        input_len = ((input_len + stride - 1) // stride).to(torch.long)

        # add day encodings
        de = self.de(
            self.sessions_to_days(day_idx) if self.encodings == "sin" else day_idx
        )
        if de.dim() == 2:
            de = de.unsqueeze(1)
        padded_de = torch.zeros(
            *de.shape[:-1], inputs_embeds.shape[-1], device=de.device
        )
        padded_de[..., -de.shape[-1] :] = de  # leave space for positional encodings
        inputs_embeds = inputs_embeds + padded_de

        return inputs_embeds.permute(0, 2, 1), input_len


class WhisperEncoder_(WhisperEncoder):
    """Whisper encoder with no convolutional embedder.

    The forward pass consumes already-embedded features, adds (frozen) time
    positional embeddings to the first ``d_model // 2`` channels and runs the
    transformer layers, optionally restricting the first layers to a
    non-causal attention window.

    NOTE(review): the parent constructor still runs, so whatever modules it
    creates (presumably including the stock convolutions) stay registered even
    though this forward pass does not use them — confirm against the installed
    transformers version.
    """

    def __init__(self, config, last_phoneme_layer, attn_window_size):
        """
        Args:
            config: Whisper configuration object.
            last_phoneme_layer: Index of the last encoder layer that receives
                the windowed attention mask (inclusive).
            attn_window_size: Width of the attention window; ``0`` (or less)
                disables attention masking altogether.
        """
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id

        # positional embeddings only span half of the channels; they are frozen
        # NOTE(review): `self.max_source_positions` is presumably set by the parent constructor
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim // 2)
        self.embed_positions.requires_grad_(False)

        self.last_phoneme_layer = last_phoneme_layer
        self.attn_window_size = attn_window_size

        self.layers = nn.ModuleList(
            [WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_features,
        attention_mask=None,  # NOTE: never used as passed in; replaced by the window mask or by None
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode pre-embedded features.

        Args:
            input_features: Embedded inputs; permuted with ``(0, 2, 1)`` before
                use, i.e. expected as ``(batch_size, d_model, time)``.
            attention_mask: Ignored; the layers receive either the windowed
                mask built here or ``None``.
            head_mask: Optional per-layer head mask; its first dimension must
                equal the number of encoder layers.
            output_attentions: Whether to return the attention weights of all
                layers (defaults to the config value).
            output_hidden_states: Whether to return the hidden states of all
                layers (defaults to the config value).
            return_dict: Whether to return a ``BaseModelOutput`` instead of a
                plain tuple (defaults to the config value).

        Returns:
            A ``BaseModelOutput`` or, when ``return_dict`` is false, a tuple
            with the non-``None`` values among the last hidden state, the
            hidden states and the attentions.
        """
        # fall back to the configuration defaults for unspecified flags
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # shape after the permutation: (batch_size, time, d_model)
        inputs_embeds = input_features.permute(0, 2, 1)
        all_positions = torch.arange(
            self.embed_positions.num_embeddings, device=inputs_embeds.device
        )

        # add positional encodings (for time)
        pe = self.embed_positions(all_positions)
        padded_pe = torch.zeros(
            *pe.shape[:-1], inputs_embeds.shape[-1], device=pe.device
        )
        padded_pe[..., : pe.shape[-1]] = pe  # leave space for custom encodings
        hidden_states = inputs_embeds + padded_pe
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        # build a non-causal windowed attention mask
        # (sized with `max_source_positions` and broadcast to (batch_size, 1, seq, seq))
        if self.attn_window_size > 0:
            attention_mask = build_window_mask(
                seq_len=self.max_source_positions,
                window_size=self.attn_window_size,
                device=hidden_states.device,
            ).expand(hidden_states.shape[0], 1, -1, -1)

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                # a skipped layer leaves the hidden states untouched and contributes no attentions
                layer_outputs = (None, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    # the window mask is only applied up to (and including) the last phoneme layer
                    (
                        attention_mask
                        if self.attn_window_size > 0 and idx <= self.last_phoneme_layer
                        else None
                    ),
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            # the final, layer-normalized states are appended as the last entry
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class WhisperModel_(WhisperModel):
    """WhisperModel updated to accommodate the custom encoder.

    The encoder is a ``WhisperEncoder_`` (windowed attention, no convolutional
    embedder in its forward pass); the decoder is the stock ``WhisperDecoder``.
    """

    def __init__(self, config, last_phoneme_layer, attn_window_size):
        """
        Args:
            config: Whisper configuration object.
            last_phoneme_layer: Forwarded to ``WhisperEncoder_``.
            attn_window_size: Forwarded to ``WhisperEncoder_``.
        """
        # call the pre-trained base initializer directly instead of
        # `WhisperModel.__init__`, so that the encoder/decoder are only the
        # ones assigned below
        WhisperPreTrainedModel.__init__(self, config)

        self.encoder = WhisperEncoder_(config, last_phoneme_layer, attn_window_size)
        self.decoder = WhisperDecoder(config)

        # initialize weights and apply final processing
        self.post_init()


class WhisperForConditionalGeneration_(WhisperForConditionalGeneration):
    """WhisperForConditionalGeneration updated to accommodate the custom Whisper model."""

    def __init__(self, config, dropout=0.4, last_phoneme_layer=1, attn_window_size=0):
        """
        Args:
            config: Whisper configuration object. NOTE: it is modified in
                place (all dropout probabilities are set to ``dropout``).
            dropout: Value written to ``config.dropout``,
                ``config.activation_dropout`` and ``config.attention_dropout``.
            last_phoneme_layer: Forwarded to ``WhisperModel_``.
            attn_window_size: Forwarded to ``WhisperModel_``.
        """
        # set dropout probabilities
        config.dropout = dropout
        config.activation_dropout = dropout
        config.attention_dropout = dropout

        # initialize the bases explicitly so that the stock
        # `WhisperForConditionalGeneration.__init__` is bypassed and the custom
        # model below is the only one that gets built
        WhisperGenerationMixin.__init__(self)
        WhisperPreTrainedModel.__init__(self, config)

        self.model = WhisperModel_(config, last_phoneme_layer, attn_window_size)
        self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.max_target_positions = config.max_target_positions

        # initialize weights and apply final processing
        self.post_init()

    def generate(self, x, **kwargs):
        """Generate token ids from embedded inputs ``x``.

        Args:
            x: Embedded encoder inputs; ``x.shape[0]`` is the batch size.
            **kwargs: Forwarded to ``GenerationMixin.generate``. An explicit
                ``decoder_input_ids`` entry overrides the default prompt.

        Returns:
            Whatever ``GenerationMixin.generate`` returns.
        """
        decoder_input_ids = kwargs.pop("decoder_input_ids", None)
        if decoder_input_ids is None:
            # build the proper SOS sequence
            # i.e., "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
            # NOTE(review): these hard-coded ids look like the English-only
            # Whisper vocabulary — verify against the tokenizer of the
            # checkpoint in use, or pass `decoder_input_ids` explicitly
            decoder_input_ids = torch.tensor(
                [[50257, 50258, 50358, 50362]],
                dtype=torch.long,
                device=x.device,
            ).expand(x.shape[0], -1)

        return GenerationMixin.generate(
            self, inputs=x, decoder_input_ids=decoder_input_ids, **kwargs
        )


class Decoder_(nn.Module):
    """Whisper-based MEA-to-text neural decoder. The model features a "vanilla" Whisper architecture with an updated convolutional embedder.

    One ``WhisperEmbedder`` is created per key of ``embedders_args`` (in sorted
    key order); an optional linear phoneme head reads an intermediate encoder
    layer.
    """

    def __init__(
        self,
        pretrained_whisper_name_or_path,
        embedders_args,
        kernel_size_1=7,
        stride_2=2,
        sessions=None,
        dropout=0.4,
        last_phoneme_layer=1,  # id of the last encoder layer dedicated to phoneme representations
        attn_window_size=0,  # window size of the windowed attention used in the early encoder layers, if 0 attention masking is disabled
        num_classes=0,  # number of classes (phonemes) in the dataset, if 0 the phoneme head is disabled (NOTE: 41 for the Brain-to-Text '25 competition)
        freeze_whisper_decoder=True,
        english_spelling_mapping=None,  # path to the Whisper text normalizer spelling_mapping
    ):
        """
        Args:
            pretrained_whisper_name_or_path: Name or path of the pre-trained
                Whisper checkpoint (tokenizer, model and processor).
            embedders_args: Mapping from an embedder key to the extra keyword
                arguments of its ``WhisperEmbedder`` (must contain
                ``"num_features"``).
            kernel_size_1: Kernel size of the first embedder convolution.
            stride_2: Stride of the second embedder convolution.
            sessions: Session names forwarded to every embedder.
            dropout: Dropout probability used throughout Whisper.
            last_phoneme_layer: Id of the last encoder layer dedicated to
                phoneme representations.
            attn_window_size: Attention window of the early encoder layers;
                ``0`` disables attention masking.
            num_classes: Number of phoneme classes; ``0`` disables the phoneme
                head.
            freeze_whisper_decoder: Whether to freeze the Whisper decoder.
            english_spelling_mapping: Optional path to the JSON spelling
                mapping of the Whisper English text normalizer; without it a
                ``BasicTextNormalizer`` is used.
        """
        super().__init__()

        # load the pre-trained Whisper
        self.tokenizer = WhisperTokenizer.from_pretrained(
            pretrained_whisper_name_or_path, language="English", task="transcribe"
        )
        self.whisper = WhisperForConditionalGeneration_.from_pretrained(
            pretrained_whisper_name_or_path,
            dropout=dropout,
            last_phoneme_layer=last_phoneme_layer,
            attn_window_size=attn_window_size,
            ignore_mismatched_sizes=True,  # NOTE: encodings have been modified
        )
        self.processor = AutoProcessor.from_pretrained(pretrained_whisper_name_or_path)

        self.last_phoneme_layer = last_phoneme_layer

        self.is_whisper_decoder_frozen = False
        if freeze_whisper_decoder:
            # freeze the Whisper decoder (also changes the flag)
            self.freeze_whisper_decoder()

        # initialize Whisper embedders
        self.embedders = nn.ModuleList(
            [
                WhisperEmbedder(
                    embed_dim=self.whisper.config.d_model,
                    kernel_size_1=kernel_size_1,
                    kernel_size_2=3,
                    stride_2=stride_2,
                    max_source_positions=self.whisper.config.max_source_positions,
                    sessions=sessions,
                    **embedders_args[k],
                )
                for k in sorted(embedders_args.keys())
            ]
        )

        # share convolutional layers across embedders
        # (conv1 can only be shared when all embedders have the same number of input features)
        nums_features = [v["num_features"] for v in embedders_args.values()]
        share_conv1 = all(n == nums_features[0] for n in nums_features)
        for embedder in self.embedders[1:]:
            if share_conv1:
                embedder.conv1 = self.embedders[0].conv1
            embedder.conv2 = self.embedders[0].conv2
        if len(self.embedders) > 1:
            print("Conv layers shared across embedders")

        if num_classes > 0:
            # compute the number of tokens corresponding to 80ms of the original neural signal
            num_tokens_to_be_concat = 4 // stride_2
            self.num_tokens_to_be_concat = num_tokens_to_be_concat

            self.phone_head = nn.Linear(
                self.num_tokens_to_be_concat * self.whisper.config.d_model,
                num_classes,
            )
        else:
            # the attribute always exists so that `forward` can test for it
            self.phone_head = None

        # initialize the Whisper text normalizer
        if english_spelling_mapping:
            with open(english_spelling_mapping, "r", encoding="utf-8") as f:
                english_spelling_mapping = json.load(f)
            self.normalizer = EnglishTextNormalizer(english_spelling_mapping)
        else:
            self.normalizer = BasicTextNormalizer()

    def freeze_whisper_decoder(self):
        """Disable gradients for all Whisper decoder parameters."""
        for param in self.whisper.model.decoder.parameters():
            param.requires_grad = False
        self.is_whisper_decoder_frozen = True

    def unfreeze_whisper_decoder(self):
        """Enable gradients for all Whisper decoder parameters."""
        for param in self.whisper.model.decoder.parameters():
            param.requires_grad = True
        self.is_whisper_decoder_frozen = False

    def embed(self, x, x_len, day_idx, sbj_idx=None):
        """Computes (subject-specific) embeddings for the Whisper encoder.

        Args:
            x: Batch of input features.
            x_len: Number of valid time bins of each sample.
            day_idx: Day/session index of each sample.
            sbj_idx: Optional tensor with the embedder index of each sample;
                when ``None`` the first embedder is used for the whole batch.

        Returns:
            A tuple ``(embeddings, lengths)`` as produced by the embedders.
        """
        if sbj_idx is None:
            # compute embeddings
            x, x_len = self.embedders[0](x, x_len, day_idx)
            return x, x_len

        # compute subject-specific embeddings and scatter them back in batch order
        x_temp = torch.zeros(
            x.shape[0],
            self.whisper.config.d_model,
            self.whisper.config.max_source_positions,
            device=x.device,
        )
        x_len_temp = torch.zeros_like(x_len)
        for s in sbj_idx.unique():
            mask = sbj_idx == s
            x_temp[mask], x_len_temp[mask] = self.embedders[int(s)](
                x[mask], x_len[mask], day_idx[mask]
            )

        return x_temp, x_len_temp

    def forward(self, x, x_len, day_idx, sbj_idx=None, **kwargs):
        """Run Whisper on the embedded inputs and, if enabled, the phoneme head.

        Args:
            x: Batch of input features.
            x_len: Number of valid time bins of each sample.
            day_idx: Day/session index of each sample.
            sbj_idx: Optional embedder index of each sample.
            **kwargs: Forwarded to the Whisper model (e.g. labels).

        Returns:
            The Whisper output. When the phoneme head is enabled the output
            additionally holds ``"phone_logits"`` and ``"x_len"`` (number of
            valid phoneme frames per sample).
        """
        x, x_len = self.embed(x, x_len, day_idx, sbj_idx)

        output = self.whisper(x, output_hidden_states=True, **kwargs)

        # `getattr` also covers instances created before the attribute was
        # always assigned in `__init__`
        phone_head = getattr(self, "phone_head", None)
        if phone_head is None:
            return output

        x = output["encoder_hidden_states"][
            self.last_phoneme_layer + 1
        ]  # NOTE: the first element contains Whisper's convolutional embeddings

        # one prediction every 80ms
        x = x.unfold(
            dimension=1,
            size=self.num_tokens_to_be_concat,
            step=self.num_tokens_to_be_concat,
        )  # shape: (batch_size, new_time, num_features, size) — unfold appends the window dimension last
        x = x.reshape(*x.shape[:2], -1)

        # update the input lengths
        x_len = (x_len / self.num_tokens_to_be_concat).to(torch.long)

        # predict phonemes
        x = phone_head(x)

        output["phone_logits"] = x
        output["x_len"] = x_len

        return output

    def generate(self, x, x_len, day_idx, sbj_idx=None, **kwargs):
        """Embed the inputs and generate token ids with Whisper."""
        x, _ = self.embed(x, x_len, day_idx, sbj_idx)

        return self.whisper.generate(x, **kwargs)
