from pathlib import Path
import torch
from transformers import AutoTokenizer

from datasets import get_dataset
from models import *
from utils import *

@measure_performance
def main():
    """Run the brain-alignment embedding-extraction pipeline end to end.

    Steps, in the order implemented below:
      1. Parse arguments, pick the device (CUDA when available, else CPU) and
         create the experiment output directory.
      2. Load the Hugging Face tokenizer, using the EOS token for padding when
         the tokenizer defines no pad token.
      3. Wrap the model in ``BrainAlignLanguageModel`` (optionally with PCA,
         per ``args.experiment.apply_pca`` / ``num_red_components``).
      4. Build the dataset and convert each word's context into batched token IDs.
      5. Extract aggregated embeddings with ``model.extract_model_embeddings`` and
         print their shape. For ``MothRadioHour`` this is done once per story.

    ``experiment_dir`` is passed to ``extract_model_embeddings``; presumably that
    is where the embeddings are persisted (TODO confirm in ``models``).
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ### Create output directory
    experiment_dir = create_output_directory(args)

    if args.experiment.verbose:
        print("Outputs will be stored in directory ", str(experiment_dir))

    ### Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model.hugging_face_model_id)
    if tokenizer.pad_token is None:
        # Some causal LMs ship without a pad token; batching needs one.
        tokenizer.pad_token = tokenizer.eos_token

    ### Get the model
    model = BrainAlignLanguageModel(
        args.model.hugging_face_model_id,
        args.experiment.apply_pca,
        args.experiment.num_red_components,
        device,
        args.model.type,
    )

    ### Get the dataset
    if args.experiment.verbose:
        print("Starting brain alignment pipeline.")
        print("Step 1: Extract contexts for each word and get word-TR correspondance.")

    dataset_dir = Path(args.data_root_dir) / args.dataset.name
    dataset = get_dataset(
        args.dataset.name,
        dataset_dir,
        tokenizer=tokenizer,
        device=device,
        context_length=args.context_length,
        remove_format_chars=args.dataset.remove_format_chars,
        remove_punc_spacing=args.dataset.remove_punc_spacing,
        verbose=args.experiment.verbose,
    )

    # Create batches of token IDs
    if args.experiment.verbose:
        print("     1c - Convert contexts to token IDs.")
    data_loader, token_idxs_to_avg, _ = dataset.get_context_token_ids(args.model.batch_size)

    ### Get the aggregated hidden representations
    if args.experiment.verbose:
        print("Step 2: Compute language model's embeddings.")

    if args.dataset.name == "MothRadioHour":
        # Multi-story dataset: data_loader, token_idxs_to_avg and tr_to_word_idxs
        # are indexed by story name, and each story is processed separately.
        # Report every story's shape here. Printing once after the loop would
        # show only the last story, and would raise UnboundLocalError when
        # there are no stories at all.
        for story_idx, story_name in dataset.story_idx_to_name.items():
            story_embeddings = model.extract_model_embeddings(
                data_loader[story_name],
                token_idxs_to_avg[story_name],
                dataset.tr_to_word_idxs[story_name],
                args.model.batch_size,
                experiment_dir,
                args.experiment.verbose,
                story_idx,
            )
            print(f"Aggregated embeddings shape ({story_name}): {story_embeddings.shape}")
    else:
        aggregate_embeddings = model.extract_model_embeddings(
            data_loader,
            token_idxs_to_avg,
            dataset.tr_to_word_idxs,
            args.model.batch_size,
            experiment_dir,
            args.experiment.verbose,
        )
        print(f"Aggregated embeddings shape: {aggregate_embeddings.shape}")

if __name__ == "__main__":
    # Script entry point: run the pipeline only when executed directly,
    # never as a side effect of importing this module.
    main()
