import os
from tqdm import tqdm
import numpy as np
from typing import List
from dataclasses import dataclass
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from tqdm import tqdm, trange
from utils import chat_gpt_wrapper, gpt3wrapper_texts_batch_iter


# Device that tokenized inputs are moved to before generation: CUDA when
# available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): not referenced in the visible part of this file; generation
# below is capped with max_new_tokens=1 instead -- confirm whether still needed.
MAX_TARGET_LENGTH = 2
# Vocabulary ids of the two answer tokens whose scores are compared. The
# validator reports the softmax weight of the second id (4273).
# Presumably 150 is "no" and 4273 is "yes" in the T5 vocabulary -- TODO confirm.
YES_NO_TOK_IDX = [150, 4273]
# Prompts are truncated to this many tokens by the tokenizer.
MAX_SOURCE_LENGTH = 1024
# Near-zero sampling temperature; same value that the T5 generate call in
# D5Validator.obtain_scores samples with.
TEMPERATURE = 0.001
# Softmax over the last dimension, used to normalise the two answer-token scores.
sm = torch.nn.Softmax(dim=-1)
# Prompt template files for GPTValidator: the default one, and the one that is
# selected when the model name contains "loose".
GPT_TEMPLATE = "templates/gpt_validator.txt"
GPT_LOOSE_TEMPLATE = "templates/gpt_validator_loose.txt"


@dataclass
class ValidatorInput:
    """One validation query: does `text` satisfy `hypothesis`?

    Attributes
    ----------
    hypothesis : str
        The natural-language description to check.
    text : str
        The text sample the description is checked against.
    """

    hypothesis: str
    text: str


class Validator:
    """Abstract base class for validators that score (hypothesis, text) pairs."""

    def obtain_scores(
        self, validator_inputs: List["ValidatorInput"], verbose: bool = False
    ):
        """
        Score each validator input; must be overridden by subclasses.

        Parameters
        ----------
        validator_inputs : List[ValidatorInput]
            The (hypothesis, text) pairs to score.
        verbose : bool
            Whether to show progress. Accepted here so the abstract signature
            matches the concrete validators and callers that pass `verbose=`.

        Returns
        -------
        An iterable with one score per input, in input order.

        Raises
        ------
        NotImplementedError
            Always; this base class provides no implementation.
        """
        raise NotImplementedError


class D5Validator(Validator):
    """
    A validator based on a T5 model to validate a hypothesis given a text.

    Each (hypothesis, text) pair is rendered into a prompt, the model generates
    a single token, and the score is the softmax weight of the second token id
    in YES_NO_TOK_IDX relative to the first (presumably "yes" vs. "no" --
    confirm against the tokenizer vocabulary).
    """

    BATCH_SIZE = 16
    # NOTE: the default template is read once, when the class is defined (i.e.
    # at import time), from a path relative to the current working directory.
    with open("templates/t5_validator.txt", "r") as f:
        DEFAULT_VALIDATOR_TEMPLATE = f.read()

    def __init__(
        self,
        model_path: str = None,
        batch_size: int = BATCH_SIZE,
        template: str = DEFAULT_VALIDATOR_TEMPLATE,
    ):
        """
        Initialize the validator

        Parameters
        ----------
        model_path : str
            The path to the T5 model weights used for validation. Required,
            despite the `None` default kept for signature compatibility.
        batch_size : int
            The batch size used for validation
        template : str
            The template used for validation; formatted with `hypothesis` and
            `text` keyword arguments.

        Raises
        ------
        ValueError
            If `model_path` is not provided.
        """
        # Fail fast with a clear message instead of erroring only after the
        # tokenizer has already been loaded.
        if model_path is None:
            raise ValueError("model_path must be provided")

        # The tokenizer is always the flan-t5-xxl one, independent of model_path.
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
        print("loading model weights")
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.model_name = os.path.basename(model_path)
        print("done")
        self.parallelize_across_device()
        self.validator_template = template
        self.batch_size = batch_size

    @staticmethod
    def _build_device_map(num_layers: int, num_device: int) -> dict:
        """
        Assign encoder block indices to GPU ids.

        Device 0 receives the smaller, leftover share (at least one block);
        the remaining blocks are spread over the other devices. Whenever the
        leftover share is positive this produces the same map as the plain
        "num_layers // num_device + 1 blocks per non-first device" rule. When
        it is not (e.g. 24 blocks on 8 devices), the remaining blocks are
        split as evenly as possible so that no device is handed an empty
        block list.

        Parameters
        ----------
        num_layers : int
            Number of blocks to distribute.
        num_device : int
            Number of available devices; must be at least 1.

        Returns
        -------
        dict
            Maps device id to the contiguous list of block indices it holds.
            Device ids are consecutive starting at 0; devices that would get
            no block are left out.
        """
        other_device_alloc = num_layers // num_device + 1
        first_device = num_layers - (num_device - 1) * other_device_alloc
        first_device = min(max(first_device, 1), num_layers)

        device_map = {0: list(range(first_device))}
        remaining_devices = num_device - 1
        if remaining_devices == 0:
            return device_map

        base, extra = divmod(num_layers - first_device, remaining_devices)
        cur = first_device
        for i in range(1, num_device):
            # The first `extra` devices absorb one additional block each.
            size = base + (1 if i <= extra else 0)
            if size == 0:
                break
            device_map[i] = list(range(cur, cur + size))
            cur += size
        return device_map

    def parallelize_across_device(self):
        """
        Spread the model over all visible GPUs, or keep it on `device` if none.

        Without any CUDA device there is nothing to parallelize; the model is
        simply moved to the module-level `device` (CPU in that case) instead of
        failing with a division by zero.
        """
        num_layers = len(self.model.encoder.block)
        num_device = torch.cuda.device_count()
        if num_device == 0:
            self.model.to(device)
            return
        device_map = self._build_device_map(num_layers, num_device)
        print("device_map", device_map)
        self.model.parallelize(device_map)

    def _score_batch(self, input_prompts: List[str]) -> List[float]:
        """Run one batch of prompts through the model and return one score per prompt."""
        with torch.no_grad():
            inputs = self.tokenizer(
                input_prompts,
                return_tensors="pt",
                padding="longest",
                max_length=MAX_SOURCE_LENGTH,
                truncation=True,
            ).to(device)
            generation_result = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                do_sample=True,
                temperature=TEMPERATURE,
                max_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
            )
            # Keep only the two answer-token columns of the first generation
            # step, normalise them against each other and report the weight of
            # the second one.
            # NOTE(review): the scores returned by generate are presumably
            # already temperature-scaled, which would push the result very
            # close to 0 or 1 -- confirm this is intended.
            return (
                sm(generation_result.scores[0][:, YES_NO_TOK_IDX])[:, 1]
                .detach()
                .cpu()
                .numpy()
                .tolist()
            )

    def obtain_scores(self, validator_inputs: List[ValidatorInput], verbose: bool = False):
        """
        Given a list of ValidatorInput, lazily yield one score per input, where the i-th score is how well the i-th ValidatorInput satisfies the description.

        This is a generator: nothing is computed until it is iterated, and
        scores become available batch by batch.

        Parameters
        ----------
        validator_inputs : List[ValidatorInput]
            A list of ValidatorInput.
        verbose : bool
            Whether to show a progress bar.

        Yields
        ------
        float
            The score of the next ValidatorInput, in input order.
        """

        prompts = [
            self.validator_template.format(
                hypothesis=validator_input.hypothesis, text=validator_input.text
            )
            for validator_input in validator_inputs
        ]
        self.model.eval()
        # Ceiling division; evaluates to 0 for an empty prompt list.
        num_batches = (len(prompts) - 1) // self.batch_size + 1
        if verbose:
            pbar = trange(num_batches)
            pbar.set_description("inference")
        else:
            pbar = range(num_batches)

        for batch_idx in pbar:
            start = batch_idx * self.batch_size
            batch_prompts = prompts[start : start + self.batch_size]
            # The no_grad block lives inside _score_batch rather than around
            # this loop: a `with torch.no_grad()` that spans a `yield` would
            # leave gradients disabled in the consumer's code while this
            # generator is suspended.
            yield from self._score_batch(batch_prompts)


class GPTValidator(Validator):
    """A validator that asks an OpenAI GPT model whether a text satisfies a hypothesis."""

    def __init__(self, model: str, template_path: str = GPT_TEMPLATE):
        """
        Parameters
        ----------
        model : str
            The GPT model to use.
        template_path : str
            Path of the prompt template file. Ignored when `model` contains
            "loose": GPT_LOOSE_TEMPLATE is used instead.
        """
        super().__init__()
        self.model = model
        self.model_name = model
        if "loose" in model:
            template_path = GPT_LOOSE_TEMPLATE
            print("Using loose template")

        with open(template_path, "r") as f:
            self.template = f.read()

    def obtain_scores(
        self, validator_inputs: List[ValidatorInput], verbose: bool = False
    ):
        """
        Given a list of ValidatorInput, lazily yield one score per input, where the i-th score says whether the i-th ValidatorInput satisfies the description.

        This is a generator: no request is sent until it is iterated.

        Parameters
        ----------
        validator_inputs : List[ValidatorInput]
            A list of ValidatorInput.
        verbose : bool
            Whether to print progress bar

        Yields
        ------
        int
            1 if the model's response contains "yes" (case-insensitive),
            otherwise 0; one value per ValidatorInput, in input order.

        Raises
        ------
        ValueError
            On first iteration, if the model is neither "gpt-4",
            "gpt-3.5-turbo" nor a name starting with "text-davinci".
        """

        # construct the prompts
        prompts = [
            self.template.format(
                hypothesis=validator_input.hypothesis, text=validator_input.text
            )
            for validator_input in validator_inputs
        ]
        if self.model in ("gpt-4", "gpt-3.5-turbo"):
            # Chat models are queried one prompt at a time, so the progress bar
            # is driven from here.
            prompt_iter = tqdm(prompts, desc="inference") if verbose else prompts
            for prompt in prompt_iter:
                response = chat_gpt_wrapper(
                    prompt=prompt, model=self.model, temperature=0.0
                )
                yield 1 if "yes" in response[0].lower() else 0
        elif self.model.startswith("text-davinci"):
            for text_response in gpt3wrapper_texts_batch_iter(
                prompt=prompts, model=self.model, temperature=0.0, verbose=verbose
            ):
                yield 1 if "yes" in text_response.lower() else 0
        else:
            # Yielding nothing here would only surface later as a confusing
            # shape mismatch in the caller.
            raise ValueError(f"Unknown GPT model {self.model}")


def get_validator_by_name(validator_name):
    """
    Build the validator that corresponds to a model name.

    Names containing "t5" map to a D5Validator (the name is passed on as the
    model path). Names containing "text-davinci", as well as "gpt-4" and
    "gpt-3.5-turbo", map to a GPTValidator.

    Raises
    ------
    ValueError
        If the name matches none of the known validators.
    """
    if "t5" in validator_name:
        return D5Validator(validator_name)

    chat_models = ("gpt-4", "gpt-3.5-turbo")
    is_gpt = "text-davinci" in validator_name or validator_name in chat_models
    if not is_gpt:
        raise ValueError(f"Unknown validator {validator_name}")
    return GPTValidator(validator_name)


def validate_descriptions(
    descriptions: List[str],
    texts: List[str],
    validator: Validator,
    progress_bar: bool = False,
) -> np.ndarray:
    """
    Score every text against every description.

    Parameters
    ----------
    descriptions : List[str]
        The descriptions (hypotheses) to be validated.
    texts : List[str]
        The texts the descriptions are validated on.
    validator : Validator
        The validator producing the scores, e.g. a D5Validator or a GPTValidator.
    progress_bar : bool, optional
        Forwarded to the validator as `verbose`, by default False

    Returns
    -------
    np.ndarray
        Array of shape (len(texts), len(descriptions)); entry [i, j] is how
        well the i-th text satisfies the j-th description.
    """
    # Text-major ordering: all descriptions for the first text, then all
    # descriptions for the second text, and so on. The reshape below relies on it.
    validator_inputs = [
        ValidatorInput(hypothesis=description, text=text)
        for text in texts
        for description in descriptions
    ]

    # Drain the validator's score stream into a flat list.
    flat_scores = list(validator.obtain_scores(validator_inputs, verbose=progress_bar))

    # One row per text, one column per description.
    score_matrix = np.array(flat_scores).reshape(len(texts), len(descriptions))
    return score_matrix


if __name__ == "__main__":
    # Small manual demo: score two film reviews against two tone descriptions
    # with a GPT validator and print the resulting 2x2 score matrix.
    descriptions = ["is positive in tone", "is negative in tone"]
    texts = ["I like this film.", "I hate this film."]

    # Use "gpt-3.5-turbo" here instead to exercise the chat-model code path.
    gpt_name = "text-davinci-002"

    validator = GPTValidator(gpt_name)
    scores = validate_descriptions(descriptions, texts, validator)
    print(scores)
    exit(0)
