import os
import json
from collections import defaultdict,Counter

import numpy as np
from scipy.stats import kendalltau, skew
from tqdm import tqdm as tqdm
import pandas as pd
from constants import * 

def get_judge_name(judge):
    """Map a raw judge identifier to its judge category.

    Args:
        judge: Either a string identifier, or a list/tuple describing a model
            judge such as ``["gpt-4", "pair-..."]``.

    Returns:
        ``"gpt4_pair"`` for a GPT-4 judge whose second entry starts with
        ``"pair"``; ``"human"`` for strings starting with ``"expert"``;
        ``"author"`` for strings starting with ``"author"``. Any other string
        is returned unchanged. Any other list/tuple is collapsed into one
        ``"_"``-joined string so the result can always be used as a dict key.
    """
    if isinstance(judge, (list, tuple)):
        is_gpt4_pair = (
            len(judge) >= 2
            and judge[0] == "gpt-4"
            and isinstance(judge[1], str)
            and judge[1].startswith("pair")
        )
        if is_gpt4_pair:
            return "gpt4_pair"
        # Sequences have no ``startswith``; without this branch any other
        # model judge would crash with AttributeError further down.
        return "_".join(str(part) for part in judge)
    if judge.startswith("expert"):
        return "human"
    if judge.startswith("author"):
        return "author"
    return judge
    
def revert(vote):
    """Swap the winner of a vote between ``"model_a"`` and ``"model_b"``.

    Any other value (for example a tie label) is returned unchanged.
    """
    if vote == "model_a":
        return "model_b"
    return "model_a" if vote == "model_b" else vote


def get_mt_bench_votes_data(raw_votes):
    """Group vote winners by turn, question/model pair and judge category.

    Args:
        raw_votes: Iterable of vote collections; every vote is a mapping with
            the keys ``"turn"``, ``"question_id"``, ``"model_a"``,
            ``"model_b"``, ``"winner"`` and ``"judge"``.

    Returns:
        A two-element list (index ``turn - 1``). Each element maps
        ``(question_id, first_model, second_model)`` to a dict from judge
        category to the list of winners recorded for that pair.
    """
    per_turn = [{}, {}]

    for judge_votes in raw_votes:
        for record in judge_votes:
            first, second = record["model_a"], record["model_b"]
            outcome = record["winner"]
            if not first < second:
                # Store every pair with the smaller model name first; the
                # winner label has to be mirrored when the pair is swapped.
                first, second = second, first
                outcome = revert(outcome)
            pair_key = (record["question_id"], first, second)
            category = get_judge_name(record["judge"])
            bucket = per_turn[record["turn"] - 1]
            bucket.setdefault(pair_key, {}).setdefault(category, []).append(outcome)

    return per_turn

def aggregate(x):
    """Reduce a list of votes to a single majority verdict.

    Args:
        x: A list of vote labels; anything that is not exactly a ``list``
            yields ``np.nan``.

    Returns:
        ``"model_a"`` or ``"model_b"`` for whichever label occurs more often,
        ``"tie"`` when both occur equally often (other labels are ignored),
        or ``np.nan`` for non-list input.
    """
    if type(x) is not list:
        return np.nan

    tally = Counter(x)
    # A Counter reports 0 for labels that never occur.
    margin = tally["model_a"] - tally["model_b"]
    if margin > 0:
        return "model_a"
    if margin < 0:
        return "model_b"
    return "tie"
    


def get_mt(ties=False, majority=False):
    """Load the first-turn MT-bench votes as numeric human vs. GPT-4 scores.

    Reads the JSON-lines files ``gpt4_pair_judgments.json`` and
    ``human_judgments.json`` from the current working directory, groups the
    votes with ``get_mt_bench_votes_data`` and keeps only the first turn.

    Args:
        ties: When false, tie votes are mapped to NaN and therefore dropped;
            when true, they are scored as 0.5.
        majority: When false, every (human vote, GPT-4 vote) combination of a
            pair becomes its own row; when true, the votes of each judge
            category are first reduced with ``aggregate``.

    Returns:
        A DataFrame with the columns ``id``, ``model_a``, ``model_b`` and one
        column per judge category other than ``author``. The ``human`` and
        ``gpt4_pair`` columns hold 1 (model_a wins), 0 (model_b wins) or 0.5
        (tie, only with ``ties=True``). Rows containing any NaN are dropped.
    """
    votefiles = ["gpt4_pair_judgments.json", "human_judgments.json"]
    votes = []
    for filename in votefiles:
        # One JSON document per line; the context manager closes the handle,
        # and blank lines (such as a trailing newline) are skipped.
        with open(filename, "r") as handle:
            votes.append([json.loads(line) for line in handle if line.strip()])

    data = get_mt_bench_votes_data(votes)
    keys = ["id", "model_a", "model_b"]
    turn1_rows = [
        {**dict(zip(keys, key)), **judge_votes}
        for key, judge_votes in data[0].items()
    ]
    a = pd.DataFrame(turn1_rows)
    # Author votes are not used; tolerate data that contains none of them.
    a = a.drop(columns=["author"], errors="ignore")

    # NaN for ties makes the final dropna() exclude them.
    tie_score = 0.5 if ties else np.nan
    score_dict = {
        "model_a": 1,
        "model_b": 0,
        "tie": tie_score,
        "tie (inconsistent)": tie_score,
    }

    judge_columns = ["human", "gpt4_pair"]
    if majority:
        for column in judge_columns:
            a[column] = a[column].apply(aggregate)
    else:
        a = a.explode("human").explode("gpt4_pair")

    for column in judge_columns:
        a[column] = a[column].map(score_dict)
    return a.dropna()

def get_targets(task_name):
    """Build the table of correct-answer indicators for one MMLU subject.

    The reference instances are taken from the single entry of ``file_list``
    that starts with ``"mmlu:subject=<task_name>"`` and contains the name of
    the hard-coded reference model run.

    Args:
        task_name: Subject name as it appears after ``"mmlu:subject="``.

    Returns:
        A DataFrame with one row per instance (indexed by a run/instance
        identifier in which the model name is replaced by ``"?"``) and one
        column per reference position; a cell is 1 when that reference's tags
        are exactly ``["correct"]`` and 0 otherwise.

    Raises:
        AssertionError: If the number of matching run directories is not
            exactly one, or if two instances produce the same identifier.
    """
    model_name = "openai_gpt-4o-2024-05-13"
    prefix = "mmlu:subject=%s" % task_name
    res = [i for i in file_list if i.startswith(prefix) and model_name in i]
    assert len(res) == 1, "expected exactly one run for %r, found %d" % (
        task_name,
        len(res),
    )

    ret = {}
    for k in res:
        path_instance = os.path.join(dir_predictions, k, "instances.json")
        with open(path_instance, "r") as handle:
            instances = json.load(handle)

        # Model-independent prefix so rows from different models share ids.
        run_key = k.replace(model_name, "?")
        for i in instances:
            idx = ("%s-%s" % (run_key, i["id"])).replace(",stop=none", "")
            assert idx not in ret, "duplicate instance identifier %r" % idx

            ret[idx] = {
                pos: int(len(ref["tags"]) == 1 and ref["tags"][0] == "correct")
                for pos, ref in enumerate(i["references"])
            }
    return pd.DataFrame(ret).transpose()

def make_deterministic(df):
    """Turn each row of scores into a one-hot row marking its maximum.

    Args:
        df: DataFrame of per-option scores, one row per instance.

    Returns:
        A new DataFrame with the same columns (and a fresh default index) in
        which the first maximal entry of every row is 1 and all others are 0.
    """
    scores = df.values
    winners = scores.argmax(axis=1)

    # zeros_like keeps the dtype of the input scores.
    hard = np.zeros_like(scores)
    rows = np.arange(len(df))
    hard[rows, winners] = 1

    return pd.DataFrame(hard, columns=df.columns)


# Names of the entries in the prediction directory, listed once at import
# time; get_targets and get_predictions filter this list by prefix and model
# name. `dir_predictions` is not defined in this file (presumably it comes from
# the star import of `constants` - confirm).
file_list = os.listdir(dir_predictions)

def get_predictions(model_name, task_name):
    """Load one model's predictions for one MMLU subject.

    Args:
        model_name: Model identifier. Names starting with ``"llama3.1"`` are
            read from ``Meta-Llama-3.1-405B_<task_name>/predictions_proba.json``;
            if the name also ends with ``"_det"`` the result is converted to
            one-hot rows with ``make_deterministic``. Any other name is looked
            up among the run directories in ``file_list``.
        task_name: Subject name as it appears after ``"mmlu:subject="``.

    Returns:
        A DataFrame with one row per instance and one column per answer
        option. For non-llama models a row is one-hot on the predicted option;
        predicted text that is not one of A-D (optionally preceded by a space)
        is treated as option index 0.

    Raises:
        AssertionError: If the number of matching run directories is not
            exactly one, if instances and predictions are misaligned, or if
            two instances produce the same identifier.
    """
    if model_name.startswith("llama3.1"):
        path_proba = os.path.join(
            dir_predictions,
            "Meta-Llama-3.1-405B_" + task_name,
            "predictions_proba.json",
        )
        proba = pd.read_json(path_proba).set_index("instance_id")
        if model_name.endswith("_det"):
            return make_deterministic(proba)
        return proba

    prefix = "mmlu:subject=%s" % task_name
    res = [i for i in file_list if i.startswith(prefix) and model_name in i]
    assert len(res) == 1, "expected exactly one run of %r for %r, found %d" % (
        model_name,
        task_name,
        len(res),
    )

    # Built once instead of per instance; unknown text falls back to index 0.
    predmap = {" A": 0, " B": 1, " C": 2, " D": 3, "A": 0, "B": 1, "C": 2, "D": 3}

    ret = {}
    for k in res:
        run_dir = os.path.join(dir_predictions, k)
        with open(os.path.join(run_dir, "instances.json"), "r") as handle:
            instances = json.load(handle)
        with open(os.path.join(run_dir, "display_predictions.json"), "r") as handle:
            predictions = json.load(handle)

        # Model-independent prefix so rows from different models share ids.
        run_key = k.replace(model_name, "?")
        for i, j in zip(instances, predictions):
            assert i["id"] == j["instance_id"], "instances and predictions are misaligned"
            idx = ("%s-%s" % (run_key, i["id"])).replace(",stop=none", "")
            assert idx not in ret, "duplicate instance identifier %r" % idx

            a = {pos: 0.0 for pos in range(len(i["references"]))}
            a[predmap.get(j["predicted_text"], 0)] = 1.0
            ret[idx] = a

    return pd.DataFrame(ret).transpose()



def get_all_targets():
    """Concatenate the target tables of every subject in ``meta_predictions``.

    Each line of ``meta_predictions`` is expected to contain
    ``"subject=<name>"``; the name runs up to the next comma.

    Returns:
        A DataFrame stacking the ``get_targets`` result of each subject, in
        the order the subjects are listed.
    """
    subjects = [
        line.split("subject=")[1].split(",")[0]
        for line in meta_predictions.split("\n")
    ]
    return pd.concat([get_targets(subject) for subject in subjects])


def randomize_labels(one_hot_array):
    """Replace every one-hot label with a random different label.

    Args:
        one_hot_array: 2-D array with one row per sample and one column per
            label; the label of a row is the position of its maximum.

    Returns:
        A new array of the same shape and dtype whose rows are one-hot on a
        label drawn uniformly from the labels other than the original one.
    """
    num_samples, num_labels = one_hot_array.shape
    original_labels = np.argmax(one_hot_array, axis=1)

    # Draw a position within the ascending list of labels that excludes the
    # original one, one draw per sample.
    draws = np.random.randint(num_labels - 1, size=num_samples)

    # In that list, positions below the original label map to themselves and
    # positions at or above it are shifted up by one to skip the original.
    new_labels = draws + (draws >= original_labels)

    shuffled = np.zeros_like(one_hot_array)
    shuffled[np.arange(num_samples), new_labels] = 1
    return shuffled



# Build the full target table once at import time using the loader defined
# earlier under this name, then rebind `get_all_targets` to a cheap accessor.
# Every later call therefore returns the cached data without touching disk.
targets = get_all_targets()
def get_all_targets():
    """Return the cached targets as a NumPy array (``targets.values``)."""
    return targets.values


def get_all_predictions(model_name):
    """Stack one model's predictions over every subject in ``meta_predictions``.

    Args:
        model_name: Model identifier understood by ``get_predictions``.

    Returns:
        A NumPy array with the rows of each subject's prediction table, in the
        order the subjects are listed.
    """
    subjects = [
        line.split("subject=")[1].split(",")[0]
        for line in meta_predictions.split("\n")
    ]
    frames = [get_predictions(model_name, subject) for subject in subjects]
    return pd.concat(frames).values

# Load the predictions of every known model once at import time using the
# loader defined above, then rebind `get_all_predictions` to a dispatcher that
# serves cached and synthetic predictors.
preds = {model: get_all_predictions(model) for model in model_list+["llama3.1","llama3.1_det"]}
def get_all_predictions(model_name):
    """Return the prediction matrix for a real or synthetic predictor.

    Args:
        model_name: One of

            * a string naming a cached model, or ``"perfect"`` for the targets;
            * a tuple ``(base_model, delta)``: the base model's predictions
              with a random subset of rows replaced by the targets, the
              replacement probability being ``delta / (1 - accuracy)``;
            * a float (Python or NumPy) ``p``: every row is a random wrong
              label, then each row is replaced by the target with
              probability ``p``.

    Returns:
        A NumPy array with one row per instance and one column per option.

    Raises:
        AssertionError: If the replacement probability of a tuple spec is not
            below 1.
        ValueError: If a string names no known model.
        TypeError: If ``model_name`` has an unsupported type.
    """
    if isinstance(model_name, str):
        if model_name in preds:
            return preds[model_name]
        if model_name == "perfect":
            return get_all_targets()
        # Fail loudly instead of returning None and breaking the caller later.
        raise ValueError("unknown model name: %r" % (model_name,))

    if isinstance(model_name, tuple):
        base_name, delta = model_name[0], float(model_name[1])
        preds_model = get_all_predictions(base_name)
        preds_true = get_all_targets()

        acc = (preds_model * preds_true).sum(1).mean()

        # Rows are replaced regardless of whether they were already correct,
        # hence the division by the current error rate.
        p_flip = delta / (1 - acc)
        assert p_flip < 1, "delta %r is too large for accuracy %r" % (delta, acc)

        mask = np.random.binomial(1, p_flip, len(preds_true)).astype(bool)

        out = preds_model.copy()
        out[mask] = preds_true[mask]
        return out

    # np.floating also covers NumPy scalars such as values taken from linspace.
    if isinstance(model_name, (float, np.floating)):
        preds_true = get_all_targets()
        # randomize_labels returns a fresh array, so it can be edited in place.
        out = randomize_labels(preds_true)

        mask = np.random.binomial(1, model_name, len(preds_true)).astype(bool)
        out[mask] = preds_true[mask]
        return out

    raise TypeError("unsupported model specification: %r" % (model_name,))
    
def get_clean_dict(modelset=None):
    """Score every model against the targets.

    Args:
        modelset: Iterable of model specifications accepted by
            ``get_all_predictions``; defaults to ``model_list``.

    Returns:
        A dict mapping each model specification to the mean, over instances,
        of the row-wise sum of ``predictions * targets``.
    """
    models = model_list if modelset is None else modelset
    truth = get_all_targets()
    return {
        name: (get_all_predictions(name) * truth).sum(1).mean()
        for name in models
    }


def get_evals(base_model,modelset = None):
    """Score every model against the predictions of a base model.

    Args:
        base_model: Model specification whose predictions act as reference.
        modelset: Iterable of model specifications accepted by
            ``get_all_predictions``; defaults to ``model_list``.

    Returns:
        A dict mapping each model specification to the mean, over instances,
        of the row-wise sum of ``model predictions * base predictions``.
    """
    models = model_list if modelset is None else modelset
    reference = get_all_predictions(base_model)
    return {
        name: (get_all_predictions(name) * reference).sum(1).mean()
        for name in models
    }

def get_model_eval(base_model,modelset=None):
    """Alias of ``get_evals``: forwards both arguments and returns its result."""
    return get_evals(base_model,modelset = modelset)