
from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip
import pickle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch
import re 
from g2p_en import G2p
import numpy as np
from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_MFCC_v3
import time
import numpy as np
from edit_distance import SequenceMatcher
import tqdm
import pytorch_lightning as pl
import jiwer
import nltk
from nltk.corpus import cmudict
from pytorch_lightning.loggers import WandbLogger
import wandb

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import copy
from difflib import get_close_matches
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer
import pandas as pd
from torchaudio.models.decoder import ctc_decoder
import string
from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE, DATASET_AFTERGO_TRIALS_ZSCORE
import torchaudio  
import torch.nn.functional as F
import matplotlib.pyplot as plt
# from model.ctc_modelling import Light
from dataclasses import dataclass
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from typing import Optional, Tuple
from torch import nn
from peft import LoraConfig, get_peft_model

# Canonicalises a sentence before word-error-rate scoring so that casing,
# punctuation and stray whitespace are not counted as word errors.
# The transforms run in the order listed.
normalize_text = Compose(
    [ToLowerCase(), RemovePunctuation(), RemoveMultipleSpaces(), Strip()]
)




@dataclass
class HybridCausalLMOutput(ModelOutput):
    """
    Language-model output container that carries separate loss components.

    Extends the Hugging Face ``ModelOutput`` with two extra loss slots
    (``ce_loss`` and ``ctc_loss``) alongside the usual ``loss`` / ``logits`` /
    ``hidden_states`` / ``attentions`` fields. Every field defaults to ``None``.

    NOTE(review): this class is not instantiated anywhere in the visible part
    of this file; the meaning of ``ce_loss`` / ``ctc_loss`` below is inferred
    from their names and should be confirmed against the code that fills them.

    Args:
        ce_loss (`torch.FloatTensor`, *optional*):
            Presumably the cross-entropy (language-modelling) loss term.
        ctc_loss (`torch.FloatTensor`, *optional*):
            Presumably the CTC loss term.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    # Field order is significant: ModelOutput exposes fields positionally.
    ce_loss : Optional[torch.FloatTensor] = None
    ctc_loss : Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

class HybridGRUDecoder(pl.LightningModule):
    """
    Neural-signal-to-text model: a neural encoder coupled to a BART language model.

    The encoder's output (either its CTC logits or its hidden embedding) is
    linearly projected to ``lm_model_dim`` and handed to BART as
    ``encoder_outputs``. Training optimises a weighted sum of three terms:
    the BART cross-entropy on the target sentence, the encoder's CTC loss on
    the phoneme sequence, and an L1 loss between predicted and reference MFCCs.

    The encoder object must expose the attributes used below:
    ``n_classes``, ``hidden_dim``, ``bidirectional``, ``kernelLen``,
    ``strideLen``, ``white_noise_SD``, ``constant_offset_SD``,
    ``clip_mfcc_afer_go``, ``get_neural_embedding``, ``fc_decoder_out``,
    ``mfcc_decoder``, ``mfcc_unfolder``, ``ctc_loss`` and ``l1oss``.
    """

    # Immutable default so the signature default can never be mutated and
    # leak between instances; it is copied into a list before use.
    _DEFAULT_LORA_TARGET_MODULES = ("q_proj", "k_proj", "v_proj", "out_proj")

    # Substrings that identify cross-attention parameters. BART names them
    # "encoder_attn" / "encoder_attn_layer_norm"; the first two entries are
    # the GPT-2 style names the original filter used and are kept so other
    # language models that use them still match.
    _CROSS_ATTENTION_KEYS = ("crossattention", "ln_cross_attn", "encoder_attn")

    def __init__(self, neural_encoder, lm_model_dim, learning_rate=1e-4, weight_decay=1e-5,
                 ce_loss_weight=1.,
                 ctc_loss_weight=0.,
                 l1_loss_weight=0.1,
                 project_from_logits=True,  # if False project from neural_embedding
                 freeze_lm=True,
                 freeze_encoder=False,
                 use_lora=True,
                 lora_r=32,
                 lora_alpha=64,
                 lora_dropout=0.2,
                 lora_target_modules=_DEFAULT_LORA_TARGET_MODULES,
                 bart_name="facebook/bart-base"  # or "facebook/bart-large"
                 ):
        """
        Build the projection layer, load BART and apply freezing / LoRA.

        Args:
            neural_encoder: Encoder module (see class docstring for the
                attributes it must provide).
            lm_model_dim: Output size of the projection layer. It is expected
                to equal the hidden size of the chosen BART model; this is
                not checked here.
            learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
            ce_loss_weight: Weight of the BART cross-entropy term.
            ctc_loss_weight: Weight of the CTC term.
            l1_loss_weight: Weight of the MFCC L1 term.
            project_from_logits: If True, project the encoder's CTC logits
                (``n_classes + 1`` features); otherwise project the encoder
                embedding through LayerNorm + Linear.
            freeze_lm: If True, freeze every BART parameter except the
                cross-attention blocks.
            freeze_encoder: If True, freeze every encoder parameter.
            use_lora: If True, wrap BART with LoRA adapters.
            lora_r, lora_alpha, lora_dropout, lora_target_modules: LoRA
                settings forwarded to ``LoraConfig``.
            bart_name: Hugging Face model id of the BART checkpoint.
        """
        super().__init__()

        self.encoder = neural_encoder
        self.lm_model_dim = lm_model_dim
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.ce_loss_weight = ce_loss_weight
        self.ctc_loss_weight = ctc_loss_weight
        self.l1_loss_weight = l1_loss_weight
        self.project_from_logits = project_from_logits

        self.freeze_lm = freeze_lm
        self.freeze_encoder = freeze_encoder
        self.use_lora = use_lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = self._copy_target_modules(lora_target_modules)

        if self.project_from_logits:
            # +1 accounts for the extra class on top of n_classes in the
            # encoder's logit layer (index 0 is treated as the CTC blank in
            # validation_step).
            self.project = nn.Linear(self.encoder.n_classes + 1, lm_model_dim)
        else:
            embed_dim = self.encoder.hidden_dim * 2 if self.encoder.bidirectional else self.encoder.hidden_dim
            self.project = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, lm_model_dim),
            )

        ## LANGUAGE HEAD
        self.language_model = BartForConditionalGeneration.from_pretrained(bart_name)
        self.tokenizer = BartTokenizer.from_pretrained(bart_name)

        if self.freeze_lm:
            # Keep only the cross-attention blocks trainable. When LoRA is
            # enabled below, PEFT freezes the base weights again and only the
            # adapters train, so this mainly matters for use_lora=False.
            for name, param in self.language_model.named_parameters():
                param.requires_grad = any(key in name for key in self._CROSS_ATTENTION_KEYS)

            print("Freezing all parameters except cross attention in the language model.")

        for param in self.encoder.parameters():
            param.requires_grad = not self.freeze_encoder

        if self.freeze_encoder:
            print("Encoder frozen.")
        else:
            print("Encoder will be fine-tuned.")

        if self.use_lora:
            self.apply_lora(
                lora_r=self.lora_r,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                lora_target_modules=self.lora_target_modules,
            )

    @staticmethod
    def _copy_target_modules(target_modules):
        """Return list/tuple inputs as a fresh list; pass None or str through unchanged."""
        if isinstance(target_modules, (list, tuple)):
            return list(target_modules)
        return target_modules

    def apply_lora(self, lora_r=128, lora_alpha=256, lora_dropout=0.2,
                   lora_target_modules=_DEFAULT_LORA_TARGET_MODULES):
        """
        Wrap ``self.language_model`` with LoRA adapters.

        The stored ``self.lora_*`` attributes are not modified; the printed
        message reports the values actually applied by this call.

        NOTE(review): calling this on a model that is already LoRA-wrapped
        (e.g. constructed with ``use_lora=True``) wraps it a second time.
        """
        target_modules = self._copy_target_modules(lora_target_modules)
        lora_cfg = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=target_modules,
        )
        self.language_model = get_peft_model(self.language_model, lora_cfg)
        print(f"Using LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout} on target modules {target_modules}.")

    def get_neural_embedding(self, neuralInput, dayIdx):
        """
        Get the neural embedding from the encoder.

        neuralInput: (batch_size, seq_len, input_dim)
        dayIdx: (batch_size,)
        """
        return self.encoder.get_neural_embedding(neuralInput, dayIdx)

    def forward(self, neuralInput, dayIdx, sentence=None):
        """
        Forward pass of the model.

        Args:
            neuralInput: (batch, time, features)
            dayIdx: Session index
            sentence: Optional target text (a string or list of strings).
                When given, the BART loss against it is computed.

        Returns:
            Tuple ``(pred, mfcc_pred, lm_outputs)``: the encoder's CTC logits,
            its MFCC prediction, and the BART output (``None`` when
            ``sentence`` is None).
        """
        hid = self.get_neural_embedding(neuralInput, dayIdx)

        # 1. predict CTC logits
        pred = self.encoder.fc_decoder_out(hid)

        # 2. Predict the MFCC
        mfcc_pred = self.encoder.mfcc_decoder(hid)

        # 3. Project to the language model dimension
        encoder_outputs = self.project(pred if self.project_from_logits else hid)

        if sentence is None:
            return pred, mfcc_pred, None

        tokens = self.tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
        # Padding positions must not contribute to the cross-entropy: -100 is
        # the ignore index of the loss, and BART maps it back to the pad token
        # when it derives decoder_input_ids from the labels.
        labels = tokens.input_ids.masked_fill(tokens.attention_mask == 0, -100)
        labels = labels.to(self.device)

        # NOTE(review): no encoder attention mask is passed, so padded neural
        # frames are attended to like real ones.
        lm_outputs = self.language_model(encoder_outputs=(encoder_outputs, encoder_outputs), labels=labels)

        return pred, mfcc_pred, lm_outputs

    def _prepare_mfcc_targets(self, batch):
        """Pad the per-trial MFCCs into one tensor and unfold them with the encoder's unfolder."""
        clip_after_go = self.encoder.clip_mfcc_afer_go
        mfcc_list = []
        for idx in range(len(batch["mfcc"])):
            trial_mfcc = batch["mfcc"][idx]
            if clip_after_go:
                # Drop everything before this trial's go-cue onset.
                trial_mfcc = trial_mfcc[batch["go_onset"][idx]:]
            mfcc_list.append(torch.tensor(trial_mfcc))

        MFCC = pad_sequence(mfcc_list, batch_first=True)
        MFCC = MFCC.to(self.device)

        # Swap the last two axes, add a trailing singleton axis for the
        # unfolder, then swap the last two axes of its output back.
        # presumably mfcc_unfolder windows the MFCC frames to the encoder's
        # temporal resolution — TODO confirm against the encoder definition.
        unfolded = self.encoder.mfcc_unfolder(torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3))
        return torch.permute(unfolded, (0, 2, 1))

    def _run_batch(self, batch, augment):
        """
        Run the forward pass on one batch and compute every loss term.

        Args:
            batch: Dict with keys "neural_feats", "phone_seq",
                "neural_time_bins", "phone_seq_len", "day", "sentence",
                "mfcc" and (when MFCC clipping is enabled) "go_onset".
            augment: If True, add the encoder-configured white noise and
                constant offset to the neural features.

        Returns:
            Dict with the device tensors, predictions, the per-sample encoder
            output lengths ("out_len") and the individual / combined losses.
        """
        MFCC = self._prepare_mfcc_targets(batch)

        X = batch["neural_feats"].to(self.device)
        y = batch["phone_seq"].to(self.device)
        X_len = batch["neural_time_bins"].to(self.device)
        y_len = batch["phone_seq_len"].to(self.device)
        dayIdx = batch["day"].to(self.device)
        sentence = batch["sentence"]

        if augment:
            # Out-of-place adds: `.to()` returns the same tensor when it is
            # already on the device, so `+=` would overwrite the batch data.
            if self.encoder.white_noise_SD > 0:
                X = X + torch.randn(X.shape, device=self.device) * self.encoder.white_noise_SD
            if self.encoder.constant_offset_SD > 0:
                X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.encoder.constant_offset_SD

        pred, mfcc_pred, lm_outputs = self.forward(X, dayIdx, sentence=sentence)

        # Number of encoder output frames per sample, as used for the CTC loss.
        out_len = ((X_len - self.encoder.kernelLen) / self.encoder.strideLen).to(torch.int32)
        ctc_loss = self.encoder.ctc_loss(
            torch.permute(pred.log_softmax(2), [1, 0, 2]),
            y,
            out_len,
            y_len,
        )

        # Predicted and reference MFCC sequences can differ in length by a
        # few frames; compare only the overlapping prefix.
        min_seq_len = min(MFCC.shape[1], mfcc_pred.shape[1])
        l1_loss = self.encoder.l1oss(
            mfcc_pred[:, :min_seq_len, :],
            MFCC[:, :min_seq_len, :],
        )

        lm_loss = lm_outputs.loss
        loss = self.ce_loss_weight * lm_loss + self.ctc_loss_weight * ctc_loss + self.l1_loss_weight * l1_loss

        return {
            "X": X,
            "y": y,
            "y_len": y_len,
            "dayIdx": dayIdx,
            "sentence": sentence,
            "pred": pred,
            "out_len": out_len,
            "ctc_loss": ctc_loss,
            "lm_loss": lm_loss,
            "loss": loss,
        }

    def training_step(self, batch, batch_idx):
        """
        Training step - Runs forward pass, computes loss, and returns it for backprop.
        """
        self.encoder.train()

        out = self._run_batch(batch, augment=True)

        self.log("train_loss", out["loss"], prog_bar=True, on_step=True, on_epoch=True)
        self.log("train_ce_loss", out["lm_loss"], prog_bar=True, on_step=True, on_epoch=True)
        self.log("train_ctc_loss", out["ctc_loss"], prog_bar=True, on_step=True, on_epoch=True)

        return out["loss"]

    def validation_step(self, batch, batch_idx):
        """
        Validation step - computes the losses, the phoneme error rate of the
        greedy CTC decoding ("val_CER") and the word error rate of the
        generated text ("val_WER"), logs them, and returns the combined loss.
        """
        out = self._run_batch(batch, augment=False)
        pred, y, y_len, out_len = out["pred"], out["y"], out["y_len"], out["out_len"]
        sentence = out["sentence"]

        # Greedy CTC decoding: argmax per frame, collapse repeats, drop blank (0).
        total_edit_distance, total_seq_length = 0, 0
        for i in range(pred.shape[0]):
            # Decode only the frames the CTC loss treats as valid, so frames
            # produced from padding are not scored.
            decodedSeq = torch.argmax(pred[i, : int(out_len[i]), :], dim=-1)
            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)
            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()

            trueSeq = y[i][:y_len[i]].cpu().numpy()
            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())
            total_edit_distance += matcher.distance()
            total_seq_length += len(trueSeq)

        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0

        generated_text = self.generate(out["X"], out["dayIdx"], max_length=60, num_beams=1)

        if batch_idx == 0:
            # Show a few examples once per validation run.
            print("Generated Text: ", generated_text[:10])
            print("True Text: ", sentence[:10])

        ref_texts = [normalize_text(s) for s in sentence]
        pred_texts = [normalize_text(s) for s in generated_text]

        # Compute WER
        wer_score = wer(ref_texts, pred_texts)

        # Log WER (averaged across all batches by Lightning)
        self.log("val_WER", wer_score, prog_bar=True, on_epoch=True)

        self.log("val_loss", out["loss"], prog_bar=True, on_epoch=True)
        self.log("val_ce_loss", out["lm_loss"], prog_bar=True, on_epoch=True)
        self.log("val_ctc_loss", out["ctc_loss"], prog_bar=True, on_epoch=True)
        self.log("val_CER", cer, prog_bar=True, on_epoch=True)
        return out["loss"]

    def generate(self, neuralInput, dayIdx, max_length=60, num_beams=1):
        """
        Generate text from the model.

        Args:
            neuralInput: (batch_size, seq_len, input_dim)
            dayIdx: Session index
            max_length: Maximum generated length, forwarded to ``generate``.
            num_beams: Beam width, forwarded to ``generate``.

        Returns:
            List of decoded strings, one per input sample.

        NOTE(review): this switches the whole module to eval mode and does
        not switch it back.
        """
        self.eval()
        with torch.no_grad():
            hid = self.get_neural_embedding(neuralInput, dayIdx)

            if self.project_from_logits:
                # If projecting from logits, use the predicted logits
                encoder_outputs = self.project(self.encoder.fc_decoder_out(hid))
            else:
                encoder_outputs = self.project(hid)

            encoder_outputs = BaseModelOutput(last_hidden_state=encoder_outputs)

            generated_ids = self.language_model.generate(
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=1,
            )
            generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        return generated_text

    def configure_optimizers(self):
        """
        Configures the optimizer (Adam over all module parameters; no
        learning-rate scheduler is attached).
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.99),
            eps=1e-8,
        )

        return optimizer
    
