import math
import os
import pickle
from pathlib import Path

import numpy as np
import torch
import tqdm
from captum.attr import LayerGradientXActivation, LayerIntegratedGradients
from transformers import AutoTokenizer

from brain_alignment import get_delay_idxs, k_fold_test_idxs, story_fold_test_idxs
from datasets import get_dataset
from models import *
from utils import *

def _merge_tr_contexts(extended_context, extended_word_span, contexts, contexts_word_idxs):
    """Extend a running word sequence with the words contributed by one TR.

    Args:
        extended_context: Words accumulated so far (list of str); empty for
            the first TR of a sample.
        extended_word_span: Two-element list of word indices covered so far,
            or an empty list when nothing has been accumulated yet.
        contexts: Non-empty list of context strings belonging to the TR.
        contexts_word_idxs: Per-context index pairs, parallel to ``contexts``.
            Only elements ``[0]`` and ``[1]`` of each pair are read
            (presumably start/end word indices -- confirm against the dataset).

    Returns:
        The updated ``(extended_context, extended_word_span)`` pair. The
        arguments themselves are left untouched.
    """
    words = list(extended_context)
    span = list(extended_word_span)

    first_new = 0
    if not span:
        # Nothing accumulated yet: the first context seeds the sequence.
        words = contexts[0].split()
        span = [contexts_word_idxs[0][0], contexts_word_idxs[0][1]]
        first_new = 1

    for context, word_idxs in zip(contexts[first_new:], contexts_word_idxs[first_new:]):
        new_words_count = word_idxs[1] - span[1]
        if new_words_count == 0:
            continue  # No new words to add
        if new_words_count < 0:
            # A context that ends before the current span still contributes
            # its last word (behaviour kept from the original implementation).
            new_words_count = 1
        words += context.split()[-new_words_count:]
        span[1] = word_idxs[1]

    return words, span


def _select_next_word(tr_idx, tr_to_word_idxs, dataset_contexts, tr_contexts):
    """Pick the word the model should predict after the TR ``tr_idx``.

    The preferred choice is the last word of the first context of the
    following TR. When there is no following TR, or it carries no usable
    context, the last word of the current TR's last context is used instead.

    Args:
        tr_idx: Index of the current TR.
        tr_to_word_idxs: Mapping/sequence from TR index to its word indices.
        dataset_contexts: Mapping/sequence from word index to context string.
        tr_contexts: Non-empty list of context strings of the current TR.

    Returns:
        The selected word, or ``None`` when neither source yields a word.
    """
    next_tr_idx = tr_idx + 1
    if next_tr_idx < len(tr_to_word_idxs):
        next_word_idxs = tr_to_word_idxs[next_tr_idx]
        if len(next_word_idxs) > 0:
            next_context_words = dataset_contexts[next_word_idxs[0]].split()
            if next_context_words:
                return next_context_words[-1]

    # NOTE(review): this fallback word is already the tail of the extended
    # context, so the caller ends up with it duplicated. The original code
    # popped it from a throw-away list; confirm whether it was meant to be
    # removed from the extended context instead.
    fallback_words = tr_contexts[-1].split()
    if fallback_words:
        return fallback_words[-1]
    return None


@measure_performance
def main():
    """Compute next-word attributions for every test TR of every fold.

    Pipeline, as implemented below:

    1. Parse arguments, create ``<experiment_dir>/next_word_attrs/<method>``.
    2. Load the tokenizer (right padding, EOS as pad token if none is set),
       the language model wrapper and the dataset.
    3. For each fold and each test TR, concatenate the contexts of its delayed
       TRs into one word sequence and append the "next word" to predict.
    4. Run the selected Captum layer attribution method on the input
       embedding layer, reduce token attributions to word attributions and
       drop the entry of the appended next word.
    5. Pickle the list of per-sample word attributions to ``fold_<i>.pkl``.

    Samples for which no next word can be determined are skipped (with a
    message), so a fold's output list may be shorter than its test index list.
    Finally the average cross-entropy accumulated by the model
    (``model.fold_loss / model.fold_count``) and its perplexity are printed.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ### Set attribution method (command-line override wins over the config)
    if args.attr_method is not None:
        args.attribution.method = args.attr_method

    ### Create output directory
    experiment_dir = create_output_directory(args)
    output_dir = experiment_dir / "next_word_attrs" / args.attribution.method
    os.makedirs(output_dir, exist_ok=True)

    if args.experiment.verbose:
        print("Outputs will be stored in directory ", str(experiment_dir))

    ### Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model.hugging_face_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Ensure the tokenizer pads on the right
    tokenizer.padding_side = "right"

    ### Get the model
    model = BrainAlignLanguageModel(args.model.hugging_face_model_id, args.experiment.apply_pca, args.experiment.num_red_components, device, args.model.type, True)

    ### Get the dataset
    if args.experiment.verbose:
        print("Starting brain alignment pipeline.")
        print("Step 1: Extract contexts for each word and get word-TR correspondance.")

    dataset_dir = Path(args.data_root_dir) / args.dataset.name
    dataset = get_dataset(
        args.dataset.name,
        dataset_dir,
        tokenizer=tokenizer,
        device=device,
        context_length=args.context_length,
        remove_format_chars=args.dataset.remove_format_chars,
        remove_punc_spacing=args.dataset.remove_punc_spacing,
        verbose=args.experiment.verbose
    )

    ### Get fold test indices and delay indices
    is_moth = args.dataset.name == "MothRadioHour"
    if is_moth:
        # One fold per story; delay indices are looked up per fold below.
        fold_test_idxs, story_delay_idxs = story_fold_test_idxs(dataset.subjects, 0, args.experiment.num_tr_trim, args.experiment.num_delays)
        story_names = list(dataset.story_idx_to_name.values())
    else:
        fold_test_idxs = k_fold_test_idxs(dataset.subjects, 0, args.experiment.num_folds, args.experiment.num_tr_trim)
        delay_idxs = get_delay_idxs(dataset.runs_cropped, args.experiment.num_delays)

    ### Define the attribution method once; it does not depend on the fold
    use_gxi = args.attribution.method == 'gxi'
    attr_cls = LayerGradientXActivation if use_gxi else LayerIntegratedGradients
    attr_method = attr_cls(
        forward_func=model.forward_next_word_prediction,
        layer=model.model.get_input_embeddings()
    )

    for fold_idx, test_idxs in enumerate(fold_test_idxs):
        if is_moth:
            delay_idxs = story_delay_idxs[fold_idx]
            story_name = story_names[fold_idx]
            tr_to_word_idxs = dataset.tr_to_word_idxs[story_name]
            dataset_contexts = dataset.contexts[story_name]
            context_word_idxs = dataset.context_word_idxs[story_name]
        else:
            tr_to_word_idxs = dataset.tr_to_word_idxs
            dataset_contexts = dataset.contexts
            context_word_idxs = dataset.context_word_idxs

        # These two lists must stay index-aligned: entry i of next_word_lens
        # is the token length of the last word of fold_extended_contexts[i].
        fold_extended_contexts, next_word_lens = [], []
        for test_idx in tqdm.tqdm(test_idxs):
            extended_context, extended_word_span = [], []
            next_word_len = None
            # Walk from the largest delay down to delay 0 so that words are
            # accumulated in the order the delays are stored.
            for d in range(args.experiment.num_delays - 1, -1, -1):
                tr_idx = delay_idxs[test_idx][d]
                if tr_idx == -1:
                    continue

                # 1) Gather the contexts of every word in this TR
                word_idxs = tr_to_word_idxs[tr_idx]
                contexts = [dataset_contexts[word_idx] for word_idx in word_idxs]
                if len(contexts) == 0:
                    log_msg = f"{fold_idx}, {tr_idx} - No contexts found for this TR."
                    print(log_msg)
                    continue
                contexts_word_idxs = [context_word_idxs[word_idx] for word_idx in word_idxs]

                extended_context, extended_word_span = _merge_tr_contexts(
                    extended_context, extended_word_span, contexts, contexts_word_idxs
                )

                if d != 0:
                    continue

                # 2) At delay 0, append the word to be predicted
                next_word = _select_next_word(tr_idx, tr_to_word_idxs, dataset_contexts, contexts)
                if next_word is None:
                    continue
                next_word_id = tokenizer(next_word, add_special_tokens=False)["input_ids"]
                next_word_len = len(next_word_id)
                extended_context.append(next_word)

            # 3) Keep the sample only if it ends with a next word; otherwise
            #    the contexts and next_word_lens would drift out of alignment
            #    and the trailing-word removal below would drop a real word.
            if next_word_len is None:
                if extended_context:
                    print(f"{fold_idx}, {test_idx} - No next word found for this TR; skipping it.")
                continue
            fold_extended_contexts.append(" ".join(extended_context))
            next_word_lens.append(next_word_len)

        # Now we have a single sequence of words for each TR and compute the corresponding token_ids
        data_loader, _, word_idx_to_tok_idx = dataset.get_context_token_ids(
            args.model.batch_size,
            fold_extended_contexts
        )

        # 4) Compute attributions for each batch
        fold_word_attrs = []
        for batch_idx, batch in enumerate(data_loader):
            token_ids, attention_mask = batch
            token_ids = token_ids.to(device) # shape: [batch_size, seq_len]
            attention_mask = attention_mask.to(device) # shape: [batch_size, seq_len]

            # Token lengths of the appended next words for this batch; relies
            # on the loader yielding samples in order, batch_size at a time.
            batch_start = batch_idx * args.model.batch_size
            batch_end = batch_start + args.model.batch_size
            lens_batch = torch.tensor(next_word_lens[batch_start:batch_end], device=device)
            forward_args = (attention_mask, lens_batch, tokenizer.pad_token_id)

            if use_gxi:
                attrs = attr_method.attribute(
                    token_ids,  # shape [batch_size, seq_len]
                    additional_forward_args=forward_args,
                )  # shape [batch_size, seq_len, hidden_dim]
                attrs = torch.norm(attrs, p=2, dim=-1)  # shape: [batch_size, seq_len]
            else:
                attrs = attr_method.attribute(
                    token_ids,  # shape [batch_size, seq_len]
                    baselines=torch.zeros_like(token_ids),
                    additional_forward_args=forward_args,
                    internal_batch_size=args.model.batch_size,
                    n_steps=20,
                )  # shape [batch_size, seq_len, hidden_dim]
                attrs = attrs.sum(dim=-1)  # shape: [batch_size, seq_len]

            # Drop attribution values for the padding tokens and store fold attributions
            attrs_cpu = attrs.detach().cpu()
            mask_cpu = attention_mask.cpu().bool()
            for b, (seq_attrs, seq_mask) in enumerate(zip(attrs_cpu, mask_cpu)):
                seq_attrs = seq_attrs[seq_mask]
                # Convert token-level attributions to word-level attributions
                context_idx = batch_start + b
                context = fold_extended_contexts[context_idx].split()
                word_attrs = np.zeros((len(context)))
                for word_idx, tok_idxs in word_idx_to_tok_idx[context_idx].items():
                    word_attrs[word_idx] = seq_attrs[tok_idxs].sum(dim=0)
                # The last word is the appended next word (the target), so it
                # is excluded from the stored attributions.
                fold_word_attrs.append(word_attrs[:-1])

        # Store attributions for the fold
        with open(output_dir / f"fold_{fold_idx}.pkl", 'wb') as f:
            pickle.dump(fold_word_attrs, f)

    if model.fold_count:
        avg_ce = model.fold_loss / model.fold_count
        print(f"Average cross-entropy: {avg_ce:.4f}, Perplexity: {math.exp(avg_ce):.4f}")
    else:
        # Nothing was scored (e.g. no folds or only skipped samples).
        print("Average cross-entropy is undefined: no next-word predictions were scored.")

# Run the attribution pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

