import numpy as np
from scipy.stats import zscore
import torch
import os
from pathlib import Path

from .alignment_utils import delay_mat, k_fold, story_fold_cv
from .ridge_regression import RidgeCV
import time
import tracemalloc
import gzip, torch, pickle

def save_ridge(path: Path, W: torch.Tensor, compress_lvl: int = 4):
    """Write ridge weights to disk as a gzip-compressed float16 tensor.

    The output file is ``path`` with its final suffix (if it has one)
    replaced by ``.pt.gz``. The tensor is serialised with ``torch.save``
    directly into the gzip stream, so no uncompressed copy is written.

    Args:
        path: Destination path; the ``.pt.gz`` suffix is applied to it.
        W: Weight tensor to store. It is cast with ``W.half()`` before saving.
        compress_lvl: Compression level forwarded to ``gzip.open``.
    """
    target = path.with_suffix(".pt.gz")
    with gzip.open(target, mode="wb", compresslevel=compress_lvl) as stream:
        half_weights = W.half()
        torch.save(half_weights, stream)
    
def get_subject_idxs_names(dataset, subject_idx):
    """Resolve a subject selector into subject indices and subject names.

    Args:
        dataset: Object exposing ``subject_idxs``, a sequence of subject names
            that supports ``len()`` and ``.index()``.
        subject_idx: Comma-separated subject names. ``None``, an empty string
            or a whitespace-only string selects every subject of the dataset.
            Whitespace around individual names is ignored, so ``"F, H"`` is
            equivalent to ``"F,H"``.

    Returns:
        A ``(subject_idxs, subject_names)`` tuple. ``subject_idxs`` holds the
        position of each selected name inside ``dataset.subject_idxs``;
        ``subject_names`` holds the names in the same order. When every
        subject is selected, ``subject_names`` is ``dataset.subject_idxs``
        itself.

    Raises:
        ValueError: If a requested name is not present in
            ``dataset.subject_idxs``.
    """
    # No explicit selection: use every subject known to the dataset.
    if subject_idx is None or not subject_idx.strip():
        subject_names = dataset.subject_idxs
        subject_idxs = list(range(len(subject_names)))
        return subject_idxs, subject_names

    # Strip each token so that "F, H" does not fail on the lookup of " H".
    subject_names = [name.strip() for name in subject_idx.split(',')]
    subject_idxs = []
    for name in subject_names:
        try:
            subject_idxs.append(dataset.subject_idxs.index(name))
        except ValueError:
            # Same exception type as the bare ``.index`` call, but the message
            # names the offending subject and the valid choices.
            raise ValueError(
                f"Unknown subject {name!r}; available subjects: {list(dataset.subject_idxs)}"
            ) from None

    return subject_idxs, subject_names

def single_story_encoding_model(aggregated_embeddings, args, dataset, experiment_dir, device):
    """Fit and score ridge encoding models for a single-story dataset.

    For every selected subject and every layer in ``aggregated_embeddings``
    the embeddings are delayed with ``delay_mat``, optionally z-scored per run,
    and split into folds by ``k_fold``. A ``RidgeCV`` model is fit on each
    training split and scored with ``r_score`` on the matching test split.

    Side effects:
        * ``<experiment_dir>/ridge_weights`` is created if needed and one
          weight file is written per (subject index, layer, fold) through
          ``save_ridge``.
        * ``<experiment_dir>/correlation_scores_<subject name>.npy`` is written
          per subject with the per-layer scores averaged over folds.

    Args:
        aggregated_embeddings: Iterable yielding one embedding matrix per layer.
        args: Config object; reads ``args.experiment.subject_idx``, ``verbose``,
            ``num_delays``, ``normalize_data_by_run``, ``num_folds`` and
            ``num_tr_trim``.
        dataset: Provides ``subject_idxs``, ``subjects`` and ``runs_cropped``.
        experiment_dir: Output directory (path-like supporting ``/``).
        device: Device handed to ``RidgeCV``.

    Raises:
        RuntimeError: If ``k_fold`` yields no fold for a layer, since there is
            then nothing to average.
    """
    subject_idxs, subject_names = get_subject_idxs_names(dataset, args.experiment.subject_idx)
    weights_dir = experiment_dir / 'ridge_weights'
    os.makedirs(weights_dir, exist_ok=True)

    for subject_name, subject_idx in zip(subject_names, subject_idxs):
        ridge = RidgeCV(n_splits=5, device=device)

        layer_scores = []
        for layer_idx, layer_embeddings in enumerate(aggregated_embeddings):
            # Delay embeddings to account for the delay between stimulus and BOLD response
            if args.experiment.verbose:
                print(f"     Layer {layer_idx}: 3b - Apply delay to account for delay between the stimulus and BOLD response.")
            layer_embeddings = delay_mat(layer_embeddings, np.arange(1, args.experiment.num_delays+1)) # (num_trs, hidden_dim*num_delays)

            # Normalize data by run
            if args.experiment.normalize_data_by_run:
                # NOTE(review): run labels 1..4 are hard-coded here; confirm
                # they match the labels stored in dataset.runs_cropped.
                layer_embeddings = torch.cat([
                    zscore(torch.nan_to_num(layer_embeddings[dataset.runs_cropped == run_id]))
                    for run_id in range(1, 5)
                ])

            # Generate CV folds
            if args.experiment.verbose:
                print(f"     Layer {layer_idx}: 3c - Perform CV and compute mean correlation across folds.")
            folds_generator = k_fold(layer_embeddings, dataset.subjects, subject_idx, args.experiment.num_folds, args.experiment.num_tr_trim)

            # Fresh accumulator per layer: scores from a previous layer must
            # never leak into this one, whatever fold indices are yielded.
            score_sum = None
            num_folds_run = 0
            for fold_idx, (test_fmri, test_feats, _), (train_fmri, train_feats, _) in folds_generator:
                # Train ridge regression model
                ridge.fit(train_feats, train_fmri)

                # Accumulate correlation scores out of place so the array
                # returned by r_score is not mutated.
                scores = ridge.r_score(test_feats, test_fmri)
                score_sum = scores if score_sum is None else score_sum + scores
                num_folds_run += 1

                # Save ridge regression weights
                save_ridge(weights_dir / f"ridge_weights_subject_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}", ridge.W)

            if score_sum is None:
                raise RuntimeError(
                    f"k_fold produced no folds for subject {subject_idx}, layer {layer_idx}"
                )

            # Mean correlation across the folds that were actually evaluated
            layer_scores.append(score_sum / num_folds_run)

        # Save correlation scores
        corr_scores_file = 'correlation_scores' + '_' + subject_name + '.npy'
        np.save(experiment_dir / corr_scores_file, layer_scores)

def multi_story_encoding_model(aggregated_embeddings, args, dataset, experiment_dir, device):
    """Fit and score ridge encoding models for a multi-story dataset.

    For every selected subject and every layer, each story's embeddings are
    delayed with ``delay_mat`` and, when ``normalize_data_by_run`` is set,
    trimmed and z-scored. The per-story matrices are handed to
    ``story_fold_cv`` under ``"story_<idx>"`` keys. A ``RidgeCV`` model is fit
    on each training split and scored with ``r_score`` on the test split.

    Side effects:
        * ``<experiment_dir>/ridge_weights`` is created if needed and one
          weight file is written per (subject index, layer, fold) through
          ``save_ridge``.
        * ``<experiment_dir>/correlation_scores_<subject name>.npy`` is written
          per subject with the per-layer scores averaged over folds.

    Args:
        aggregated_embeddings: Indexed as ``[story_idx][layer_idx]``.
        args: Config object; reads ``args.model.num_layers`` and
            ``args.experiment.subject_idx``, ``verbose``, ``num_delays``,
            ``normalize_data_by_run`` and ``num_tr_trim``.
        dataset: Provides ``subject_idxs``, ``subjects`` and
            ``story_idx_to_name``.
        experiment_dir: Output directory (path-like supporting ``/``).
        device: Device handed to ``RidgeCV``.

    Raises:
        RuntimeError: If ``story_fold_cv`` yields no fold for a layer, since
            there is then nothing to average.
    """
    subject_idxs, subject_names = get_subject_idxs_names(dataset, args.experiment.subject_idx)
    weights_dir = experiment_dir / 'ridge_weights'
    os.makedirs(weights_dir, exist_ok=True)

    num_tr_trim = args.experiment.num_tr_trim
    # A stop of ``-0`` would select nothing, so a zero trim must map to an
    # open-ended slice instead.
    trim_stop = None if num_tr_trim == 0 else -num_tr_trim

    for subject_name, subject_idx in zip(subject_names, subject_idxs):
        ridge = RidgeCV(n_splits=5, device=device)
        layer_scores = []
        for layer_idx in range(args.model.num_layers):
            processed_embeddings = {}
            for story_idx, _ in dataset.story_idx_to_name.items():
                story_embeddings = aggregated_embeddings[story_idx][layer_idx]  # (n_trs, hidden_dim)

                # Delay embeddings to account for the delay between stimulus and BOLD response
                if args.experiment.verbose:
                    print(f"     Layer {layer_idx}: 3b - Apply delay to account for delay between the stimulus and BOLD response.")
                story_embeddings = delay_mat(story_embeddings, np.arange(1, args.experiment.num_delays+1)) # (num_trs, hidden_dim*num_delays)

                # Normalize data by run
                if args.experiment.normalize_data_by_run:
                    # NOTE(review): the 5 extra leading TRs dropped here, and
                    # the fact that trimming only happens in this branch, are
                    # kept as in the original - confirm against story_fold_cv.
                    trimmed = story_embeddings[5 + num_tr_trim:trim_stop]
                    story_embeddings = zscore(torch.nan_to_num(trimmed))
                processed_embeddings[f"story_{story_idx}"] = story_embeddings

            # Generate CV folds
            if args.experiment.verbose:
                print(f"     Layer {layer_idx}: 3c - Perform CV and compute mean correlation across folds.")
            folds_generator = story_fold_cv(processed_embeddings, dataset.subjects, subject_idx, num_tr_trim)

            # Fresh accumulator per layer: scores from a previous layer must
            # never leak into this one, whatever fold indices are yielded.
            score_sum = None
            num_folds_run = 0
            for fold_idx, (test_fmri, test_feats, _), (train_fmri, train_feats, _) in folds_generator:
                # Train ridge regression model
                ridge.fit(train_feats, train_fmri)

                # Accumulate correlation scores out of place so the array
                # returned by r_score is not mutated.
                scores = ridge.r_score(test_feats, test_fmri)
                score_sum = scores if score_sum is None else score_sum + scores
                num_folds_run += 1

                # Save ridge regression weights
                save_ridge(weights_dir / f"ridge_weights_subject_{subject_idx}_layer_{layer_idx}_fold_{fold_idx}", ridge.W)

            if score_sum is None:
                raise RuntimeError(
                    f"story_fold_cv produced no folds for subject {subject_idx}, layer {layer_idx}"
                )

            # story_fold_cv is never told args.experiment.num_folds, so the
            # mean must use the number of folds it actually yielded.
            layer_scores.append(score_sum / num_folds_run)

        # Save correlation scores
        corr_scores_file = 'correlation_scores' + '_' + subject_name + '.npy'
        np.save(experiment_dir / corr_scores_file, layer_scores)

def train_brain_encoding_model(aggregated_embeddings, args, dataset, experiment_dir, device):
    """Train the encoding model that matches ``args.dataset.name``.

    ``"HarryPotter"`` is routed to ``single_story_encoding_model`` and
    ``"MothRadioHour"`` to ``multi_story_encoding_model``; all arguments are
    forwarded unchanged. Any other dataset name trains nothing: a
    ``RuntimeWarning`` is emitted so the skipped step is visible, and the
    function returns normally.

    Args:
        aggregated_embeddings: Embeddings forwarded to the selected model.
        args: Config object; ``args.dataset.name`` selects the model.
        dataset: Dataset object forwarded to the selected model.
        experiment_dir: Output directory forwarded to the selected model.
        device: Device forwarded to the selected model.
    """
    import warnings

    dataset_name = args.dataset.name
    if dataset_name == "HarryPotter":
        # For Harry Potter, we have only one story, so we can use the single story encoding model
        single_story_encoding_model(aggregated_embeddings, args, dataset, experiment_dir, device)
    elif dataset_name == "MothRadioHour":
        # For Moth Radio Hour, we have multiple stories, so we use the multi-story encoding model
        multi_story_encoding_model(aggregated_embeddings, args, dataset, experiment_dir, device)
    else:
        # Warn rather than raise so callers that relied on the silent no-op
        # keep running, while the skipped training no longer goes unnoticed.
        warnings.warn(
            f"No encoding model is defined for dataset {dataset_name!r}; "
            "expected 'HarryPotter' or 'MothRadioHour'. Nothing was trained.",
            RuntimeWarning,
            stacklevel=2,
        )
