import os
import numpy as np
from pathlib import Path
import torch
from transformers import AutoTokenizer

from brain_alignment import train_brain_encoding_model
from datasets import get_dataset
from models import *
from utils import *


def _log(verbose, *parts):
    """Print *parts* (space-separated, exactly as ``print`` would) when *verbose* is truthy.

    Replaces the repeated ``if args.experiment.verbose: print(...)`` guards
    around each progress message of the pipeline.
    """
    if verbose:
        print(*parts)


def _extract_embeddings(model, dataset, data_loader, token_idxs_to_avg, args, experiment_dir):
    """Step 2 of the pipeline: compute the language model's aggregated embeddings.

    For the ``"MothRadioHour"`` dataset, ``data_loader``, ``token_idxs_to_avg`` and
    ``dataset.tr_to_word_idxs`` are indexed by story name. ``extract_model_embeddings``
    is called once per story, with the story index passed as an extra trailing
    argument. For every other dataset, a single call covers the whole dataset.

    Returns:
        For ``"MothRadioHour"``: a dict ``{story_idx: <result of extract_model_embeddings>}``
        built in ``dataset.story_idx_to_name`` iteration order.
        Otherwise: the result of one ``model.extract_model_embeddings`` call.
    """
    batch_size = args.model.batch_size
    verbose = args.experiment.verbose

    if args.dataset.name != "MothRadioHour":
        return model.extract_model_embeddings(
            data_loader,
            token_idxs_to_avg,
            dataset.tr_to_word_idxs,
            batch_size,
            experiment_dir,
            verbose,
        )

    # Multi-story dataset: loaders, averaging indices and TR mappings are looked
    # up per story name; the story index is forwarded to the model as well.
    return {
        story_idx: model.extract_model_embeddings(
            data_loader[story_name],
            token_idxs_to_avg[story_name],
            dataset.tr_to_word_idxs[story_name],
            batch_size,
            experiment_dir,
            verbose,
            story_idx,
        )
        for story_idx, story_name in dataset.story_idx_to_name.items()
    }


@measure_performance
def main():
    """Run the brain-alignment pipeline configured by ``parse_args()``.

    Steps:
      1. Load the dataset: extract contexts for each word, the word-TR
         correspondence, and batched context token IDs.
      2. Compute the language model's aggregated embeddings.
      3. Train the brain encoding model and compute correlation scores
         (delegated to ``train_brain_encoding_model``).

    Outputs are stored in the directory returned by ``create_output_directory``.
    Runs on CUDA when available, otherwise on CPU.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ### Create output directory
    experiment_dir = create_output_directory(args)
    verbose = args.experiment.verbose
    _log(verbose, "Outputs will be stored in directory ", str(experiment_dir))

    ### Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model.hugging_face_model_id)
    # Some tokenizers define no pad token; fall back to EOS so padding has a token to use.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ### Get the model
    model = BrainAlignLanguageModel(
        args.model.hugging_face_model_id,
        args.experiment.apply_pca,
        args.experiment.num_red_components,
        device,
        args.model.type,
    )

    ### Step 1: get the dataset
    _log(verbose, "Starting brain alignment pipeline.")
    _log(verbose, "Step 1: Extract contexts for each word and get word-TR correspondance.")

    dataset_dir = Path(args.data_root_dir) / args.dataset.name
    dataset = get_dataset(
        args.dataset.name,
        dataset_dir,
        tokenizer=tokenizer,
        device=device,
        context_length=args.context_length,
        remove_format_chars=args.dataset.remove_format_chars,
        remove_punc_spacing=args.dataset.remove_punc_spacing,
        verbose=verbose,
    )

    ### Create batches of token IDs (third return value is unused here)
    _log(verbose, "     1c - Convert contexts to token IDs.")
    data_loader, token_idxs_to_avg, _ = dataset.get_context_token_ids(args.model.batch_size)

    ### Step 2: get the aggregated hidden representations
    _log(verbose, "Step 2: Compute language model's embeddings.")
    model_embeddings = _extract_embeddings(
        model, dataset, data_loader, token_idxs_to_avg, args, experiment_dir
    )

    ### Step 3: train brain encoding model and predict brain alignment
    _log(verbose, "Step 3: Train brain activity predictor model and compute correlation scores.")
    train_brain_encoding_model(model_embeddings, args, dataset, experiment_dir, device)

if __name__ == "__main__":
    # Start the pipeline only when this file is executed as a script;
    # a plain import of the module does not trigger a run.
    main()
