import argparse
import glob
import numpy
import numpy as np
import pathlib
import pickle
from sklearn.pipeline import make_pipeline
import re
from sklearn.decomposition import PCA
# Sentinel default for --layer-number meaning "no specific layer requested";
# main's layer-selection logic maps this value to every layer.
LAYER_NUM_DEFAULT_VALUE = -10000

# Base directories holding one sub-folder per video clip with pickled hidden
# states. These are relative paths, resolved against the current working directory.
WOLF_VIDEOS_DIR = pathlib.Path("../../dataset/internvl/wolf_videos")  # 17 folders
BOURNE_VIDEOS_DIR = pathlib.Path("../../dataset/internvl/bourne_videos")  # 10 folders
LIFE_VIDEOS_DIR = pathlib.Path("../../dataset/internvl/life_videos")  # 5 folders
HIDDEN_VIDEOS_DIR = pathlib.Path("../../dataset/internvl/figures_videos") # 12 folders

# (video_type, directory) pairs whose hidden states form the training set.
# The order matters: main concatenates them in this order, and load_brain_data('all')
# stacks the brain data as bourne then wolf — NOTE(review): confirm these orders agree.
TRAIN_VIDEO_DIRS = [
    ("wolf", WOLF_VIDEOS_DIR),
    ("bourne", BOURNE_VIDEOS_DIR),
]

# (video_type, directory) pairs used as the held-out test set.
# The "hidden" (figures) set is currently disabled.
TEST_VIDEO_DIRS = [
    ("life", LIFE_VIDEOS_DIR),
    #("hidden", HIDDEN_VIDEOS_DIR),
]

def process_video_category(video_dirs, prompt_number, model_name, subject=None, stack_all=False):
    """
    Load the hidden states of several video directories.

    Args:
        video_dirs: Iterable of ``(video_type, video_dir)`` tuples; each
            ``video_dir`` is handed to ``process_video_directory``.
        prompt_number: Prompt number whose ``prompt_<n>`` sub-directory is read.
        model_name: Model sub-directory name under each prompt directory.
        subject: Subject number, forwarded down to ``process_folder`` (which
            applies a subject-specific file truncation).
        stack_all: If True, also build one array with every video type's
            states, each z-scored separately (``zscore``, axis 0) and then
            concatenated along axis 0.

    Returns:
        If ``stack_all`` is False: a dict mapping ``video_type`` to its raw
        (not z-scored) stacked hidden states.
        If ``stack_all`` is True: a tuple ``(dict, combined)``; ``combined`` is
        ``None`` when no video directory produced any hidden states, so the
        result can always be unpacked as a pair.
    """
    hidden_states_dict = {}
    zscored_parts = []

    for video_type, video_dir in video_dirs:
        video_hidden_states = process_video_directory(
            video_dir, prompt_number, model_name, subject
        )
        if video_hidden_states is None:
            continue

        hidden_states_dict[video_type] = video_hidden_states
        if stack_all:
            # Normalize each video type on its own before concatenating; only
            # needed (and only paid for) when the combined array is requested.
            zscored_parts.append(zscore(video_hidden_states))

        print(f"{video_type} hidden states shape: {video_hidden_states.shape}")

    if not stack_all:
        return hidden_states_dict

    if not zscored_parts:
        # Keep the return shape consistent: callers do `_, combined = ...`.
        return hidden_states_dict, None

    combined_hidden_states = np.concatenate(zscored_parts, axis=0)
    print(f"All videos combined shape: {combined_hidden_states.shape}")
    return hidden_states_dict, combined_hidden_states

def natural_key(x):
    """Sort key for a ``pathlib.Path`` that orders embedded numbers numerically.

    ``x.name`` is split into alternating text / digit runs. Digit runs become
    ``int`` and text runs are lower-cased, so ``"b2"`` sorts before ``"b10"``
    and case is ignored.
    """
    def _coerce(token):
        return int(token) if token.isdigit() else token.lower()

    return list(map(_coerce, re.split(r'(\d+)', x.name)))

def process_video_directory(video_dir, prompt_number, model_name, subject=None):
    """
    Concatenate the hidden states of every immediate sub-folder of ``video_dir``.

    Sub-folders are visited in natural order (``natural_key``); each one is
    printed and handed to ``process_folder``. Folders that yield nothing are
    skipped.

    Args:
        video_dir: ``pathlib.Path`` whose sub-directories are processed.
        prompt_number: Forwarded to ``process_folder``.
        model_name: Forwarded to ``process_folder``.
        subject: Forwarded to ``process_folder``.

    Returns:
        The per-folder arrays concatenated along axis 0, or ``None`` when no
        folder produced hidden states.
    """
    subfolders = sorted(
        (entry for entry in video_dir.iterdir() if entry.is_dir()),
        key=natural_key,
    )

    collected = []
    for subfolder in subfolders:
        print(subfolder)
        stacked = process_folder(subfolder, prompt_number, model_name, subject)
        if stacked is not None:
            collected.append(stacked)

    return np.concatenate(collected, axis=0) if collected else None

def process_folder(folder_path, prompt_number, model_name, subject=None):
    """
    Load and stack the hidden states pickled under one video folder.

    Every ``*.pkl`` file in ``folder_path / f"prompt_{prompt_number}" / model_name``
    is read in natural (numeric-aware) filename order. Each file must unpickle
    to an iterable of batches, and each batch is indexed with
    ``'language_hidden_states'``.

    Two folders are truncated to a fixed number of files (presumably to match
    the length of the corresponding fMRI run — TODO confirm):
      * subject 3, folder path containing ``"life05"``: first 383 files;
      * any subject, folder path containing ``"bourne10"``: first 379 files.

    Args:
        folder_path: ``pathlib.Path`` of the video folder.
        prompt_number: Prompt number selecting the ``prompt_<n>`` sub-directory.
        model_name: Model sub-directory name.
        subject: Subject number; only ``3`` changes behavior (see above).

    Returns:
        ``np.stack`` of every batch's hidden states in file/batch order, or
        ``None`` when no batches were found.

    Note:
        Uses ``pickle.load``; only point this at trusted, locally produced files.
    """
    pkl_dir = folder_path.joinpath(f"prompt_{prompt_number}", model_name)
    # Escape the directory part so glob metacharacters ('[', '*', '?') in folder
    # or model names are matched literally; only "*.pkl" acts as a wildcard.
    pattern = str(pathlib.Path(glob.escape(str(pkl_dir)), "*.pkl"))
    batch_files = glob.glob(pattern)
    batch_files.sort(key=lambda path: natural_key(pathlib.Path(path)))

    folder_str = str(folder_path)
    # Subject 3, life05: keep at most the first 383 files (slicing is a no-op
    # when there are fewer).
    if subject == 3 and "life05" in folder_str:
        batch_files = batch_files[:383]
        print(f"Subject 3, life05 folder: Using {len(batch_files)} pkl files")

    # bourne10 (every subject): keep at most the first 379 files.
    if "bourne10" in folder_str:
        batch_files = batch_files[:379]
        print(f"bourne 10 folder: Using {len(batch_files)} pkl files")

    folder_hidden_states = []
    for batch_file in batch_files:
        with open(batch_file, "rb") as f:
            batches = pickle.load(f)
        folder_hidden_states.extend(
            batch['language_hidden_states'] for batch in batches
        )

    if folder_hidden_states:
        return np.stack(folder_hidden_states)
    return None

def zscore(a, mean=None, std=None, return_info=False):
    """Standardize ``a`` along axis 0: ``(a - mean) / (std + 1e-9)``.

    Statistics are computed with ``np.nanmean`` / ``np.nanstd`` over axis 0, so
    NaNs are ignored when estimating them (but still propagate in the output).

    Args:
        a: Array to standardize; the original author's note is ``[TRs x voxels]``,
            i.e. each column is normalized over axis 0.
        mean: Optional precomputed mean to use instead of ``nanmean(a, axis=0)``.
            Any non-``None`` value is used: ndarrays, NumPy scalars (as
            returned for 1-D input with ``return_info=True``), or Python numbers.
        std: Optional precomputed standard deviation, same rules as ``mean``.
        return_info: If True, also return the ``mean`` and ``std`` that were used.

    Returns:
        The standardized array, or ``(standardized, mean, std)`` when
        ``return_info`` is True.
    """
    EPSILON = 0.000000001  # avoids division by zero for constant columns
    mean = np.nanmean(a, axis=0) if mean is None else np.asarray(mean)
    std = np.nanstd(a, axis=0) if std is None else np.asarray(std)
    standardized = (a - np.expand_dims(mean, axis=0)) / (np.expand_dims(std, axis=0) + EPSILON)
    if return_info:
        return standardized, mean, std
    return standardized

def load_brain_data(story, subj):
    """Load a subject's training fMRI data from ``../../sub{subj}-{movie}-fsaverage6.npy``.

    Args:
        story: Which recording(s) to load:
            * ``'bourn'`` (or ``'bourne'``): the Bourne file, first 4024 columns, raw;
            * ``'wolf'``: the Wolf file, first 6989 columns, raw;
            * ``'all'``: Bourne (first 4024 columns) and Wolf (first 6993
              columns), each z-scored with ``zscore`` and then stacked side
              by side with ``np.hstack`` (Bourne first).
        subj: Subject number inserted into the file name.

    Returns:
        The loaded ndarray. Only ``'all'`` is z-scored.

    Raises:
        ValueError: If ``story`` is not one of the values above.

    NOTE(review): ``'wolf'`` keeps 6989 columns but ``'all'`` keeps 6993 for the
    same file — confirm which matches the stimulus length.
    NOTE(review): columns are truncated/concatenated (axis 1) while ``zscore``
    normalizes over axis 0; if the files are stored as (voxels, TRs) this
    normalizes across voxels per TR rather than per voxel over time — confirm.
    """
    def _load(movie, n_cols):
        # Keep only the first n_cols columns of the saved array.
        return np.load(f'../../sub{subj!s}-{movie}-fsaverage6.npy')[:, :n_cols]

    if story in ('bourn', 'bourne'):
        return _load('bourne', 4024)
    if story == 'wolf':
        return _load('wolf', 6989)
    if story == 'all':
        bourne = zscore(_load('bourne', 4024))
        wolf = zscore(_load('wolf', 6993))
        return np.hstack([bourne, wolf])
    raise ValueError(
        f"Unknown story {story!r}; expected 'bourn', 'bourne', 'wolf' or 'all'"
    )

def load_test_brain_data(subj):
    """Load a subject's held-out "life" fMRI data.

    Reads ``../../sub{subj}-life-fsaverage6.npy``, z-scores it with ``zscore``
    (over axis 0), and keeps only the first 2013 columns.
    """
    raw = np.load(f'../../sub{subj!s}-life-fsaverage6.npy')
    normalized = zscore(raw)
    return normalized[:, :2013]

def make_delayed(stim, delays, circpad=False):
    """Stack time-shifted copies of a 2-D ``stim`` side by side.

    For each entry ``lag`` in ``delays`` (in samples, along axis 0):
      * ``lag > 0``: rows move down by ``lag`` (row ``t`` holds ``stim[t - lag]``);
      * ``lag < 0``: rows move up by ``-lag``;
      * ``lag == 0``: an unshifted copy.
    Vacated rows are zero-filled, or, with ``circpad=True``, filled with the rows
    that were shifted out (a circular shift). The shifted copies are
    concatenated along axis 1 in the order of ``delays``.
    """
    n_time, n_feat = stim.shape
    shifted_blocks = []
    for lag in delays:
        if lag == 0:
            shifted_blocks.append(stim.copy())
            continue

        block = np.zeros((n_time, n_feat))
        if lag > 0:
            block[lag:, :] = stim[:-lag, :]
            if circpad:
                block[:lag, :] = stim[-lag:, :]
        else:
            block[:lag, :] = stim[-lag:, :]
            if circpad:
                block[lag:, :] = stim[:-lag, :]
        shifted_blocks.append(block)

    return np.hstack(shifted_blocks)

def stimulus_features(train_features, test_features, delays=None):
    """Apply ``make_delayed`` to every train and test feature matrix.

    Args:
        train_features: Iterable of 2-D training feature matrices.
        test_features: Iterable of 2-D test feature matrices.
        delays: Sample delays passed to ``make_delayed``. Defaults to
            ``[1, 2, 3, 4]``, the value previously hard-coded here (note that
            ``main`` uses delays 1..5 directly).

    Returns:
        ``(X_train, X_test)``: lists of delayed matrices, in input order.
    """
    if delays is None:
        delays = list(range(1, 5))
    X_train = [make_delayed(features, delays) for features in train_features]
    X_test = [make_delayed(features, delays) for features in test_features]
    return X_train, X_test

def evaluate_model(delRstim, delPstim, Y_train, Y_test):
    """Fit a multiple-kernel ridge model (one linear kernel per feature space) and score it.

    NOTE(review): ``backend``, ``solve_multiple_kernel_ridge_random_search``,
    ``r2_score_split``, ``correlation_score_split`` and
    ``predict_and_score_weighted_kernel_ridge`` are not imported anywhere in
    this file. They look like the ``himalaya`` library's API; confirm and add
    the imports before calling this function (``main`` does not call it).

    Args:
        delRstim: Sequence of training feature matrices, one per feature space
            (presumably (n_train_samples, n_features_i) — TODO confirm).
        delPstim: Matching sequence of test feature matrices.
        Y_train: Training targets.
        Y_test: Test targets.

    Returns:
        ``(dual_weights, scores, alphas, deltas, cv_scores)``:
        ``dual_weights``, ``deltas`` and ``cv_scores`` converted with
        ``backend.to_numpy``; ``scores`` maps ``"r2"`` and ``"r"`` to the split
        test scores; ``alphas`` is the regularization grid that was searched.
    """
    # Move the feature matrices to the active backend (e.g. GPU) as float32.
    Xs_train = [backend.asarray(X_train, dtype="float32") for X_train in delRstim]
    Xs_test = [backend.asarray(X_test, dtype="float32") for X_test in delPstim]

    # Precompute one linear kernel per feature space, cast to float32.
    Ks_train = backend.stack([X_train @ X_train.T for X_train in Xs_train])
    Ks_train = backend.asarray(Ks_train, dtype=backend.float32)
    Y_train = backend.asarray(Y_train, dtype=backend.float32)

    Ks_test = backend.stack(
        [X_test @ X_train.T for X_train, X_test in zip(Xs_train, Xs_test)])
    Ks_test = backend.asarray(Ks_test, dtype=backend.float32)
    Y_test = backend.asarray(Y_test, dtype=backend.float32)

    # Random search over kernel weights. It works well for a small number of
    # kernels (< 20); more kernels need a larger n_iter. There is no stopping
    # criterion, so more iterations generally converge better but run longer.
    n_iter = 100

    # Grid of regularization parameters.
    alphas = np.logspace(-10, 10, 21)
    print("alphas: ", alphas)

    # Batch sizes trade speed for GPU memory; lower them if the solver runs out
    # of memory. Optimal values depend on the dataset size.
    n_targets_batch = 2000
    n_alphas_batch = 20

    # return_weights='dual' needs more memory; n_targets_batch_refit limits it
    # during the refit. Use return_weights=None if dual weights are not needed.
    return_weights = 'dual'
    n_targets_batch_refit = 200

    # Per the original notes, each iteration samples kernel weights from a
    # Dirichlet distribution, fits ridge models for every split/alpha/target,
    # averages validation scores over splits and keeps the best alpha; with
    # return_weights set it then refits on the full training set.
    results = solve_multiple_kernel_ridge_random_search(
        Ks=Ks_train,
        Y=Y_train,
        n_iter=n_iter,
        alphas=alphas,
        n_targets_batch=n_targets_batch,
        return_weights=return_weights,
        n_alphas_batch=n_alphas_batch,
        n_targets_batch_refit=n_targets_batch_refit,
        jitter_alphas=True,
    )

    # Cast the results back to CPU NumPy arrays (done once; this used to be
    # repeated, and an unused primal-weights computation has been dropped).
    deltas = backend.to_numpy(results[0])
    dual_weights = backend.to_numpy(results[1])
    cv_scores = backend.to_numpy(results[2])

    # Predict on the test set (requires the dual weights) and score each
    # feature space's split contribution with both metrics.
    split = True
    score_funcs = [r2_score_split, correlation_score_split]
    score_names = ["r2", "r"]
    scores = {}
    for score_name, score_func in zip(score_names, score_funcs):
        scores[score_name] = backend.to_numpy(predict_and_score_weighted_kernel_ridge(
            Ks_test, dual_weights, deltas, Y_test, split=split,
            n_targets_batch=n_targets_batch, score_func=score_func))

    return dual_weights, scores, alphas, deltas, cv_scores

def _select_layers(layer_spec, num_layers):
    """Return the layer indices ``main`` should train for ``layer_spec``.

    * ``(start, end)`` tuple -> ``start`` .. ``end`` inclusive.
    * ``LAYER_NUM_DEFAULT_VALUE`` -> every layer.
    * int ``k`` -> layers ``k`` .. ``num_layers - 1`` (the previous
      ``np.arange(LAYER_NUM, num_layers)`` behavior). A negative ``k`` counts
      from the end, so ``-1`` selects only the last layer.

    Raises:
        ValueError: If a negative ``k`` points before the first layer.
    """
    if isinstance(layer_spec, tuple):
        start, end = layer_spec
        return list(range(start, end + 1))
    if layer_spec == LAYER_NUM_DEFAULT_VALUE:
        return list(range(num_layers))
    start = layer_spec + num_layers if layer_spec < 0 else layer_spec
    if start < 0:
        raise ValueError(f"Layer {layer_spec} is out of range for a model with {num_layers} layers")
    return list(range(start, num_layers))


def _layer_features(train_states, test_states, delays, n_components=1024):
    """PCA-reduce one layer's train/test states (PCA fitted on train only), then add delays."""
    pca = PCA(n_components=n_components)
    train_reduced = pca.fit_transform(np.nan_to_num(train_states))
    test_reduced = pca.transform(np.nan_to_num(test_states))
    return make_delayed(train_reduced, delays), make_delayed(test_reduced, delays)


def main():
    """Fit one ridge encoding model per selected layer and save voxel correlations.

    Reads module-level globals assigned in the ``__main__`` block: ``SUBJECT``,
    ``PROMPT_NUMBER``, ``MODEL_NAME``, ``LAYER_NUM`` and ``OUTPUT_DIR``.

    For every selected layer (see ``_select_layers``), the hidden states are
    reduced with a 1024-component PCA fitted on the training set, expanded with
    delays 1..5 via ``make_delayed``, and regressed onto the brain data with
    ``ridge_utils.ridge.bootstrap_ridge``. The per-voxel Pearson correlation
    between test data and predictions is pickled once per layer, and all of
    them together at the end.

    Raises:
        RuntimeError: If no training or testing hidden states were found.
    """
    # Imported once, up front: previously re-imported on every layer iteration.
    # This also fails fast, before the slow data loading, if ridge_utils is missing.
    from ridge_utils.ridge import bootstrap_ridge

    # Process training videos
    print("Processing training videos...")
    _, train_states = process_video_category(TRAIN_VIDEO_DIRS, PROMPT_NUMBER, MODEL_NAME, subject=SUBJECT, stack_all=True)
    if train_states is None:
        raise RuntimeError("No training hidden states were found")
    # Swap the first two axes so that indexing [layer] selects one layer.
    train_states = np.transpose(train_states, (1, 0, 2))
    print(f"Final combined tensor shape: {train_states.shape}")

    # Process testing videos
    print("Processing testing videos...")
    _, test_states = process_video_category(TEST_VIDEO_DIRS, PROMPT_NUMBER, MODEL_NAME, subject=SUBJECT, stack_all=True)
    if test_states is None:
        raise RuntimeError("No testing hidden states were found")
    test_states = np.transpose(test_states, (1, 0, 2))
    print(f"Final combined tensor  test shape: {test_states.shape}")

    zRresp, zPresp = load_brain_data('all', SUBJECT).T, load_test_brain_data(SUBJECT).T
    print(zRresp.shape, zPresp.shape)
    # NaN-free copies for the regression, computed once instead of per layer.
    train_resp = np.nan_to_num(zRresp)
    test_resp = np.nan_to_num(zPresp)

    num_layers = train_states.shape[0]
    selected_layers = _select_layers(LAYER_NUM, num_layers)

    delays = list(range(1, 6))
    nboots = 1  # Number of cross-validation runs.
    chunklen = 40
    nchunks = 20
    # 10 log-spaced alphas between 10 and 1000.
    ridge_alphas = np.logspace(1, 3, 10)

    all_scores = []
    for layer in selected_layers:
        print(f"\nTraining encoding model for layer {layer}...")

        train_layer_embeddings, test_layer_embeddings = _layer_features(
            train_states[layer], test_states[layer], delays
        )

        wt, _corr, _alphas, _bscorrs, _valinds = bootstrap_ridge(
            train_layer_embeddings, train_resp, test_layer_embeddings, test_resp,
            ridge_alphas, nboots, chunklen, nchunks,
            singcutoff=1e-10, single_alpha=True,
        )
        pred = np.dot(test_layer_embeddings, wt)
        print("pred has shape: ", pred.shape)

        # Per-voxel Pearson correlation between held-out data and predictions.
        voxcorrs = np.array(
            [np.corrcoef(zPresp[:, vi], pred[:, vi])[0, 1] for vi in range(zPresp.shape[1])],
            dtype=float,
        )
        print(voxcorrs)
        all_scores.append(voxcorrs)

        # Save results for this layer
        save_dict = dict(
            desc='mkl',
            subject=SUBJECT,
            layer=layer,
            scores=voxcorrs,
        )
        output_file = OUTPUT_DIR.joinpath(f'mkl_layer{layer}_{SUBJECT}.pkl')
        with open(output_file, "wb") as f:
            pickle.dump(save_dict, f)
        print(f"Saved results for layer {layer} to {output_file}")

    # Save combined results. `layers` records which layer each entry of
    # `scores` belongs to, since the run may not start at layer 0.
    combined_save_dict = dict(
        desc='mkl_all_layers',
        subject=SUBJECT,
        layers=list(selected_layers),
        scores=all_scores,
    )

    combined_path = OUTPUT_DIR.joinpath(f'mkl_combined_all_layers_{SUBJECT}.pkl')
    with open(combined_path, "wb") as f:
        pickle.dump(combined_save_dict, f)
    print(f"Saved combined results for all layers to {combined_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Hidden States Aligner",
        description="This is a generic script that will train models to predict fMRI readings given model hidden states",
    )

    parser.add_argument(
        "-s",
        "--subject",
        required=True,
        type=int,
        choices=(1, 2, 3, 5),
        help="The subject number from the Movie10 dataset whose image embeddings are to be trained",
    )
    # Required: OUTPUT_DIR below is built from it unconditionally.
    parser.add_argument(
        "-d",
        "--base-dir",
        required=True,
        type=pathlib.Path,
        help="The path to the directory where all the models, inputs and outputs will be stored and loaded from",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        required=True,
        help="The model id whose hidden state representations are to be used",
    )
    parser.add_argument(
        "-p",
        "--prompt-number",
        type=int,
        required=True,
        help="The prompt number to use for aligning",
    )

    parser.add_argument(
        "-v",
        "--videos-folder",
        type=str,
        required=True,
        help="The movie folder",
    )
    parser.add_argument(
        "-l",
        "--layer-number",
        type=str,
        required=False,
        default=LAYER_NUM_DEFAULT_VALUE,
        help="The layer numbers to find the alignment for. It can be a number like 0, 1 or a negative number like -1 for the last layer. If not passed, then all the layers will be trained and average correlation will be extracted per layer",
    )
    parser.add_argument(
        "--max-log-10-alpha",
        required=False,
        default=4,
        type=int,
        help="Maximum value of log10 alpha to consider",
    )
    parser.add_argument(
        "--num-alphas",
        required=False,
        default=60,
        type=int,
        help="Number of alpha values to sample",
    )
    parser.add_argument(
        "--telegram-bot-token",
        required=False,
        default="",
        type=str,
        help="Telegram Bot token to use for tqdm",
    )
    parser.add_argument(
        "--telegram-chat-id",
        required=False,
        default=0,
        type=int,
        help="Telegram Chat ID to send tqdm updates to",
    )

    args = parser.parse_args()

    # These names are module globals read by main().
    SUBJECT: int = args.subject
    BASE_DIR: pathlib.Path = args.base_dir
    MODEL_ID: str = args.model_id
    PROMPT_NUMBER: int = args.prompt_number
    MAX_LOG_10_ALPHA: int = args.max_log_10_alpha  # currently unused
    NUM_ALPHAS: int = args.num_alphas  # currently unused
    TELEGRAM_BOT_TOKEN: str = args.telegram_bot_token
    TELEGRAM_CHAT_ID: int = args.telegram_chat_id
    TO_USE_TELEGRAM: bool = TELEGRAM_BOT_TOKEN != "" and TELEGRAM_CHAT_ID != 0
    VIDEOS_DIR: str = args.videos_folder  # currently unused

    MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

    if TO_USE_TELEGRAM:
        from tqdm.contrib.telegram import tqdm
    else:
        from tqdm.auto import tqdm

    # Parse --layer-number into an int:
    #   not given    -> LAYER_NUM_DEFAULT_VALUE (the sentinel default)
    #   "k" or "-k"  -> k (negative values used to crash on the "-" split)
    #   "a-b", "a-"  -> a; NOTE(review): "b" is ignored, as before.
    if args.layer_number == LAYER_NUM_DEFAULT_VALUE:
        LAYER_NUM: int = LAYER_NUM_DEFAULT_VALUE
    else:
        layer_spec = str(args.layer_number).strip()
        range_match = re.match(r"(\d+)\s*-", layer_spec)
        LAYER_NUM = int(range_match.group(1)) if range_match else int(layer_spec)

    OUTPUT_DIR = BASE_DIR.joinpath(
        "final_scores",
        f"prompt_{PROMPT_NUMBER}",
        MODEL_NAME,
        f"subj{SUBJECT:02}",
        "layer_all",
    )

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    main()
