from prompts_builder import generate_prompt_elements_StaA, generate_prompt_elements_REVB
import strings as S
import utils_models as UM
import utils_distributions as UD
import numpy as np
import constants as C
import config as Cf
import torch.backends as backends
import pandas as pd
import os
import gc
import torch
import metrics
import csv
from tqdm import tqdm
from scenarios import get_scenarios

# Scenario configurations to evaluate, keyed by scenario name; all five
# include_* flags are enabled.
scenarios = get_scenarios(
    include_stat=True,
    include_preference=True,
    include_coins=True,
    include_dice=True,
    include_choice=True,
)

# Pick models from MODEL_LIST, or use your own model id. The model must be
# compatible with the transformers library.
MODELS = Cf.MODEL_LIST

# Fail fast, before any model is loaded, if a selected model has no entry in
# Cf.MODEL_TO_GPU_LIST. An explicit raise is used instead of `assert` so the
# check still runs under `python -O`, and every missing model is reported at once.
_missing_models = [model_id for model_id in MODELS if model_id not in Cf.MODEL_TO_GPU_LIST]
if _missing_models:
    raise ValueError(f"Models missing in Cf.MODEL_TO_GPU_LIST: {_missing_models}")

if __name__ == "__main__":
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"

    if not os.path.exists(Cf.RESULT_DIR):
        os.makedirs(Cf.RESULT_DIR)

    for model_id in MODELS:
        # Load model
        model, tokenizer, model_short_name = UM.load_model(
            model_id=model_id,
            device_map="mps" if backends.mps.is_available() else "auto",
            trust_remote_code=True,
        )

        print(f"Status: loaded {model_short_name}")

        for scenario_name, configs_to_test in scenarios.items():
            print(f"Status: Starting eval of tasks in scenario {scenario_name}")

            scenario_results = []
            for config_scenario in tqdm(configs_to_test):
                if scenario_name == "stat":
                    config_scenario[S.scenario_names.STATS]["tokenizer"] = tokenizer
                true_distribution = UD.generateTrueDistribution(config=config_scenario)
                is_valid, max_likelihood = UD.find_maximum_likelihoo(true_distribution)

                if not is_valid:
                    if Cf.VERBOSE:
                        print(f"The provided scenario {config_scenario[S.prompt_parameter.SCENARIO]} subscenario {config_scenario[S.prompt_parameter.SUBSCENARIO]} does not have a unique maximum likelihood. Only evaluating the implicit distribution.")
                    implicit_prompt, outcomes = generate_prompt_elements_REVB(config_scenario)

                if Cf.VERBOSE:
                    print(f"Status: scenario {config_scenario[S.prompt_parameter.SCENARIO]} subscenario {config_scenario[S.prompt_parameter.SUBSCENARIO]} loaded")

                ########### Computing REVB
                if Cf.VERBOSE:
                    print(f"Status: Computing Implicit Prediction")

                implicit_prompt, outcomes = generate_prompt_elements_REVB(config_scenario)

                batch_prompts = []
                batch_next_token = []
                batch_outcome = []

                base_prompt_tokenized = tokenizer.tokenize(implicit_prompt)

                for outcome in outcomes:
                    new_prompt = implicit_prompt + f" " + outcome
                    new_prompt_tokenized = tokenizer.tokenize(new_prompt)
                    for i_tokens in range(len(base_prompt_tokenized), len(new_prompt_tokenized)):
                        batch_prompts.append(tokenizer.decode(tokenizer.convert_tokens_to_ids(new_prompt_tokenized[:i_tokens])))
                        batch_next_token.append(tokenizer.convert_tokens_to_ids(new_prompt_tokenized[i_tokens]))
                        batch_outcome.append(outcome)

                # Run inference
                predictions = UM.batch_inference(
                    model,
                    tokenizer,
                    batch_prompts,
                    batch_size=1,
                    new_token_count=1,
                    top_p=None,
                    top_k=None,
                    temperature=None,
                )

                # Compute Implicit Prediction
                proba_outcome = {outcome: 0 for outcome in outcomes}

                for i in range(len(batch_prompts)):
                    log_proba = predictions[i][0][batch_next_token[i]]
                    proba_outcome[batch_outcome[i]] += log_proba

                model_implicit = {outcome: float(np.exp(proba_outcome[outcome])) for outcome in outcomes}
                if scenario_name == "dice" and not "phi-4" in model_short_name.lower():
                    # Adjust for overlapping tokens in dice scenario when digits are tokenized individually
                    max_val = max([int(val) for val in model_implicit.keys()])
                    if max_val >= 10:
                        for val in range(10, max_val + 1):
                            str_val = str(val)
                            if str_val in model_implicit:
                                base_val = str(val // 10)
                                if base_val in model_implicit:
                                    model_implicit[base_val] -= model_implicit[str_val]

                if Cf.VERBOSE:
                    print(f"Prompt: {implicit_prompt}")
                    print(f"Computed Implicit Prediction: ", {key: round(value, 2) for key, value in model_implicit.items()})
                    print(f"True distribution: ", {key: round(frac.to_float(), 2) for key, frac in true_distribution.items()})

                implicitPred_likelihood_max_likelihood = model_implicit[max_likelihood]
                implicitPred_favored_outcome = None
                implicitPred_likelihood_favored_outcome = None
                for outcome in model_implicit.keys():
                    if (implicitPred_favored_outcome is None) or (model_implicit[outcome] > implicitPred_likelihood_favored_outcome):
                        implicitPred_favored_outcome = outcome
                        implicitPred_likelihood_favored_outcome = model_implicit[outcome]

                # Store results for Implicit Prediction
                sumimplicit = sum(model_implicit.values())
                model_implicit_normalized = {key: value / sumimplicit for key, value in model_implicit.items()}
                if scenario_name == "stat":
                    del config_scenario[S.scenario_names.STATS]["tokenizer"]
                result_row_dict = {
                    "implicitPred_prompt": implicit_prompt,
                    "implicitPred_likelihoods": {key: round(value, 5) for key, value in model_implicit.items()},
                    "implicitPred_true_distribution": {key: round(frac.to_float(), 5) for key, frac in true_distribution.items()},
                    "implicitPred_favored_outcome": implicitPred_favored_outcome,
                    "implicitPred_likelihood_favored_outcome": implicitPred_likelihood_favored_outcome,
                    "implicitPred_max_likelihood": max_likelihood,
                    "implicitPred_likelihood_max_likelihood": implicitPred_likelihood_max_likelihood,
                    "implicitPred_scenario_settings": config_scenario,
                    "implicitPred_tail_sum": 1 - sumimplicit,
                    "implicit_l1_distance": metrics.l1_distance(true_distribution, model_implicit_normalized),
                    "implicit_chebyshev_distance": metrics.chebyshev_distance(true_distribution, model_implicit_normalized),
                    "implicit_symmetric_kl_divergence": metrics.symmetric_kl_divergence(true_distribution, model_implicit_normalized),
                    "implicitPred_is_favored_outcome_correct": implicitPred_favored_outcome == max_likelihood,
                }

                ########### Computing Explicit
                if not is_valid:
                    result_row_dict["explicitAnswer_prompt"] = ""
                    result_row_dict["explicitAnswer_likelihoods"] = {}
                    result_row_dict["explicitAnswer_correct_answer"] = ""
                    result_row_dict["explicitAnswer_likelihood_correct_answer"] = np.nan
                    result_row_dict["explicitAnswer_favored_answer"] = ""
                    result_row_dict["explicitAnswer_likelihood_favored_answer"] = np.nan
                    result_row_dict["explicitAnswer_scenario_settings"] = {}
                    result_row_dict["explicitAnswer_probability_correct_answer"] = np.nan
                    result_row_dict["explicitAnswer_is_favored_outcome_correct"] = None
                else:
                    if Cf.VERBOSE:
                        print(f"Status: Computing Explicit Answer")
                    base_prompt, question, choices, values, correct_answer = generate_prompt_elements_StaA(config_scenario)

                    chat_template = UM.build_staA_prompt(
                        prompt=base_prompt,
                        question=question,
                        choices=choices,
                        values=values,
                        tokenizer=tokenizer,
                    )

                    chat_template_patched = UM.patch_chat_template(model_name=model_short_name, tokenizer=tokenizer, prompt=chat_template)

                    chat_template_with_first_tokens, _ = UM.generate_first_tokens(model=model, tokenizer=tokenizer, prompt=chat_template_patched)

                    batch_prompts = [chat_template_with_first_tokens]

                    predictions = UM.batch_inference(
                        model,
                        tokenizer,
                        batch_prompts,
                        batch_size=1,
                        new_token_count=1,
                        top_p=None,
                        top_k=None,
                        temperature=None,
                    )

                    log_proba_answers = [predictions[0][0][tokenizer.convert_tokens_to_ids(l)] for l in C.prompt_parameter.LETTERS[: C.prompt_parameter.QCM_OPTIONS]]

                    proba_answers = np.array([float(np.exp(log_proba_answers[i])) for i in range(len(log_proba_answers))])
                    proba_answers = proba_answers / np.sum(proba_answers)
                    proba_answers = {C.prompt_parameter.LETTERS[i]: proba_answers[i] for i in range(len(log_proba_answers))}

                    sta_prob_on_correct = proba_answers[correct_answer]

                    if Cf.VERBOSE:
                        print(f"Computed Explicit")
                        print("===" * 30)
                        print(f"Summary scenario {config_scenario[S.prompt_parameter.SCENARIO]} subscenario {config_scenario[S.prompt_parameter.SUBSCENARIO]}")
                        print(f"Prompt: {chat_template_with_first_tokens}")
                    max_answer = max(proba_answers, key=proba_answers.get)

                    if Cf.VERBOSE:
                        print(f"explicitAnswer favored option {max_answer} with likelihood {proba_answers[max_answer]:.2f}")
                        print(f"explicitAnswer correct answer {correct_answer} with likelihood {sta_prob_on_correct:.2f}")
                        print(f"implicitPred favored outcome {implicitPred_favored_outcome} with likelihood {implicitPred_likelihood_favored_outcome:.2f}")
                        print(f"implicit True max likelihood ({max_likelihood}) with {implicitPred_likelihood_max_likelihood:.2f} (true: {true_distribution[max_likelihood].to_float():.2f})")

                    result_row_dict["explicitAnswer_prompt"] = chat_template_with_first_tokens
                    result_row_dict["explicitAnswer_likelihoods"] = {key: round(value, 5) for key, value in proba_answers.items()}
                    result_row_dict["explicitAnswer_correct_answer"] = correct_answer
                    result_row_dict["explicitAnswer_likelihood_correct_answer"] = sta_prob_on_correct
                    result_row_dict["explicitAnswer_favored_answer"] = max_answer
                    result_row_dict["explicitAnswer_likelihood_favored_answer"] = proba_answers[max_answer]
                    result_row_dict["explicitAnswer_scenario_settings"] = config_scenario
                    result_row_dict["explicitAnswer_probability_correct_answer"] = sta_prob_on_correct
                    result_row_dict["explicitAnswer_is_favored_outcome_correct"] = max_answer == correct_answer

                    if Cf.VERBOSE:
                        print(f"===" * 30)
                        print()

                scenario_results.append(result_row_dict)
            df_results = pd.DataFrame(scenario_results)
            df_results.to_csv(f"{Cf.RESULT_DIR}/results_{model_short_name.split('/')[-1]}_{scenario_name}.csv", index=True, index_label="index")
            print(f"Results saved to {Cf.RESULT_DIR}/results_{model_short_name.split('/')[-1]}_{scenario_name}.csv")

        # clear gpu memory
        del model
        del tokenizer
        gc.collect()
        torch.cuda.empty_cache()

    exit(0)
