import pandas as pd 
import time
import clearml
from argparse import ArgumentParser
from abstract_cf.text_generation.utils import get_samples_from_acf_task
from ravfogel_lm_counterfactuals.utils import (
    load_model, get_continuation, get_counterfactual_output 
)
import tqdm 
from transformers import AutoTokenizer
import torch
import pickle


# Preferred torch backend, checked in priority order: Apple MPS, then CUDA,
# falling back to CPU when neither accelerator is reported as available.
# NOTE(review): nothing in the visible part of this file reads `device`;
# confirm whether it is still needed.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def make_parser():
    """Build the command-line parser for this script.

    Options registered, in order:
        --acf_task_id          (str, required)
        --model_name           (str, default 'openai-community/gpt2-xl')
        --max_new_tokens       (int, default 80)
        --num_counterfactuals  (int, default 30)

    Returns:
        ArgumentParser: a parser with the options above attached.
    """
    option_specs = (
        # The task id is mandatory: it determines which samples the
        # sentiment tracking experiment is run on.
        (
            '--acf_task_id',
            {
                'type': str,
                'required': True,
                'help': 'the clearml task id of the abstract counterfactuals experiment to use for the sentiment tracking experiment',
            },
        ),
        ('--model_name', {'type': str, 'default': 'openai-community/gpt2-xl'}),
        ('--max_new_tokens', {'type': int, 'default': 80}),
        ('--num_counterfactuals', {'type': int, 'default': 30}),
    )

    cli = ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli


if __name__ == '__main__':
    # Script entry point: for every sample of the referenced ACF task, generate
    # one factual continuation plus `--num_counterfactuals` counterfactual
    # outputs, then pickle everything and attach the file to the ClearML task.

    # The ClearML task is created before the command line is parsed; the
    # timestamp suffix keeps task names distinct between runs.
    run_name = f'tlcf_sentiment_tracking_token_level_{time.time()}'
    task = clearml.Task.init(
        project_name='abstract_counterfactuals',
        task_name=run_name,
    )

    args = make_parser().parse_args()

    print('fetching samples from acf task')
    dataset = get_samples_from_acf_task(args.acf_task_id)
    print('sample shape:', dataset.shape)

    tokenizer_options = dict(
        model_max_length=512,
        padding_side="right",
        use_fast=False,
        trust_remote_code=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, **tokenizer_options)
    # Padding reuses the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token

    model = load_model(args.model_name)

    # Options for sampling the factual continuation; only the continuation
    # (not the prompt) is requested back.
    generation_options = dict(
        max_new_tokens=args.max_new_tokens,
        return_only_continuation=True,
        num_beams=1,
        do_sample=True,
        token_healing=False,
    )

    # Layout of the result, keyed by the dataset row index:
    #   {
    #     'factual_prompt':        {'text': ...},
    #     'interventional_prompt': {'text': ...},
    #     'factual':               {'text': prompt + continuation},
    #     'counterfactuals':       [{'tokens': ..., 'text': ...}, ...],
    #   }
    experiment_data = {}
    progress = tqdm.tqdm(dataset.iterrows(), total=len(dataset))
    for row_index, row in progress:
        factual_prompt = row.factual
        interventional_prompt = row.interventional

        # The continuation's token ids are returned as well but not used here.
        _, factual_continuation = get_continuation(
            model,
            tokenizer,
            factual_prompt,
            **generation_options,
        )

        counterfactuals = []
        for _ in range(args.num_counterfactuals):
            # The same model object is supplied for both model arguments.
            cf_tokens, cf_text = get_counterfactual_output(
                model,
                model,
                tokenizer,
                factual_prompt,
                interventional_prompt,
                factual_continuation,
                args.max_new_tokens,
            )
            # Tensors are copied to CPU before being stored so the pickled
            # result does not hold accelerator-resident tensors.
            stored_tokens = (
                cf_tokens.cpu() if isinstance(cf_tokens, torch.Tensor) else cf_tokens
            )
            counterfactuals.append({"tokens": stored_tokens, "text": cf_text})

        factual_text = factual_prompt + factual_continuation
        experiment_data[row_index] = {
            'factual_prompt': {'text': factual_prompt},
            'interventional_prompt': {'text': interventional_prompt},
            'factual': {'text': factual_text},
            'counterfactuals': counterfactuals,
        }

    # Persist locally, then register the pickle file as a task artifact.
    output_path = 'experiment_data.pkl'
    with open(output_path, 'wb') as out_file:
        pickle.dump(experiment_data, out_file)
    task.upload_artifact(name=output_path, artifact_object=output_path)
