from typing import Literal, List, Dict
from dataclasses import dataclass
from adapter import GPT3Adapter
from tqdm import tqdm
import numpy as np
import statistics
import pickle
from scipy import stats

from json import load
import random
from random import randrange
import re

import json

from data import CausalDataset, MoralDataset, Example, AbstractDataset, JsonSerializable, Example
from prompt import CausalJudgmentPrompt, MoralJudgmentPrompt, JudgmentPrompt
from evaluator import AccuracyEvaluatorWithAmbiguity, CorrelationEvaluator, RMSEEvaluator, AuROCEvaluator, MAEEvaluator, \
    CEEvaluator

@dataclass
class ExperimentResult(JsonSerializable):
    """Metrics for one evaluated set of predictions (pooled over personas, or for one persona).

    Field order matters: instances are constructed positionally as
    ``ExperimentResult(acc, conf_interval, r, p, rmse, mae, auroc, ce)``.
    """
    acc: float  # accuracy from AccuracyEvaluatorWithAmbiguity
    conf_interval: tuple[float, float]  # confidence interval reported with acc; presumably (lower, upper)
    r: float  # correlation from CorrelationEvaluator
    p: float  # p-value returned alongside r
    rmse: float  # from RMSEEvaluator
    mae: float  # from MAEEvaluator
    auroc: float  # from AuROCEvaluator
    ce: float  # from CEEvaluator (presumably cross-entropy — confirm in evaluator.py)

@dataclass
class Persona(JsonSerializable):
    """A sampled persona: a two-token name plus the free-text description after it.

    Constructed positionally by ``sample`` as ``Persona(name, file, description)``,
    so field order matters.
    """
    name: str  # first two space-separated tokens of the persona entry
    category: str  # the personas/<category>.json file the persona was drawn from
    description: str  # remainder of the entry after the name

def evaluate_question(
    name,
    description,
    ex: Example,
    adapter: GPT3Adapter,
    jp: JudgmentPrompt,
    method: str='yesno',
):
    """Score one example with the persona rendered directly into the prompt.

    Args:
        name: persona name passed to the prompt template.
        description: persona description passed to the prompt template.
        ex: the example to render.
        adapter: model adapter used to score the rendered prompt.
        jp: judgment prompt whose ``apply`` builds the model input.
        method: scoring method forwarded to ``adapter.adapt``.

    Returns:
        ``(instance, choice_scores)``: the rendered prompt and the adapter's scores.
    """
    prompt_instance = jp.apply(ex, name, description)
    return prompt_instance, adapter.adapt(prompt_instance, method=method)

def sample(file, number) -> list[Persona]:
    """Randomly draw ``number`` personas, without replacement, from ``personas/<file>.json``.

    The JSON file holds a list of strings shaped like ``"<First> <Last> <description>"``.
    The first two space-separated tokens become ``Persona.name``, the rest becomes
    ``Persona.description``, and ``file`` is stored as ``Persona.category``.

    Draws use the module-level ``random`` state. Callers seed it (e.g.
    ``random.seed(42)``) to get a reproducible sample.

    Raises:
        ValueError: if more personas are requested than can be drawn, or if an
            entry does not have the expected "<name> <description>" shape.
    """
    # JSON text is UTF-8 by spec. Pass the encoding explicitly so names such as
    # "ñ" decode correctly regardless of the platform's locale encoding.
    with open(f'personas/{file}.json', encoding='utf-8') as persona_file:
        entries = load(persona_file)

    # Raw string so \w / \W reach the regex engine without relying on Python's
    # lenient handling of unknown string escapes.
    persona_pattern = re.compile(r"^([\w'.ñ-]+ [\w'.ñ-]+) ([\w\W]+?)$")

    # NOTE(review): draws start at index 1, so entries[0] is never selected.
    # Kept unchanged because different bounds would alter which personas
    # random.seed(42) reproduces. Confirm whether skipping entry 0 is intended.
    drawable = max(len(entries) - 1, 0)
    if number > drawable:
        raise ValueError(
            f"Cannot draw {number} personas from personas/{file}.json: "
            f"only {drawable} entries are eligible"
        )

    personas = []
    for _ in range(number):
        index = randrange(1, len(entries))
        entry = entries[index]
        match = persona_pattern.search(entry)
        if match is None:
            raise ValueError(
                f"Persona entry does not look like '<name> <description>': {entry!r}"
            )
        personas.append(Persona(match.group(1), file, match.group(2)))
        # Remove the drawn entry so the same persona cannot be sampled twice.
        entries.pop(index)

    return personas

def run_template_for_gpt3(
    name,
    description,
    cd: AbstractDataset,
    adapter: GPT3Adapter,
    jp: JudgmentPrompt,
    method: str='yesno',
    chat_mode=False
):
    """Score every example in ``cd`` under one judgment prompt for a single persona.

    With ``chat_mode`` the persona is supplied through a system prompt
    (``evaluate_chat_question``). Otherwise it is rendered into the prompt itself
    (``evaluate_question``).

    Returns:
        Three parallel lists: rendered instances, per-example choice scores, and
        each example's ``answer_dist``.
    """
    ask = evaluate_chat_question if chat_mode else evaluate_question

    rendered, scores, label_dists = [], [], []
    for example in cd:
        instance, choice_scores = ask(name, description, example, adapter, jp, method)
        rendered.append(instance)
        scores.append(choice_scores)
        label_dists.append(example.answer_dist)

    return rendered, scores, label_dists

def run_experiment(
    name,
    description,
    dataset,
    judgement_prompt_1,
    judgement_prompt_2,
    engine: str='text-davinci-002',
    chat_mode=False
):
    """Run one persona over ``dataset`` with both judgment prompts.

    Prompt 1 is scored with the ``'yesno'`` method and prompt 2 with
    ``'multiple_choice'``. Their outputs are concatenated in that order.

    Returns:
        ``(instances, choice_scores, label_indices)``. The third list holds each
        example's ``answer_dist`` as produced by ``run_template_for_gpt3``.
    """
    adapter = GPT3Adapter(engine=engine)

    collected = ([], [], [])
    for prompt, scoring_method in (
        (judgement_prompt_1, 'yesno'),
        (judgement_prompt_2, 'multiple_choice'),
    ):
        outputs = run_template_for_gpt3(
            name, description, dataset, adapter, prompt, method=scoring_method,
            chat_mode=chat_mode
        )
        for bucket, values in zip(collected, outputs):
            bucket.extend(values)

    all_instances, all_choice_scores, all_label_indices = collected
    return all_instances, all_choice_scores, all_label_indices

def select_yes(choices):
    """Return the first element of ``choices``.

    Going by this function's name, position 0 is treated as the "yes" option.
    """
    yes_position = 0
    return choices[yes_position]

def get_persona_performance(persona_to_data) -> Dict[str, str]:
    """Compute metrics separately for each persona so personas can be compared.

    Args:
        persona_to_data: mapping from a persona label to its
            ``(instances, choice_scores, label_indices)`` triple.

    Returns:
        Mapping from the same labels to ``ExperimentResult(...).json``. It is
        serialized here so the result can be dumped directly.
    """
    acc_eval = AccuracyEvaluatorWithAmbiguity()
    corr_eval = CorrelationEvaluator()
    rmse_eval = RMSEEvaluator()
    mae_eval = MAEEvaluator()
    auroc_eval = AuROCEvaluator()
    ce_eval = CEEvaluator()

    def score(scores, labels):
        # Evaluators run in a fixed order: accuracy, correlation, RMSE, MAE, AuROC, CE.
        accuracy, interval = acc_eval.evaluate(scores, labels)
        r_value, p_value = corr_eval.evaluate(scores, labels)
        return ExperimentResult(
            accuracy,
            interval,
            r_value,
            p_value,
            rmse_eval.evaluate(scores, labels),
            mae_eval.evaluate(scores, labels),
            auroc_eval.evaluate(scores, labels),
            ce_eval.evaluate(scores, labels),
        ).json

    return {
        label: score(scores, labels)
        for label, (_instances, scores, labels) in persona_to_data.items()
    }

def evaluate_chat_question(
    name,
    description,
    ex: Example,
    adapter: GPT3Adapter,
    jp: JudgmentPrompt,
    method: str='yesno',
):
    """Score one example for a chat model, passing the persona via a system prompt.

    The prompt template is applied to ``ex`` alone. The persona's name and
    description go into a separate system prompt that asks the model to answer
    as that person.

    Returns:
        ``(instance, choice_scores)``: the rendered prompt and the adapter's scores.
    """
    rendered = jp.apply(ex)
    persona_system_prompt = (
        f"{name} {description}. {name} is given the scenario and asked a question. "
        f"You should act as if you were {name} and answer the question."
    )
    scores = adapter.adapt(rendered, method=method, system_prompt=persona_system_prompt)
    return rendered, scores
def _load_experiment_inputs(experiment, prompt_dir):
    """Build the dataset and both judgment prompts for ``experiment``.

    Prompt 1 is ``<prompt_dir>/exp1_<kind>_prompt.jinja`` and prompt 2 is
    ``<prompt_dir>/exp1_<kind>_prompt_2.jinja``, where ``<kind>`` is
    ``experiment.lower()``.

    Returns:
        ``(dataset, judgement_prompt_1, judgement_prompt_2)``.

    Raises:
        ValueError: if ``experiment`` is neither "Causal" nor "Moral".
    """
    if experiment == "Causal":
        dataset_cls, prompt_cls = CausalDataset, CausalJudgmentPrompt
    elif experiment == "Moral":
        dataset_cls, prompt_cls = MoralDataset, MoralJudgmentPrompt
    else:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError("Experiment is neither Causal nor Moral")

    kind = experiment.lower()
    dataset = dataset_cls()
    judgement_prompt_1 = prompt_cls(f"{prompt_dir}/exp1_{kind}_prompt.jinja")
    judgement_prompt_2 = prompt_cls(f"{prompt_dir}/exp1_{kind}_prompt_2.jinja")
    return dataset, judgement_prompt_1, judgement_prompt_2


def _compute_experiment_metrics(choice_scores, label_indices):
    """Evaluate pooled predictions with every evaluator and bundle them in an ExperimentResult."""
    evaluator = AccuracyEvaluatorWithAmbiguity()
    corr_evaluator = CorrelationEvaluator()
    rmse_evaluator = RMSEEvaluator()
    mae_evaluator = MAEEvaluator()
    auroc_evaluator = AuROCEvaluator()
    ce_evaluator = CEEvaluator()

    acc, conf_interval = evaluator.evaluate(choice_scores, label_indices)
    r, p = corr_evaluator.evaluate(choice_scores, label_indices)
    rmse = rmse_evaluator.evaluate(choice_scores, label_indices)
    mae = mae_evaluator.evaluate(choice_scores, label_indices)
    auroc = auroc_evaluator.evaluate(choice_scores, label_indices)
    ce = ce_evaluator.evaluate(choice_scores, label_indices)

    return ExperimentResult(acc, conf_interval, r, p, rmse, mae, auroc, ce)


def _print_experiment_metrics(experiment, engine, result):
    """Print a human-readable summary of ``result`` for one experiment/engine."""
    print()
    print(f"engine: {engine}")
    print(f"{experiment} Accuracy: {result.acc:.4f} ({result.conf_interval[0]:.4f}, {result.conf_interval[1]:.4f})")
    print(f"{experiment} Correlation: {result.r:.4f} (p={result.p:.4f})")
    print(f"{experiment} RMSE: {result.rmse:.4f}")
    print(f"{experiment} MAE: {result.mae:.4f}")
    print(f"{experiment} AuROC: {result.auroc:.4f}")
    print(f"{experiment} CE: {result.ce:.4f}")


def _collect_persona_results(total_sample, experiment, engine, ind_persona_result, prompt_dir, chat_mode):
    """Shared implementation of ``collect_data`` and ``collect_gp4_data``.

    Runs every persona through both judgment prompts and pickles the
    per-persona predictions. It then prints and returns metrics pooled over all
    personas, and optionally returns per-persona metrics as well.
    """
    dataset, judgement_prompt_1, judgement_prompt_2 = _load_experiment_inputs(experiment, prompt_dir)

    all_choice_scores, all_label_indices = [], []
    persona_to_data = {}

    for persona in tqdm(total_sample):
        instances, choice_scores, label_indices = run_experiment(
            persona.name,
            persona.description,
            dataset,
            judgement_prompt_1,
            judgement_prompt_2,
            engine=engine,
            chat_mode=chat_mode
        )
        all_choice_scores.extend(choice_scores)
        all_label_indices.extend(label_indices)

        # NOTE: two personas with the same category and name would share this key;
        # the later one overwrites the per-persona entry (pooled lists keep both).
        persona_to_data[persona.category + "_" + persona.name] = (instances, choice_scores, label_indices)

        print("Finished running experiment for persona: ", persona.name, " (", persona.category, ")")

    with open(f"../../results/preds/exp5_persona_{engine}_{experiment.lower()}_preds.pkl", "wb") as f:
        pickle.dump(persona_to_data, f)

    persona_exp_results = get_persona_performance(persona_to_data) if ind_persona_result else None

    result = _compute_experiment_metrics(all_choice_scores, all_label_indices)
    _print_experiment_metrics(experiment, engine, result)

    if ind_persona_result:
        return result, persona_exp_results
    return result


def collect_gp4_data(
    total_sample: List[Persona],
    experiment: Literal["Causal", "Moral"],
    engine: str='gpt-4',
    ind_persona_result: bool = False
):
    """Evaluate all personas on a chat-format model (e.g. GPT-4).

    Chat-format models (GPT-4/Claude/ChatGPT) changed the prompt format. The
    persona is therefore supplied through a system prompt (``chat_mode=True``),
    and the plain question templates under ``./prompts`` are loaded.

    Predictions are pickled to
    ``../../results/preds/exp5_persona_{engine}_{experiment}_preds.pkl``. This is
    the same path template ``collect_data`` uses, so runs with the same engine
    name overwrite each other.

    Returns:
        The pooled ``ExperimentResult``, or ``(ExperimentResult, per_persona)``
        when ``ind_persona_result`` is true.

    Raises:
        ValueError: if ``experiment`` is neither "Causal" nor "Moral".
    """
    return _collect_persona_results(
        total_sample, experiment, engine, ind_persona_result,
        prompt_dir="./prompts", chat_mode=True,
    )


def collect_data(
    total_sample: List[Persona],
    experiment: Literal["Causal", "Moral"],
    engine: str='text-davinci-002',
    ind_persona_result: bool = False
):
    """Evaluate all personas on a completion-style model.

    The persona is rendered directly into the prompt using the templates under
    ``./prompts/persona_prompts``.

    Predictions are pickled to
    ``../../results/preds/exp5_persona_{engine}_{experiment}_preds.pkl``.

    Returns:
        The pooled ``ExperimentResult``, or ``(ExperimentResult, per_persona)``
        when ``ind_persona_result`` is true.

    Raises:
        ValueError: if ``experiment`` is neither "Causal" nor "Moral".
    """
    return _collect_persona_results(
        total_sample, experiment, engine, ind_persona_result,
        prompt_dir="./prompts/persona_prompts", chat_mode=False,
    )

# Persona source files, in the order they are sampled. The order matters for
# reproducing the seed-42 sample.
_PERSONA_CATEGORIES = (
    "insaneProtestors",
    "luxuryLifeHabits",
    "newDealAmerica",
    "nonPoliticalTwitter",
    "politicsPeopleTwitter",
)


def _sample_and_save_personas():
    """Seed the RNG with 42, draw 5 personas per category, and pickle the sample.

    Returns:
        The list of sampled personas, in category order.
    """
    random.seed(42)

    total_sample = []
    for category in _PERSONA_CATEGORIES:
        total_sample.extend(sample(category, 5))

    print(len(total_sample))

    with open("../../results/sampled_persona_seed42.pkl", "wb") as f:
        pickle.dump(total_sample, f)

    print("persona saved")
    return total_sample


def _run_engines_and_dump(collect_fn, total_sample, experiment, engines, out_path):
    """Run ``collect_fn`` for each engine and write the combined results to ``out_path`` as JSON.

    Each engine maps to ``(pooled_result.json, per_persona_results)``.
    """
    result = {}
    for engine in tqdm(engines):
        er, persona_er = collect_fn(total_sample, experiment, engine, ind_persona_result=True)
        result[engine] = (er.json, persona_er)

    # Use a context manager so the file is flushed and closed deterministically.
    with open(out_path, 'w') as f:
        json.dump(result, f, indent=2)


def produce_table2():
    """Produce the persona results for completion-style engines (causal and moral)."""
    total_sample = _sample_and_save_personas()

    # Other engines were listed here but commented out: gpt-3.5-turbo, gpt-4,
    # text-babbage-001, text-curie-001.
    engines = ['text-davinci-002', 'text-davinci-003']

    _run_engines_and_dump(collect_data, total_sample, "Causal", engines,
                          '../../results/exp5_causal_persona_avg_result.json')
    _run_engines_and_dump(collect_data, total_sample, "Moral", engines,
                          '../../results/exp5_moral_persona_avg_result.json')


def produce_table2_chat():
    """Produce the persona results for the chat engine (GPT-4) on causal and moral tasks."""
    total_sample = _sample_and_save_personas()

    engines = ['gpt-4']

    _run_engines_and_dump(collect_gp4_data, total_sample, "Causal", engines,
                          '../../results/exp5_gpt4_causal_persona_avg_result.json')
    _run_engines_and_dump(collect_gp4_data, total_sample, "Moral", engines,
                          '../../results/exp5_gpt4_moral_persona_avg_result.json')

if __name__ == '__main__':
    # Uncomment one of these to regenerate the Table 2 results:
    # produce_table2()
    # produce_table2_chat()

    # Default: preview the personas selected by random.seed(42), using the same
    # draw order as produce_table2 / produce_table2_chat.
    random.seed(42)

    total_sample = (
        sample("insaneProtestors", 5) +
        sample("luxuryLifeHabits", 5) +
        sample("newDealAmerica", 5) +
        sample("nonPoliticalTwitter", 5) +
        sample("politicsPeopleTwitter", 5)
    )

    for s in total_sample:
        # sample() consumes the space separating name from description, so
        # re-insert it; plain concatenation fused the two together.
        print(s.category, ":", f"{s.name} {s.description}")
