import os
from pathlib import Path
import pickle
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer
from captum.attr import LayerGradientXActivation

from datasets import get_dataset
from models import *
from utils import *
from tqdm import tqdm


def _load_or_tokenize_contexts(model, tokenizer, contexts, cache_files, verbose):
    """Return (token_ids, word_idx_to_tok_idx, token_avg_idxs) lists, one entry per context.

    When all three cache pickles in ``cache_files`` exist they are loaded; otherwise every
    context is tokenized with ``model.get_context_token_ids`` and the results are written
    to those files.

    Args:
        model: object exposing ``get_context_token_ids(context, tokenizer)``.
        tokenizer: Hugging Face tokenizer passed through to the model helper.
        contexts: iterable of contexts (``dataset.contexts``).
        cache_files: (token_ids_file, word_idx_to_tok_idx_file, token_avg_idxs_file) paths.
        verbose: print progress messages when true.

    Returns:
        Tuple of three lists, aligned by context index.
    """
    if all(path.exists() for path in cache_files):
        if verbose:
            print("Loading tokenized contexts from files...")
        # NOTE: pickle.load executes arbitrary code for crafted files; these caches are
        # produced by this script and must not come from an untrusted source.
        loaded = []
        for path in cache_files:
            with open(path, "rb") as f:
                loaded.append(pickle.load(f))
        return tuple(loaded)

    if verbose:
        print("Tokenizing contexts...")
    token_ids_list, word_maps, avg_idxs = [], [], []
    for context in contexts:
        token_ids, word_idx_to_tok_idx, token_avg_idxs = model.get_context_token_ids(context, tokenizer)
        # reshape(-1) instead of squeeze(): a single-token context must stay 1-D,
        # otherwise pad_sequence rejects the resulting 0-d tensor.
        token_ids_list.append(token_ids.reshape(-1).cpu())
        word_maps.append(word_idx_to_tok_idx)
        avg_idxs.append(token_avg_idxs)

    # Save tokenized contexts so later runs can skip tokenization.
    for path, obj in zip(cache_files, (token_ids_list, word_maps, avg_idxs)):
        with open(path, "wb") as f:
            pickle.dump(obj, f)
    return token_ids_list, word_maps, avg_idxs


def _compute_token_attributions(attr_method, token_batch, hidden_dim):
    """Return L1-normalized token attributions as a CPU tensor of shape (batch, seq_len, hidden_dim).

    For each index ``i`` in ``range(hidden_dim)`` the attribution method is run with
    ``additional_forward_args=i`` (presumably selecting output unit ``i`` of the
    aggregated representation — confirm in the model class). The result is summed over
    its last axis to get one score per token.
    """
    batch_size, seq_len = token_batch.size()
    llm_attrs = torch.zeros((batch_size, seq_len, hidden_dim))
    for i in range(hidden_dim):
        # Last axis is presumably the embedding dimension of the attributed layer.
        llm_attrs[:, :, i] = attr_method.attribute(token_batch, additional_forward_args=i).sum(dim=-1).cpu()

    # Normalize each unit's attributions to unit L1 norm over the sequence axis.
    # clamp_min keeps an all-zero column at 0 instead of producing 0/0 = NaN.
    denom = llm_attrs.abs().sum(dim=1, keepdim=True).clamp_min(torch.finfo(llm_attrs.dtype).tiny)
    return llm_attrs / denom


def _aggregate_to_word_level(llm_attrs, word_maps, hidden_dim):
    """Average token attributions over each word's tokens.

    Args:
        llm_attrs: tensor (batch, seq_len, hidden_dim) of token attributions.
        word_maps: one dict per batch row mapping word index -> token indices.
            Word indices are assumed to be 0..len(dict)-1 (TODO confirm).
        hidden_dim: size of the last axis of ``llm_attrs``.

    Returns:
        List of numpy arrays of shape (num_words, hidden_dim), one per batch row.
    """
    word_level = []
    for b_idx, word_to_tok_idx in enumerate(word_maps):
        word_attrs = torch.zeros((len(word_to_tok_idx), hidden_dim))
        for word_idx, tok_idxs in word_to_tok_idx.items():
            word_attrs[word_idx] = llm_attrs[b_idx, tok_idxs].mean(dim=0)
        word_level.append(word_attrs.numpy())
    return word_level


@measure_performance
def main():
    """Compute gradient x activation attributions of the language model and save them per layer.

    Attributions are taken with respect to the model's embedding layer output. The target
    is ``model.extract_layer_aggregated_representation``, configured per layer via
    ``model.set_attribution_params``. Token-level attributions are averaged to word level.
    Each layer's list of word-level arrays is pickled to ``<attribution_dir>/layer_<idx>.pkl``.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    verbose = args.experiment.verbose

    ### Create output directory
    experiment_dir = create_output_directory(args)
    experiment_dir = experiment_dir / "attribution"
    os.makedirs(experiment_dir, exist_ok=True)

    ### Attribution output folder
    attribution_dir = experiment_dir / "llm"
    os.makedirs(attribution_dir, exist_ok=True)

    # Tokenization cache files (shared by all layers).
    token_ids_file = attribution_dir / "token_ids.pkl"
    word_idx_to_tok_idx_file = attribution_dir / "word_idx_to_tok_idx.pkl"
    token_avg_idxs_file = attribution_dir / "token_avg_idxs.pkl"

    if verbose:
        print("Outputs will be stored in directory ", str(attribution_dir))

    ### Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model.hugging_face_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ### Get the model
    model = BrainAlignLanguageModel(args.model.hugging_face_model_id, device, args.model.type)
    embedding_layer = model.get_embedding_layer()

    ### Get the dataset
    if verbose:
        print("Step 1: Extract contexts for each word and get word-TR correspondance.")

    dataset_dir = Path("./data") / args.dataset.name
    dataset = get_dataset(
        args.dataset.name,
        dataset_dir,
        tokenizer=tokenizer,
        device=device,
        context_length=args.context_length,
        remove_format_chars=args.dataset.remove_format_chars,
        remove_punc_spacing=args.dataset.remove_punc_spacing,
        verbose=verbose
    )
    attribution_layers = args.attribution.layers
    attribution_subjs = args.attribution.subject_idxs.split(',')
    # NOTE(review): attribution_subjs is not used below; index() still validates the
    # requested subject names (raises ValueError for an unknown one).
    attribution_subjs = [dataset.subject_idxs.index(subject_name) for subject_name in attribution_subjs]

    ## Define attribution method
    attr_method = LayerGradientXActivation(model.extract_layer_aggregated_representation, embedding_layer)

    # Tokenization does not depend on the layer, so load/compute it once, not per layer.
    token_ids_list, word_idx_to_tok_idx_batches, token_avg_idxs_batches = _load_or_tokenize_contexts(
        model,
        tokenizer,
        dataset.contexts,
        (token_ids_file, word_idx_to_tok_idx_file, token_avg_idxs_file),
        verbose,
    )

    if verbose:
        print("Batching contexts...")
    # NOTE(review): padding uses token id 0 and no attention mask is passed. Padded
    # positions are fed to the model as real tokens — confirm this is harmless for the
    # model types in use.
    padded_token_ids = torch.nn.utils.rnn.pad_sequence(token_ids_list, batch_first=True).cpu()
    batch_size = args.attribution.batch_size
    # shuffle=False is required: batch i must line up with the metadata slices below.
    dataset_loader = DataLoader(TensorDataset(padded_token_ids), batch_size=batch_size, shuffle=False, num_workers=4)

    # Compute the language model's attributions
    for layer_idx in attribution_layers:
        if verbose:
            print(f"Processing layer {layer_idx}...")

        layer_llm_attrs = []

        if verbose:
            print(f"Starting batch processing for layer {layer_idx}...")
        for batch_idx, (token_batch,) in enumerate(dataset_loader):
            token_batch = token_batch.to(device)  # (batch_size, num_tokens)
            if verbose:
                print(f"Processing batch {batch_idx + 1} of layer {layer_idx}...")

            start, stop = batch_idx * batch_size, (batch_idx + 1) * batch_size
            model.set_attribution_params(layer_idx, token_avg_idxs_batches[start:stop])

            # Compute and normalize token-level attributions for the batch
            llm_attrs = _compute_token_attributions(attr_method, token_batch, args.model.hidden_dim)

            # Convert token-level attributions to word-level
            if verbose:
                print("Aggregating token-level attributions to word-level...")
            layer_llm_attrs.extend(
                _aggregate_to_word_level(llm_attrs, word_idx_to_tok_idx_batches[start:stop], args.model.hidden_dim)
            )

            # Free up memory
            del llm_attrs
            torch.cuda.empty_cache()

        # Save attributions for the layer
        with open(attribution_dir / f"layer_{layer_idx}.pkl", "wb") as f:
            pickle.dump(layer_llm_attrs, f)

if __name__ == "__main__":
    # Script entry point: run the attribution pipeline only when executed directly, not on import.
    main()

