import pandas as pd 
import time
import clearml
from argparse import ArgumentParser
from ravfogel_lm_counterfactuals.utils import (
    load_model, get_continuation, get_counterfactual_output, get_counterfactual_model
)
import tqdm 
from transformers import AutoTokenizer
import torch
import pickle
from abstract_cf.text_generation.utils import get_samples_from_acf_task
import yaml


# Preferred torch device, chosen in priority order: Apple MPS, then CUDA, then CPU.
# NOTE(review): nothing in the visible script reads `device`; model placement
# presumably happens inside load_model / get_counterfactual_model — confirm
# before relying on or removing this constant.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def make_parser():
    """Build the command-line parser for the counterfactual-generation script.

    Options (all optional, each with a default):
        --factual_model_name   model/tokenizer identifier for the factual model
        --intervention_config  path to the YAML intervention config
        --max_new_tokens       number of tokens to generate per continuation
        --num_counterfactuals  counterfactual samples drawn per dataset row
        --acf_task_id          ClearML task id to pull the input samples from

    Returns:
        A configured ``ArgumentParser``.
    """
    # (flag, value type, default) — registered in this exact order.
    option_table = (
        ('--factual_model_name', str, 'openai-community/gpt2-xl'),
        ('--intervention_config', str, 'intervention_configs/gender_steering_gpt2-xl.yaml'),
        ('--max_new_tokens', int, 80),
        ('--num_counterfactuals', int, 10),
        ('--acf_task_id', str, '2d7e908328fd43e695ec68904f86f548'),
    )

    cli = ArgumentParser()
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli


if __name__ == '__main__':
    # Timestamp suffix keeps each run's ClearML task name distinct.
    unique_task_name = f'tlcf_bios_gender_steering_{time.time()}'
    task = clearml.Task.init(project_name='abstract_counterfactuals', task_name=unique_task_name)

    parser = make_parser()
    args = parser.parse_args()

    print('fetching samples from acf task')
    dataset = get_samples_from_acf_task(args.acf_task_id)
    print('sample shape:', dataset.shape)

    tokenizer = AutoTokenizer.from_pretrained(
        args.factual_model_name,  # tokenizer is shared between the two models
        model_max_length=512,
        padding_side="right",
        use_fast=False,
        trust_remote_code=False,
    )
    # Reuse EOS as the padding token (GPT-2-style tokenizers define no pad token).
    tokenizer.pad_token = tokenizer.eos_token

    factual_model = load_model(args.factual_model_name)
    factual_model.config.pad_token_id = tokenizer.eos_token_id

    # Only `intervention_kwargs.model_name` is read from the intervention config here.
    with open(args.intervention_config, 'r', encoding='utf-8') as f:
        intervention_config = yaml.safe_load(f)
    counterfactual_model = get_counterfactual_model(
        intervention_config['intervention_kwargs']['model_name']
    )
    counterfactual_model.config.pad_token_id = tokenizer.eos_token_id

    # Maps dataset index -> {
    #     'factual_prompt':        {'text': factual prompt},
    #     'interventional_prompt': {'text': interventional prompt},
    #     'factual':               {'text': factual prompt + sampled continuation},
    #     'counterfactuals':       [{'text': ...}, ...],  # args.num_counterfactuals entries
    # }
    # Only decoded text is stored; token ids returned by the helpers are discarded.
    experiment_data = {}
    for i, row in tqdm.tqdm(dataset.iterrows(), total=len(dataset)):
        # NOTE: these will be identical for the gender steering task, because we do not intervene on the prompt
        factual_inputs = row.factual
        interventional_inputs = row.interventional

        # Draw one sampled (do_sample=True) factual continuation of the prompt.
        _, factual_samples = get_continuation(
            factual_model,
            tokenizer,
            factual_inputs,
            max_new_tokens=args.max_new_tokens,
            return_only_continuation=True,
            num_beams=1,
            do_sample=True,
            token_healing=False,
        )

        # Same arguments every iteration; presumably get_counterfactual_output is
        # stochastic so each call yields a distinct counterfactual — confirm.
        counterfactuals = []
        for _ in range(args.num_counterfactuals):
            _, count_text = get_counterfactual_output(
                counterfactual_model,
                factual_model,
                tokenizer,
                factual_inputs,
                interventional_inputs,
                factual_samples,
                args.max_new_tokens,
            )
            counterfactuals.append({'text': count_text})

        experiment_data[i] = {
            'factual_prompt': {'text': factual_inputs},
            'interventional_prompt': {'text': interventional_inputs},
            'factual': {'text': factual_inputs + factual_samples},
            'counterfactuals': counterfactuals,
        }

    # Persist locally, then attach the same file to the ClearML task.
    output_path = 'experiment_data.pkl'
    with open(output_path, 'wb') as f:
        pickle.dump(experiment_data, f)
    task.upload_artifact(name='experiment_data.pkl', artifact_object=output_path)
    