import torch
from model_arithmetic import ModelArithmetic, Evaluation, LLMPrompt, enable_logging
from transformers import set_seed
import pandas as pd
from formulas_toxicity import *
from loguru import logger
import os

# Enable model_arithmetic's logging for this run (presumably progress /
# debug output from the library — TODO confirm in model_arithmetic).
enable_logging()

def evaluate_formula(formula, dataset, default_model, formula_file, store_file, store_file_monitor, dataset_file,
                    batch_size=4, temperature=1, top_p=1, top_k=0, model_name_fluency="meta-llama/Llama-2-7b-chat-hf",
                    dtype=torch.bfloat16, preserve_memory=True, classifier_name="SkolkovoInstitute/roberta_toxicity_classifier", classification_with_input=False,
                    dtype_faithfulness=torch.bfloat16, finetune_model=False, batch_size_faithfulness=8,
                    reload=False, reload_data=False, max_tokens=32, seed=42):
    """Evaluate one model-arithmetic formula on ``dataset`` and store the results.

    Builds a ``ModelArithmetic`` from ``formula``, writes ``str(formula)`` to
    ``formula_file``, runs ``Evaluation.evaluate`` (Perspective scoring
    disabled, generation stopped at ``"\\n"`` or ``"Person 1:"``), and finally
    stores the arithmetic monitor to ``store_file_monitor``.

    Args:
        formula: A formula, or a ``(formula, retroactive_operator)`` tuple whose
            second element is passed as the single retroactive operator.
        dataset: Prompts to evaluate, forwarded to ``Evaluation``.
        default_model: Forwarded to ``ModelArithmetic``.
        formula_file: Path that receives ``str(formula)``.
        store_file: Forwarded to ``Evaluation.evaluate``.
        store_file_monitor: Path the monitor is stored to. If this file already
            exists the run is treated as finished and skipped.
        dataset_file: Forwarded to ``Evaluation.evaluate``.
        seed: Passed to ``set_seed`` before anything else (default 42, the
            value that used to be hard-coded).
        The remaining keyword arguments are forwarded unchanged to
        ``Evaluation.evaluate`` (``classifier_name`` is passed as its
        ``model_name``).

    Returns:
        Whatever ``Evaluation.evaluate`` returns, or ``None`` when the run was
        skipped because ``store_file_monitor`` already exists.
    """
    set_seed(seed)
    # A tuple formula carries its retroactive operator as the second element.
    if isinstance(formula, tuple):
        retroactive = [formula[1]]
        formula = formula[0]
    else:
        retroactive = []
    arithmetic = ModelArithmetic(formula, default_model=default_model, retroactive_operators=retroactive)
    # NOTE(review): this runs before the completed-run check below, so the fixed
    # path is rewritten even for formulas that end up skipped — confirm intended.
    arithmetic.save_pretrained("../finetune/arithmetic")

    # The monitor file is written last, so its presence marks a completed run.
    if os.path.isfile(store_file_monitor):
        return None

    # Ensure parent directories exist for the files written directly here.
    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create non-empty directory components.
    for path in (formula_file, store_file_monitor):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    with open(formula_file, "w", encoding="utf-8") as outfile:
        outfile.write(str(formula))

    evaluator = Evaluation(arithmetic, dataset=dataset)

    output = evaluator.evaluate(
        store_file=store_file,
        dataset_file=dataset_file,
        batch_size=batch_size,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        model_name_fluency=model_name_fluency,
        dtype=dtype,
        preserve_memory=preserve_memory,
        model_name=classifier_name,
        classification_with_input=classification_with_input,
        dtype_faithfulness=dtype_faithfulness,
        finetune_model=finetune_model,
        batch_size_faithfulness=batch_size_faithfulness,
        reload=reload,
        reload_data=reload_data,
        max_tokens=max_tokens,
        do_perspective=False,  # we do this in parallel since this is main bottleneck
        stop_texts=["\n", "Person 1:"],
    )

    arithmetic.monitor.store(store_file_monitor)
    return output

# Base models both per-model sweeps below iterate over. Defined once so the two
# sweeps cannot silently drift apart.
BASE_MODELS = ["meta-llama/Llama-2-13b-hf", "EleutherAI/Pythia-12b", "mosaicml/mpt-7b"]

# NOTE: a formula's position in ``formulas`` becomes its output directory index
# (eval/toxicity/{i}/) in the evaluation loop, so append new formulas at the
# end — inserting or reordering would remap existing results.
formulas = []

# Sweep 1: baselines and single-technique formulas for each base model.
for model in BASE_MODELS:
    formulas += [
        main_model(model=model),
        main_model(sentence=positive_sentence, model=model) + 0.0 * main_model(sentence="", model=model),  # 0.0 needed for monitoring
        main_model(sentence=negative_sentence, model=model) + 0.0 * main_model(sentence="", model=model),
        negative_biasing(-0.5, model=model),
        negative_biasing(-0.8, max_=True, model=model),
        negative_biasing(-0.99, max_=True, model=model),
        negative_biasing(-0.96, max_=True, model=model),
        selfdebias(10, model=model),
        classifier(1.0, c_model="finetune/toxicity_classifier", m_model=model),
        combo(0.01, -0.0, -0.96, c_model="finetune/toxicity_classifier", m_model=model),
    ]

# Small-model variants (Pythia sizes selected via ``variant_big``/``variant_small``).
formulas += [
    small_model_negative(0, max_=False, variant_small="12b"),
    small_model_negative(-0.5, max_=False, variant_small="12b"),
    small_model_negative(-5/6, max_=True, variant_small="12b"),
    small_model_negative(0, max_=False, variant_big="1.4b", variant_small="1.4b"),
    small_model_negative(0, max_=False, variant_big="410m", variant_small="410m"),
    small_model_negative(0, max_=False, variant_big="1b", variant_small="1b"),
    small_model_negative(-0.5, max_=False, variant_small="1b"),
    small_model_negative(-5/6, max_=True, variant_small="1b"),
    small_model_negative(-0.5, max_=False, variant_small="1b", bad=True),
    small_model_negative(-5/6, max_=True, variant_small="1b", bad=True),
    small_model_negative(-0.5, max_=False, variant_small="410m"),
    small_model_negative(-5/6, max_=True, variant_small="410m"),
    small_model_negative(-0.5, max_=False, variant_small="410m", bad=True),
    small_model_negative(-5/6, max_=True, variant_small="410m", bad=True),
]

# Sweep 2: additional classifier/biasing combo settings for each base model.
for model in BASE_MODELS:
    formulas += [
        combo(0.2, -0.0, -0.8, c_model="finetune/toxicity_classifier", m_model=model),
        combo(0.04, -0.0, -0.96, c_model="finetune/toxicity_classifier", m_model=model),
        combo(0.01, -0.0, -0.99, c_model="finetune/toxicity_classifier", m_model=model),
    ]

# Prompt set: toxic rows (toxicity >= 0.5) of pol.csv, shuffled with a fixed
# seed, truncated to the first 2000, and wrapped in a two-speaker template.
dataset = pd.read_csv("data/datasets/pol.csv")
dataset_toxic = (
    dataset.loc[dataset["toxicity"] >= 0.5]
    .reset_index(drop=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
dataset_subset = dataset_toxic.iloc[:2000].reset_index(drop=True)


def _as_dialogue_prompt(text):
    """Wrap ``text`` as Person 1's turn and open Person 2's reply."""
    return f"Person 1: {text}\nPerson 2:"


dataset_subset["input"] = dataset_subset["text"].map(_as_dialogue_prompt)

def _fluency_model_name(model_name):
    """Map a formula's first model to the model used for fluency scoring.

    Any name containing "Pythia" is scored with "EleutherAI/Pythia-12b", and any
    name containing "gpt2-xl" is normalized to "gpt2-xl". Other names are
    returned unchanged.
    """
    if "Pythia" in model_name:
        return "EleutherAI/Pythia-12b"
    if "gpt2-xl" in model_name:
        return "gpt2-xl"
    return model_name


# Index of the first formula to evaluate; raise it to resume a partial sweep.
# (evaluate_formula also skips runs whose monitor file already exists.)
START_INDEX = 0

# NOTE: loguru's logger.catch() logs an exception and leaves the with-block, so
# a failure in one formula ends the sweep; the remaining formulas are not run.
with logger.catch():
    for i, formula in enumerate(formulas[START_INDEX:], start=START_INDEX):
        # Tuple formulas are (formula, retroactive_operator); read the model
        # from the formula part.
        base_formula = formula[0] if isinstance(formula, tuple) else formula
        first_model = base_formula.runnable_operators()[0].model

        evaluate_formula(
            formula=formula,
            dataset=dataset_subset,
            default_model=None,
            reload=False,
            reload_data=False,
            formula_file=f"eval/toxicity/{i}/formula.txt",
            store_file=f"eval/toxicity/{i}/evaluation.json",
            store_file_monitor=f"eval/toxicity/{i}/monitor.json",
            dataset_file=f"eval/toxicity/{i}/data.csv",
            batch_size=8,
            temperature=1.0,
            top_p=1.0,
            top_k=0,
            model_name_fluency=_fluency_model_name(first_model),
            dtype=torch.bfloat16,
            preserve_memory=True,
            classifier_name=["SkolkovoInstitute/roberta_toxicity_classifier", "cardiffnlp/twitter-roberta-base-sentiment-latest"],
            classification_with_input=False,
            dtype_faithfulness=torch.bfloat16,
            finetune_model=False,
            batch_size_faithfulness=32,
            max_tokens=32,
        )