import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# from model import tokenization_with_bert, BertClassifier
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizer
import pickle

from src.utils.ft_pytorch.squad import *
from src.core.interface import dialog, ranking, annotation

class EnsembleUnion(nn.Module):
    """Ensemble that returns the union of tokens found by a NER model and a SQuAD model.

    Both underlying models are loaded through ``annotation.load_ner_model`` with the
    same ``model_name`` but a different ``task`` ('ner' / 'squad') and checkpoint path.
    """

    def __init__(self, model_name, ner_path, squad_path):
        super(EnsembleUnion, self).__init__()
        self.ner_model = annotation.load_ner_model(model_name, task='ner', model_path=ner_path)
        self.squad_model = annotation.load_ner_model(model_name, task='squad', model_path=squad_path)
        self.model_name = model_name

    @staticmethod
    def _tokens(spans):
        """Flatten annotated spans (dicts with a 'text' key) into whitespace-split tokens."""
        # A nested comprehension avoids sum(..., []), which re-copies the
        # accumulator on every addition (quadratic in the number of spans).
        return [token for span in spans for token in span['text'].split()]

    @staticmethod
    def _combine(squad_text, ner_text):
        """Return the de-duplicated union of both token lists (order unspecified).

        An empty side contributes nothing to a union, so no special-casing is needed.
        """
        return list(set(squad_text) | set(ner_text))

    def forward(self, input_text, question, feature):
        """Annotate ``input_text`` with both models and return the union of their tokens.

        Args:
            input_text: passage to annotate.
            question: question string forwarded to ``annotation.annotate_text``.
            feature: feature name forwarded as ``active_feature``.

        Returns:
            list of unique tokens predicted by either model; order is unspecified.
        """
        # NOTE(review): annotate_text is assumed to return an iterable of dicts with a
        # 'text' entry (that is the only field read here) — confirm in the annotation module.
        output_squad = annotation.annotate_text(input_text, active_feature=feature, model_name=self.model_name, model=self.squad_model, task='squad', question=question)
        output_ner = annotation.annotate_text(input_text, active_feature=feature, model_name=self.model_name, model=self.ner_model, task='ner', question=question)

        return self._combine(self._tokens(output_squad), self._tokens(output_ner))

class EnsembleIntersection(nn.Module):
    """Ensemble that returns tokens found by both a NER model and a SQuAD model.

    If one of the two models produces no tokens at all, the other model's tokens are
    used instead (a plain intersection would otherwise always be empty).
    """

    def __init__(self, model_name, ner_path, squad_path):
        super(EnsembleIntersection, self).__init__()
        self.ner_model = annotation.load_ner_model(model_name, task='ner', model_path=ner_path)
        self.squad_model = annotation.load_ner_model(model_name, task='squad', model_path=squad_path)
        self.model_name = model_name

    @staticmethod
    def _tokens(spans):
        """Flatten annotated spans (dicts with a 'text' key) into whitespace-split tokens."""
        # A nested comprehension avoids sum(..., []), which is quadratic.
        return [token for span in spans for token in span['text'].split()]

    @staticmethod
    def _combine(squad_text, ner_text):
        """Intersect both token lists, falling back to the non-empty side.

        Previously the fallback branches were unconditionally overwritten by the
        intersection, so an empty prediction from either model always yielded [].
        All paths return de-duplicated tokens in unspecified order.
        """
        if not ner_text:
            return list(set(squad_text))
        if not squad_text:
            return list(set(ner_text))
        return list(set(squad_text) & set(ner_text))

    def forward(self, input_text, question, feature):
        """Annotate ``input_text`` with both models and return their common tokens.

        Args:
            input_text: passage to annotate.
            question: question string forwarded to ``annotation.annotate_text``.
            feature: feature name forwarded as ``active_feature``.

        Returns:
            list of unique tokens predicted by both models, or by whichever model
            produced any tokens when the other produced none; order is unspecified.
        """
        # NOTE(review): annotate_text is assumed to return an iterable of dicts with a
        # 'text' entry (that is the only field read here) — confirm in the annotation module.
        output_squad = annotation.annotate_text(input_text, active_feature=feature, model_name=self.model_name, model=self.squad_model, task='squad', question=question)
        output_ner = annotation.annotate_text(input_text, active_feature=feature, model_name=self.model_name, model=self.ner_model, task='ner', question=question)

        return self._combine(self._tokens(output_squad), self._tokens(output_ner))

# Maps each templated question (lower-cased) used in the SQuAD-format test data
# to the annotation feature it asks about. Built once at import time.
_QUESTION_TO_FEATURE = {
    "what is the target attribute?": "attribute",
    "what is the target aggregator?": "aggregator",
    "what is the target filtering attribute?": "filter",
    "what is the target filtering operator?": "filter_operation",
    "what is the prediction window?": "prediction_window",
    "what number is used?": "number",
    "what is the target entity?": "entity",
}


def get_feature_from_question(question):
    """Return the annotation feature targeted by a templated question.

    String questions are matched ignoring surrounding whitespace and letter case,
    so ``"What is the target entity? "`` resolves like ``"what is the target entity?"``.

    Raises:
        KeyError: if ``question`` is not one of the known templates.
    """
    key = question.strip().lower() if isinstance(question, str) else question
    try:
        return _QUESTION_TO_FEATURE[key]
    except KeyError:
        raise KeyError(
            f"Unrecognised question {question!r}; expected one of {sorted(_QUESTION_TO_FEATURE)}"
        ) from None

def evaluation(model, queries, questions, answers, file):
    """Run ``model`` over every (query, question) pair and report token-level F1.

    Args:
        model: ensemble module called as ``model(text, question, feature)``; its
            output is collected as the prediction for that example.
        queries: sequence of context strings.
        questions: sequence of question strings aligned with ``queries``; each is
            mapped to a feature via ``get_feature_from_question``.
        answers: sequence aligned with ``queries``; only ``answer['text']`` is read,
            and it is whitespace-split into the gold tokens.
        file: writable text handle; the metric line is written to it.

    Raises:
        ValueError: if ``queries``, ``questions`` and ``answers`` differ in length.
    """
    # Misaligned inputs would silently pair the wrong answer with a query.
    if not (len(queries) == len(questions) == len(answers)):
        raise ValueError(
            f"length mismatch: {len(queries)} queries, {len(questions)} questions, "
            f"{len(answers)} answers"
        )

    model.eval()

    predictions = []
    ground_truths = []

    # Inference only: disable autograd once for the whole loop instead of per item.
    with torch.no_grad():
        for text, question, answer in tqdm(zip(queries, questions, answers), total=len(queries)):
            feature = get_feature_from_question(question)
            predictions.append(model(text, question, feature))
            ground_truths.append(answer['text'].split())

    # compute_f1 is expected (per the unpacking here) to return (f1, precision, recall).
    f1, precision, recall = compute_f1(predictions, ground_truths)
    report = f"\nAccuracy in F1: {f1}, \t Precision: {precision}, \t Recall: {recall}\n"
    print(report)
    file.write(report)



def get_union_intersection_f1(
    *,
    conf_loc="src/core/evaluation/global_config.json",
    results_path="src/core/ensembles/acl/results_universal_model.txt",
):
    """Evaluate union and intersection ensembles for every schema/model pair.

    Reads the schema names (``"schema"``) and model names (``"ner_models"``) from
    the JSON config at ``conf_loc``. For each pair, it prints the F1 / precision /
    recall of both ensembles and appends them to ``results_path``.

    Args:
        conf_loc: path of the JSON configuration file.
        results_path: text file the results are appended to.
    """
    # Explicit import: the module otherwise relies on a star import to provide json.
    import json

    # Read the config up front so the file is not held open during evaluation.
    with open(conf_loc, encoding="utf-8") as config_file:
        configuration = json.load(config_file)

    schema_names = configuration["schema"]
    model_list = configuration["ner_models"]
    batch_size = 32

    # `with` guarantees the results file is closed even if an evaluation fails.
    with open(results_path, 'a', encoding="utf-8") as file:
        for schema_name in schema_names:
            print(f"**************__{schema_name}_**************")
            file.write(f"**************__{schema_name}_**************\n")

            # The test split depends only on the schema, so load it once per schema.
            data_path = f"src/data/test_data/squad_format/{schema_name}.json"
            contexts, questions, answers = read_squad(data_path, batch_size)

            for model_name in model_list:
                print(f"**************__{model_name}_**************")
                file.write(f"**************__{model_name}_**************\n")
                ner_path = f"models/lm/ner/combined/{model_name}/{model_name}"
                squad_path = f"models/lm/squad/combined/{model_name}"
                union_model = EnsembleUnion(model_name, ner_path=ner_path, squad_path=squad_path)
                intersection_model = EnsembleIntersection(model_name, ner_path=ner_path, squad_path=squad_path)

                print("------------union-----------------")
                file.write("------------union-----------------")
                evaluation(union_model, contexts, questions, answers, file)
                print("------------intersection-----------------")
                file.write("------------intersection-----------------")
                evaluation(intersection_model, contexts, questions, answers, file)



# Run the full evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    get_union_intersection_f1()