from typing import List
from pathlib import Path
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from deepeval.models.base_model import DeepEvalBaseLLM
from deepeval.benchmarks.mmlu.mmlu import MMLU


class Mistral7B(DeepEvalBaseLLM):
    """DeepEval LLM wrapper around a pre-loaded causal LM and its tokenizer.

    The wrapper never loads weights itself: it is handed an already
    constructed model/tokenizer pair and exposes text generation through the
    methods below. The model is used on whatever device it already lives on;
    it is not moved between devices by any of the methods.
    """

    def __init__(self, model, tokenizer):
        """Store the pre-loaded ``model`` and its matching ``tokenizer``."""
        self.model = model
        self.tokenizer = tokenizer

    def load_model(self):
        """Return the model object that was passed to the constructor."""
        return self.model

    def _generate_text(
        self, prompts: List[str], *, max_new_tokens: int, do_sample: bool
    ) -> List[str]:
        """Tokenize ``prompts``, run ``model.generate`` and decode each row.

        Args:
            prompts: Non-empty list of prompt strings.
            max_new_tokens: Upper bound on tokens generated per prompt.
            do_sample: Forwarded to ``model.generate``; ``False`` selects
                greedy decoding.

        Returns:
            One decoded string per prompt. The full generated sequence is
            decoded without skipping special tokens, so each string contains
            the prompt followed by the continuation.
        """
        model = self.load_model()
        tokenizer = self.tokenizer

        if len(prompts) > 1:
            # Prompts of different lengths cannot be stacked into one tensor
            # without padding. Pad on the left so that every prompt ends at
            # the last position, right where generation continues.
            if tokenizer.pad_token is None:
                # NOTE: this permanently sets a pad token on the shared
                # tokenizer when it did not define one.
                tokenizer.pad_token = tokenizer.eos_token
            previous_side = tokenizer.padding_side
            tokenizer.padding_side = "left"
            try:
                model_inputs = tokenizer(
                    prompts, return_tensors="pt", padding=True
                )
            finally:
                tokenizer.padding_side = previous_side
            # Number of left-pad positions per row, dropped again on decode.
            pad_counts = (
                (model_inputs["attention_mask"] == 0).sum(dim=1).tolist()
            )
        else:
            model_inputs = tokenizer(prompts, return_tensors="pt")
            pad_counts = [0] * len(prompts)

        # Send the inputs to the model instead of moving the model: a
        # hard-coded device either bounces the weights between devices on
        # every call or fails outright when that device is unavailable.
        device = next(model.parameters()).device
        model_inputs = model_inputs.to(device)

        generated_ids = model.generate(
            **model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample
        )
        return [
            tokenizer.decode(row[skip:])
            for row, skip in zip(generated_ids, pad_counts)
        ]

    def generate2(self, prompt: str) -> str:
        """Greedily generate up to 100 new tokens for ``prompt``.

        Returns the decoded sequence (prompt plus continuation).
        """
        return self._generate_text(
            [prompt], max_new_tokens=100, do_sample=False
        )[0]

    def generate(self, prompt: str) -> str:
        """Greedily generate up to 512 new tokens for ``prompt``.

        Uses ``model.generate`` rather than a manual loop of full forward
        passes, so no autograd graph is built per step and generation ends
        early once the model emits its end-of-sequence token.

        Returns the decoded sequence (prompt plus continuation).
        """
        return self._generate_text(
            [prompt], max_new_tokens=512, do_sample=False
        )[0]

    async def a_generate(self, prompt: str) -> str:
        """Async entry point; delegates to the synchronous ``generate``.

        The call runs synchronously inside the coroutine.
        """
        return self.generate(prompt)

    # This is optional.
    def batch_generate(self, promtps: List[str]) -> List[str]:
        """Generate up to 100 new tokens for every prompt in one batch.

        Sampling is enabled (``do_sample=True``), so results vary between
        calls. The parameter keeps its original spelling so that existing
        keyword callers continue to work.

        Args:
            promtps: Prompt strings; may be empty and may differ in length.

        Returns:
            One decoded string per prompt, in input order.
        """
        if not promtps:
            return []
        return self._generate_text(
            list(promtps), max_new_tokens=100, do_sample=True
        )

    def get_model_name(self):
        """Return the display name reported for this model."""
        return "Mistral 7B"


def main():
    """Load the ``google/gemma-2-2b`` checkpoint on CPU and print a sample.

    The model and tokenizer are fetched through the Hugging Face cache
    directory under ``base_dir``, wrapped in ``Mistral7B``, and asked to
    continue a single prompt whose output is printed to stdout.
    """
    base_dir = "/data_user/data/"
    checkpoint = "google/gemma-2-2b"

    # Instantiated here but not used further in this function.
    benchmark = MMLU()

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        cache_dir=Path(base_dir) / "input/cache_hugg",
        device_map="cpu",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        cache_dir=Path(base_dir) / "input/cache_hugg",
    )

    wrapped_llm = Mistral7B(model=model, tokenizer=tokenizer)
    completion = wrapped_llm.generate("Tell me a joke")
    print(completion)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
