from tqdm import tqdm
from typing import Dict, List, Set, Tuple, Union
import numpy as np
import os
import time 
import os
import json
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer
import pytz
import argparse
import itertools

import numpy as np
from typing import Dict
import torch
from lm_polygraph.estimators.estimator import Estimator
from lm_polygraph.utils.deberta import MultilingualDeberta

from absl import logging as absl_logging

class LUQ_vllm:
    """Score how well responses support each other, using an LLM judge run through vLLM.

    Every response (the main one plus each sampled one) is split into
    sentences by the caller. Each sentence is checked against every *other*
    response by prompting the judge model; the textual verdicts are mapped to
    numbers via ``text_mapping`` and averaged.
    """

    # Accepted ``model`` aliases and the Hugging Face repository each one loads.
    _MODEL_PATHS = {
        "llama3-8b-instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
        "llama3.2-1b-instruct": "meta-llama/Llama-3.2-1B-Instruct",
    }

    # Per-method (prompt template, verdict -> score table).
    _METHOD_CONFIGS = {
        "binary": (
            "Context: {context}\n\nSentence: {sentence}\n\nIs the sentence supported by the context above? Answer Yes or No.\n\nAnswer: ",
            {'yes': 1, 'no': 0, 'n/a': 0.5},
        ),
        "multiclass": (
            (
                "Context: {context}\n\n"
                "Sentence: {sentence}\n\n"
                "Is the sentence supported, refuted or not mentioned by the context above? "
                "You should answer the question purely based on the given context. "
                "Do not output the explanations.\n\n"
                "Your answer should be within \"supported\", \"refuted\", or \"not mentioned\".\n\n"
                "Answer: "
            ),
            {'supported': 1, 'refuted': 0, 'not mentioned': -1, 'n/a': 0.5},
        ),
    }

    # Per-method (answer prefix, canonical verdict) pairs, tried in order.
    _VERDICT_PREFIXES = {
        "binary": (("yes", "yes"), ("no", "no")),
        "multiclass": (
            ("support", "supported"),
            ("refut", "refuted"),
            ("not", "not mentioned"),
        ),
    }

    def __init__(
        self,
        model = "llama3-8b-instruct",
        method = "binary",
        abridged = False,
    ):
        """
        model: str
            Alias of the judge model: "llama3-8b-instruct" or "llama3.2-1b-instruct".
        method: str
            "binary" (Yes/No verdicts) or "multiclass" (supported / refuted / not mentioned).
        abridged: bool
            If True, `predict` returns as soon as the first response (the
            `sentences` argument) has been scored against the samples, instead
            of also scoring every sample against the others.

        Raises ValueError for an unknown `model` or `method`.
        """
        # Validate both options before any expensive model loading happens.
        if model not in self._MODEL_PATHS:
            raise ValueError("Model not supported")
        if method not in self._METHOD_CONFIGS:
            raise ValueError(f"Method not supported: {method!r}")
        model_path = self._MODEL_PATHS[model]

        self.method = method
        self.abridged = abridged

        # Imported here rather than at module level so that importing this
        # module does not require vLLM unless this scorer is constructed.
        from vllm import LLM, SamplingParams

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Only the leading word of the verdict is parsed, so generations are
        # kept short; temperature 0 requests greedy decoding.
        self.sampling_params = SamplingParams(
            n=1,
            temperature=0,
            top_p=0.9,
            max_tokens=5,
            stop_token_ids=[self.tokenizer.eos_token_id],
            skip_special_tokens=True,
        )

        self.llm = LLM(model=model_path, tensor_parallel_size=1, gpu_memory_utilization=0.2)

        prompt_template, text_mapping = self._METHOD_CONFIGS[method]
        self.prompt_template = prompt_template
        # Copy so that per-instance tweaks never leak into the class-level table.
        self.text_mapping = dict(text_mapping)

        # Unparseable answers already reported, so each is warned about once.
        self.not_defined_text = set()

    def set_prompt_template(self, prompt_template: str):
        """Replace the prompt template; it is formatted with `context` and `sentence`."""
        self.prompt_template = prompt_template

    def completion(self, prompts: List[str]):
        """Run the judge model on a batch of prompts and return vLLM's outputs."""
        return self.llm.generate(prompts, self.sampling_params, use_tqdm=False)

    def _build_prompt(self, context: str, sentence: str) -> str:
        """Render one judge prompt through the tokenizer's chat template."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": self.prompt_template.format(context=context, sentence=sentence)},
        ]
        return self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    @staticmethod
    def _mean_support(scores):
        """Collapse a (sentence x context) verdict matrix into a single score."""
        # -1 ("not mentioned") verdicts are excluded from each sentence's mean;
        # a sentence with no usable verdict at all counts as 0.
        usable = np.ma.masked_equal(scores, -1)
        per_sentence = usable.mean(axis=-1).filled(0)
        return per_sentence.mean()

    def predict(
        self,
        sentences: List[str],
        sampled_passages: List[List[str]],
        verbose: bool = False,
    ):
        """
        Compute the average support score over all responses.

        sentences: sentences of the main response.
        sampled_passages: one list of sentences per sampled response.
        verbose: currently unused; kept for interface compatibility.

        Returns the mean over responses of the per-response support score, or,
        when `abridged` is set, the support score of the main response only.
        """
        all_samples = [sentences] + sampled_passages

        # Each response serves as context for all the others, so flatten every
        # one of them exactly once up front.
        flattened = [" ".join(sample).replace("\n", " ") for sample in all_samples]

        luq_scores = np.zeros(len(all_samples))
        for index, item in enumerate(all_samples):
            contexts = [text for k, text in enumerate(flattened) if k != index]
            scores = np.zeros((len(item), len(contexts)))

            for sent_i, sentence in enumerate(item):
                prompts = [self._build_prompt(context, sentence) for context in contexts]
                # Nothing to ask the model when there is no other response.
                outputs = self.completion(prompts) if prompts else []
                for sample_i, output in enumerate(outputs):
                    scores[sent_i, sample_i] = self.text_postprocessing(output.outputs[0].text)

            item_score = self._mean_support(scores)
            luq_scores[index] = item_score
            if self.abridged:
                return item_score
        return luq_scores.mean()

    def text_postprocessing(
        self,
        text,
    ):
        """
        Map the judge's generated text to its numeric score.

        The answer is lower-cased and stripped, then matched by prefix against
        the verdicts of the active method. Anything unrecognised is reported
        once on stdout and scored as 'n/a'.

        Raises ValueError if `self.method` is not a supported method.
        """
        prefixes = self._VERDICT_PREFIXES.get(self.method)
        if prefixes is None:
            raise ValueError(f"Method not supported: {self.method!r}")

        text = text.lower().strip()
        for prefix, verdict in prefixes:
            if text.startswith(prefix):
                return self.text_mapping[verdict]

        if text not in self.not_defined_text:
            print(f"warning: {text} not defined")
            self.not_defined_text.add(text)
        return self.text_mapping['n/a']

class LUQ(Estimator):
    """Sequence-level uncertainty from how well the greedy and sampled texts support each other.

    Two backends are available, chosen by `model`:

    * a name containing "llama": an LLM judge (`LUQ_vllm`, binary method);
    * anything else: the "cross-encoder/nli-deberta-v3-large" NLI model
      wrapped by `MultilingualDeberta`.

    In both cases the returned value is ``1 - support``, so larger values
    mean the texts agree less.
    """

    def __init__(self, model: str = "llama3.2-1b-instruct"):
        """
        model: backend selector, see the class docstring. For the LLM judge it
            is also passed on to `LUQ_vllm` as the model alias.
        """
        self.model = model
        if "llama" in self.model:
            self.scorer = LUQ_vllm(model=self.model, method="binary", abridged=False)
        else:
            self.scorer = MultilingualDeberta("cross-encoder/nli-deberta-v3-large", batch_size=1)

        # Both statistics are read in __call__, so both are declared.
        super().__init__(["greedy_texts", "sample_texts"], "sequence")

    def __str__(self):
        return f"LUQ ({self.model})"

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Compute one uncertainty value per input.

        stats: must contain "greedy_texts" (one text per input) and
            "sample_texts" (a list of sampled texts per input).

        Returns a 1-D array with one value per input.
        """
        greedy_texts = stats["greedy_texts"]
        batch_texts = stats["sample_texts"]

        if "llama" in self.model:
            scores = self._llm_judge_scores(greedy_texts, batch_texts)
        else:
            scores = self._nli_scores(greedy_texts, batch_texts)
        return np.array(scores)

    def _llm_judge_scores(self, greedy_texts, batch_texts):
        """Uncertainty per input from the LLM judge; each text is treated as a single sentence."""
        return [
            1 - self.scorer.predict([greedy_text], [[t] for t in texts])
            for greedy_text, texts in zip(greedy_texts, batch_texts)
        ]

    def _nli_scores(self, greedy_texts, batch_texts):
        """Uncertainty per input from pairwise entailment/contradiction logits of the NLI model."""
        deberta = self.scorer
        device = deberta.device
        nli_model = deberta.deberta
        ent_id = nli_model.config.label2id["ENTAILMENT"]
        contra_id = nli_model.config.label2id["CONTRADICTION"]
        tokenizer = deberta.deberta_tokenizer

        scores = []
        for greedy_text, texts in zip(greedy_texts, batch_texts):
            candidates = [greedy_text] + list(texts)

            # Sampling from an LLM often produces many identical outputs, so
            # only pairs of unique texts are scored.
            unique_texts, inv = np.unique(candidates, return_inverse=True)
            n_unique = len(unique_texts)
            pairs = list(itertools.product(unique_texts, unique_texts))

            loader = torch.utils.data.DataLoader(pairs, batch_size=deberta.batch_size)
            logit_chunks = []
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for first_texts, second_texts in loader:
                    batch = list(zip(first_texts, second_texts))
                    encoded = tokenizer.batch_encode_plus(
                        batch, padding=True, return_tensors="pt"
                    ).to(device)
                    logit_chunks.append(nli_model(**encoded).logits.detach().cpu())
            logits = torch.cat(logit_chunks, dim=0)

            # Pairs were generated row-major, so the flat logits fold back
            # into (n_unique, n_unique) matrices.
            unique_E = logits[:, ent_id].reshape(n_unique, n_unique).numpy()
            unique_C = logits[:, contra_id].reshape(n_unique, n_unique).numpy()

            # Recover the full matrices from the unique ones by gathering
            # along both axes with the inverse index.
            E = unique_E[np.ix_(inv, inv)]
            C = unique_C[np.ix_(inv, inv)]

            # Softmax over the entailment and contradiction logits only.
            sim_scores = np.exp(E) / (np.exp(E) + np.exp(C))
            # Row-wise mean that leaves out each text's comparison with itself.
            sim_scores = (sim_scores.sum(axis=1) - sim_scores.diagonal()) / (sim_scores.shape[-1] - 1)
            scores.append(1 - sim_scores.mean())

        return scores
