from transformers import ModernBertModel, ModernBertPreTrainedModel, ModernBertConfig
from transformers.models.modernbert.modeling_modernbert import ModernBertPredictionHead, _pad_modernbert_output, _unpad_modernbert_input
from contextlib import nullcontext

import torch.nn as nn
import torch
import torch.nn.functional as F
from typing import Optional, Union, Tuple, List
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

class PianoBertConfig(ModernBertConfig):
    """ModernBERT configuration preset for the piano token vocabulary.

    Vocabulary layout as set up below (5389 ids in total):

    * ``0`` pad, ``1`` mask, ``2`` bos, ``3`` eos, ``4`` play
    * ``5 .. 132``      pitch    (128 ids, starts at ``pitch_start``)
    * ``133 .. 260``    velocity (128 ids, starts at ``velocity_start``)
    * ``261 .. 5260``   timing   (5000 ids, starts at ``timing_start``)
    * ``5261 .. 5388``  pedal    (128 ids, starts at ``pedal_start``)

    All remaining keyword arguments are forwarded to ``ModernBertConfig``.
    """

    #model_type = "piano-bert"

    def __init__(self, **kwargs):
        total_vocab_size = 5389

        self.mask_token_id = 1
        self.play_token_id = 4
        self.pitch_start = 5
        self.velocity_start = 5 + 128
        self.timing_start = 5 + 128 + 128
        self.pedal_start = 5 + 128 + 128 + 5000

        # Half-open (start, end) id range that is valid at each of the 8 slots
        # of a token group. PianoHead consumes the slots in this order:
        # pitch, interval, velocity, duration, pedal1..pedal4.
        self.valid_id_range = [
            (5, 133),
            (261, 5261),
            (133, 261),
            (261, 5261),
            (5261, 5389),
            (5261, 5389),
            (5261, 5389),
            (5261, 5389),
        ]

        # Use setdefault instead of explicit keywords: when the config is
        # rebuilt from its own serialized dict these keys are already present
        # in kwargs, and passing them a second time raises
        # "TypeError: got multiple values for keyword argument".
        kwargs.setdefault("vocab_size", total_vocab_size)
        kwargs.setdefault("pad_token_id", 0)
        kwargs.setdefault("bos_token_id", 2)
        kwargs.setdefault("eos_token_id", 3)

        super().__init__(**kwargs)
        

class PianoBertEmbeddings(nn.Module):
    """Embed 8-token groups and compress each group into two hidden vectors.

    Every consecutive run of 8 input ids is treated as one group. The first
    token of the group keeps its full word embedding; each of the other seven
    tokens is passed through its own linear projection, and the seven
    projections are concatenated into a second vector of width
    ``hidden_size``. A sequence of ``L`` ids therefore becomes ``L // 8 * 2``
    embedding positions.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        projection_sizes = self._projection_sizes(config.hidden_size)
        self.projection_layers = nn.ModuleList(
            [nn.Linear(config.hidden_size, size) for size in projection_sizes]
        )
        self.hidden_size = config.hidden_size

    @staticmethod
    def _projection_sizes(hidden_size: int) -> List[int]:
        """Return seven projection widths that sum exactly to ``hidden_size``.

        Every projection gets ``ceil(hidden_size / 7)`` features and the
        second one absorbs the shortfall. For ``hidden_size == 768`` this is
        ``[110, 108, 110, 110, 110, 110, 110]``, the layout that used to be
        hard-coded, so existing 768-wide weights keep their shapes.

        Raises:
            ValueError: if ``hidden_size`` is too small to split this way.
        """
        width = -(-hidden_size // 7)  # ceiling division
        sizes = [width] * 7
        sizes[1] -= 7 * width - hidden_size
        if sizes[1] < 1:
            raise ValueError(
                f"hidden_size={hidden_size} cannot be split into 7 positive projection widths"
            )
        return sizes

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Return grouped embeddings for ``input_ids``.

        Args:
            input_ids: ``(batch, seq_length)`` ids; ``seq_length`` must be a
                multiple of 8.
            token_type_ids: unused.
            position_ids: unused.
            inputs_embeds: if given, it is returned unchanged and
                ``input_ids`` is ignored.
            past_key_values_length: unused.

        Returns:
            ``inputs_embeds`` as passed, or a tensor of shape
            ``(batch, seq_length // 8 * 2, hidden_size)``.

        Raises:
            ValueError: if neither input is given, or ``seq_length`` is not a
                multiple of 8.
        """
        if inputs_embeds is not None:
            return inputs_embeds
        if input_ids is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided")

        batch_size = input_ids.shape[0]
        seq_length = input_ids.shape[1]
        if seq_length % 8 != 0:
            # view(..., -1) below would otherwise silently fold the leftover
            # tokens into the feature dimension.
            raise ValueError(
                f"seq_length must be a multiple of 8 (one token group), got {seq_length}"
            )
        num_groups = seq_length // 8

        token_embeds = self.word_embeddings(input_ids)
        grouped_embeds = token_embeds.view(batch_size, num_groups, 8, self.hidden_size)
        first_token_embeds = grouped_embeds[:, :, 0, :]
        # Slot i + 1 of every group goes through projection layer i.
        projected = [
            layer(grouped_embeds[:, :, slot + 1, :])
            for slot, layer in enumerate(self.projection_layers)
        ]
        projection_cat = torch.cat(projected, dim=-1)
        new_grouped_embeds = torch.stack([first_token_embeds, projection_cat], dim=2)
        return new_grouped_embeds.view(batch_size, num_groups * 2, self.hidden_size)

class PianoHead(nn.Module):
    """Expand pairs of hidden vectors back into per-token vocabulary logits.

    The input sequence is read as consecutive pairs of hidden vectors. The
    first vector of a pair feeds the ``pitch`` head; the second feeds the
    other seven heads. Each head writes its scores into the id range given by
    ``config.valid_id_range`` for its slot; every other vocabulary entry of
    that slot is ``-inf``. A sequence of ``S`` hidden vectors therefore yields
    ``S // 2 * 8`` rows of logits.
    """

    def __init__(self, config):
        super().__init__()
        ranges = config.valid_id_range
        if len(ranges) != 8:
            raise ValueError(f"config.valid_id_range must have 8 entries, got {len(ranges)}")
        # Each head must be exactly as wide as the id range forward() writes
        # it into, so derive the widths from the ranges instead of repeating
        # them as literals.
        widths = [end - start for start, end in ranges]
        self.pitch = nn.Linear(config.hidden_size, widths[0])
        self.interval = nn.Linear(config.hidden_size, widths[1])
        self.velocity = nn.Linear(config.hidden_size, widths[2])
        self.duration = nn.Linear(config.hidden_size, widths[3])
        self.pedal1 = nn.Linear(config.hidden_size, widths[4])
        self.pedal2 = nn.Linear(config.hidden_size, widths[5])
        self.pedal3 = nn.Linear(config.hidden_size, widths[6])
        self.pedal4 = nn.Linear(config.hidden_size, widths[7])
        self.config = config

    def forward(self, sequence_output):
        """Compute logits of shape ``(batch, seq_len // 2 * 8, vocab_size)``.

        Args:
            sequence_output: tensor of shape ``(batch, seq_len, hidden)`` with
                an even ``seq_len``.

        Raises:
            ValueError: if ``seq_len`` is odd.
        """
        batch_size = sequence_output.shape[0]
        seq_len = sequence_output.shape[1]
        if seq_len % 2 != 0:
            # view(..., -1) below would otherwise fold the extra position
            # into the feature dimension.
            raise ValueError(f"seq_len must be even (two positions per token group), got {seq_len}")
        num_groups = seq_len // 2

        grouped_output = sequence_output.view(batch_size, num_groups, 2, -1)
        first_output = grouped_output[:, :, 0, :]
        other_output = grouped_output[:, :, 1, :]

        # Slot order matches config.valid_id_range.
        scores_list = [
            self.pitch(first_output),
            self.interval(other_output),
            self.velocity(other_output),
            self.duration(other_output),
            self.pedal1(other_output),
            self.pedal2(other_output),
            self.pedal3(other_output),
            self.pedal4(other_output),
        ]

        vocab_size = self.config.vocab_size
        prediction_scores = torch.full(
            (batch_size, num_groups, 8, vocab_size),
            float('-inf'),
            device=scores_list[0].device,
            dtype=scores_list[0].dtype,
        )
        for slot, ((start, end), scores) in enumerate(zip(self.config.valid_id_range, scores_list)):
            prediction_scores[:, :, slot, start:end] = scores

        return prediction_scores.view(batch_size, -1, vocab_size)

class PianoBert(ModernBertPreTrainedModel):
    """Masked-token model over 8-token piano groups, built on ModernBERT.

    ``input_ids`` are compressed by ``PianoBertEmbeddings`` (8 tokens -> 2
    positions), encoded by ``ModernBertModel``, passed through
    ``ModernBertPredictionHead`` and expanded back to one row of vocabulary
    logits per input token by ``PianoHead``. ``generate`` fills masked tokens
    by iterative unmasking/remasking.
    """

    # NOTE(review): ``self.decoder`` is a PianoHead, which defines no ``weight``
    # of its own; kept as in the original - confirm this key is still wanted.
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        #self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
        self.embeddings = PianoBertEmbeddings(config)
        self.decoder = PianoHead(config)

        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
        self.loss_type = "ForMaskedLM"
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the output head (a ``PianoHead``)."""
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        """Replace the output head."""
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        """Compiled variant of ``self.decoder(self.head(output))``."""
        return self.decoder(self.head(output))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Run the encoder and return per-token logits (and loss if labelled).

        Args:
            input_ids: ``(batch, L)`` token ids, ``L`` a multiple of 8.
                Required.
            attention_mask: optional ``(batch, L)`` per-token mask. It is
                reduced to two entries per group of 8: the first token's mask
                and "any of the other seven".
            inputs_embeds: accepted for signature compatibility only; the
                value is always recomputed from ``input_ids``.
            labels: optional per-token targets handed to
                ``self.loss_function`` together with the logits.
            sliding_window_mask, position_ids, indices, cu_seqlens,
            max_seqlen, batch_size, seq_len, output_attentions,
            output_hidden_states: forwarded to ``ModernBertModel``.
            return_dict: return a ``MaskedLMOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.

        Returns:
            ``MaskedLMOutput`` or ``(loss?, logits)``.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        if input_ids is None:
            # The embedding step below cannot run without ids; fail with a
            # clear message instead of an AttributeError deep inside it.
            raise ValueError("PianoBert.forward requires input_ids")

        # Always recomputed from input_ids: 8 tokens -> 2 positions.
        inputs_embeds = self.embeddings(input_ids)

        if attention_mask is not None:
            # Compress the per-token mask to the 2-positions-per-group layout.
            B, L = attention_mask.shape
            block_mask = attention_mask.view(B, L // 8, 8)
            mask1 = block_mask[:, :, 0]
            mask2 = block_mask[:, :, 1:].any(dim=-1).long()
            attention_mask = torch.stack([mask1, mask2], dim=2).view(B, -1)

        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    # Compressed length, not the raw token count.
                    batch_size, seq_len = inputs_embeds.shape[:2]
                device = input_ids.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                # inputs_embeds is never None here, so only the embeddings
                # are unpadded.
                inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
                    inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids
                )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        # NOTE(review): this sparse path looks incompatible with the grouped
        # layout - labels carry one entry per input token while the hidden
        # states hold two positions per 8 tokens, and PianoHead expects a 3-D
        # input. Left unchanged; confirm before enabling sparse_prediction.
        if self.sparse_prediction and labels is not None:
            # flatten labels and output first
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)

            # then filter out the non-masked tokens
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]

        if self.config._attn_implementation == "flash_attention_2":
            # Re-pad before the head: PianoHead regroups positions in pairs
            # and needs the (batch, seq, hidden) layout.
            last_hidden_state = _pad_modernbert_output(inputs=last_hidden_state, indices=indices, batch=batch_size, seqlen=seq_len)

        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _sample(self, logits, sample_strategy="greedy", temperature=1.0, top_k=10, top_p=0.9):
        """Choose one token id per position and report its probability.

        Args:
            logits: tensor of shape ``(batch, seq_len, vocab)``.
            sample_strategy: ``"greedy"`` takes the argmax of the softmax;
                ``"sample"`` draws from the temperature-scaled distribution
                after optional top-k and top-p filtering.
            temperature: divisor applied to the logits (``"sample"`` only);
                must be positive.
            top_k: when ``> 0`` keep only the ``top_k`` largest logits per
                position (clamped to the vocabulary size).
            top_p: nucleus threshold, applied when ``0 < top_p <= 1``.

        Returns:
            ``(sampled_tokens, confidences)``, both ``(batch, seq_len)``;
            ``confidences`` is the probability of the chosen token under the
            distribution it was drawn from.

        Raises:
            ValueError: unknown ``sample_strategy`` or non-positive
                ``temperature`` in ``"sample"`` mode.
        """
        batch_size, seq_len, vocab_size = logits.shape

        if sample_strategy == "greedy":
            probabilities = F.softmax(logits, dim=-1)
            sampled_tokens = torch.argmax(probabilities, dim=-1)
            confidences = torch.gather(probabilities, -1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
        elif sample_strategy == "sample":
            if temperature <= 0:
                # Dividing by zero (or flipping the sign of -inf logits)
                # produces NaN probabilities further down.
                raise ValueError(f"temperature must be positive, got {temperature}")
            scaled_logits = logits / temperature
            flat_scaled_logits = scaled_logits.view(-1, vocab_size)
            min_value = torch.finfo(flat_scaled_logits.dtype).min

            if top_k > 0:
                # torch.topk rejects k larger than the last dimension.
                k = min(top_k, vocab_size)
                top_k_values, top_k_indices = torch.topk(flat_scaled_logits, k, dim=-1)
                mask = torch.full_like(flat_scaled_logits, min_value)
                mask.scatter_(-1, top_k_indices, top_k_values)
                flat_scaled_logits = mask

            if 0.0 < top_p <= 1.0:
                sorted_logits, sorted_indices = torch.sort(flat_scaled_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                if sorted_indices_to_remove.shape[-1] > 1:
                    # Shift right so the first token that crosses top_p is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
                flat_scaled_logits[indices_to_remove] = min_value
            probabilities = F.softmax(flat_scaled_logits, dim=-1)

            sampled_tokens_flat = torch.multinomial(probabilities, num_samples=1).squeeze(-1)
            sampled_tokens = sampled_tokens_flat.view(batch_size, seq_len)

            probabilities_reshaped = probabilities.view(batch_size, seq_len, vocab_size)
            confidences = torch.gather(probabilities_reshaped, -1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
        else:
            raise ValueError(f"Unknown sample_strategy: {sample_strategy}. Must be 'greedy' or 'sample'.")

        return sampled_tokens, confidences

    @torch.no_grad()
    def generate(self, batch_sample, steps, unmask = None, sample_strategy="greedy", temperature=1.0, top_k=10, top_p=0.9, remask_strategy="random"):
        """Fill masked tokens by iterative prediction and remasking.

        All positions not flagged in ``unmask`` start as the mask token. Each
        of the ``steps`` iterations predicts every position, writes the
        predictions into the currently masked positions, then remasks a
        shrinking share of the non-protected positions (none on the last
        step).

        Args:
            batch_sample: list of token-id lists. They are right-padded to
                the longest sample; the padded length must be a multiple of 8
                because ``forward`` regroups tokens in blocks of 8.
            steps: number of refinement iterations.
            unmask: optional per-sample list of 0/1 flags, same length as the
                sample; 1 marks a token that is kept and never masked. With
                ``None`` every token of every sample is generated.
            sample_strategy, temperature, top_k, top_p: see ``_sample``.
            remask_strategy: ``"random"`` remasks each free position with
                probability ``t``; ``"ar"`` remasks the trailing
                ``int(t * max_length)`` positions; ``"confidence"`` remasks
                the ``mask_num * t`` least confident free positions.

        Returns:
            List of token-id lists, each trimmed to its sample's length.

        Raises:
            ValueError: unknown ``remask_strategy``.
        """
        if remask_strategy not in ("random", "ar", "confidence"):
            # Previously an unknown value surfaced as a NameError on mask_ind.
            raise ValueError(
                f"Unknown remask_strategy: {remask_strategy}. Must be 'random', 'ar' or 'confidence'."
            )

        batch_size = len(batch_sample)
        if batch_size == 0:
            return []
        len_list = [len(sample) for sample in batch_sample]
        max_length = max(len_list)

        # Follow the model's own device instead of assuming CUDA.
        device = next(self.parameters()).device
        pad_id = self.config.pad_token_id
        mask_id = self.config.mask_token_id

        x = torch.tensor([sample + [pad_id] * (max_length - len(sample)) for sample in batch_sample]).long().to(device)
        attention_mask = torch.tensor([[1] * len(sample) + [0] * (max_length - len(sample)) for sample in batch_sample]).long().to(device)

        # unmask_ind is True for protected positions; padding is always protected.
        if unmask is None:
            unmask_ind = torch.tensor([[0] * len(sample) + [1] * (max_length - len(sample)) for sample in batch_sample]).bool()
        else:
            unmask_ind = torch.tensor([unmask[i] + [1] * (max_length - len_list[i]) for i in range(batch_size)]).bool()
        unmask_ind = unmask_ind.to(device)

        x[~unmask_ind] = mask_id
        mask_num = (~unmask_ind).sum(dim=1)

        ts = torch.linspace(1, 0, steps + 1)[1:]
        for t in tqdm(ts):
            # return_dict=True because .logits is read regardless of config.
            logits = self.forward(input_ids=x, attention_mask=attention_mask, return_dict=True).logits
            token_ids, confidences = self._sample(logits, sample_strategy, temperature, top_k, top_p)
            still_masked = x == mask_id
            x[still_masked] = token_ids[still_masked]
            if remask_strategy == "random":
                mask_p = torch.ones_like(x) * t
                mask_p[unmask_ind] = 0
                mask_ind = torch.bernoulli(mask_p).bool()
            elif remask_strategy == "ar":
                mask_ind = torch.zeros_like(x).bool()
                mask_ind[:, max_length - int(t * max_length):] = True
                mask_ind[unmask_ind] = False
            else:  # "confidence"
                mask_ind = torch.zeros_like(x).bool()
                now_mask_num = (mask_num * t).long()
                # Protected positions must not compete for the k slots,
                # otherwise fewer tokens than scheduled get remasked.
                candidate_conf = confidences.masked_fill(unmask_ind, float("inf"))
                for i in range(batch_size):
                    _, indices = torch.topk(candidate_conf[i], k=now_mask_num[i].item(), largest=False)
                    mask_ind[i, indices] = True
                mask_ind[unmask_ind] = False

            x[mask_ind] = mask_id
        x = x.cpu().numpy().tolist()
        return [x[i][:len_list[i]] for i in range(len(x))]

def render_wrap(config, batch_sample):
    """Build generation inputs that repeat each sample after a separator.

    Each sample of length ``n`` becomes ``sample + 8 play tokens + sample``.
    The matching keep-flags are 1 for the first ``n + 8`` positions, followed
    by ``n // 8`` copies of ``[1, 0, 0, 0, 0, 0, 0, 0]`` (first token of each
    8-token group kept, the other seven left to be generated).

    Args:
        config: object exposing ``play_token_id``.
        batch_sample: list of token-id lists.

    Returns:
        ``(x, unmask)``: the wrapped samples and their keep-flag lists.
    """
    group_flags = [1, 0, 0, 0, 0, 0, 0, 0]
    wrapped = []
    unmask = []
    for sample in batch_sample:
        n = len(sample)
        wrapped.append(sample + [config.play_token_id] * 8 + sample)
        unmask.append([1] * (n + 8) + group_flags * (n // 8))
    return wrapped, unmask

def render_unwrap(batch_sample):
    """Return the trailing copy from each ``sample + 8 separators + sample`` list.

    For a wrapped list of length ``2 * n + 8`` this is its last ``n`` items.

    Args:
        batch_sample: list of wrapped token-id lists.

    Returns:
        List with the trailing ``len(x) // 2 - 4`` items of every input list;
        an empty list when that count is zero or negative.
    """
    unwrapped = []
    for x in batch_sample:
        tail_len = max(len(x) // 2 - 4, 0)
        # Slice from an explicit start index: ``x[-0:]`` would return the
        # whole list for a wrapped empty sample instead of nothing.
        unwrapped.append(x[len(x) - tail_len:])
    return unwrapped

#config = BertConfig(vocab_size=vocab_size, max_position_embeddings=block_size, hidden_size=840)

#model = PianoBert(config)
#model(torch.ones((1, 16), dtype=torch.long)).logits.shape



