'''
This script tries to find interesting case studies for the BiosBias dataset.
We look for samples whose abstraction distribution changes significantly when we apply the 'gender-steering' intervention.
To do this, we measure the KL divergence between the factual and interventional abstraction distributions.
'''

from dotenv import load_dotenv
load_dotenv()

from abstract_cf.bios.utils import load_dataset, load_learned_abstraction
from transformers.modeling_outputs import SequenceClassifierOutput
from abstract_cf.bios.profession_classifier import ProfessionClassifier
import plotly.express as px
import torch.nn.functional as F
import torch
from transformers import AutoTokenizer
import ravfogel_lm_counterfactuals
from ravfogel_lm_counterfactuals.utils import get_counterfactual_model, load_model
from abstract_cf.bios.utils import sample_from_model
from abstract_cf.bios.profession_classifier import analyze_career_distribution
import tqdm 
import clearml
import os 
from argparse import ArgumentParser
from clearml import Task
import pandas as pd

# ClearML project that the run is registered under. Read straight from the
# environment (load_dotenv() has already run at import time); a missing
# variable raises KeyError here.
CLEARML_PROJECT_NAME = os.environ['CLEARML_PROJECT_NAME']

# Device preference order: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Command-line options as (flag, type, default), registered in this order.
_CLI_OPTIONS = (
    ('--factual_model_name', str, 'openai-community/gpt2-xl'),
    ('--intervention_type', str, 'mimic_gender_gpt2_instruct'),
    ('--learned_abstraction_checkpoint', str, 'model_data/learned_abstractions/profession/checkpoint-7991'),
    # Kept small to speed things up; only a rough estimate is needed.
    ('--n_samples_per_generation', int, 30),
    ('--max_length', int, 100),
    ('--prompt_tokens', int, 8),
    ('--dataset_sample_size', int, 75),
)

parser = ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)



def get_abstraction_kl_divergence(
    inputs: torch.Tensor,
    factual_model,
    counterfactual_model,
    tokenizer,
    profession_classifier,
    n_samples_per_generation: int = 50,
    max_length: int = 100,
):
    """Estimate KL(factual || counterfactual) over the profession abstraction.

    Samples continuations of ``inputs`` from both models with
    ``sample_from_model``, runs each set of generated texts through
    ``analyze_career_distribution``, and compares the ``probability`` columns
    of the two resulting DataFrames.

    Args:
        inputs: Prompt passed unchanged to ``sample_from_model`` (the script
            in this file passes the tokenizer output moved to ``device``).
        factual_model: Model sampled for the factual distribution.
        counterfactual_model: Model sampled for the interventional distribution.
        tokenizer: Tokenizer forwarded to ``sample_from_model``.
        profession_classifier: Classifier forwarded to
            ``analyze_career_distribution``.
        n_samples_per_generation: Number of samples requested from each model.
        max_length: ``max_length`` forwarded to ``sample_from_model``.

    Returns:
        A 0-dim tensor holding ``sum_i p_i * log(p_i / q_i)`` where ``p`` is
        the factual and ``q`` the counterfactual probability vector. Entries
        with ``p_i == 0`` contribute 0 (the ``0 * log 0 = 0`` convention).
        The result is ``inf`` when some ``q_i == 0`` while ``p_i > 0``.

    Raises:
        ValueError: If the two probability vectors differ in shape.
    """
    # Only the generated text is needed; the sampled token ids are discarded.
    factual_text, _ = sample_from_model(
        factual_model, tokenizer, inputs, n_samples=n_samples_per_generation, max_length=max_length
    )
    cf_text, _ = sample_from_model(
        counterfactual_model, tokenizer, inputs, n_samples=n_samples_per_generation, max_length=max_length
    )
    factual_df = analyze_career_distribution(factual_text, profession_classifier)
    counterfactual_df = analyze_career_distribution(cf_text, profession_classifier)

    # NOTE(review): this assumes both DataFrames list the professions in the
    # same row order -- confirm against analyze_career_distribution.
    factual_probs = torch.tensor(factual_df.probability.values)
    counterfactual_probs = torch.tensor(counterfactual_df.probability.values)

    # Fail loudly instead of letting broadcasting pair up unrelated entries.
    if factual_probs.shape != counterfactual_probs.shape:
        raise ValueError(
            "factual and counterfactual probability vectors differ in shape: "
            f"{tuple(factual_probs.shape)} vs {tuple(counterfactual_probs.shape)}"
        )

    # Restrict the sum to the support of the factual distribution. Without
    # this, a zero factual probability yields 0 * log(0) = NaN and poisons
    # the whole sum.
    support = factual_probs > 0
    p = factual_probs[support]
    q = counterfactual_probs[support]
    kl = (p * (p / q).log()).sum()
    return kl


if __name__ == '__main__':
    # Script entry point: score a random sample of dev-set biographies by how
    # far the intervention moves their abstraction distribution, then log and
    # save the scores.
    cli_args = parser.parse_args()

    Task.init(project_name=CLEARML_PROJECT_NAME, task_name="Find Bios Case Studies")
    clearml_logger = Task.current_task().get_logger()

    datasets, id_to_label = load_dataset()
    sampled_bios = datasets['dev'].sample(cli_args.dataset_sample_size)

    classifier = ProfessionClassifier(
        id_to_label=id_to_label,
        device=device,
        learned_abstraction_checkpoint=cli_args.learned_abstraction_checkpoint,
    )

    base_model = load_model(cli_args.factual_model_name)
    steered_model = get_counterfactual_model(cli_args.intervention_type)

    tokenizer_options = dict(
        model_max_length=512,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )
    tok = AutoTokenizer.from_pretrained(cli_args.factual_model_name, **tokenizer_options)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Both models get the same pad token id as the tokenizer.
    for model in (base_model, steered_model):
        model.config.pad_token_id = tok.pad_token_id

    # Arguments that are identical for every biography.
    shared_kwargs = dict(
        factual_model=base_model,
        counterfactual_model=steered_model,
        tokenizer=tok,
        profession_classifier=classifier,
        n_samples_per_generation=cli_args.n_samples_per_generation,
        max_length=cli_args.max_length,
    )

    scores = []
    numbered_bios = enumerate(sampled_bios.hard_text)
    for step, bio_text in tqdm.tqdm(numbered_bios, total=len(sampled_bios)):
        # Only the first `prompt_tokens` tokens of the biography form the prompt.
        prompt = tok(
            bio_text,
            return_tensors='pt',
            truncation=True,
            max_length=cli_args.prompt_tokens,
        ).to(device)

        divergence = get_abstraction_kl_divergence(inputs=prompt, **shared_kwargs)
        score = divergence.item()
        scores.append(score)
        clearml_logger.report_scalar(
            title="KL",
            series="kl_value",
            iteration=step,
            value=score,
        )

    # One row per sampled biography: dataset index, text and KL score.
    result_columns = {
        "index": sampled_bios.index,
        "text": sampled_bios.hard_text,
        "kl": scores,
    }
    results_df = pd.DataFrame(result_columns)
    results_df.to_csv("kl_results.csv", index=False)
    clearml_logger.current_logger().report_table(
        title="table pd",
        series="PD with index",
        iteration=0,
        table_plot=results_df,
    )
