import os
import numpy as np
from pathlib import Path
import torch
from transformers import AutoTokenizer

from brain_alignment import train_brain_encoding_model
from datasets import get_dataset
from models import *
from utils import *

@measure_performance
def main():
    """Run the brain-alignment pipeline on pre-computed model embeddings.

    The pipeline has three steps:

    1. Build the dataset named by ``args.dataset.name`` (using the model's
       tokenizer) from ``args.data_root_dir``.
    2. Load the aggregated embeddings from the experiment directory: one
       ``aggregated_embeddings_story_<idx>.npy`` file per story for the
       ``"MothRadioHour"`` dataset, a single ``aggregated_embeddings.npy``
       file otherwise.
    3. Pass embeddings, dataset and configuration to
       ``train_brain_encoding_model``.

    All configuration comes from ``parse_args()``. The function itself
    returns nothing.
    """
    args = parse_args()
    verbose = args.experiment.verbose
    # The device string is forwarded to both the dataset and the encoder.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ### Create output directory
    # NOTE(review): the embeddings are *read* from this directory below, so
    # it is presumably populated by an earlier embedding-extraction run.
    experiment_dir = create_output_directory(args)

    if verbose:
        print("Outputs will be stored in directory ", str(experiment_dir))

    ### Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model.hugging_face_model_id)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    ### Get the dataset
    if verbose:
        print("Starting brain alignment pipeline.")
        print("Step 1: Extract contexts for each word and get word-TR correspondence.")

    dataset_dir = Path(args.data_root_dir) / args.dataset.name
    dataset = get_dataset(
        args.dataset.name,
        dataset_dir,
        tokenizer=tokenizer,
        device=device,
        context_length=args.context_length,
        remove_format_chars=args.dataset.remove_format_chars,
        remove_punc_spacing=args.dataset.remove_punc_spacing,
        verbose=verbose,
    )

    ### Load model embeddings
    if verbose:
        print("Step 2: Load pre-computed embeddings.")
    # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
    # objects, which can execute code from a tampered file. Only load
    # embeddings produced by this pipeline. TODO(review): drop allow_pickle
    # if the saved arrays are plain numeric.
    if args.dataset.name == "MothRadioHour":
        # This dataset stores one embeddings file per story, keyed by index.
        model_embeddings = {}
        for story_idx, story_name in dataset.story_idx_to_name.items():
            story_embeddings = np.load(
                experiment_dir / f'aggregated_embeddings_story_{story_idx}.npy',
                allow_pickle=True,
            )
            model_embeddings[story_idx] = story_embeddings
            print(f"Model embeddings shape for story {story_name}: {story_embeddings.shape}")
    else:
        # Single array of shape (n_layers, n_tr, hidden_dim).
        model_embeddings = np.load(
            experiment_dir / 'aggregated_embeddings.npy',
            allow_pickle=True,
        )
        print(f"Model embeddings shape: {model_embeddings.shape}")

    ### Train brain encoding model and predict brain alignment
    if verbose:
        print("Step 3: Train brain activity predictor model and compute correlation scores.")
    train_brain_encoding_model(model_embeddings, args, dataset, experiment_dir, device)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
