import numpy as np
from regression.feature_loader import test_feature_loader
from regression.regression_utils import load_pkl
from regression.losses import brain_score_1 as brain_score
from regression.load_meg_targets import load_meg_targets
from regression.session_story_configs import subject_test_configs
import torch
from regression.regression_closed_form import block_gpu_multiply
import os

def block_resample_generator(
    ts_list: list[np.ndarray],
    n_samples: int = 100,
    time_window: int = 20,
    replacement: bool = False,
):
    """
    Lazily yield block-resampled copies of several aligned time series.

    Every series is cut along axis 0 into consecutive blocks of ``time_window``
    rows; trailing rows that do not fill a whole block are dropped. For each
    sample a single block order is drawn and applied to *all* series, so the
    series stay aligned with one another and the ordering inside each block is
    preserved. Only one sample is materialised at a time.

    Args:
        ts_list: Arrays that all have the same length on axis 0. Any trailing
            axes are carried through unchanged.
        n_samples: Number of resampled tuples to yield.
        time_window: Block length in rows; must be >= 1 and no longer than
            the series.
        replacement: If True, blocks are drawn with replacement (bootstrap);
            otherwise the blocks are permuted.

    Yields:
        A tuple with one array per input series, each of shape
        ``(n_blocks * time_window, ...)`` and the same dtype as its input.

    Raises:
        ValueError: If ``ts_list`` is empty, ``time_window`` is smaller than 1
            or longer than the series, or the series lengths differ. Because
            this is a generator, the checks run on the first iteration rather
            than when the generator object is created.

    Usage:
        for sample_idx, (r1, r2, ...) in enumerate(
                block_resample_generator([ts1, ts2, ...], 100, 20)):
            # process r1, r2, ... here
    """
    # Fail with a clear message instead of an IndexError / ZeroDivisionError.
    if len(ts_list) == 0:
        raise ValueError("ts_list must contain at least one time series")
    if time_window < 1:
        raise ValueError("time_window must be a positive integer")

    # 1) common time-axis length check
    n_time = ts_list[0].shape[0]
    if any(ts.shape[0] != n_time for ts in ts_list):
        raise ValueError("All time series must share the same length on axis 0")

    # 2) how many full blocks fit?
    n_blocks = n_time // time_window
    if n_blocks < 1:
        raise ValueError("time_window must be <= length of each series")
    kept = n_blocks * time_window

    # 3) view each (truncated) series as (n_blocks, time_window, ...)
    blocked = [
        ts[:kept].reshape((n_blocks, time_window) + ts.shape[1:])
        for ts in ts_list
    ]

    # 4) yield one sample at a time
    for _ in range(n_samples):
        # One draw per sample, shared by every series. Uses NumPy's global
        # random state, so np.random.seed() makes the sequence reproducible.
        if replacement:
            order = np.random.randint(0, n_blocks, size=n_blocks)
        else:
            order = np.random.permutation(n_blocks)

        # Fancy indexing already returns a new array with the input dtype, so
        # no extra astype() copy is needed before flattening the block axis.
        yield tuple(
            blocks[order].reshape((kept,) + blocks.shape[2:])
            for blocks in blocked
        )
def score_test_and_predicted(test_meg, meg_predicted):
    """Build a ``brain_score`` scorer for ``test_meg`` and apply it to ``meg_predicted``.

    The scorer is constructed with a second argument of 0.03 and
    ``ceiling_cutoff=None``; its result for ``meg_predicted`` is returned
    unchanged.
    """
    return brain_score(test_meg, 0.03, ceiling_cutoff = None)(meg_predicted)

def get_score_for_fixed_test_meg(test_meg):
    """Return a reusable ``brain_score`` scorer bound to ``test_meg``.

    The returned object is callable on a prediction (see
    ``score_test_and_predicted``), which lets one scorer be reused for many
    predictions against the same target.
    """
    return brain_score(test_meg, 0.03, ceiling_cutoff = None)


def load_subject_weights(subject, full_models_loc = "./runs/full_models"):
    """Load a subject's full model and return its weights with the bias appended.

    Reads ``<full_models_loc>/subject_<subject>/layer_3/final_model.pkl`` via
    ``load_pkl`` and stacks the model's ``bias`` underneath its ``weights``.

    Args:
        subject: Subject identifier used in the directory name.
        full_models_loc: Root directory holding the per-subject full models.

    Returns:
        A 2-D array equal to ``weights`` with ``bias`` added as one extra
        final row (assumes ``weights`` is 2-D and ``bias`` is 1-D with one
        entry per weight column — TODO confirm against the model class).
    """
    model_loc = full_models_loc + f"/subject_{subject}/layer_3"
    full_model = load_pkl(model_loc + "/final_model.pkl")
    W = full_model.weights.T
    b = full_model.bias
    # np.concatenate instead of np.concat: the latter is only an alias added
    # in NumPy 2.0 and raises AttributeError on NumPy 1.x.
    # The bias becomes the last column here and the last row after the
    # transpose back.
    W_total = np.concatenate((W, b[:, None]), axis=1).T
    return W_total

def subject_rank_predictions(subject, r, test_features, save_loc = "./runs"):
    """Predict from the saved rank-``r`` model of ``subject``.

    Loads ``<save_loc>/subject_<subject>_rank_sweep_single/rank_<r>/final_model.pt``
    and returns the result of its ``numpy_forward`` method on ``test_features``.
    """
    checkpoint_path = (
        save_loc + f"/subject_{subject}_rank_sweep_single/rank_{r}" + "/final_model.pt"
    )
    # weights_only=False lets torch.load unpickle the whole model object, not
    # just tensors; only point this at checkpoints you produced yourself.
    model = torch.load(checkpoint_path, weights_only=False)
    return model.numpy_forward(test_features)

def subject_full_predictions(subject, test_features):
    """Predict from the full model of ``subject``.

    Multiplies ``test_features`` by the bias-augmented weight matrix from
    ``load_subject_weights`` using ``block_gpu_multiply``. The two trailing
    1000 arguments are presumably block sizes — confirm against
    ``block_gpu_multiply``.
    """
    return block_gpu_multiply(
        test_features,
        load_subject_weights(subject),
        1000,
        1000,
    )

def rank_full_compare_p_value(subject, rank, test_feature, n_samples = 1000, time_window = 500, meg_targets = None):
    """Block-bootstrap comparison of the rank-``rank`` model against the full model.

    Both models' predictions and the MEG targets are resampled together in
    blocks of ``time_window`` time points (with replacement). For each
    bootstrap sample a fresh ``brain_score`` scorer is built from the
    resampled MEG and the difference ``rank_score - full_score`` is recorded.

    Args:
        subject: Subject identifier used to locate the saved models.
        rank: Rank of the low-rank model to compare.
        test_feature: Features fed to both models.
        n_samples: Number of bootstrap samples.
        time_window: Block length, in time points, for the block bootstrap.
        meg_targets: 3-D MEG array whose axis 1 is time (it is transposed to
            put time first before resampling). If None, the module-level
            ``meg`` set by the script section is used, as before.

    Returns:
        ``(pvalue, bootstrap_score_diffs)``. ``pvalue`` is the fraction of
        bootstrap samples with a negative difference, floored at
        ``1 / n_samples`` so it is never reported as exactly zero.

    Raises:
        NameError: If ``meg_targets`` is None and no module-level ``meg``
            exists (e.g. when this module is imported, not run as a script).
    """
    if meg_targets is None:
        # Backward-compatible fallback to the global created under
        # `if __name__ == "__main__"`.
        meg_targets = meg

    rank_prediction = subject_rank_predictions(subject, rank, test_feature)
    full_prediction = subject_full_predictions(subject, test_feature)

    bootstrap_generator = block_resample_generator(
        [rank_prediction, full_prediction, meg_targets.transpose(1, 0, 2)],
        n_samples=n_samples,
        time_window=time_window,
        replacement=True,
    )

    bootstrap_score_diffs = []
    for rank_prediction_bootstrap, full_prediction_bootstrap, meg_bootstrap in bootstrap_generator:
        # Undo the time-first transpose so the scorer sees the original layout.
        score = brain_score(meg_bootstrap.transpose(1, 0, 2), 0.03)
        rank_score = score(rank_prediction_bootstrap)
        full_score = score(full_prediction_bootstrap)
        bootstrap_score_diffs.append(rank_score - full_score)

    bootstraps_below_zero = sum(np.array(bootstrap_score_diffs) < 0)
    pvalue = max(1, bootstraps_below_zero) / n_samples
    return pvalue, bootstrap_score_diffs

def rank_test_performance_p_value(subject, rank, test_feature, n_samples = 1000, time_window=500, meg_targets = None):
    """Block-permutation test of the rank-``rank`` model's test score.

    The model's prediction is shuffled in blocks of ``time_window`` time
    points (without replacement) to build a null distribution of scores,
    which is compared with the score of the unshuffled prediction.

    Args:
        subject: Subject identifier used to locate the saved model.
        rank: Rank of the low-rank model to test.
        test_feature: Features fed to the model.
        n_samples: Number of block permutations.
        time_window: Block length, in time points, for the permutation.
        meg_targets: 3-D MEG array whose axis 1 is time. If None, the
            module-level ``meg`` set by the script section is used, as before.

    Returns:
        ``(pvalue, true_score, permutation_scores)``. ``pvalue`` is the
        fraction of permutation scores above ``true_score``, floored at
        ``1 / n_samples`` so it is never reported as exactly zero.

    Raises:
        NameError: If ``meg_targets`` is None and no module-level ``meg``
            exists (e.g. when this module is imported, not run as a script).
    """
    if meg_targets is None:
        # Backward-compatible fallback to the global created under
        # `if __name__ == "__main__"`.
        meg_targets = meg

    rank_prediction = subject_rank_predictions(subject, rank, test_feature)
    true_score = brain_score(meg_targets, 0.03, None)(rank_prediction)

    rank_permutation_generator = block_resample_generator(
        [rank_prediction],
        n_samples=n_samples,
        time_window=time_window,
        replacement=False,
    )

    # The resampler drops time points that do not fill a whole block, so the
    # MEG target is truncated to the same length before scoring permutations.
    meg_max_index_to_use = (meg_targets.shape[1] // time_window) * time_window
    # NOTE(review): true_score above uses the full-length MEG with an explicit
    # ceiling_cutoff of None, while this scorer uses the truncated MEG and
    # brain_score's default ceiling_cutoff — confirm the two are comparable.
    scorer = brain_score(meg_targets[:, :meg_max_index_to_use, :], 0.03)

    permutation_scores = []
    for (rank_permutation,) in rank_permutation_generator:
        permutation_scores.append(scorer(rank_permutation))

    permutations_above_score = sum(np.array(permutation_scores) > true_score)
    pvalue = max(1, permutations_above_score) / n_samples
    return pvalue, true_score, permutation_scores

if __name__ == "__main__":
    # Script entry: one SLURM array task evaluates one rank for one subject
    # and appends its two p-values to shared text files.
    #
    # NOTE: this is deliberately left as module-level code and not wrapped in
    # a main() function: the p-value functions in this file read `meg` as a
    # module-level global.
    task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
    ranks = list(range(1, 21))
    # NOTE(review): task_id is used as a 0-based index into the 20 ranks, so
    # the job array is expected to be 0-19 — confirm against the sbatch script.
    r = ranks[task_id]
    llm_features = { "name": "llama2","layer": 3,"context": 20,"pca": 0.95,"load": True,"delays": 40 }
    subject = "D"
    test_configs = subject_test_configs(subject, llm_features["name"])
    # Stack the loaded targets along a new leading axis; axis 1 is then used
    # as the time axis by the p-value functions.
    meg = np.stack(load_meg_targets(test_configs),  axis=0)
    embeddings_loc = "./embeddings"
    save_folder = "./runs"
    # The dict key is "context"; the previous "context_len" lookup raised
    # KeyError before any work was done.
    llm_store_loc = (
        f"{embeddings_loc}/embeddings_sweep/{llm_features['name']}"
        f"/layer_{llm_features['layer']}_context_{llm_features['context']}"
    )
    _, test_features = test_feature_loader(llm_features, lm_feature_map_loc=llm_store_loc,
                        subject = subject, controls = [], delays = [],
                        force_load=True, load_as_control=False)

    single_test_feature = test_features[0]
    os.makedirs(save_folder + "/significance_tests", exist_ok=True)
    save_loc = save_folder + "/significance_tests" + f"/subject_{subject}"

    pvalue, score_diffs = rank_full_compare_p_value(subject, r, single_test_feature, 1000)
    perf_pvalue, true_score, permutation_scores = rank_test_performance_p_value(subject, r, single_test_feature, 1000)

    # Append mode: every array task adds its own line to the same files.
    with open(save_loc + "_score_diff.txt", "a") as f:
        f.writelines([f"rank {r} {pvalue}\n"])

    with open(save_loc + "_performance.txt", "a") as f:
        f.writelines([f"rank {r} {perf_pvalue}\n"])