import numpy as np
from abc import ABC, abstractmethod
from torch import nn
from source.language_models.prompt_templates import *
import torch
from source.controller.retriever.passage_retrieval import embed_queries
from munch import DefaultMunch
from data.dataset import config

# Mean-squared-error criteria (default reduction). QA_loss is the loss wired
# into the QA model classes; TC_loss is a separate, independent instance.
QA_loss, TC_loss = nn.MSELoss(), nn.MSELoss()

class AbstractModel(ABC):
    """Base wrapper that pairs an underlying model with a loss criterion.

    Subclasses typically override ``predict`` and ``evaluate`` with
    prompt-driven logic; the defaults below simply delegate to the wrapped
    ``model`` and ``loss`` callables.
    """

    def __init__(self, model, loss, verbose=False):
        """Store the wrapped model, its loss and the default prompt templates.

        Args:
            model: callable used by the default ``predict`` as ``model(x, Ks)``.
            loss: callable used by the default ``evaluate`` as ``loss(output, y)``.
            verbose: flag subclasses may consult to emit diagnostics.
        """
        self.model, self.loss, self.verbose = model, loss, verbose
        # Template names presumably come from the prompt_templates star
        # import; subclasses may replace them.
        self.thought_generation_prompt = thought_generation_prompt
        self.document_query_prompt = document_query_prompt

    def predict(self, x, Ks):
        """Return whatever the wrapped model produces for input ``x`` and contexts ``Ks``."""
        return self.model(x, Ks)

    def evaluate(self, output, y):
        """Return the loss between a model ``output`` and the reference ``y``."""
        return self.loss(output, y)
    

class QA_model(AbstractModel):
    """Classification-style QA model that reads digit-token scores from an LLM.

    The wrapped ``LLM`` is used through two methods (inferred from usage):
    ``score(prompt)[0]`` must be indexable by vocabulary id, and
    ``query(prompt)[1][0]`` must be a generated string.
    """

    def __init__(self, LLM, verbose=False):
        # Forward ``verbose``: it used to be dropped, so ``self.verbose`` was
        # always False and ``evaluate_heuristic`` could never print diagnostics.
        super().__init__(model=LLM, loss=QA_loss, verbose=verbose)
        # Vocabulary ids of the digit tokens "0".."4".
        # NOTE(review): these ids are tokenizer-specific — confirm they match
        # the vocabulary of the LLM in use.
        self.dict_class_token = {
            "0": 29900,
            "1": 29896,
            "2": 29906,
            "3": 29941,
            "4": 29946,
        }

    @staticmethod
    def _contexts(Ks):
        """Return ``Ks``, or ``[""]`` (a single context-free pass) when it is empty/None."""
        return Ks if Ks else [""]

    @staticmethod
    def _qa_prompt(x, K):
        """Build the answer prompt; an empty ``K`` selects the no-context template."""
        return QA_nocontext_prompt(x) if K == "" else QA_prompt(x, K)

    @staticmethod
    def _score_prompt(x, K):
        """Build the scoring prompt; an empty ``K`` selects the no-context template."""
        return QA_score_nocontext_prompt(x) if K == "" else QA_score_prompt(x, K)

    def predict(self, x, Ks):
        """Predict a binary label for ``x`` under each context in ``Ks``.

        Returns:
            list[int]: one label per context — 0 when the "0" token scores
            strictly higher than the "1" token, otherwise 1.
        """
        outputs = []
        for K in self._contexts(Ks):
            scores = self.model.score(self._qa_prompt(x, K))[0]
            score_0 = scores[self.dict_class_token["0"]]
            score_1 = scores[self.dict_class_token["1"]]
            # Ties resolve to class 1.
            outputs.append(0 if score_0 > score_1 else 1)
        return outputs

    def predict_rational(self, x, Ks):
        """Return the LLM's generated text for ``x`` under each context in ``Ks``.

        An empty context now uses the no-context template, like every other
        method here; it previously passed ``""`` into ``QA_prompt``.
        """
        return [
            self.model.query(self._qa_prompt(x, K))[1][0]
            for K in self._contexts(Ks)
        ]

    def evaluate(self, x, Ks, y, class_number=2):
        """Score how closely the LLM's class distribution matches gold label ``y``.

        For each context, the scores of digit tokens "0".."class_number-1" are
        softmaxed and compared with the one-hot target using ``self.loss``
        (``QA_loss``, an MSE criterion). Each returned value is ``1 - loss``,
        so larger means closer to the target.

        Args:
            x: the question / input passed to the prompt templates.
            Ks: list of context strings; empty means one context-free pass.
            y: gold class index, valid for indexing a length-``class_number`` vector.
            class_number: number of classes; at most 5 (tokens "0".."4").

        Returns:
            list of NumPy scalars, one per context.
        """
        # The one-hot target does not depend on the context, so build it once.
        # It matches the probabilities' dtype instead of NumPy's float64.
        target = torch.zeros(class_number, dtype=torch.float32)
        target[y] = 1.0

        results = []
        for K in self._contexts(Ks):
            scores = self.model.score(self._qa_prompt(x, K))[0]
            logits = [scores[self.dict_class_token[str(i)]] for i in range(class_number)]
            # Explicit dim: softmax without ``dim`` is deprecated and warns.
            probs = torch.nn.functional.softmax(
                torch.tensor(logits, dtype=torch.float32), dim=-1
            )
            loss = self.loss(probs, target)
            results.append(1 - loss.cpu().detach().numpy())
        return results

    def evaluate_heuristic(self, x, Ks):
        """Return a heuristic score for ``x`` under each context in ``Ks``.

        Reads the scores of tokens "4" and "3" from the scoring prompt, clamps
        both at 0, and returns ``min(s3 / (s4 + s3), s3)``. The denominator is
        floored at 1e-7 to avoid division by zero.
        """
        scores = []
        for K in self._contexts(Ks):
            result = self.model.score(self._score_prompt(x, K))[0]
            score_0 = max(result[self.dict_class_token["4"]], 0)
            score_1 = max(result[self.dict_class_token["3"]], 0)
            if self.verbose:
                print("score 0 : ", score_0)
                print("score 1 : ", score_1)
            scores.append(min(score_1 / max(score_0 + score_1, 0.0000001), score_1))
        return scores





class QT_QA_model(AbstractModel):
    """QA model that answers by short free-text generation and exact matching.

    Answers are normalised by removing "mg"/"m", which suggests numeric
    dosage-style answers.
    NOTE(review): that reading is inferred from the normalisation — confirm.

    ``predict``, ``evaluate`` and ``evaluate_heuristic`` were previously
    indented at module level by mistake. The class silently fell back to
    ``AbstractModel.predict``/``evaluate``; they are now real methods.
    """

    def __init__(self, LLM, verbose=False):
        super().__init__(model=LLM, loss=QA_loss, verbose=verbose)
        # Vocabulary ids of digit tokens "1".."4".
        # NOTE(review): tokenizer-specific — confirm for the LLM in use.
        self.dict_class_token = {"1": 29896, "2": 29906, "3": 29941, "4": 29946}
        self.thought_generation_prompt = ade_thought_generation_prompt

    @staticmethod
    def _contexts(Ks):
        """Return ``Ks``, or ``[""]`` (a single context-free pass) when it is empty/None."""
        return Ks if Ks else [""]

    def _generate_answer(self, x, K):
        """Generate and normalise a short answer for ``x`` given context ``K``.

        Keeps the first generated line, lower-cases it, removes "mg" and then
        every remaining "m", and strips surrounding whitespace.
        """
        prompt = med_nocontext_prompt(x, K) if K == "" else med_prompt(x, K)
        generated = self.model.query(prompt, temperature=0.2, max_new_tokens=4)[1][0]
        return generated.split("\n")[0].lower().replace("mg", "").replace("m", "").strip()

    def predict(self, x, Ks):
        """Return the normalised generated answer for ``x`` under each context in ``Ks``."""
        return [self._generate_answer(x, K) for K in self._contexts(Ks)]

    def evaluate(self, x, Ks, y):
        """Return 1/0 per context: whether the generated answer matches ``y``.

        ``y`` may list several accepted answers separated by "-". Each one is
        whitespace-stripped and compared to the normalised answer by exact
        string equality.
        """
        # Parse the gold answers once. Previously ``y`` was overwritten with
        # the split list inside the loop, so from the second context onward
        # the gold answers were parsed from ``str(list)``.
        gold = str(y).split("-")
        accepted = {v.strip() for v in gold}

        hits = []
        for K in self._contexts(Ks):
            result = self._generate_answer(x, K)
            hits.append(int(result in accepted))
            # Debug output used to be unconditional; gate it on ``verbose``.
            if self.verbose:
                print("y ", gold)
                print("result ", result)
                print("loss ", hits)
        return hits

    def evaluate_heuristic(self, x, Ks):
        """Return a heuristic score for ``x`` under each context in ``Ks``.

        Reads the scores of tokens "4" and "3" from ``QA_score_prompt`` (used
        for every context, including the empty one), clamps both at 0, and
        returns ``min(s3 / (s4 + s3), s3)``. The denominator is floored at 1e-7.
        """
        scores = []
        for K in self._contexts(Ks):
            result = self.model.score(QA_score_prompt(x, K))[0]
            score_0 = max(result[self.dict_class_token["4"]], 0)
            score_1 = max(result[self.dict_class_token["3"]], 0)
            if self.verbose:
                print("score 0 : ", score_0)
                print("score 1 : ", score_1)
            scores.append(min(score_1 / max(score_0 + score_1, 0.0000001), score_1))
        return scores


# Backward-compatible aliases: these names used to exist at module level
# (because of the indentation error above) and are kept importable.
predict = QT_QA_model.predict
evaluate = QT_QA_model.evaluate
evaluate_heuristic = QT_QA_model.evaluate_heuristic




