import re
import torch
import random
import warnings
import transformers
import torch.nn as nn
from tqdm import tqdm
from .prompts import *
from typing import List
from typing import Callable
from locolms.utils.training import *
from peft import (
    AutoPeftModelForCausalLM, 
    LoraConfig, get_peft_model, 
    prepare_model_for_kbit_training
)
from transformers import (
    AutoConfig, AutoTokenizer, 
    AutoModelForCausalLM, BitsAndBytesConfig
)

class LoCoLM(nn.Module):
    """Hugging Face language-model wrapper used for LoCo fine-tuning.

    The architecture family is looked up in ``hf_models[model_hf_name]["type"]``
    and must be ``"decoder"`` (causal LM) or ``"seq2seq"``. Decoder models can
    optionally be loaded in 4-bit NF4 quantization with LoRA adapters attached.
    """

    def __init__(self, gpu_id, model_hf_name, quantization: bool = False):
        """Load the tokenizer and model for ``model_hf_name``.

        Args:
            gpu_id: device that input tensors are moved to via ``.to(gpu_id)``.
                The model itself is not moved here.
            model_hf_name: Hugging Face model id; must be a key of ``hf_models``.
            quantization: decoder models only -- load in 4-bit and wrap the
                model with LoRA adapters. Ignored, with a warning, for seq2seq.

        Raises:
            ValueError: if the configured model type is neither decoder nor seq2seq.
        """
        super().__init__()

        self.model_hf_name = model_hf_name
        self.gpu_id = gpu_id

        # Model
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_hf_name)

        if self.is_decoder():
            if quantization:
                print("[-] Loading decoder only model with quantization")
                self.model = transformers.AutoModelForCausalLM.from_pretrained(
                    self.model_hf_name,
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.bfloat16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type='nf4'
                    ),
                    torch_dtype=torch.bfloat16,
                )
                # KV caching only helps incremental generation; turn it off.
                self.model.config.use_cache = False
                self.model = prepare_model_for_kbit_training(self.model)
                peft_config = LoraConfig(
                    r=128,
                    lora_alpha=16,
                    target_modules=find_all_linear_names(self.model),
                    lora_dropout=0.05,
                    bias="none",
                    task_type="CAUSAL_LM",
                )
                self.model = get_peft_model(self.model, peft_config)
                print_trainable_parameters(self.model)
            else:
                print("[-] Loading decoder only model")
                self.model = transformers.AutoModelForCausalLM.from_pretrained(
                    self.model_hf_name
                )
                self.model.config.use_cache = False
            # Reuse EOS as the padding token for decoder models.
            # TODO: evaluate self.tokenizer.padding_side = "right".
            self.tokenizer.pad_token = self.tokenizer.eos_token

        elif self.is_seq2seq():
            if quantization:
                warnings.warn(
                    "quantization=True is only supported for decoder-only models; "
                    "loading the seq2seq model without quantization."
                )
            print("[-] Loading seq2seq model")
            self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
                self.model_hf_name,
                torch_dtype=torch.bfloat16,
            )

        else:
            raise ValueError("Invalid model type. It must be either decoder-only or seq2seq.")

        # Only reachable for tokenizers that still have no pad token (decoder
        # branches set one above): add one and grow the embedding matrix.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.MAX_GEN_TOKENS = 4

    def answer(self, quests: List[str]):
        """Generate at most ``MAX_GEN_TOKENS`` new tokens for each prompt.

        Puts the model in eval mode, tokenizes ``quests`` with padding, calls
        ``generate`` and decodes the result with special tokens skipped. For
        decoder-only models the decoded strings presumably include the prompt
        text, since ``generate`` returns the full sequence.
        """
        self.model.eval()
        in_prompts = self.tokenizer(quests, padding=True, return_tensors="pt").to(self.gpu_id)
        generated = self.model.generate(**in_prompts, max_new_tokens=self.MAX_GEN_TOKENS)
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)

    def prob_formula(self, s1, s2, label=None):
        """Score the answer ``label`` after each prompt (training mode).

        Args:
            s1: list of formatted antecedent-fact prompts.
            s2: list of formatted consequent-fact prompts.
            label: value passed to ``prompt_answer`` to build the target answer.

        Returns:
            (s1_probs, s2_probs): LM probabilities of the answer after each
            s1 / s2 prompt. Returns None if the model type is unrecognised
            (``__init__`` rejects that case).
        """
        self.model.train()

        if self.is_seq2seq():
            return self.__seq2seq_prob_formula(s1, s2, label)

        elif self.is_decoder():
            return self.__decoder_prob_formula(s1, s2, label)

    def __seq2seq_prob_formula(self, s1, s2, label=None):
        """Seq2seq branch of ``prob_formula``.

        Each prompt batch is encoded and the tokenized answer is scored with
        teacher forcing via the ``labels=`` argument.

        Returns:
            (s1_probs, s2_probs): tensors of shape (B, 1) holding
            exp(sum of answer-token log-probabilities) for each prompt.
        """
        positive_labels = [prompt_answer(self.get_model_type(), label) for _ in range(len(s1))]
        positive_answ = self.tokenizer(positive_labels, padding=True, return_tensors="pt")
        label_ids = positive_answ.input_ids.to(self.gpu_id)
        # Copy of the labels with padding set to -100 so pads are not scored.
        masked_label_ids = positive_answ.input_ids.masked_fill(
            positive_answ.input_ids == self.tokenizer.pad_token_id, -100
        ).to(self.gpu_id)

        probs = []
        for prompts in (s1, s2):  # ---- Factual: s1, then s2
            enc = self.tokenizer(prompts, padding=True, return_tensors="pt")
            outputs = self.model(
                input_ids=enc.input_ids.to(self.gpu_id),
                # Without the mask the encoder attends to padding tokens.
                attention_mask=enc.attention_mask.to(self.gpu_id),
                labels=label_ids,
            )
            log_probs, _ = seq2seq_get_target_probs(
                outputs.logits.to(self.gpu_id), masked_label_ids, should_reduce=False
            )
            probs.append(log_probs.exp().unsqueeze(-1))
        s1_probs, s2_probs = probs
        return s1_probs, s2_probs

    def __decoder_prob_formula(self, s1, s2, label=None):
        """Decoder-only branch of ``prob_formula``.

        Returns:
            (s1_probs, s2_probs): outputs of ``gpt_get_target_probs`` for the
            s1 and s2 prompts against the same answer label.
        """
        positive_labels = [prompt_answer(self.get_model_type(), label) for _ in range(len(s1))]
        # ---- Factual: s1
        s1_probs = gpt_get_target_probs(
            model=self.model,
            tokenizer=self.tokenizer,
            inputs=s1,
            targets=positive_labels,
            gpu_id=self.gpu_id
        )
        # ---- Factual: s2
        s2_probs = gpt_get_target_probs(
            model=self.model,
            tokenizer=self.tokenizer,
            inputs=s2,
            targets=positive_labels,
            gpu_id=self.gpu_id
        )
        return s1_probs, s2_probs

    @staticmethod
    def _context_length(config):
        """Return the maximum number of input tokens the model accepts.

        Prefers the position-embedding size (``n_positions`` /
        ``max_position_embeddings`` / ``n_ctx``). Falls back to
        ``config.max_length``, which is typically a generation default rather
        than the context size.
        """
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            value = getattr(config, attr, None)
            if isinstance(value, int) and value > 0:
                return value
        return config.max_length

    def get_perplexity(self, data, window_size=512):
        """Sliding-window perplexity, as in https://huggingface.co/docs/transformers/en/perplexity

        The documents in ``data`` are joined with blank lines. Windows of the
        model's context length are advanced by ``window_size`` tokens (clamped
        to the context length so no token is skipped). In each window only the
        tokens not already scored by the previous window contribute to the loss.

        Returns:
            exp(mean of the per-window losses), as a Python float.

        Raises:
            ValueError: if ``data`` tokenizes to an empty sequence.
        """
        encodings = self.tokenizer("\n\n".join(data), return_tensors="pt")
        seq_len = encodings.input_ids.size(1)
        if seq_len == 0:
            raise ValueError("Cannot compute perplexity of empty data.")
        max_length = self._context_length(self.model.config)
        # A stride larger than the window would leave tokens between windows unscored.
        stride = min(window_size, max_length)
        self.model.eval()
        nlls = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            end_loc = min(begin_loc + max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.gpu_id)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100  # context-only tokens are not scored
            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
            nlls.append(outputs.loss)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break
        return torch.exp(torch.stack(nlls).mean()).item()

    def get_model_type(self):
        """Return ``hf_models[self.model_hf_name]["type"]``."""
        return hf_models[self.model_hf_name]["type"]

    def is_decoder(self):
        """True when the configured model type is ``"decoder"``."""
        return self.get_model_type() == "decoder"

    def is_seq2seq(self):
        """True when the configured model type is ``"seq2seq"``."""
        return self.get_model_type() == "seq2seq"

def gpt_get_target_probs(model, tokenizer, inputs, targets, gpu_id):
    """Probability of each target's first token as the next token after its input.

    A single forward pass is run over the padded ``inputs``. The next-token
    distribution is read at each row's last non-padding position, and the
    probability of the first token of the matching target is returned.

    Args:
        model: causal LM whose output exposes ``.logits`` of shape (B, L, V).
        tokenizer: tokenizer returning ``input_ids`` and ``attention_mask``.
        inputs: list of B prompt strings.
        targets: list of B target strings; only their first token is scored.
        gpu_id: device the encoded tensors are moved to.

    Returns:
        Tensor of shape (B, 1) holding probabilities (not log-probabilities).

    Raises:
        ValueError: if any target tokenizes to no tokens.
    """
    enc_inputs = tokenizer(inputs, padding=True, return_tensors="pt").to(gpu_id)
    # Encode targets without special tokens so the first real token is the
    # answer token, whether or not the tokenizer prepends a BOS token.
    enc_targets = tokenizer(
        targets, padding=True, add_special_tokens=False, return_tensors="pt"
    ).to(gpu_id)
    target_mask = enc_targets.attention_mask
    if (target_mask.sum(-1) == 0).any():
        raise ValueError("every target must contain at least one token")

    in_mask = enc_inputs.attention_mask  # B, L
    batch, seq_len = in_mask.shape
    positions = torch.arange(seq_len, device=in_mask.device)
    # Last non-pad position per row; correct for both left and right padding.
    last_idx = (positions * in_mask).max(-1).values
    # First non-pad target position per row (= number of leading pads).
    first_tgt_idx = (target_mask.cumsum(-1) == 0).sum(-1)
    rows = torch.arange(batch, device=in_mask.device)

    logits = model(input_ids=enc_inputs.input_ids, attention_mask=in_mask).logits  # B, L, D
    # Softmax only the selected positions instead of the whole (B, L, D) tensor.
    next_token_probs = logits[rows, last_idx].softmax(-1)  # B, D
    first_target = enc_targets.input_ids[rows, first_tgt_idx].unsqueeze(-1)  # B, 1
    return torch.gather(next_token_probs, 1, first_target)

def gpt_get_seq_target_probs(model, tokenizer, inputs, targets, gpu_id):
    """Sum of target-token log-probabilities along a greedy continuation.

    At each step the model is run on the current sequence. The log-probability
    of the target token for that step is read from the last position, and the
    sequence is extended with the argmax (greedy) token.

    NOTE(review): the continuation is conditioned on the greedy token, not on
    the target token, so this is not the teacher-forced log-likelihood of the
    target. Padded inputs are also fed without an attention mask and their
    last position may be padding. Confirm both are intended.

    Args:
        model: causal LM whose output exposes ``.logits`` of shape (B, L, V).
        tokenizer: tokenizer returning ``input_ids`` and ``attention_mask``.
        inputs: list of B prompt strings.
        targets: list of B target strings, tokenized without special tokens.
        gpu_id: device the encoded tensors are moved to.

    Returns:
        Tensor of shape (B,) with summed log-probabilities. Padding positions
        of shorter targets contribute 0; empty targets give 0.
    """
    input_ids = tokenizer(inputs, padding=True, return_tensors="pt").to(gpu_id).input_ids  # B, L
    enc_targets = tokenizer(
        targets, padding=True, add_special_tokens=False, return_tensors="pt"
    ).to(gpu_id)
    target_toks = enc_targets.input_ids  # B, T
    target_valid = enc_targets.attention_mask.bool()  # B, T
    total = torch.zeros(target_toks.shape[0], device=target_toks.device)
    for step in range(target_toks.shape[1]):
        log_probs = model(input_ids=input_ids).logits[:, -1].log_softmax(-1)  # B, D
        step_lp = log_probs.gather(1, target_toks[:, step:step + 1]).squeeze(-1)  # B
        # masked_fill (not multiply) so a -inf at a pad position cannot become NaN.
        total = total + step_lp.masked_fill(~target_valid[:, step], 0.0)
        input_ids = torch.cat((input_ids, log_probs.argmax(-1, keepdim=True)), dim=1)
    return total

def mask_hf_labels(labels, null_token=0):
    """Split Hugging Face labels into a validity mask and gather-safe labels.

    Positions equal to -100 (the Hugging Face loss ignore index) are treated
    as invalid. In the returned labels those positions are replaced by
    ``null_token`` so the tensor can be used as an index.

    Returns:
        (valid_mask, valid_labels): a boolean tensor that is True at valid
        positions, and the labels with ignored positions set to ``null_token``.
    """
    ignored = labels == -100
    return ~ignored, labels.masked_fill(ignored, null_token)

def gather_log_probs(logits, labels):
    """Return the log-probability that ``logits`` assigns to each label.

    Args:
        logits: tensor of shape (..., V) with unnormalized scores.
        labels: integer tensor of shape ``logits.shape[:-1]`` indexing the
            last axis of ``logits``.

    Returns:
        Tensor of shape ``labels.shape``.
    """
    assert labels.dim() == logits.dim() - 1
    assert labels.shape == logits.shape[:-1]
    # Normalize over the vocabulary axis, then pick each label's entry.
    log_probs = torch.log_softmax(logits, dim=-1)
    picked = torch.gather(log_probs, -1, labels[..., None])
    return picked[..., 0]

def seq2seq_get_target_probs(pred, targets, should_reduce=True):
    """Score label sequences under decoder logits.

    Args:
        pred: logits of shape (B, T, V) (or (T, V)); assumed aligned
            position-by-position with ``targets``.
        targets: labels of shape ``pred.shape[:-1]``; -100 marks ignored
            (e.g. padding) positions.
        should_reduce: if True, return the per-sequence mean log-probability
            over valid tokens and the batch-mean accuracy. If False, return the
            per-sequence summed log-probability and per-sequence accuracy.

    Returns:
        (log_probs, acc). Accuracy is 1.0 for a sequence only if the argmax
        prediction matches every valid target token (3-D ``pred``).
    """
    NULL_TOKEN = 0  # a placeholder used for masked target locations

    # Mask out ignored positions and make the labels safe to use as indices.
    mask, targ = mask_hf_labels(targets, NULL_TOKEN)
    # masked_fill (not multiply) so a -inf at an ignored position cannot become NaN.
    token_log_probs = gather_log_probs(pred, targ).masked_fill(~mask, 0.0)

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = correct.all(-1)  # We want to get the whole sequence right
    acc = correct.float()

    summed = token_log_probs.sum(-1)
    if not should_reduce:
        return summed, acc

    # Average over valid tokens only (not the padded length); clamp avoids
    # dividing by zero for fully-masked sequences.
    n_valid = mask.float().sum(-1).clamp(min=1.0)
    return summed / n_valid, acc.mean()
