import torch
from dataclasses import dataclass
from torchsummary import summary
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, LlamaForCausalLM
import numpy

@dataclass
class EvalResult:
    """Outcome of scoring one model completion against an expected answer."""

    # Prompt the completion belongs to (most Evaluator.eval* methods strip it first).
    prompt: str
    # Text produced by (or attributed to) the model for this prompt.
    solution: str
    # Expected answer the solution is compared against.
    answer: str
    # True when the solution matched the answer; the matching rule (substring
    # or exact match on part of the solution) depends on the eval* method used.
    accuracy: bool


class Evaluator:
    """Load a causal language model and score its greedy completions.

    The ``eval*`` methods build an ``EvalResult`` by comparing an expected
    answer with a solution string; only ``eval`` actually runs the model.
    The ``prompt_to_solution*`` methods generate text token by token, always
    picking the highest-scoring next token.
    """

    def __init__(self, pretrained, device="cuda") -> None:
        """
        pretrained (str):
            The HuggingFace Hub model ID name or the path to a pre-trained
            model to load. This is effectively the `pretrained_model_name_or_path`
            argument of `from_pretrained` in the HuggingFace `transformers` API.
        device (str):
            Device the encoded prompt tensors are moved to before each
            forward pass.
        """
        self.device = device
        self.model_id = pretrained
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained)
        # NOTE(review): some transformers releases appear to reject
        # low_cpu_mem_usage=False combined with device_map -- confirm against
        # the pinned version.
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=False,
            device_map="auto",
            attn_implementation="eager",
        )
        # NOTE(review): the weights are requested as float16 above and then
        # converted with .bfloat16() here -- confirm which dtype is intended.
        self.model.eval().bfloat16()

    def get_model_id(self):
        """Return the `pretrained` identifier this evaluator was built with."""
        return self.model_id

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _last_segment(solution, separator):
        """Return the stripped text that follows the last `separator` in `solution`.

        When `separator` does not occur, the whole stripped `solution` is
        returned (same result as ``solution.split(separator)[-1].strip()``).
        """
        return solution.rpartition(separator)[2].strip()

    def eval(self, prompt, answer, seq_len: int = 2) -> EvalResult:
        """Generate a solution for `prompt` and mark it correct if it contains `answer`.

        prompt (str):
            Prompt text; surrounding whitespace is stripped.
        answer (str):
            Expected answer, matched as a substring of the generated text.
        seq_len (int):
            Maximum number of tokens to generate.
        """
        prompt = prompt.strip()
        solution = self.prompt_to_solution(prompt, seq_len=seq_len)
        return EvalResult(prompt, solution, answer, answer in solution)

    def eval2(self, prompt, answer, solution, seq_len: int = 2) -> EvalResult:
        """Score a precomputed `solution`: correct if it contains `answer`.

        `seq_len` is accepted for signature parity with `eval` and is unused.
        """
        return EvalResult(prompt.strip(), solution, answer, answer in solution)

    def eval2distb(self, prompt, answer, solution, seq_len: int = 2) -> EvalResult:
        """Score a precomputed `solution`: correct if, once stripped, it equals `answer`.

        `seq_len` is accepted for signature parity with `eval` and is unused.
        """
        return EvalResult(prompt.strip(), solution, answer, answer == solution.strip())

    def eval2cot(self, prompt, answer, solution, seq_len: int = 2) -> EvalResult:
        """Score a precomputed `solution` by the text after its last ">".

        Correct when that trailing text, stripped, equals `answer`.
        `seq_len` is accepted for signature parity with `eval` and is unused.
        """
        final = self._last_segment(solution, ">")
        return EvalResult(prompt.strip(), solution, answer, answer == final)

    def eval2exp(self, prompt, answer, solution, seq_len: int = 2) -> EvalResult:
        """Score a precomputed `solution` by the text after its last "step3:".

        Correct when that trailing text, stripped, equals `answer`.
        `seq_len` is accepted for signature parity with `eval` and is unused.
        """
        final = self._last_segment(solution, "step3:")
        return EvalResult(prompt.strip(), solution, answer, answer == final)

    def eval3(self, prompt, answer, solution, seq_len: int = 2) -> EvalResult:
        """Same rule as `eval2cot`, but the prompt is stored without stripping.

        `seq_len` is accepted for signature parity with `eval` and is unused.
        """
        final = self._last_segment(solution, ">")
        return EvalResult(prompt, solution, answer, answer == final)

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def translation(self, prompt):
        """Encode `prompt` and return its token ids as a 1-D tensor on `self.device`."""
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        return tokens[0]

    def _greedy_decode(self, prompt, seq_len, max_newlines=1, capture=False):
        """Run the shared greedy decoding loop used by the prompt_to_solution* methods.

        Each step feeds the full token sequence to the model (no key/value
        cache is passed), appends the arg-max token, and stops early once
        `max_newlines` generated tokens have decoded to text containing "\\n".
        The token that triggers the stop is kept in the sequence.

        prompt (str):
            Text to encode; used exactly as given (callers strip if needed).
        seq_len (int):
            Maximum number of tokens to generate; values <= 0 generate nothing.
        max_newlines (int):
            Number of newline-bearing tokens that ends generation.
        capture (bool):
            When True, request hidden states and attentions from the model.

        Returns a tuple ``(tokens, prompt_token_count, hidden_states,
        prompt_attentions)``: the ``(1, length)`` id tensor including the
        prompt, the number of prompt tokens, the hidden states of the last
        forward pass (None if not captured or no pass ran), and a list of
        per-layer attention tensors on CPU (empty unless captured).
        """
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        prompt_token_count = tokens.shape[1]

        extra_kwargs = {}
        if capture:
            extra_kwargs = {"output_hidden_states": True, "output_attentions": True}

        hidden_states = None
        prompt_attentions = []
        newlines_seen = 0
        for step in range(seq_len):
            with torch.no_grad():
                outputs = self.model(input_ids=tokens, return_dict=True, **extra_kwargs)

            if capture:
                hidden_states = outputs.hidden_states
                attentions = outputs.attentions
                # Keep the attentions of the pass whose query length equals the
                # prompt length, and only pay for the CPU copy in that case.
                # Assumes each layer is (batch, heads, query_len, key_len), as
                # the transformers docs describe -- TODO confirm for this model.
                if attentions[0].shape[2] == prompt_token_count:
                    prompt_attentions = [layer.cpu() for layer in attentions]
                    print("given " + str(step))

            # Greedy choice: highest logit at the last position.
            new_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
            tokens = torch.cat((tokens, new_token), dim=1)

            new_word = self.tokenizer.decode(new_token.squeeze().tolist())
            if "\n" in new_word:
                newlines_seen += 1
                if newlines_seen >= max_newlines:
                    break

        return tokens, prompt_token_count, hidden_states, prompt_attentions

    def prompt_to_solution(self, prompt, seq_len: int = 2):
        """Greedily continue `prompt` until the first newline or `seq_len` tokens.

        The prompt is stripped, the whole sequence is decoded with special
        tokens skipped, and the first ``len(prompt)`` characters are removed.

        NOTE(review): the character-based slice assumes the decoded text
        starts with the prompt verbatim -- confirm for the tokenizer in use
        (`prompt_to_solutionexp` slices by token count instead).
        """
        prompt = prompt.strip()
        tokens, _, _, _ = self._greedy_decode(prompt, seq_len)
        decoded = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
        return decoded[len(prompt):]

    def prompt_to_solutionexp(self, prompt, seq_len: int = 2):
        """Greedily continue `prompt` until the second newline or `seq_len` tokens.

        The prompt is stripped; only the newly generated token ids are
        decoded (special tokens skipped) and returned.
        """
        prompt = prompt.strip()
        tokens, prompt_token_count, _, _ = self._greedy_decode(
            prompt, seq_len, max_newlines=2
        )
        return self.tokenizer.decode(
            tokens[0][prompt_token_count:], skip_special_tokens=True
        )

    def prompt_to_solution2(self, prompt, seq_len: int = 2):
        """Like `prompt_to_solution`, but without stripping and with model internals.

        Returns ``(solution, hidden_states, attentions)``:
        `solution` is the decoded sequence minus its first ``len(prompt)``
        characters; `hidden_states` comes from the last forward pass (None
        when `seq_len` <= 0, which previously raised UnboundLocalError);
        `attentions` is a list of per-layer CPU tensors from the pass whose
        query length equals the prompt token count (empty if none matched).
        """
        tokens, _, hidden_states, prompt_attentions = self._greedy_decode(
            prompt, seq_len, capture=True
        )
        decoded = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
        return decoded[len(prompt):], hidden_states, prompt_attentions
    
    
    
