# %%

from functools import partial
import torch

from config_utils import load_from_yaml

from datamodule import AllDatamodule

from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

import argparse
import os

# %%
# Command-line options. ``--model`` is restricted to the backbones handled by
# the model-selection cell below. An unsupported name is then rejected here,
# before the (expensive) datamodule setup runs, rather than only at the
# dispatch's ``else`` branch.
parser = argparse.ArgumentParser(description="pca model")
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument(
    "--model",
    type=str,
    default="afo",
    choices=["dinov2", "cliprn50", "cliprn50x4", "vitb", "swinb", "imrn50", "afo"],
    help="model",
)
parser.add_argument("--data_dir", type=str, default="/data/VWET", help="data dir")
parser.add_argument("--subject_id", type=str, default="NSD_01", help="subject id")
parser.add_argument(
    "--save_dir", type=str, default="/data/results/pca", help="save dir"
)
args = parser.parse_args()

device = args.device
model = args.model
save_dir = args.save_dir
# Output directory for the per-split correlation tensors saved at the end.
os.makedirs(save_dir, exist_ok=True)

# %%
# Start from the shared YAML config, override the dataset location, subject
# and batch size for this run, then build and set up the datamodule.
cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
cfg.DATASET.SUBJECT_LIST = [args.subject_id]
cfg.DATASET.ROOT = args.data_dir
cfg.DATAMODULE.BATCH_SIZE = 300
dm = AllDatamodule(cfg)
dm.setup()

# %%
def _graph_feature_extractor(net, return_nodes, device):
    """Move ``net`` to ``device`` in eval mode and wrap it as an FX feature extractor.

    Args:
        net: The backbone ``nn.Module``.
        return_nodes: List of graph node names whose outputs the extractor
            returns (as a dict keyed by node name).
        device: Target device string, e.g. ``"cuda:0"``.

    Returns:
        The module built by torchvision's ``create_feature_extractor``.
    """
    net = net.to(device)
    net.eval()  # inference only
    return create_feature_extractor(net, return_nodes=return_nodes)


# Stage outputs tapped from the ResNet-style backbones (CLIP RN50 variants and
# torchvision ResNet-50).
_RESNET_STAGES = ["layer1", "layer2", "layer3", "layer4"]
# CLI model name -> architecture name passed to ``clip.load``.
_CLIP_ARCHS = {"cliprn50": "RN50", "cliprn50x4": "RN50x4"}

# Build ``feature_extractor``: a callable mapping an image batch to a dict of
# per-layer feature tensors. The loaded network is bound to ``net`` so that the
# CLI string ``model`` is not overwritten.
if model == "dinov2":
    dinov2_vitb14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
    dinov2_vitb14.to(device)  # nn.Module.to moves the parameters in place
    dinov2_vitb14.eval()  # set the model to evaluation mode, since you are not training it

    def dino_feature_extractor(x):
        """Return the layers requested with ``n=[2, 5, 8, 11]``, keyed ``features.0``..``features.3``."""
        out = dinov2_vitb14.get_intermediate_layers(x, n=[2, 5, 8, 11])
        return {f"features.{k}": v for k, v in enumerate(out)}

    feature_extractor = dino_feature_extractor
elif model in _CLIP_ARCHS:
    import clip

    # Load on CPU and keep only the image encoder; the full CLIP model is not
    # kept alive.
    net = clip.load(_CLIP_ARCHS[model], device="cpu")[0].visual
    feature_extractor = _graph_feature_extractor(net, list(_RESNET_STAGES), device)
elif model == "vitb":
    import torchvision

    net = torchvision.models.vit_b_16(
        weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
    )
    feature_extractor = _graph_feature_extractor(
        net,
        [
            "encoder.layers.encoder_layer_2",
            "encoder.layers.encoder_layer_5",
            "encoder.layers.encoder_layer_8",
            "encoder.layers.encoder_layer_11",
        ],
        device,
    )
elif model == "swinb":
    import torchvision

    net = torchvision.models.swin_v2_b(
        weights=torchvision.models.Swin_V2_B_Weights.DEFAULT
    )
    feature_extractor = _graph_feature_extractor(
        net, ["features.1", "features.2", "features.3", "features.4"], device
    )
elif model == "imrn50":
    import torchvision

    net = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2
    )
    feature_extractor = _graph_feature_extractor(net, list(_RESNET_STAGES), device)
elif model == "afo":
    from models import VEModel

    net = VEModel(
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )
    path = "/data/results/xgaa/yesgt_1/soup.pth"
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    sd = torch.load(path, map_location="cpu")
    # strict=False tolerates key mismatches. Report them so that a partially
    # loaded model is noticed instead of silently evaluated.
    load_result = net.load_state_dict(sd, strict=False)
    missing = getattr(load_result, "missing_keys", [])
    unexpected = getattr(load_result, "unexpected_keys", [])
    if missing or unexpected:
        print(
            f"[afo] load_state_dict: {len(missing)} missing keys, "
            f"{len(unexpected)} unexpected keys"
        )
    net.eval()
    net = net.to(device)
    net.move_device()
    feature_extractor = net.get_intermidiate_outputs
else:
    raise NotImplementedError(f"Unsupported model: {model!r}")

# %%
subject_id = args.subject_id
# Datamodule batch size (set in the config cell above). The active code path
# below does not read it; only the commented-out IncrementalPCA variant did.
batch_size = cfg.DATAMODULE.BATCH_SIZE
# %%
from sklearn.decomposition import IncrementalPCA
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr as corr

# %%
from tqdm import tqdm
import numpy as np


# model_layer = "features.2" #@param ["features.2", "features.5", "features.7", "features.9", "features.12", "classifier.2", "classifier.5", "classifier.6"] {allow-input: true}
# feature_extractor = create_feature_extractor(model, return_nodes=[model_layer])
# # %%
# def fit_pca(feature_extractor, dataloader, max_iter=1):
#     # Define PCA parameters
#     pca = IncrementalPCA(n_components=256, batch_size=batch_size)


#     # Fit PCA to batch
#     for i, d in tqdm(enumerate(dataloader), total=len(dataloader)):
#         if i >= max_iter:
#             break
#         # Extract features
#         img = d[0].to(device)
#         ft = feature_extractor(img)
#         # Flatten the features
#         ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])
#         # Fit PCA to batch
#         pca.partial_fit(ft.detach().cpu().numpy().astype(np.float32))
#     return pca
def fit_pca(feature_extractor, dataloader, max_iter=1):
    """Fit a rank-256 PCA basis on flattened features from the first batches.

    The features of every layer returned by ``feature_extractor`` are flattened
    per sample and joined side by side. The rows of the first ``max_iter``
    batches are pooled and passed to ``torch.pca_lowrank``.

    Args:
        feature_extractor: Callable mapping an image batch to a dict of
            per-layer feature tensors.
        dataloader: Sized iterable whose batches hold the images at index 0.
        max_iter: Number of leading batches to pool. The default of 1 fits on
            a single batch, as before.

    Returns:
        The ``(U, S, V)`` tuple from ``torch.pca_lowrank(..., q=256)``. Callers
        project onto ``V`` (the last element).

    Raises:
        ValueError: If ``max_iter`` < 1 or ``dataloader`` yields no batches.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")
    chunks = []
    for d in tqdm(dataloader, total=min(len(dataloader), max_iter)):
        img = d[0].to(device)
        ft = feature_extractor(img)
        # One row per sample: flatten each layer's output, then join the layers.
        chunks.append(torch.hstack([torch.flatten(t, start_dim=1) for t in ft.values()]))
        # Stop before the next batch is fetched from the loader.
        if len(chunks) >= max_iter:
            break
    if not chunks:
        raise ValueError("dataloader yielded no batches; cannot fit PCA")
    # torch.pca_lowrank centres its input by default (center=True) and requires
    # q <= min(rows, cols), so the pooled batches need at least 256 samples.
    return torch.pca_lowrank(torch.cat(chunks, dim=0), q=256)


# %%
# Fit the PCA basis on the unshuffled training split; no gradients are needed.
with torch.no_grad():
    pca = fit_pca(
        feature_extractor=feature_extractor,
        dataloader=dm.train_dataloader(shuffle=False, subject=subject_id),
        max_iter=1,
    )


# torch.save(pca, os.path.join(save_dir, f"{model}_{subject_id}_pca.pt"))
# %%
# %%
# def extract_features(feature_extractor, dataloader, pca):
#     features = []
#     for _, d in tqdm(enumerate(dataloader), total=len(dataloader)):
#         # Extract features
#         img = d[0].to(device)
#         ft = feature_extractor(img)
#         # Flatten the features
#         ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])
#         # Apply PCA transform
#         ft = pca.transform(ft.cpu().detach().numpy().astype(np.float32))
#         features.append(ft)
#     return np.vstack(features)
def extract_features(feature_extractor, dataloader, pca):
    """Project every batch from ``dataloader`` onto the PCA basis ``pca[-1]``.

    Args:
        feature_extractor: Callable returning a dict of per-layer feature
            tensors for an image batch.
        dataloader: Sized iterable whose batches hold the images at index 0.
        pca: Sequence whose last element is the projection matrix. This is the
            ``V`` of the ``(U, S, V)`` tuple that ``fit_pca`` obtains from
            ``torch.pca_lowrank``.

    Returns:
        The projected features of all batches, one row per sample, stacked in
        loader order.
    """
    projected = []
    for batch in tqdm(dataloader, total=len(dataloader)):
        layer_outputs = feature_extractor(batch[0].to(device))
        # Same layout as in fit_pca: each layer flattened per sample, layers
        # concatenated side by side.
        flat = torch.hstack(
            [torch.flatten(out, start_dim=1) for out in layer_outputs.values()]
        )
        # NOTE(review): torch.pca_lowrank centres its input by default, but no
        # mean is subtracted here. The scores are therefore offset by
        # (train mean @ V) relative to a true PCA transform. Confirm that the
        # downstream regressor fits an intercept, or subtract the training mean.
        projected.append(flat @ pca[-1])
    return torch.vstack(projected)


# %%
# Project each split onto the PCA basis (inference only).
# NOTE(review): val/test loaders are requested without ``shuffle``; the targets
# are collected in a separate pass below, so those loaders are assumed to be
# deterministic. TODO: confirm.
with torch.no_grad():
    features_train = extract_features(
        feature_extractor=feature_extractor,
        dataloader=dm.train_dataloader(subject=subject_id, shuffle=False),
        pca=pca,
    )
    features_val = extract_features(
        feature_extractor=feature_extractor,
        dataloader=dm.val_dataloader(subject=subject_id),
        pca=pca,
    )
    features_test = extract_features(
        feature_extractor=feature_extractor,
        dataloader=dm.test_dataloader(subject=subject_id),
        pca=pca,
    )


# %%
def get_ys(dataloader):
    """Collect the targets of every batch into a single tensor.

    For each batch, the sequence of tensors at index 1 is stacked along a new
    leading dim. The per-batch results are then concatenated along dim 0.
    """
    return torch.cat([torch.stack(batch[1]) for batch in dataloader], dim=0)


# %%
# Ground-truth targets per split, collected in a second pass over the same loaders.
ys_train = get_ys(dataloader=dm.train_dataloader(subject=subject_id, shuffle=False))
ys_val = get_ys(dataloader=dm.val_dataloader(subject=subject_id))
ys_test = get_ys(dataloader=dm.test_dataloader(subject=subject_id))

# %%
# Reshape every tensor to (1, rows, width) and move it to ``device`` for the
# Ridge helper. Presumably that helper expects a leading batch dim; confirm in
# ridge_regression.
# The feature width is read from the tensor rather than hard-coded to 256.
# Otherwise, changing the PCA rank in fit_pca could silently regroup rows.
features_train = features_train.view(1, -1, features_train.shape[-1]).to(device)
features_val = features_val.view(1, -1, features_val.shape[-1]).to(device)
features_test = features_test.view(1, -1, features_test.shape[-1]).to(device)
ys_train = ys_train.view(1, -1, ys_train.shape[-1]).to(device)
ys_val = ys_val.view(1, -1, ys_val.shape[-1]).to(device)
ys_test = ys_test.view(1, -1, ys_test.shape[-1]).to(device)
# %%
print(features_train.shape, ys_train.shape)  # sanity check of the (1, rows, width) shapes
# %%
torch.cuda.empty_cache()  # release cached, unused GPU memory before fitting
# %%
from ridge_regression import Ridge

# alpha=0: no ridge penalty (plain least squares under the usual
# parameterisation; confirm in ridge_regression).
ridge = Ridge(alpha=0)
ridge.fit(features_train, ys_train)
# %%
ys_train_pred, ys_val_pred, ys_test_pred = (
    ridge.predict(split_features)
    for split_features in (features_train, features_val, features_test)
)
# %%
from metrics import vectorized_correlation

# Correlate predictions with targets for each split. Index [0] drops the
# leading singleton dim added before fitting; the predictions presumably carry
# the same leading dim.
p_train, p_val, p_test = (
    vectorized_correlation(pred[0], target[0])
    for pred, target in (
        (ys_train_pred, ys_train),
        (ys_val_pred, ys_val),
        (ys_test_pred, ys_test),
    )
)
# %%
# Re-read the CLI model name; the model-selection cell may have rebound ``model``.
model = args.model
print("Model:", model)
for label, split_corr in (("Train:", p_train), ("Val:", p_val), ("Test:", p_test)):
    print(label, split_corr.mean())
# %%
# Persist the correlation tensor of each split as <model>_<subject>_p_<split>.pt.
for split_name, split_corr in (("train", p_train), ("val", p_val), ("test", p_test)):
    torch.save(
        split_corr,
        os.path.join(save_dir, f"{model}_{subject_id}_p_{split_name}.pt"),
    )
# %%
