import argparse
import glob
import numpy
import numpy as np
import pathlib
import pickle
from sklearn.pipeline import make_pipeline
import re
from sklearn.decomposition import PCA


# Sentinel default for the --layer-number CLI option: a value no real layer
# index takes, so it can stand for "no layer was specified".
LAYER_NUM_DEFAULT_VALUE = -10000

# Base directories of the per-movie feature folders (relative to the working
# directory the script is launched from).
WOLF_VIDEOS_DIR = pathlib.Path("../../dataset/wolf_videos")  # 17 folders
BOURNE_VIDEOS_DIR = pathlib.Path("../../dataset/bourne_videos")  # 10 folders
LIFE_VIDEOS_DIR = pathlib.Path("../../dataset/life_videos")  # 5 folders
HIDDEN_VIDEOS_DIR = pathlib.Path("../../dataset/figures_videos") # 12 folders

# (label, directory) pairs whose hidden states form the training stimulus.
# ORDER MATTERS: the per-movie features are concatenated along time in this
# order and then paired row-by-row with load_brain_data('all'), which stacks
# the bourne recording first and the wolf recording second. Listing wolf first
# (as this file previously did) silently misaligns features and responses,
# because both orders produce the same total number of time points.
TRAIN_VIDEO_DIRS = [
    ("bourne", BOURNE_VIDEOS_DIR),
    ("wolf", WOLF_VIDEOS_DIR),
]

# (label, directory) pairs used as the held-out test stimulus.
TEST_VIDEO_DIRS = [
    ("life", LIFE_VIDEOS_DIR),
    #("hidden", HIDDEN_VIDEOS_DIR),
]

def natural_key(x):
    """Return a sort key that orders embedded numbers numerically.

    ``x`` must expose a ``name`` attribute (e.g. a ``pathlib.Path``). The name
    is split into digit runs and the text between them; digit runs become
    ``int`` and text is lower-cased, so "clip2" sorts before "clip10".
    """
    key = []
    for chunk in re.split(r'(\d+)', x.name):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key

def process_video_category(video_dirs, prompt_number, model_name, subject=None, stack_all=False):
    """
    Load the hidden states of several video categories.

    Args:
        video_dirs: Iterable of tuples (video_type, video_dir)
        prompt_number: The prompt number to use
        model_name: The model name to use
        subject: Optional subject number, forwarded unchanged to
            process_video_directory
        stack_all: If True, also return all hidden states stacked together

    Returns:
        If stack_all is False: a dictionary mapping each video type to its
        (un-normalised) hidden states. Video types for which nothing was
        loaded are omitted.

        If stack_all is True: a tuple ``(dictionary, combined)`` where
        ``combined`` is the concatenation along axis 0 of each video type's
        z-scored hidden states, in the order of ``video_dirs``. ``combined``
        is None when no video type yielded any data, so callers can always
        unpack two values.
    """
    hidden_states_dict = {}
    normalised_blocks = []

    for video_type, video_dir in video_dirs:
        video_hidden_states = process_video_directory(
            video_dir, prompt_number, model_name, subject
        )
        if video_hidden_states is None:
            continue

        hidden_states_dict[video_type] = video_hidden_states
        # Each category is normalised on its own before stacking; skip the
        # work entirely when the stacked array was not requested.
        if stack_all:
            normalised_blocks.append(zscore(video_hidden_states))

        # Print shape information
        print(f"{video_type} hidden states shape: {hidden_states_dict[video_type].shape}")

    if not stack_all:
        return hidden_states_dict

    if not normalised_blocks:
        # Keep the two-value shape of the return even when nothing was found.
        return hidden_states_dict, None

    combined_hidden_states = np.concatenate(normalised_blocks, axis=0)
    print(f"All videos combined shape: {combined_hidden_states.shape}")
    return hidden_states_dict, combined_hidden_states

def natural_key(x):
    """Return a sort key that orders embedded numbers numerically.

    ``x`` must expose a ``name`` attribute (e.g. a ``pathlib.Path``). Digit
    runs in the name become ``int`` and the text between them is lower-cased,
    so "clip2" sorts before "clip10".

    NOTE: this repeats the ``natural_key`` defined earlier in this file; being
    the later definition, it is the one in effect when the loaders run.
    """
    pieces = re.split(r'(\d+)', x.name)
    # With one capturing group, re.split alternates text / digit-run / text...,
    # so the odd positions are exactly the r'\d+' matches. Deciding by position
    # instead of str.isdigit() avoids calling int() on text such as "²", which
    # isdigit() accepts but int() rejects with ValueError.
    return [int(piece) if idx % 2 else piece.lower()
            for idx, piece in enumerate(pieces)]

def process_video_directory(video_dir, prompt_number, model_name, subject=None):
    """
    Load and concatenate the hidden states of every sub-folder of a video directory.

    Args:
        video_dir: Path to the video directory
        prompt_number: The prompt number to use
        model_name: The model name to use
        subject: Optional subject number, forwarded unchanged to process_folder

    Returns:
        The per-folder arrays concatenated along axis 0 (folders visited in
        natural-sort order), or None if no folder produced any data
    """
    # Only directories count; visit them in natural order (folder2 before folder10).
    subfolders = sorted(
        (entry for entry in video_dir.iterdir() if entry.is_dir()),
        key=natural_key,
    )

    per_folder = []
    for subfolder in subfolders:
        print(subfolder)
        states = process_folder(subfolder, prompt_number, model_name, subject)
        if states is not None:
            per_folder.append(states)

    if not per_folder:
        return None
    return np.concatenate(per_folder, axis=0)

def process_folder(folder_path, prompt_number, model_name, subject=None):
    """
    Load the hidden states stored in one folder's pickle files.

    Files are read from ``<folder_path>/prompt_<prompt_number>/<model_name>/*.pkl``
    in natural-sort order. Each file is expected to unpickle to an iterable of
    dicts carrying a 'language_hidden_states' entry.

    Args:
        folder_path: Path to the folder
        prompt_number: The prompt number to use
        model_name: The model name to use
        subject: Optional subject number; only used for the subject-3/life05
            truncation below

    Returns:
        All 'language_hidden_states' entries stacked along a new leading
        axis, or None if the folder yielded none
    """
    pattern = str(folder_path.joinpath(f"prompt_{prompt_number}", model_name, "*.pkl"))
    batch_files = sorted(glob.glob(pattern), key=lambda name: natural_key(pathlib.Path(name)))

    folder_name = str(folder_path)

    # Subject 3's life05 folder: keep at most the first 383 files.
    if subject == 3 and "life05" in folder_name:
        batch_files = batch_files[:383]
        print(f"Subject 3, life05 folder: Using {len(batch_files)} pkl files")

    # The bourne10 folder (any subject): keep at most the first 379 files.
    if "bourne10" in folder_name:
        batch_files = batch_files[:379]
        print(f"bourne 10 folder: Using {len(batch_files)} pkl files")

    collected = []
    for batch_file in batch_files:
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # feature dumps you produced yourself.
        with open(batch_file, "rb") as handle:
            batches = pickle.load(handle)
        collected.extend(batch['language_hidden_states'] for batch in batches)

    return np.stack(collected) if collected else None

def zscore(a, mean=None, std=None, return_info=False):
    """Standardise ``a`` along axis 0.

    Args:
        a: Array to normalise (the original author's note: [TRs x voxels]).
        mean: Optional per-column mean to subtract. Computed from ``a`` with
            np.nanmean when None. Scalars and sequences are accepted.
        std: Optional per-column standard deviation to divide by. Computed
            from ``a`` with np.nanstd when None. Scalars and sequences are
            accepted.
        return_info: If True, also return the mean and std that were used.

    Returns:
        The normalised array, or ``(normalised, mean, std)`` when
        ``return_info`` is True.
    """
    # Small constant added to the denominator so constant columns do not divide by zero.
    EPSILON = 0.000000001
    # Test for "not supplied" with `is None`: the previous
    # `type(x) != np.ndarray` check silently discarded scalar/list statistics
    # passed by the caller and recomputed them from `a`.
    mean = np.nanmean(a, axis=0) if mean is None else np.asarray(mean)
    std = np.nanstd(a, axis=0) if std is None else np.asarray(std)

    normalised = (a - np.expand_dims(mean, axis=0)) / (np.expand_dims(std, axis=0) + EPSILON)
    if return_info:
        return normalised, mean, std
    return normalised

def load_brain_data(story, subj):
    """Load a subject's training fMRI responses from ../../../movie10_braindata.

    Args:
        story: 'bourn' (or 'bourne'), 'wolf', or 'all'.
        subj: Subject number, inserted into the file name as ``sub<subj>``.

    Returns:
        - 'bourn'/'bourne': the bourne recording truncated to its first 4024
          columns, not normalised.
        - 'wolf': the wolf recording truncated to its first 6989 columns, not
          normalised.
        - 'all': the bourne recording (first 4024 columns) followed by the
          wolf recording (first 6993 columns), each z-scored separately and
          joined with np.hstack. Stimulus features paired with this array
          must be concatenated in the same bourne-then-wolf order.

    Raises:
        ValueError: if ``story`` is not one of the names above.
    """
    bourne_file = f'../../../movie10_braindata/sub{subj}-bourne-fsaverage6.npy'
    wolf_file = f'../../../movie10_braindata/sub{subj}-wolf-fsaverage6.npy'

    # 'bourn' is the spelling this function has always accepted; 'bourne'
    # is accepted as well so the natural spelling no longer fails.
    if story in ('bourn', 'bourne'):
        return np.load(bourne_file)[:, :4024]

    if story == 'wolf':
        # NOTE(review): 6989 here vs 6993 in the 'all' branch — confirm which
        # length is intended.
        return np.load(wolf_file)[:, :6989]

    if story == 'all':
        # NOTE(review): zscore normalises along axis 0, while the slicing and
        # hstack here treat axis 1 as the one that is truncated/extended —
        # confirm the normalisation runs over the intended axis.
        bourne = zscore(np.load(bourne_file)[:, :4024])
        wolf = zscore(np.load(wolf_file)[:, :6993])
        return np.hstack([bourne, wolf])

    # Previously an unknown story fell through to `return brain_data` and
    # raised an opaque UnboundLocalError.
    raise ValueError(
        f"Unknown story {story!r}; expected 'bourn', 'bourne', 'wolf' or 'all'"
    )

def load_test_brain_data(subj):
    """Load a subject's 'life' fMRI recording, z-score it, and keep the first 2013 columns.

    Args:
        subj: Subject number, inserted into the file name as ``sub<subj>``.

    Returns:
        The z-scored array truncated to ``[:, :2013]``.
    """
    life_file = '../../../movie10_braindata/sub' + str(subj) + '-life-fsaverage6.npy'
    # Normalise the full recording first, then truncate (same order as before).
    normalised = zscore(np.load(life_file))
    return normalised[:, :2013]

def make_delayed(stim, delays, circpad=False):
    """Stack time-shifted copies of [stim] side by side, one copy per entry of [delays].

    [stim] is a 2-D array (time x features). A positive delay shifts the rows
    down (later in time), a negative delay shifts them up, and a delay of 0
    contributes an unshifted copy. No interpolation is performed.

    Rows vacated by a shift are left as zeros, unless [circpad] is set, in
    which case the rows pushed off one end wrap around to the other.

    Returns an array of shape (time, features * len(delays)).
    """
    n_samples, n_features = stim.shape
    shifted_copies = []
    for lag in delays:
        if lag == 0:
            shifted_copies.append(stim.copy())
            continue

        shifted = np.zeros((n_samples, n_features))
        if lag > 0:
            shifted[lag:, :] = stim[:-lag, :]
            if circpad:
                shifted[:lag, :] = stim[-lag:, :]
        else:
            # Negative lag: rows move towards the start.
            shifted[:lag, :] = stim[-lag:, :]
            if circpad:
                shifted[lag:, :] = stim[:-lag, :]
        shifted_copies.append(shifted)
    return np.hstack(shifted_copies)

def stimulus_features(train_features, test_features):
    """Apply delays 1 through 4 to every feature matrix via make_delayed.

    Args:
        train_features: Iterable of 2-D training feature arrays.
        test_features: Iterable of 2-D test feature arrays.

    Returns:
        Tuple ``(X_train, X_test)`` of lists holding the delayed arrays, in
        the same order as the inputs.
    """
    delays = list(range(1, 5))

    def _delay_each(feature_sets):
        return [make_delayed(features, delays) for features in feature_sets]

    return _delay_each(train_features), _delay_each(test_features)

def evaluate_model(delRstim, delPstim, Y_train, Y_test):
    """Fit a multiple-kernel ridge model by random search and score it on test data.

    NOTE(review): ``backend``, ``solve_multiple_kernel_ridge_random_search``,
    ``primal_weights_weighted_kernel_ridge``, ``predict_and_score_weighted_kernel_ridge``,
    ``r2_score_split`` and ``correlation_score_split`` are not imported in the
    visible part of this file (they look like himalaya names) — confirm they
    are in scope before calling this function.

    Args:
        delRstim: List of training feature matrices, one per feature space.
        delPstim: List of test feature matrices, matching ``delRstim``.
        Y_train: Training targets.
        Y_test: Test targets.

    Returns:
        Tuple ``(dual_weights, scores, alphas, deltas, cv_scores)`` where
        ``scores`` maps "r2" and "r" to the corresponding test-set scores.
    """
    # Move every feature space to the active backend as float32.
    train_spaces = [backend.asarray(features, dtype="float32") for features in delRstim]
    test_spaces = [backend.asarray(features, dtype="float32") for features in delPstim]

    # Precompute one linear kernel per feature space (train x train, test x train).
    Ks_train = backend.asarray(
        backend.stack([features @ features.T for features in train_spaces]),
        dtype=backend.float32,
    )
    Y_train = backend.asarray(Y_train, dtype=backend.float32)

    Ks_test = backend.asarray(
        backend.stack([
            test_feats @ train_feats.T
            for train_feats, test_feats in zip(train_spaces, test_spaces)
        ]),
        dtype=backend.float32,
    )
    Y_test = backend.asarray(Y_test, dtype=backend.float32)

    print("Ks_train.shape:", Ks_train.shape)  # Expect (n_kernels, n_samples, n_samples)
    print("Y_train.shape:", Y_train.shape)    # Expect (n_samples, n_targets)
    print(f"X_test shape: {Ks_test.shape}, Y_test shape: {Y_test.shape}")

    # Random-search budget. The original notes say the method has no stopping
    # criterion and that more kernels call for more iterations.
    n_iter = 1000

    # Grid of regularization parameters.
    alphas = np.logspace(-10, 10, 21)
    print("alphas: ", alphas)

    # Batch sizes bound memory use: larger is a bit faster but may run out of memory.
    n_targets_batch = 1000
    n_alphas_batch = 20

    # Dual weights are needed for the test-set predictions below; asking for
    # them costs memory, so the refit uses a smaller target batch.
    return_weights = 'dual'
    n_targets_batch_refit = 200

    results = solve_multiple_kernel_ridge_random_search(
        Ks=Ks_train,
        Y=Y_train,
        n_iter=n_iter,
        alphas=alphas,
        n_targets_batch=n_targets_batch,
        return_weights=return_weights,
        n_alphas_batch=n_alphas_batch,
        n_targets_batch_refit=n_targets_batch_refit,
        jitter_alphas=True,
    )

    # Cast the solver outputs back to numpy arrays.
    deltas, dual_weights, cv_scores = (backend.to_numpy(results[i]) for i in range(3))

    # Primal weights are computed "for looking at later"; they are not returned.
    primal_weights = primal_weights_weighted_kernel_ridge(
        results[1],  # dual weights
        results[0],  # deltas
        train_spaces
    )
    primal_weights = [backend.to_numpy(pw) for pw in primal_weights]

    # Score the test-set predictions with each metric (requires the dual weights).
    scorers = {"r2": r2_score_split, "r": correlation_score_split}
    scores = {
        score_name: backend.to_numpy(predict_and_score_weighted_kernel_ridge(
            Ks_test, dual_weights, deltas, Y_test, split=True,
            n_targets_batch=n_targets_batch, score_func=score_func))
        for score_name, score_func in scorers.items()
    }

    return dual_weights, scores, alphas, deltas, cv_scores

def _pca_layer_embeddings(prompt_num, layer, n_components=1024):
    """Load one prompt's hidden states and PCA-reduce a single layer.

    The PCA is fitted on the training videos and then applied to the test
    videos. Reads the module-level SUBJECT and MODEL_NAME set by the CLI block.

    Returns:
        Tuple ``(train, test)`` of reduced embeddings, NaNs replaced by zeros.
    """
    _, train_states = process_video_category(
        TRAIN_VIDEO_DIRS, prompt_num, MODEL_NAME, subject=SUBJECT, stack_all=True)
    _, test_states = process_video_category(
        TEST_VIDEO_DIRS, prompt_num, MODEL_NAME, subject=SUBJECT, stack_all=True)

    # Swap the first two axes so the layer can be selected by its index.
    train_layer = np.transpose(train_states, (1, 0, 2))[layer]
    test_layer = np.transpose(test_states, (1, 0, 2))[layer]

    pca = PCA(n_components=n_components)
    train_reduced = pca.fit_transform(np.nan_to_num(train_layer))
    test_reduced = pca.transform(np.nan_to_num(test_layer))
    return np.nan_to_num(train_reduced), np.nan_to_num(test_reduced)


def main():
    """Fit and score a ridge encoding model for every pair of prompts.

    For each pair, the PCA-reduced hidden states of the two prompts are
    concatenated feature-wise, delayed with make_delayed, and regressed onto
    the brain responses with bootstrap_ridge. Per-voxel test correlations are
    pickled to OUTPUT_DIR, one file per pair.

    Relies on the module-level SUBJECT, MODEL_NAME and OUTPUT_DIR that the
    ``__main__`` block sets from the command line.
    """
    from itertools import combinations

    # Imported once, up front, so a missing dependency fails before any data
    # is loaded (it used to be re-imported inside the loop).
    from ridge_utils.ridge import bootstrap_ridge

    PROMPTS = [1, 4, 8, 15, 20, 30, 33, 38, 43, 48, 52, 67, 71]
    PROMPT_PAIRS = list(combinations(PROMPTS, 2))
    BEST_LAYER = 10  # Or use a dictionary to specify per prompt if needed
    delays = list(range(1, 6))

    zRresp, zPresp = load_brain_data('all', SUBJECT).T, load_test_brain_data(SUBJECT).T
    print(zRresp.shape, zPresp.shape)
    # The NaN-free responses do not change between pairs; compute them once.
    train_resp = np.nan_to_num(zRresp)
    test_resp = np.nan_to_num(zPresp)

    nboots = 1  # Number of cross-validation runs.
    chunklen = 40
    nchunks = 20

    # Every prompt appears in many pairs. Cache its reduced embeddings so the
    # pickles are read and the PCA is fitted once per (prompt, layer) instead
    # of once per pair.
    embeddings_by_prompt = {}

    for prompt_pair in PROMPT_PAIRS:
        print(f"Running prompt pair: {prompt_pair}")
        prompt_nums = list(prompt_pair)
        best_layers = [BEST_LAYER, BEST_LAYER]  # Replace with per-prompt values if you have them

        train_parts = []
        test_parts = []
        for prompt_num, best_layer in zip(prompt_nums, best_layers):
            print(f"Prompt {prompt_num}, Layer {best_layer}")
            cache_key = (prompt_num, best_layer)
            if cache_key not in embeddings_by_prompt:
                embeddings_by_prompt[cache_key] = _pca_layer_embeddings(prompt_num, best_layer)
            train_part, test_part = embeddings_by_prompt[cache_key]
            train_parts.append(train_part)
            test_parts.append(test_part)

        # Feature-wise concatenation of the two prompts: [T, 2 * n_components]
        TRAIN_HIDDEN_EMBEDDINGS = np.concatenate(train_parts, axis=1)
        TEST_HIDDEN_EMBEDDINGS = np.concatenate(test_parts, axis=1)
        # Report the full shape of each array (the TEST line used to print the
        # TRAIN array, and both printed only the shape of the first row).
        print("Final TRAIN embedding shape:", TRAIN_HIDDEN_EMBEDDINGS.shape)
        print("Final TEST embedding shape:", TEST_HIDDEN_EMBEDDINGS.shape)

        # Run regression
        train_layer_embeddings = make_delayed(np.nan_to_num(TRAIN_HIDDEN_EMBEDDINGS), delays)
        test_layer_embeddings = make_delayed(np.nan_to_num(TEST_HIDDEN_EMBEDDINGS), delays)

        alphas = np.logspace(1, 3, 10) # Equally log-spaced alphas between 10 and 1000. The third number is the number of alphas to test.
        all_corrs = []
        wt, corr, alphas, bscorrs, valinds = bootstrap_ridge(train_layer_embeddings, train_resp, test_layer_embeddings, test_resp,
                                                                 alphas, nboots, chunklen, nchunks,
                                                                 singcutoff=1e-10, single_alpha=True)
        pred = np.dot(test_layer_embeddings, wt)

        print ("pred has shape: ", pred.shape)
        voxcorrs = np.zeros((zPresp.shape[1],)) # create zero-filled array to hold correlations
        for vi in range(zPresp.shape[1]):
            voxcorrs[vi] = np.corrcoef(zPresp[:,vi], pred[:,vi])[0,1]
        print (voxcorrs)
        all_corrs.append(voxcorrs)

        pair_str = f"{prompt_nums[0]}_{prompt_nums[1]}"

        # Save results for this prompt pair
        save_dict = dict(
            desc='mkl',
            subject=SUBJECT,
            prompts=prompt_nums,
            scores=all_corrs,
        )

        output_file = OUTPUT_DIR.joinpath(f'mkl_prompts_{pair_str}_subj{SUBJECT:02}.pkl')
        with open(output_file, "wb") as f:
            pickle.dump(save_dict, f)
        print(f"Saved results for prompts {pair_str} to {output_file}")


# Command-line entry point: parse the options, publish them as module-level
# names (main() and its helpers read SUBJECT, MODEL_NAME and OUTPUT_DIR), make
# sure the output directory exists, then run main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Hidden States Aligner",
        description="This is a generic script that will train models to predict fMRI readings given model hidden states",
    )

    parser.add_argument(
        "-s",
        "--subject",
        required=True,
        type=int,
        # A list (not a set) keeps the order stable in --help and error messages.
        choices=[1, 2, 3, 5],
        help="The subject number from the Movie10 dataset whose image embeddings are to be trained",
    )
    parser.add_argument(
        "-d",
        "--base-dir",
        # Required: OUTPUT_DIR is derived from it below, so leaving it out
        # used to crash with an AttributeError on None.
        required=True,
        type=pathlib.Path,
        help="The path to the directory where all the models, inputs and outputs will be stored and loaded from",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        required=True,
        help="The model id whose hidden state representations are to be used",
    )

    parser.add_argument(
        "-v",
        "--videos-folder",
        # Parsed as a path so VIDEOS_DIR really is the pathlib.Path it is annotated as.
        type=pathlib.Path,
        required=True,
        help="The movie folder",
    )
    parser.add_argument(
        "-l",
        "--layer-number",
        type=int,
        required=False,
        default=LAYER_NUM_DEFAULT_VALUE,
        help="The layer numbers to find the alignment for. It can be a number like 0, 1 or a negative number like -1 for the last layer. If not passed, then all the layers will be trained and average correlation will be extracted per layer",
    )
    parser.add_argument(
        "--max-log-10-alpha",
        required=False,
        default=4,
        type=int,
        help="Maximum value of log10 alpha to consider",
    )
    parser.add_argument(
        "--num-alphas",
        required=False,
        default=60,
        type=int,
        help="Number of alpha values to sample",
    )
    parser.add_argument(
        "--telegram-bot-token",
        required=False,
        default="",
        type=str,
        help="Telegram Bot token to use for tqdm",
    )
    parser.add_argument(
        "--telegram-chat-id",
        required=False,
        default=0,
        type=int,
        help="Telegram Chat ID to send tqdm updates to",
    )

    args = parser.parse_args()

    SUBJECT: int = args.subject
    BASE_DIR: pathlib.Path = args.base_dir
    MODEL_ID: str = args.model_id
    MAX_LOG_10_ALPHA: int = args.max_log_10_alpha
    NUM_ALPHAS: int = args.num_alphas
    TELEGRAM_BOT_TOKEN: str = args.telegram_bot_token
    TELEGRAM_CHAT_ID: int = args.telegram_chat_id
    # Telegram progress reporting needs both a token and a chat id.
    TO_USE_TELEGRAM: bool = TELEGRAM_BOT_TOKEN != "" and TELEGRAM_CHAT_ID != 0
    LAYER_NUM: int = args.layer_number
    VIDEOS_DIR: pathlib.Path = args.videos_folder

    # File-system-safe form of the model id; used as a folder name.
    MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

    if TO_USE_TELEGRAM:
        from tqdm.contrib.telegram import tqdm
    else:
        from tqdm.auto import tqdm

    OUTPUT_DIR = BASE_DIR.joinpath(
        "final_scores",
        "concat_prompts",
        MODEL_NAME,
        f"subj{SUBJECT:02}"
        )

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    main()
