'''
This module implements the high-level interface for computing abstract counterfactuals.
It is mostly tailored to LM agents, but should be general enough to be easily extended
to other settings.
'''
from dotenv import load_dotenv
load_dotenv()
import torch
from abstract_cf.sampling_utils import gumbel_max_rejection_sampling, gumbel_max
from abstract_cf.text_generation.utils import importance_counterfactual_posterior_estimate, load_clearml_dataset
import numpy as np 
import plotly.express as px
from argparse import ArgumentParser
from abstract_cf.text_generation.state_interventions import get_intervention_fn
from abstract_cf.text_generation.state import construct_state, State
from transformers import AutoTokenizer
from abstract_cf.text_generation.utils import load_model
import tqdm
import pickle
import time
import os

# Import ClearML for monitoring and artefact saving
from clearml import Task

# Pick the compute device: Apple MPS first, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def compute_abstract_counterfactual(
    factual_state: State,
    interventional_state: State,
    abstraction: object,
    abstraction_type: str,
    observed_abstraction: object = None,
    argmax_observation: bool = False,
    n_state_samples: int = 50,
    n_gumbel_samples: int = 100,
    n_posterior_samples: int = 10,
    make_figures: bool = False,
) -> dict:
    '''
    Compute an abstract counterfactual for a (factual, interventional) pair of states.

    Procedure:
      1. sample actions from both states and estimate the factual abstraction distribution;
      2. take / choose an observed factual abstraction and sample exogenous Gumbel noise
         consistent with it via ``gumbel_max_rejection_sampling``;
      3. estimate the interventional abstraction distribution;
      4. push each noise sample through ``gumbel_max`` on the interventional
         log-probabilities to obtain counterfactual abstraction samples;
      5./6. estimate a posterior over the interventional action samples with
         ``importance_counterfactual_posterior_estimate`` and draw counterfactual
         actions from it.

    Args:
        factual_state: state whose ``call_policy(n_samples=...)`` yields factual action samples.
        interventional_state: intervened-upon state, sampled the same way.
        abstraction: callable mapping a list of action samples to a tensor of abstraction
            probabilities that is averaged over axis 0 here (presumably one row per
            sample — TODO confirm). Must expose ``id_to_label`` and ``label_to_id``
            mappings. When ``abstraction_type == 'online'`` it must also provide ``fit``,
            ``print_topics`` and ``topics``.
        abstraction_type: ``'online'`` fits the abstraction on the freshly drawn samples
            before use; any other value uses the abstraction as-is.
        observed_abstraction: label of the observed factual abstraction. If None, it is
            chosen by argmax (``argmax_observation``) or sampled from the factual distribution.
        argmax_observation: use the most likely factual abstraction as the observation.
            Ignored, with a warning, when ``observed_abstraction`` is given.
        n_state_samples: number of actions sampled from each state.
        n_gumbel_samples: number of exogenous-noise samples requested from the rejection sampler.
        n_posterior_samples: number of counterfactual actions drawn (with replacement)
            from the estimated posterior.
        make_figures: additionally build plotly bar charts of every distribution.

    Returns:
        A dict with keys ``'samples'``, ``'observed_abstraction_index'`` and
        ``'distributions'``; plus ``'figures'`` when ``make_figures`` is set and
        ``'online_abstraction_topics'`` for online abstractions.

    Raises:
        ValueError: if the rejection sampler returned no exogenous noise samples
            (e.g. ``n_gumbel_samples <= 0``).
    '''
    if observed_abstraction is not None and argmax_observation:
        print('Warning: both observed_abstraction and argmax_observation are set. argmax_observation will be ignored.')
        argmax_observation = False

    # Sample both states up front so that, for online abstractions, the same
    # samples can be used to fit the abstraction.
    factual_action_samples = factual_state.call_policy(n_samples=n_state_samples)
    interventional_action_samples = interventional_state.call_policy(n_samples=n_state_samples)

    if abstraction_type == 'online':
        abstraction.fit(factual_action_samples + interventional_action_samples)
        abstraction.print_topics()

    # Read the label set only after a possible online fit. Materialise it as a list:
    # plotly expects a sequence for `x`, and a dict view is not a sequence.
    labels = list(abstraction.id_to_label.values())
    n_labels = len(abstraction.id_to_label)

    # Step 1: estimate the abstraction distribution for the factual setting.
    factual_abstraction_probs = abstraction(factual_action_samples).mean(axis=0)
    if make_figures:
        fig_factual_abstraction_probs = px.bar(
            x=labels,
            y=factual_abstraction_probs.cpu().numpy()
        )

    # Step 2: pick the observed abstraction we will try to 'match' in the
    # counterfactual, then sample exogenous noise consistent with it.
    if observed_abstraction is not None:
        observed_abstraction_index = abstraction.label_to_id[observed_abstraction]
    elif argmax_observation:
        observed_abstraction_index = torch.argmax(factual_abstraction_probs).cpu().item()
        observed_abstraction = abstraction.id_to_label[observed_abstraction_index]
    else:
        observed_abstraction_index = torch.multinomial(factual_abstraction_probs, 1).cpu().item()
        observed_abstraction = abstraction.id_to_label[observed_abstraction_index]

    factual_abstraction_probs = factual_abstraction_probs.cpu().numpy()
    _, G = gumbel_max_rejection_sampling(
        factual_abstraction_probs,
        observed_abstraction_index,
        n_samples=n_gumbel_samples
    )
    if len(G) == 0:
        # Without noise samples there is nothing to propagate in step 4.
        raise ValueError(
            'gumbel_max_rejection_sampling returned no exogenous noise samples; '
            f'check n_gumbel_samples (got {n_gumbel_samples}).'
        )

    # Step 3: estimate the abstraction distribution for the intervened-upon state.
    all_interventional_abstraction_probs = abstraction(interventional_action_samples)
    interventional_abstraction_probs = all_interventional_abstraction_probs.mean(axis=0).cpu().numpy()
    if make_figures:
        fig_interventional_abstraction_probs = px.bar(
            x=labels,
            y=interventional_abstraction_probs
        )

    # Step 4: counterfactual abstraction samples = Gumbel-max of the interventional
    # log-probabilities under each posterior noise sample. gumbel_max returns a
    # pair; only its first element (the sampled index) is used here.
    counterfactual_abstraction_index_samples = np.array([
        gumbel_max(np.log(interventional_abstraction_probs), g)[0]
        for g in G
    ])

    # Empirical distribution of the counterfactual abstraction samples: the
    # 'theoretical' target, up to numerical imprecision and finite sample size.
    n_cf = len(counterfactual_abstraction_index_samples)
    cf_abstraction_samples_dist = np.array([
        np.sum(counterfactual_abstraction_index_samples == label_id) / n_cf
        for label_id in range(n_labels)
    ])
    if make_figures:
        fig_counterfactual_abstraction_samples = px.bar(
            x=labels,
            y=cf_abstraction_samples_dist
        )

    # Steps 5 & 6: posterior over counterfactual actions given the abstraction
    # samples, marginalising over the abstraction in the same step.
    counterfactual_posterior_action = importance_counterfactual_posterior_estimate(
        all_interventional_abstraction_probs,
        counterfactual_abstraction_index_samples
    )
    if make_figures:
        fig_counterfactual_posterior_action = px.bar(
            x=list(range(len(counterfactual_posterior_action))),
            y=counterfactual_posterior_action
        )

    # Draw i.i.d. samples from the posterior. Sampling without replacement (torch's
    # default) would not follow the posterior, and it raises when n_posterior_samples
    # exceeds the number of actions with non-zero posterior mass.
    posterior_action_samples_indices = torch.multinomial(
        torch.as_tensor(counterfactual_posterior_action),
        n_posterior_samples,
        replacement=True,
    )
    posterior_action_samples = [
        interventional_action_samples[idx] for idx in posterior_action_samples_indices.tolist()
    ]

    acf_data = {
        'samples': {
            'factual_samples': factual_action_samples,
            'observation': {
                'observed_abstraction': observed_abstraction,
                'observed_abstraction_index': observed_abstraction_index,
            },
            'interventional_samples': interventional_action_samples,
            'counterfactual_samples': posterior_action_samples,
        },
        'observed_abstraction_index': observed_abstraction_index,
        'distributions': {
            'factual_abstraction_probs': factual_abstraction_probs,
            'interventional_abstraction_probs': interventional_abstraction_probs,
            'counterfactual_abstraction': cf_abstraction_samples_dist,
            'counterfactual_posterior_action': counterfactual_posterior_action
        },
    }
    if make_figures:
        acf_data['figures'] = {
            'factual_abstraction_probs': fig_factual_abstraction_probs,
            'interventional_abstraction_probs': fig_interventional_abstraction_probs,
            'counterfactual_abstraction': fig_counterfactual_abstraction_samples,
            'counterfactual_posterior_action': fig_counterfactual_posterior_action
        }

    if abstraction_type == 'online':
        acf_data['online_abstraction_topics'] = abstraction.topics

    return acf_data

def make_parser():
    '''
    Build the command-line parser for the abstract-counterfactual experiment script.

    Returns:
        An ``ArgumentParser`` whose defaults reproduce the standard experiment setup.
        ``--abstraction_type`` is restricted to ``offline`` / ``online`` so that invalid
        values are rejected at parse time.
    '''
    parser = ArgumentParser()
    parser.add_argument('--factual_model_name', type=str, default='openai-community/gpt2-xl',
                        help="model (and tokenizer) name used for the factual state")
    parser.add_argument('--intervention_config', type=str, default='intervention_configs/goemotions_token_replacement.yaml',
                        help="config passed to get_intervention_fn")
    parser.add_argument('--abstraction_type', type=str, default='offline', choices=('offline', 'online'),
                        help="`offline` loads a learned abstraction checkpoint; `online` fits an LM abstraction per sample")
    parser.add_argument('--offline_abstraction_checkpoint', type=str, default='model_data/learned_abstractions/profession/checkpoint-5994',
                        help="checkpoint loaded when --abstraction_type is `offline`")
    parser.add_argument('--n_samples_per_generation', type=int, default=150,
                        help="number of actions sampled from each of the factual and interventional states")
    parser.add_argument('--max_length', type=int, default=80,
                        help="max length used both for the state and for sampling")
    parser.add_argument('--n_gumbel_samples', type=int, default=100,
                        help="number of exogenous Gumbel noise samples")
    parser.add_argument('--n_posterior_cf_samples', type=int, default=10,
                        help="number of counterfactual actions drawn from the posterior")
    parser.add_argument('--observation', type=str, default=None,
                        help="observed abstraction label (sampled from the factual distribution if omitted)")
    parser.add_argument('--argmax_observation', action='store_true', default=False,
                        help="use the most likely factual abstraction as the observation")
    parser.add_argument('--prompt_tokens', type=int, default=8,
                        help="number of prompt tokens kept from each dataset text")
    parser.add_argument('--batch_size', type=int, default=5,
                        help="sampling batch size")
    parser.add_argument('--dataset_id', type=str, default=None,
                        help="ClearML dataset ID (if provided, dataset_name is ignored)")
    parser.add_argument('--dataset_name', type=str, default='emotion_sample_250',
                        help="ClearML dataset name")
    parser.add_argument('--project_name', type=str, default=os.environ.get('CLEARML_PROJECT_NAME', 'abstract_counterfactuals'),
                        help="ClearML project name where the dataset is stored")
    parser.add_argument('--make_figures', action='store_true', default=False,
                        help="build plotly figures of the computed distributions")
    return parser


if __name__ == '__main__':
    # Initialise the ClearML task for monitoring; a timestamp suffix keeps task
    # names unique across runs.
    # NOTE(review): Task.init is presumably kept ahead of parse_args so that ClearML
    # can capture the argparse arguments — confirm before reordering.
    unique_task_name = f"acf_{int(time.time())}"
    task = Task.init(project_name='abstract_counterfactuals', task_name=unique_task_name)

    parser = make_parser()
    args = parser.parse_args()
    # Validate CLI input before the expensive dataset/model loading. parser.error
    # reports the problem and exits; an `assert` would be stripped under `python -O`.
    if args.abstraction_type not in {'offline', 'online'}:
        parser.error('abstraction type must be either `offline` or `online`')

    print('loading dataset')
    if args.dataset_id:
        dataset = load_clearml_dataset(dataset_id=args.dataset_id)
        print(f'dataset loaded from ID: {args.dataset_id}')
    else:
        dataset = load_clearml_dataset(dataset_name=args.dataset_name, project_name=args.project_name)
        print(f'dataset loaded by name: {args.dataset_name} from project: {args.project_name}')

    print(f'dataset loaded, {len(dataset)} samples')

    print('loading abstraction')
    if args.abstraction_type == 'offline':
        from abstract_cf.text_generation.learned_abstraction import LearnedAbstractionPipeline
        abstraction = LearnedAbstractionPipeline.load(args.offline_abstraction_checkpoint, device=device)
        openai_client = None
    else:
        # --offline_abstraction_checkpoint has a non-None default in make_parser, so a
        # plain `is not None` check would warn on every online run. Only warn when the
        # user actually supplied a different checkpoint.
        default_checkpoint = parser.get_default('offline_abstraction_checkpoint')
        if args.offline_abstraction_checkpoint not in (None, default_checkpoint):
            print('WARNING: online abstraction is being used, but an offline checkpoint was provided. The offline checkpoint will be ignored.')

        from openai import OpenAI
        from abstract_cf.text_generation.unsupervised_abstractions import LMAbstraction
        openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
        abstraction = LMAbstraction(openai_client)

    # Store the results of the run, keyed by dataset row index:
    # index -> {samples -> (factual, interventional, counterfactual),
    #           distributions -> {...}, observed_abstraction_index, prompts, ...}
    experiment_data = {}

    factual_model = load_model(args.factual_model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.factual_model_name)
    tokenizer.pad_token = tokenizer.eos_token

    intervention_fn = get_intervention_fn(args.intervention_config)

    print('starting experiment')
    for i, row in tqdm.tqdm(dataset.iterrows(), total=len(dataset)):
        print('constructing factual state')
        factual_state = construct_state(
            prompt=row.text,
            n_prompt_tokens=args.prompt_tokens,
            max_length=args.max_length,
            model=factual_model,
            tokenizer=tokenizer,
            sampling_batch_size=args.batch_size,
            sampling_max_length=args.max_length,
        )
        print('constructing interventional state')

        interventional_state = intervention_fn(factual_state)

        print('computing abstract counterfactual')
        with torch.no_grad():
            acf_data = compute_abstract_counterfactual(
                factual_state=factual_state,
                interventional_state=interventional_state,
                abstraction=abstraction,
                abstraction_type=args.abstraction_type,
                observed_abstraction=args.observation,
                argmax_observation=args.argmax_observation,
                n_state_samples=args.n_samples_per_generation,
                n_gumbel_samples=args.n_gumbel_samples,
                n_posterior_samples=args.n_posterior_cf_samples,
                make_figures=args.make_figures
            )

        experiment_data[i] = {
            **acf_data,
            'prompts': {
                'dataset_text': row.text,
                'factual': factual_state.inputs_text,
                'interventional': interventional_state.inputs_text
            }
        }

    # Single source of truth for the local output file that is then uploaded.
    experiment_data_path = 'experiment_data.pkl'

    print('saving experiment data')
    with open(experiment_data_path, 'wb') as f:
        pickle.dump(experiment_data, f)

    print('uploading experiment data to clearml')
    task.upload_artifact(name='experiment_data.pkl', artifact_object=experiment_data_path)
    print('finished uploading experiment data to clearml')
    print('goodbye!')

# NOTE: if uploading the artefact causes a deadlock, try setting `sdk.development.report_use_subprocess: false` in the clearml.conf file.
# XXXX
