import numpy as np
from scipy.stats import zscore
from sklearn.decomposition import PCA
import torch
from typing import Dict, Generator, List, Tuple
from collections import defaultdict

def pca(embeddings: torch.Tensor, num_red_components: int):
    """Z-score the embeddings column-wise and project them with PCA.

    NaNs in the input are replaced with zeros before standardizing. Columns
    with zero variance z-score to NaN; those are zeroed as well so that the
    PCA fit does not reject the matrix (this matches how the rest of this
    module normalizes data).

    Args:
        embeddings: 2-D matrix of samples x features (a CPU tensor or
            anything ``np.nan_to_num`` can convert to an array).
        num_red_components: Value forwarded to ``PCA(n_components=...)``.

    Returns:
        A tensor holding the PCA-transformed embeddings.
    """
    # zscore yields NaN for constant columns (0 / 0); clean them up before PCA.
    standardized = np.nan_to_num(zscore(np.nan_to_num(embeddings)))
    reducer = PCA(n_components=num_red_components, svd_solver='full')
    return torch.tensor(reducer.fit_transform(standardized))

def delay_one(mat, d):
    """Shift the rows of ``mat`` by ``d`` positions, zero-filling the gap.

    A positive ``d`` delays the matrix: row ``t`` of the result holds row
    ``t - d`` of the input. A negative ``d`` shifts the other way. Rows that
    have no source row are left as zeros. The input is never modified.

    Args:
        mat: A numpy array or a torch tensor; the shift is applied along the
            first axis.
        d: Integer shift (may be zero or negative).

    Returns:
        A new torch tensor. Numpy input is converted to a tensor for every
        value of ``d`` (including ``d == 0``), so callers can rely on a single
        return type.

    Raises:
        TypeError: If ``mat`` is neither a numpy array nor a torch tensor.
    """
    is_numpy = isinstance(mat, np.ndarray)
    if not is_numpy and not isinstance(mat, torch.Tensor):
        raise TypeError("Input mat must be either a numpy array or a torch tensor.")

    if d == 0:
        # No shift: return an independent copy, promoted to a tensor like the
        # shifted paths below.
        return torch.from_numpy(np.copy(mat)) if is_numpy else mat.clone()

    new_mat = np.zeros_like(mat) if is_numpy else torch.zeros_like(mat)
    if d > 0:
        new_mat[d:] = mat[:-d]
    else:
        new_mat[:d] = mat[-d:]
    return torch.from_numpy(new_mat) if is_numpy else new_mat

def delay_mat(mat, delays):
    """Concatenate several delayed copies of ``mat`` along the last axis.

    Row ``t`` of the result is the concatenation of
    ``row(t - delays[0]), row(t - delays[1]), ..., row(t - delays[-1])``,
    with zeros wherever a delayed row does not exist. For a 2-D input the
    result therefore has ``mat.shape[1] * len(delays)`` columns.

    Args:
        mat: A numpy array or a torch tensor.
        delays: Iterable of integer delays, each passed to ``delay_one``.

    Returns:
        A torch tensor (numpy input is converted).

    Raises:
        TypeError: If ``mat`` is neither a numpy array nor a torch tensor.
    """
    # delay_one never mutates its input and always returns fresh storage, so
    # no defensive copy of ``mat`` is needed before each call.
    if isinstance(mat, np.ndarray):
        pieces = []
        for d in delays:
            piece = delay_one(mat, d)
            # Accept either return type so a delay of 0 cannot break the
            # concatenation with a missing ``.numpy()`` attribute.
            pieces.append(piece.numpy() if isinstance(piece, torch.Tensor) else piece)
        return torch.from_numpy(np.concatenate(pieces, axis=-1))
    if isinstance(mat, torch.Tensor):
        return torch.cat([delay_one(mat, d) for d in delays], dim=-1)
    raise TypeError("Input mat must be either a numpy array or a torch tensor.")

def k_fold_test_idxs(subjects: List[np.array], subject_idx: int, folds: int = 5, trim: int = 5):
    """List the test-set row indices of every fold, without building samples.

    The index ranges follow the same rule as the ``k_fold`` generator: the
    rows are cut into ``folds`` contiguous chunks, ``trim`` rows are dropped
    from the inner edges of each chunk, the first chunk starts at row 0 and
    the last chunk runs to the final row.

    Args:
        subjects: Per-subject measurement arrays; only the row count of
            ``subjects[subject_idx]`` is used.
        subject_idx: Which subject's row count to use.
        folds: Number of folds. With ``folds == 1`` a single test set made of
            the first 20% of the rows is returned.
        trim: Rows removed at the inner boundaries of each test chunk.

    Returns:
        A list with one list of integer indices per fold.

    Raises:
        AssertionError: If ``2 * trim`` exceeds the fold size.
    """
    total = subjects[subject_idx].shape[0]
    width = total // folds

    # Single split: the leading 20% of the rows form the only test set.
    if folds == 1:
        return [list(range(0, int(total * 0.2)))]

    assert 2 * trim <= width

    final = folds - 1
    return [
        list(
            range(
                0 if k == 0 else trim + width * k,
                total if k == final else width * (k + 1) - trim,
            )
        )
        for k in range(folds)
    ]

def k_fold(
    embeddings: torch.Tensor, subjects: List[np.array], subject_idx: int, folds: int = 5, trim: int = 5
) -> Generator[Tuple[int, np.array, np.array], None, None]:
    """Generate contiguous test/train splits with a gap of `trim` rows between them.

    The rows of `embeddings` are cut into `folds` contiguous chunks. Each test
    set is one chunk with `trim` rows removed at its inner edges (the first
    chunk keeps row 0, the last chunk keeps the final row). The matching
    training set is every row that lies at least `trim` rows outside the test
    range, so test and training rows never touch.

    With `folds == 1` a single split is produced: the first 20% of the rows
    are the test set and the rows from `trim` past that point onward are the
    training set.

    Credits: Adapted from https://github.com/alansun17904/circuit-alignment

    Args:
        embeddings: Feature matrix; its first dimension gives the row count.
        subjects: Per-subject measurement arrays, row-aligned with `embeddings`.
        subject_idx: Index of the subject whose measurements are sampled.
        folds: The number of folds.
        trim: The number of rows removed around the test range.

    Yields:
        A tuple `(fold_index, test_samples, train_samples)` where each samples
        entry is the value returned by `idx2samples` for the corresponding
        index list. Note the order: test first, then train.

    Raises:
        AssertionError: If `2 * trim` is greater than the fold size.
    """
    total = embeddings.shape[0]
    width = total // folds

    if folds == 1:
        held_out = int(total * 0.2)
        test_part = idx2samples(
            embeddings, subjects, subject_idx, list(range(0, held_out))
        )
        train_part = idx2samples(
            embeddings, subjects, subject_idx, list(range(held_out + trim, total))
        )
        yield (0, test_part, train_part)
        return

    assert 2 * trim <= width

    final = folds - 1
    for fold in range(folds):
        # Test range: first fold is anchored at row 0, last fold at the end.
        lo = 0 if fold == 0 else trim + width * fold
        hi = total if fold == final else width * (fold + 1) - trim

        # Training rows stay `trim` rows clear of the test range on both sides.
        gap_lo = max(lo - trim, 0)
        gap_hi = min(hi + trim, total)

        held_out_idxs = list(range(lo, hi))
        fit_idxs = [*range(0, gap_lo), *range(gap_hi, total)]

        test_part = idx2samples(embeddings, subjects, subject_idx, held_out_idxs)
        train_part = idx2samples(embeddings, subjects, subject_idx, fit_idxs)
        yield fold, test_part, train_part

def idx2samples(embeddings, subjects, subject_idx, idxs=None):
    """Select rows from one subject's measurements and the embeddings, then z-score them.

    Both matrices are indexed with the same ``idxs``, NaNs are replaced with
    zeros, each column is z-scored over the selected rows, and any NaN
    produced by a zero-variance column is zeroed again.

    Args:
        embeddings: Feature matrix, row-aligned with the measurements.
        subjects: Per-subject measurement arrays.
        subject_idx: Index of the subject to sample from.
        idxs: Row indices to select. ``None`` (the default) selects every row.

    Returns:
        A tuple ``(measures, embeddings, idxs)`` where the first two entries
        are float64 tensors and ``idxs`` is the index list that was used
        (the full ``range`` as a list when ``None`` was passed).
    """
    if idxs is None:
        # Indexing with None would insert a new leading axis instead of
        # selecting rows, which z-scores every value to zero.
        idxs = list(range(len(embeddings)))

    measures = subjects[subject_idx][idxs]
    embeddings = embeddings[idxs]

    # Normalize data
    measures = np.nan_to_num(zscore(np.nan_to_num(measures)))
    embeddings = np.nan_to_num(zscore(np.nan_to_num(embeddings)))

    # Convert straight to float64; going through torch.Tensor(...) first
    # would round the values to float32 before the cast.
    return (
        torch.tensor(measures, dtype=torch.float64),
        torch.tensor(embeddings, dtype=torch.float64),
        idxs,
    )

def story_fold_cv(
    embeddings: Dict[str, torch.Tensor], 
    subjects: List[Dict[str, np.array]], 
    subject_idx: int,
    trim: int = 5
) -> Generator[Tuple[int, Tuple[torch.Tensor, torch.Tensor, List[int]], 
                     Tuple[torch.Tensor, torch.Tensor, List[int]]], None, None]:
    """A generator that yields leave-one-story-out folds for cross-validation.

    Each fold uses one story as test set and all other stories as training set,
    respecting story boundaries. Stories are visited in sorted name order, so
    there is one fold per story.

    For every story the first and last ``5 + trim`` rows of the fMRI
    measurements are dropped, while the embeddings are used as given. The
    per-story assertion on training stories therefore requires
    ``embeddings[story]`` to have exactly as many rows as the trimmed
    measurements. Measurements and embeddings are z-scored per story.

    Args:
        embeddings: Dictionary mapping story_name -> embedding matrix for that story.
        subjects: List of dictionaries, one per subject. Each dict maps
            story_name -> measurement array for that story.
        subject_idx: Index of the subject to use
        trim: Extra number of rows (on top of a fixed 5) removed from both
            ends of each story's measurements.

    Yields:
        A tuple of (fold_index, test_data, train_data) where:
        - fold_index: Current fold number (0 to number of stories - 1)
        - test_data: Tuple of (fMRI measurements, embeddings, indices) for the
          test story; indices are ``0 .. n_rows - 1`` of the test embeddings.
        - train_data: Tuple of (fMRI measurements, embeddings, indices) for the
          training stories; indices is a list of ``(local_row, story_name)``
          pairs, one per training row and in the same order as the rows,
          where ``local_row`` is the row position in that story's untrimmed
          measurements.

    Raises:
        AssertionError: If the story keys differ between embeddings and the
            subject, or if a training story's trimmed measurements and
            embeddings have different row counts.
        ValueError: From ``np.concatenate`` when there is only one story and
            hence no training data.
    """
    # Get the subject's data
    subject_data = subjects[subject_idx]

    # Ensure embeddings and subject data have the same stories
    assert set(embeddings.keys()) == set(subject_data.keys()), \
        f"Embeddings and subject {subject_idx} data must have the same story keys"

    story_names = sorted(embeddings.keys())  # Sort for reproducibility
    edge = 5 + trim  # rows removed from each end of a story's measurements

    def _standardize(block):
        # NaN-safe column-wise z-score; zero-variance columns become zeros.
        return np.nan_to_num(zscore(np.nan_to_num(block)))

    # Perform leave-one-story-out cross-validation
    for fold_idx, test_story in enumerate(story_names):
        test_embeddings = embeddings[test_story]
        test_indices = list(range(test_embeddings.shape[0]))

        # Convert straight to float64 to avoid rounding through float32.
        test_data = (
            torch.tensor(
                _standardize(subject_data[test_story][edge:-edge]), dtype=torch.float64
            ),
            torch.tensor(_standardize(test_embeddings), dtype=torch.float64),
            test_indices,
        )

        # Training data: all other stories with trimming
        train_measures_list = []
        train_embeddings_list = []
        train_indices = []

        for story in story_names:
            if story == test_story:
                continue

            story_measures = subject_data[story]

            # Local row positions of THIS story that survive trimming. These
            # must come from the training story itself and must not replace
            # the accumulator, otherwise indices no longer match the rows.
            trimmed_indices = list(range(story_measures.shape[0]))[edge:-edge]

            # Normalize training data
            trimmed_measures = _standardize(story_measures[edge:-edge])
            trimmed_embeddings = _standardize(embeddings[story])
            assert trimmed_measures.shape[0] == trimmed_embeddings.shape[0], \
                f"assert {story}, {trimmed_measures.shape[0]}, {trimmed_embeddings.shape[0]}"

            train_measures_list.append(trimmed_measures)
            train_embeddings_list.append(trimmed_embeddings)
            train_indices.extend((idx, story) for idx in trimmed_indices)

        # Concatenate all training stories
        train_measures = np.concatenate(train_measures_list, axis=0)
        train_embeddings = np.concatenate(train_embeddings_list, axis=0)

        train_data = (
            torch.tensor(train_measures, dtype=torch.float64),
            torch.tensor(train_embeddings, dtype=torch.float64),
            train_indices,
        )

        yield fold_idx, test_data, train_data

def story_fold_test_idxs(
    subjects: List[Dict[str, np.array]], 
    subject_idx: int,
    trim: int = 5,
    num_delays: int = 4,
) -> Tuple[List[List[int]], List[Dict[int, List[int]]]]:
    """
    Build per-story test indices for leave-one-story-out cross-validation.

    Stories are processed in sorted name order. For each story the rows that
    remain after dropping ``5 + trim`` rows from both ends are listed by
    their local (per-story) position, i.e. starting at ``5 + trim``.

    When ``num_delays`` is not None, a mapping is also built per story from
    each test index to the ``num_delays`` source indices
    ``idx, idx - 1, ..., idx - (num_delays - 1)``; any source index below 5
    is recorded as ``-1``.

    Returns:
        A pair ``(test_idxs, stories_delay_idxs)``: one index list per story,
        and one ``defaultdict(list)`` per story (an empty list when
        ``num_delays`` is None).
    """
    per_story = subjects[subject_idx]
    first_kept = 5 + trim

    fold_idxs = []
    fold_delay_maps = []
    for name in sorted(per_story.keys()):
        kept_rows = per_story[name][first_kept:-trim - 5].shape[0]
        local_idxs = list(range(first_kept, first_kept + kept_rows))
        fold_idxs.append(local_idxs)

        if num_delays is None:
            continue

        # Map every kept row to the rows it depends on at each delay step.
        lag_map = defaultdict(list)
        for tr in local_idxs:
            for lag in range(num_delays):
                source = tr - lag
                lag_map[tr].append(source if source >= 5 else -1)
        fold_delay_maps.append(lag_map)

    return fold_idxs, fold_delay_maps

