import torch
from tqdm import tqdm
from torch import nn
import torch.nn.functional as F
from torchmetrics import F1Score
import pandas as pd
from torch import optim
import pytorch_lightning as pl
import numpy  as np
from sentence_transformers import SentenceTransformer

import sys
sys.path.append('..')

from src.models.core import Classifier_Lighting
from src.models.utils import MultiHeadSelfAttention


def retrieve_from_dict(dict, list_ids):
    """Return the values stored in ``dict`` for every id in ``list_ids``.

    Each element of ``list_ids`` must expose ``.item()`` (as the elements of
    a tensor or numpy array do); the unwrapped value is used as the lookup
    key. Order follows ``list_ids``. A missing key raises ``KeyError``.
    """
    # The parameter name shadows the builtin ``dict``; it is kept for
    # interface compatibility and aliased locally for readability.
    lookup = dict
    values = []
    for element in list_ids:
        values.append(lookup[element.item()])
    return values



######## Basic Classifier class 
class MHAClassifier(Classifier_Lighting):
    """Document classifier: frozen sentence embeddings -> self-attention -> mean pooling -> linear head.

    Sentences are looked up by id in ``vocab_sentences.csv``, embedded with a
    frozen ``paraphrase-MiniLM-L6-v2`` SentenceTransformer, passed through one
    ``MultiHeadSelfAttention`` layer, averaged over the non-padded positions
    and classified by one (or, with ``intermediate=True``, two) linear layers.

    Args:
        embed_dim: Second argument of ``MultiHeadSelfAttention``; also the
            input width of the classification head.
        num_classes: Number of output logits.
        hidden_dim: Width of the intermediate linear layer (only used when
            ``intermediate`` is true).
        max_len: Stored on the instance; not used inside this class.
        lr: Stored on the instance; presumably read by the base class when
            configuring the optimizer -- TODO confirm.
        intermediate: If true, use ``fc1 -> ReLU -> fc2`` instead of a single
            ``fc`` layer.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention output; a falsy
            value disables it.
        class_weights: Optional per-class weights for the cross-entropy loss.
            ``None`` or an empty sequence means an unweighted loss.
        temperature_scheduler: ``"anneal_decrease"``, ``"step_decrease"`` or
            ``None`` (temperature stays at 1).
        temperature_step: Decay rate / step size used by the scheduler.
        attn_dropout: Dropout forwarded to the attention layer.
        path_invert_vocab_sent: Prefix prepended to ``"vocab_sentences.csv"``.
        activation_attention: Forwarded to the attention layer.
    """

    def __init__(self, embed_dim, num_classes, hidden_dim,  max_len, lr, 
                 intermediate=False, num_heads=4, dropout=False, class_weights=None, 
                 temperature_scheduler=None, temperature_step=None, attn_dropout=0.0,
                 path_invert_vocab_sent = "", activation_attention="softmax"):
        super(MHAClassifier, self).__init__()
        # The sentence encoder is used as a fixed feature extractor.
        self.sent_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
        for name, param in self.sent_model.named_parameters():
            param.requires_grad = False

        # Width of the vectors produced by the sentence encoder. Padding rows
        # must use this width (not ``embed_dim``) to be stackable with them.
        self.sent_embed_dim = self.sent_model.get_sentence_embedding_dimension()

        self.attention = MultiHeadSelfAttention(self.sent_embed_dim, embed_dim, num_heads, temperature=1, dropout=attn_dropout, activation_attention=activation_attention)
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        if not intermediate:
            self.fc = nn.Linear(self.embed_dim, num_classes)
        else: 
            self.fc1 = nn.Linear(self.embed_dim, hidden_dim)  # First linear layer
            self.fc2 = nn.Linear(hidden_dim, num_classes)  # Second linear layer

        self.dropout = dropout
        self.max_len = max_len
        self.lr = lr
        self.intermediate = intermediate

        # CrossEntropyLoss only accepts a Tensor (or None) as weight, so a
        # list is converted and an empty/None value means "unweighted".
        # NOTE: do not rebind ``class_weights`` itself; save_hyperparameters()
        # below records the constructor arguments.
        if class_weights is not None and len(class_weights) > 0:
            weight_tensor = torch.as_tensor(class_weights, dtype=torch.float)
            self.criterion = nn.CrossEntropyLoss(weight=weight_tensor)
        else: 
            self.criterion = nn.CrossEntropyLoss()

        self.temperature = 1
        self.temperature_scheduler = temperature_scheduler
        self.temperature_step = temperature_step
        self.global_iter = 0

        # Map sentence id -> sentence text, read once from disk.
        sent_dict_disk = pd.read_csv(path_invert_vocab_sent+"vocab_sentences.csv")
        self.invert_vocab_sent = {k:v for k,v in zip(sent_dict_disk['Sentence_id'],sent_dict_disk['Sentence'])}        
        self.save_hyperparameters(ignore=["invert_vocab_sent"]) 
    
    def training_step(self, batch, batch_idx):
        """Run the base training step, then update the annealed temperature."""
        return_val = super(MHAClassifier, self).training_step(batch, batch_idx)
        if self.temperature_scheduler=="anneal_decrease":
            # We recommend to set step with a very low value: 1e-3, 1e-4, 1e-5
            self.temperature = max(0.1, np.exp(- self.temperature_step*self.global_iter) )
        self.global_iter+=1

        return return_val

    def on_train_epoch_end(self):
        """Run the base epoch-end hook, then apply the per-epoch temperature step."""
        return_val = super(MHAClassifier, self).on_train_epoch_end()  
        if self.temperature_scheduler=="step_decrease":
            self.temperature = self.temperature - self.temperature_step 
        # The temperature never goes below 0.1.
        self.temperature = max(self.temperature, 0.1)
        if self.temperature_scheduler=="step_decrease" or self.temperature_scheduler=="anneal_decrease":
            print ("Temperature at the end of epoch:", self.temperature)
        
        return return_val
        
    ### New checkpoint hooks -- only works for MHAClassifiers trained with the last changes 
    def on_save_checkpoint(self, checkpoint) -> None:
        """Objects to include in checkpoint file"""
        checkpoint["temperature_end_of_epoch"] = self.temperature

    def on_load_checkpoint(self, checkpoint) -> None:
        """Objects to retrieve from checkpoint file"""
        # Raises KeyError for checkpoints that do not contain this key.
        self.temperature= checkpoint["temperature_end_of_epoch"]

    def _embed_documents(self, doc_ids):
        """Embed every document's sentences and zero-pad to the original length.

        Non-zero ids are treated as sentence ids, zeros as padding. The
        embedded sentences come first, followed by the zero rows.

        Returns a float32 CPU tensor; assumes all documents in ``doc_ids``
        have the same length so they can be stacked.
        """
        sent_dim = self.sent_embed_dim
        per_doc = []
        for doc in doc_ids:
            sentence_ids = doc[doc != 0]
            n_pad = len(doc) - len(sentence_ids)
            if len(sentence_ids) > 0:
                sentences = retrieve_from_dict(self.invert_vocab_sent, sentence_ids)
                source = np.asarray(self.sent_model.encode(sentences), dtype=np.float32)
            else:
                # Nothing to encode: keep a well-formed (0, dim) block.
                source = np.zeros((0, sent_dim), dtype=np.float32)
            complement = np.zeros((n_pad, sent_dim), dtype=np.float32)
            per_doc.append(np.concatenate((source, complement)))
        return torch.from_numpy(np.stack(per_doc))

    def forward(self, doc_ids, src_key_padding_mask, matrix_mask):  
        """Classify a batch of documents.

        Args:
            doc_ids: Iterable of per-document id tensors (0 = padding).
            src_key_padding_mask: Forwarded unchanged to the attention layer.
            matrix_mask: Forwarded to the attention layer; its slice
                ``[:, 0, :]`` is non-zero at padded positions -- assumes shape
                (batch, seq, seq), TODO confirm.

        Returns:
            Tuple ``(logits, attn_weights)``.
        """
        x_emb = self._embed_documents(doc_ids)

        attn_output, attn_weights = self.attention(x_emb.float().to(self.device), src_key_padding_mask, matrix_mask, temperature=self.temperature)

        if self.dropout:
            attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
        
        # Exclude padded positions from the mean by setting them to NaN and
        # using nanmean. masked_fill is out-of-place, so tensors saved for
        # the backward pass are left untouched.
        padded = (matrix_mask[:, 0, :] != 0).unsqueeze(-1).to(attn_output.device)
        attn_output = attn_output.masked_fill(padded, float("nan"))
        # Mean pooling over the sequence length.
        # TODO: what happens with the masking of non-sentences when computing
        # the loss? -- the document's max length needs to be recovered.
        attn_output = torch.nanmean(attn_output, dim=1)
        
        if not self.intermediate:
            logits = self.fc(attn_output)
        else: 
            attn_output = torch.relu(self.fc1(attn_output))
            logits = self.fc2(attn_output)
        return logits, attn_weights    



######## Extended Classifier class supporting a second MHA layer 
class MHAClassifier_extended(Classifier_Lighting):
    """Document classifier with an optional second self-attention layer.

    Sentences are looked up by id in ``vocab_sentences.csv``, embedded with a
    frozen ``paraphrase-MiniLM-L6-v2`` SentenceTransformer and passed through
    a ``MultiHeadSelfAttention`` layer. Optionally a ``Linear + ReLU``
    projection (``intermediate``) and a second attention layer
    (``multi_layer``) follow, before mean pooling over the non-padded
    positions and a linear classification head.

    Args:
        embed_dim: Second argument of the first ``MultiHeadSelfAttention``.
        num_classes: Number of output logits.
        hidden_dim: Width of the intermediate projection (only used when
            ``intermediate`` is true).
        max_len: Stored on the instance; not used inside this class.
        lr: Stored on the instance; presumably read by the base class when
            configuring the optimizer -- TODO confirm.
        intermediate: If true, project the attention output to ``hidden_dim``
            with ``Linear + ReLU``.
        num_heads: Number of attention heads.
        multi_layer: If true, apply a second attention layer.
        class_weights: Optional per-class weights for the cross-entropy loss.
            ``None`` or an empty sequence means an unweighted loss.
        temperature_scheduler: ``"anneal_decrease"``, ``"step_decrease"`` or
            ``None`` (temperature stays at 1).
        temperature_step: Decay rate / step size used by the scheduler.
        attn_dropout: Dropout forwarded to the attention layers.
        path_invert_vocab_sent: Prefix prepended to ``"vocab_sentences.csv"``.
        activation_attention: Forwarded to the attention layers.
    """

    def __init__(self, embed_dim, num_classes, hidden_dim,  max_len, lr, 
                 intermediate=False, num_heads=4, multi_layer=False, class_weights=None, 
                 temperature_scheduler=None, temperature_step=None, attn_dropout=0.0,
                 path_invert_vocab_sent = "", activation_attention="softmax"):
        super(MHAClassifier_extended, self).__init__()
        # The sentence encoder is used as a fixed feature extractor.
        self.sent_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
        for name, param in self.sent_model.named_parameters():
            param.requires_grad = False

        # Width of the vectors produced by the sentence encoder. Padding rows
        # must use this width (not ``embed_dim``) to be stackable with them.
        self.sent_embed_dim = self.sent_model.get_sentence_embedding_dimension()

        self.attention = MultiHeadSelfAttention(self.sent_embed_dim, embed_dim, num_heads, temperature=1, dropout=attn_dropout, activation_attention=activation_attention)
        
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        if intermediate:
            self.pre_linear = nn.Linear(self.embed_dim, hidden_dim)  # First linear layer
            # Built whenever ``intermediate`` is set (even if ``multi_layer``
            # is false) so existing checkpoints keep the same state-dict keys.
            self.multi_attention = MultiHeadSelfAttention(hidden_dim, hidden_dim, num_heads, temperature=1, dropout=attn_dropout, activation_attention=activation_attention)
        elif multi_layer:
            # Without the intermediate projection the second attention layer
            # works directly on the embed_dim-wide output of the first one.
            # Previously no layer was created here and forward() failed with
            # AttributeError.
            self.multi_attention = MultiHeadSelfAttention(self.embed_dim, self.embed_dim, num_heads, temperature=1, dropout=attn_dropout, activation_attention=activation_attention)
            
        self.head = nn.Linear(hidden_dim if intermediate else self.embed_dim, num_classes)  
        self.max_len = max_len
        self.lr = lr
        self.intermediate = intermediate

        # CrossEntropyLoss only accepts a Tensor (or None) as weight, so a
        # list is converted and an empty/None value means "unweighted".
        # NOTE: do not rebind ``class_weights`` itself; save_hyperparameters()
        # below records the constructor arguments.
        if class_weights is not None and len(class_weights) > 0:
            weight_tensor = torch.as_tensor(class_weights, dtype=torch.float)
            self.criterion = nn.CrossEntropyLoss(weight=weight_tensor)
        else: 
            self.criterion = nn.CrossEntropyLoss()
        self.multi_layer = multi_layer  
        self.temperature = 1
        self.temperature_scheduler = temperature_scheduler
        self.temperature_step = temperature_step
        self.global_iter = 0

        # Map sentence id -> sentence text, read once from disk.
        sent_dict_disk = pd.read_csv(path_invert_vocab_sent+"vocab_sentences.csv")
        self.invert_vocab_sent = {k:v for k,v in zip(sent_dict_disk['Sentence_id'],sent_dict_disk['Sentence'])}
        self.save_hyperparameters(ignore=["invert_vocab_sent"]) 
    
    def training_step(self, batch, batch_idx):
        """Run the base training step, then update the annealed temperature."""
        return_val = super(MHAClassifier_extended, self).training_step(batch, batch_idx)
        if self.temperature_scheduler=="anneal_decrease":
            # We recommend to set step with a very low value: 1e-3, 1e-4, 1e-5
            self.temperature = max(0.1, np.exp(- self.temperature_step*self.global_iter) )
        self.global_iter+=1

        return return_val

    def on_train_epoch_end(self):
        """Run the base epoch-end hook, then apply the per-epoch temperature step."""
        return_val = super(MHAClassifier_extended, self).on_train_epoch_end()  
        if self.temperature_scheduler=="step_decrease":
            self.temperature = self.temperature - self.temperature_step 
        # The temperature never goes below 0.1.
        self.temperature = max(self.temperature, 0.1)
        if self.temperature_scheduler=="step_decrease" or self.temperature_scheduler=="anneal_decrease":
            print ("Temperature at the end of epoch:", self.temperature)
        
        return return_val
        

    def on_save_checkpoint(self, checkpoint) -> None:
        """Objects to include in checkpoint file"""
        checkpoint["temperature_end_of_epoch"] = self.temperature

    def on_load_checkpoint(self, checkpoint) -> None:
        """Objects to retrieve from checkpoint file"""
        # Raises KeyError for checkpoints that do not contain this key.
        self.temperature= checkpoint["temperature_end_of_epoch"]

    def _embed_documents(self, doc_ids):
        """Embed every document's sentences and zero-pad to the original length.

        Non-zero ids are treated as sentence ids, zeros as padding. The
        embedded sentences come first, followed by the zero rows.

        Returns a float32 CPU tensor; assumes all documents in ``doc_ids``
        have the same length so they can be stacked.
        """
        sent_dim = self.sent_embed_dim
        per_doc = []
        for doc in doc_ids:
            sentence_ids = doc[doc != 0]
            n_pad = len(doc) - len(sentence_ids)
            if len(sentence_ids) > 0:
                sentences = retrieve_from_dict(self.invert_vocab_sent, sentence_ids)
                source = np.asarray(self.sent_model.encode(sentences), dtype=np.float32)
            else:
                # Nothing to encode: keep a well-formed (0, dim) block.
                source = np.zeros((0, sent_dim), dtype=np.float32)
            complement = np.zeros((n_pad, sent_dim), dtype=np.float32)
            per_doc.append(np.concatenate((source, complement)))
        return torch.from_numpy(np.stack(per_doc))

    def forward(self, doc_ids, src_key_padding_mask, matrix_mask):  
        """Classify a batch of documents.

        Args:
            doc_ids: Iterable of per-document id tensors (0 = padding).
            src_key_padding_mask: Forwarded unchanged to the attention layers.
            matrix_mask: Forwarded to the attention layers; its slice
                ``[:, 0, :]`` is non-zero at padded positions -- assumes shape
                (batch, seq, seq), TODO confirm.

        Returns:
            Tuple ``(logits, attn_weights)``; the weights come from the last
            attention layer that was applied.
        """
        x_emb = self._embed_documents(doc_ids)
        attn_output, attn_weights = self.attention(x_emb.float().to(self.device), src_key_padding_mask, matrix_mask, temperature=self.temperature)

        if self.intermediate:
            attn_output = torch.relu(self.pre_linear(attn_output))

        if self.multi_layer:
            attn_output, attn_weights = self.multi_attention(attn_output.float().to(self.device), src_key_padding_mask, matrix_mask, temperature=self.temperature)

        # Exclude padded positions from the mean by setting them to NaN and
        # using nanmean. This must be out-of-place: attn_output may be the
        # output of torch.relu, which autograd keeps for the backward pass,
        # and an in-place write would invalidate it.
        padded = (matrix_mask[:, 0, :] != 0).unsqueeze(-1).to(attn_output.device)
        attn_output = attn_output.masked_fill(padded, float("nan"))
        attn_output = torch.nanmean(attn_output, dim=1)  # Mean pooling over the sequence length 
        
        logits = self.head(attn_output)
        return logits, attn_weights    
