from omegaconf import DictConfig, OmegaConf
import hydra
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import os
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm as tqdm
import argparse
import logging
import json
import sys
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import models
import pickle

# Configure the root logger so INFO-level (and above) records are emitted.
logging.basicConfig(level = logging.INFO)

log = logging.getLogger(__name__)
# Choose the output root for predicted features from the $USER environment
# variable. This runs at import time: an unrecognised (or unset) user makes
# the import itself fail with ValueError.
USER = os.getenv('USER')
if USER == "user1":
    SAVE_ROOT_PATH = Path(f'/storage/user1/BrainBitsWIP/data/predicted_features/')
elif USER == "user1":
    # NOTE(review): this condition repeats the `if` above, so this branch can
    # never execute — presumably a different username was intended; confirm.
    SAVE_ROOT_PATH = Path(f'/storage/user1/projects/brainbits/BrainBitsWIP/data/predicted_features/')
else:
    raise ValueError(f"Unknown user {USER}")
# Root directory that the loaders below read features and processed fMRI from.
BD_ROOT_PATH = Path('/storage/user1/brain-diffuser')


class fMRI2latent(Dataset):
    """Dataset that pairs each fMRI sample with its VDVAE latent target.

    Both containers are indexed with the same position. The dataset length is
    the length of ``fmri_data``.
    """

    def __init__(self, fmri_data, vdvae_embeds):
        self.fmri_data, self.vdvae_embeds = fmri_data, vdvae_embeds

    def __len__(self):
        n_samples = len(self.fmri_data)
        return n_samples

    def __getitem__(self, idx):
        # Each entry is converted to a float32 tensor on access.
        sample = {}
        sample["inputs"] = torch.FloatTensor(self.fmri_data[idx])
        sample["vdvae_targets"] = torch.FloatTensor(self.vdvae_embeds[idx])
        return sample

class BottleneckLinear(nn.Module):
    """Linear fMRI -> bottleneck -> VDVAE-latent regressor.

    The forward pass subtracts ``pca_mean`` from the input, applies a
    bias-free linear projection to ``bottleneck_size`` dimensions, rescales
    the bottleneck activations (divide by 300, then standardise with
    ``norm_mean_train`` / ``norm_scale_train``) and finally applies a linear
    layer producing ``d_vdvae`` outputs.

    Args:
        input_size: Number of input features per sample.
        bottleneck_size: Width of the bottleneck layer.
        d_vdvae: Number of output (latent) dimensions.
        norm_mean_train: Mean subtracted from the bottleneck activations after
            the division by 300; must broadcast against
            ``(batch, bottleneck_size)``.
        norm_scale_train: Scale the centred bottleneck activations are divided
            by; same broadcasting requirement.
        embed_w: Accepted for interface compatibility; currently unused.
        multi_gpu: Accepted for interface compatibility; currently unused.
        reg_weights: Accepted for interface compatibility; currently unused.
        pca_mean: Value subtracted from the raw inputs. A scalar (such as the
            default ``0``) or anything that broadcasts against
            ``(batch, input_size)``.
    """

    def __init__(self, input_size, bottleneck_size, d_vdvae, norm_mean_train, norm_scale_train, embed_w=None, multi_gpu=False, reg_weights=None, pca_mean=0):
        super().__init__()

        self.fmri2embed = nn.Sequential(nn.Linear(input_size, bottleneck_size, bias=False),
                                        #nn.Linear(bottleneck_size, bottleneck_size),
                                        #torch.nn.ReLU(),
                                        #nn.Linear(bottleneck_size, bottleneck_size),
                                       )
        #self.fmri2embed[0].weight = torch.nn.Parameter(torch.FloatTensor(embed_w))#TODO
        self.vdvae_embed = nn.Linear(bottleneck_size, d_vdvae)
        #self.vdvae_embed.weight = torch.nn.Parameter(torch.FloatTensor(reg_weights["weight"]))#TODO
        #self.vdvae_embed.bias = torch.nn.Parameter(torch.FloatTensor(reg_weights["bias"]))#TODO

        # Fixed (non-trainable) constants are registered as buffers so they
        # follow the module across devices. They are non-persistent so the
        # state_dict keeps containing only the two linear layers.
        self.register_buffer("norm_mean_train", self._to_float_tensor(norm_mean_train), persistent=False)
        self.register_buffer("norm_scale_train", self._to_float_tensor(norm_scale_train), persistent=False)
        self.register_buffer("pca_mean", self._to_float_tensor(pca_mean), persistent=False)

    @staticmethod
    def _to_float_tensor(values):
        """Copy ``values`` (scalar, sequence, ndarray or tensor) into a float32 tensor."""
        # torch.FloatTensor(<int>) would allocate a tensor of that *size*
        # (FloatTensor(0) is empty), so scalars must not go through it.
        if torch.is_tensor(values):
            return values.detach().clone().to(torch.float32)
        return torch.tensor(values, dtype=torch.float32)

    def fmri_scaling(self, fmri_inputs):
        """Divide by 300, then standardise with the stored mean and scale."""
        fmri_inputs = fmri_inputs/300#remember to divide by 300 when getting the means
        fmri_inputs = (fmri_inputs - self.norm_mean_train) / self.norm_scale_train
        return fmri_inputs

    def forward(self, fmri_inputs):
        """Map a batch of fMRI inputs to predicted VDVAE latents."""
        centered_fmri_inputs = fmri_inputs - self.pca_mean
        bottleneck_mapping = self.fmri2embed(centered_fmri_inputs)
        scaled_mapping = self.fmri_scaling(bottleneck_mapping)
        vdvae_mapping = self.vdvae_embed(scaled_mapping)
        return vdvae_mapping

def get_loss(criterion, vdvae_preds, vdvae_targets, batch, reg_cfg, n_batch):
    """Return ``(total_loss, vdvae_loss)`` for one batch.

    The total is the VDVAE term alone, so both returned values are the same
    object. ``batch``, ``reg_cfg`` and ``n_batch`` are accepted but not used.
    """
    vdvae_term = criterion(vdvae_preds, vdvae_targets)
    return vdvae_term, vdvae_term

def get_eval_loss(criterion, model, val_loader, reg_cfg):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode for the duration of the pass and its
    previous train/eval mode is restored afterwards, so calling this between
    training epochs does not leave the model stuck in eval mode.

    Args:
        criterion: Callable ``(preds, targets) -> loss tensor``.
        model: Module called on ``batch["inputs"]``.
        val_loader: Iterable of dict batches with ``"inputs"`` and
            ``"vdvae_targets"`` entries; must support ``len()``.
        reg_cfg: Config object; only ``reg_cfg.device`` is read here.

    Returns:
        Sum of the per-batch ``loss.item()`` values divided by the number of
        batches.
    """
    was_training = model.training
    model.eval()
    try:
        total_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                inputs = batch["inputs"].to(reg_cfg.device)
                n_batch = inputs.shape[0]
                vdvae_preds = model(inputs)
                vdvae_targets = batch["vdvae_targets"].to(reg_cfg.device)
                loss, _ = get_loss(criterion, vdvae_preds, vdvae_targets, batch, reg_cfg, n_batch)
                total_loss += loss.item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    return total_loss/len(val_loader)

def _build_optimizer(params, reg_cfg, lr):
    """Create the optimizer selected by ``reg_cfg.optim`` over ``params``."""
    if reg_cfg.optim == "SGD":
        return optim.SGD(params, lr=lr, momentum=0.0, weight_decay=0.001)
    if reg_cfg.optim == "Adam":
        # The "Adam" option maps to AdamW (decoupled weight decay).
        return optim.AdamW(params, lr=lr, weight_decay=0.001)
    raise ValueError(f"Unknown optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'")


def _select_trainable_params(model, freeze_embed):
    """Set requires_grad for the current stage and return the trainable parameters."""
    trainable = []
    for name, param in model.named_parameters():
        if 'fmri2embed' in name:
            param.requires_grad = not freeze_embed
        elif 'vdvae_embed' in name:
            param.requires_grad = freeze_embed
        if param.requires_grad:
            trainable.append(param)
        else:
            # Drop gradients left over from the previous stage.
            param.grad = None
    return trainable


def train_linear_mapping(model, train_loader, val_loader, reg_cfg):
    """Train ``model`` with MSE loss, alternating which layer is trainable.

    Training proceeds in stages of 20 epochs. In the first stage the
    ``fmri2embed`` parameters are frozen and ``vdvae_embed`` is trained; the
    roles swap at every stage boundary. Each stage gets a fresh optimizer (of
    the type named by ``reg_cfg.optim``) built over the trainable parameters
    only, plus a fresh ``ReduceLROnPlateau`` scheduler. The learning rate a
    stage type ended on is remembered and reused the next time that stage
    type runs.

    Args:
        model: Module whose parameter names contain ``fmri2embed`` /
            ``vdvae_embed``.
        train_loader: Iterable of dict batches with ``"inputs"`` and
            ``"vdvae_targets"``; must support ``len()``.
        val_loader: Validation batches, passed to ``get_eval_loss``.
        reg_cfg: Config providing ``optim``, ``lr``, ``n_epochs`` and
            ``device``.

    Returns:
        The model as it is after the final epoch (no best-checkpoint
        selection is performed).

    Raises:
        ValueError: If ``reg_cfg.optim`` is neither ``"SGD"`` nor ``"Adam"``.
    """
    if reg_cfg.optim not in ("SGD", "Adam"):
        raise ValueError(f"Unknown optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'")

    epochs_per_stage = 20
    criterion = nn.MSELoss()

    # Last learning rate reached by each stage type, keyed by freeze_embed.
    stage_lr = {True: reg_cfg.lr, False: reg_cfg.lr}
    freeze_embed = False
    optimizer, scheduler = None, None

    for epoch in range(reg_cfg.n_epochs):
        if epoch % epochs_per_stage == 0:
            if optimizer is not None:
                # Remember where the scheduler left the finishing stage's lr.
                stage_lr[freeze_embed] = optimizer.param_groups[0]['lr']
            freeze_embed = not freeze_embed
            trainable_params = _select_trainable_params(model, freeze_embed)
            optimizer = _build_optimizer(trainable_params, reg_cfg, stage_lr[freeze_embed])
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

        # Validation switches the model to eval mode; make sure every epoch
        # trains in train mode.
        model.train()
        with tqdm(total=len(train_loader)) as bar:
            bar.set_description(f"Epoch {epoch}")
            train_loss, train_vdvae_loss = 0, 0
            for batch in train_loader:
                inputs = batch["inputs"].to(reg_cfg.device)
                n_batch = inputs.shape[0]
                optimizer.zero_grad()
                vdvae_preds = model(inputs)
                vdvae_targets = batch["vdvae_targets"].to(reg_cfg.device)

                loss, vdvae_loss = get_loss(criterion, vdvae_preds, vdvae_targets, batch, reg_cfg, n_batch)
                loss.backward()
                optimizer.step()
                bar.set_postfix({"v":float(vdvae_loss)})
                bar.update()
                train_loss += float(loss)
                train_vdvae_loss += float(vdvae_loss)

            avg_loss = train_loss/len(train_loader)
            avg_vdvae_loss = train_vdvae_loss/len(train_loader)

            eval_loss = get_eval_loss(criterion, model, val_loader, reg_cfg)
            bar.set_postfix({"eval": eval_loss, "mse":avg_loss, "v": avg_vdvae_loss})
        # The plateau scheduler is driven by the average *training* loss.
        scheduler.step(avg_loss)
    # TODO: return the best-validation checkpoint instead of the last model.
    return model

def eval_model(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and return all predictions as one NumPy array.

    The model is put in eval mode and gradients are disabled. Per-batch
    outputs are concatenated along the first dimension, in loader order.
    """
    model.eval()
    with torch.no_grad():
        batch_preds = [model(batch["inputs"].to(device)) for batch in tqdm(test_loader)]
        stacked_preds = torch.cat(batch_preds)
    return stacked_preds.cpu().detach().numpy()

def scale_latents(pred_test_latent, train_latents):
    """Rescale predictions so each column matches the training latents' mean and std.

    Each column of ``pred_test_latent`` is standardised with its own mean and
    standard deviation (taken over axis 0), then multiplied by the training
    column's standard deviation and shifted by the training column's mean.
    """
    pred_mean = np.mean(pred_test_latent, axis=0)
    pred_std = np.std(pred_test_latent, axis=0)
    standardized = (pred_test_latent - pred_mean) / pred_std

    train_std = np.std(train_latents, axis=0)
    train_mean = np.mean(train_latents, axis=0)
    return standardized * train_std + train_mean

def get_vdvae_targets(sub):
    """Load the VDVAE latent targets for subject ``sub``.

    Reads ``train_latents`` and ``test_latents`` from the subject's
    ``nsd_vdvae_features_31l.npz`` archive under ``BD_ROOT_PATH``.

    Args:
        sub: Subject number; formatted as a zero-padded two-digit integer in
            the path.

    Returns:
        Tuple ``(train_latents, test_latents)``.
    """
    log.info("Getting VDVAE targets")

    #get latent targets
    nsd_path = 'data/extracted_features/subj{:02d}/nsd_vdvae_features_31l.npz'.format(sub)
    # np.load on an .npz returns a lazily-read archive that holds the file
    # open; read both arrays inside the context manager so it gets closed.
    with np.load(BD_ROOT_PATH / nsd_path) as nsd_features:
        train_latents = nsd_features['train_latents']
        test_latents = nsd_features['test_latents']

    return train_latents, test_latents

def get_fmri_inputs(sub):
    """Load the train and test fMRI arrays for subject ``sub``.

    The arrays are returned exactly as stored on disk under ``BD_ROOT_PATH``.

    NOTE: the /300 scaling and the mean/std normalisation that used to be
    applied here have been moved into the network's forward pass.
    """
    log.info("Getting fMRI inputs")
    train_rel = 'data/processed_data/subj{:02d}/nsd_train_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
    test_rel = 'data/processed_data/subj{:02d}/nsd_test_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
    # Train split is loaded first, then test.
    train_fmri, test_fmri = [np.load(BD_ROOT_PATH / rel_path) for rel_path in (train_rel, test_rel)]
    return train_fmri, test_fmri

def save_preds(arr, sub, bottleneck_size, out_name):
    """Save ``arr`` as ``<out_name>.npy`` in the subject/bottleneck output directory.

    The directory lives under ``SAVE_ROOT_PATH`` and is created (including
    parents) when it does not exist yet.
    """
    target_dir = SAVE_ROOT_PATH / f'subj_{sub}/bbits_{bottleneck_size}/'
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / f"{out_name}.npy"
    np.save(target_file, arr)

def train_all(sub, bottleneck_size, train_fmri, test_fmri, reg_cfg):
    """Train the bottleneck regressor for one subject and save its test predictions.

    Steps: load the VDVAE targets, split the training set into train and
    validation subsets, derive normalisation statistics from the PCA
    projection of the training fMRI, train a ``BottleneckLinear`` model, then
    predict on the test set and save both the raw predictions and the
    predictions rescaled to the training latents' statistics.

    Args:
        sub: Subject number, used to locate the subject's input files and the
            output directory.
        bottleneck_size: Width of the model's bottleneck layer.
        train_fmri: Training fMRI array, one row per sample.
        test_fmri: Test fMRI array, one row per sample.
        reg_cfg: Config providing ``batch_size`` and ``device`` here, plus the
            fields read by ``train_linear_mapping``.
    """
    vdvae_embeds_train, vdvae_embeds_test = get_vdvae_targets(sub)

    #vdvae_embeds_train = (vdvae_embeds_train - np.mean(vdvae_embeds_train, axis=0))/np.std(vdvae_embeds_train, axis=0)#TODO scaling happens here
    _, d_vdvae = vdvae_embeds_train.shape

    val_split = 0.15 #TODO hardcode

    all_train_data = fMRI2latent(train_fmri, vdvae_embeds_train)
    train_idx, val_idx = train_test_split(list(range(len(all_train_data))), test_size=val_split)

    # Per-subject artefacts are resolved from `sub` rather than being pinned
    # to subj01, following the same layout the other loaders use.
    # NOTE(review): the PCA size stays fixed at 50, and the statistics below
    # have one entry per PCA component — presumably bottleneck_size is
    # expected to equal 50; confirm.
    subj_dir = 'subj{:02d}'.format(sub)
    pca_components = np.load(BD_ROOT_PATH / f'data/pca_reduced/{subj_dir}/pca_50/components.npy')

    # NOTE(review): these statistics use the full training array, validation
    # rows included.
    pca_mean = train_fmri.mean(axis=0)
    transformed_train = np.dot(train_fmri - pca_mean, pca_components.T)
    log.info("Mean |PCA-projected train fMRI|: %s", np.abs(transformed_train).mean())
    #train_input_arr = train_fmri[train_idx]
    norm_mean_train = np.mean(transformed_train/300, axis=0)#we divide by 300 here
    norm_scale_train = np.std(transformed_train/300, axis=0, ddof=1)

    train_data = Subset(all_train_data, train_idx)
    val_data = Subset(all_train_data, val_idx)
    train_loader = DataLoader(train_data, batch_size=reg_cfg.batch_size, shuffle=False)#TODO
    val_loader = DataLoader(val_data, batch_size=reg_cfg.batch_size, shuffle=False)#TODO

    #bottleneck_size = d_vdvae
    # SECURITY: pickle.load executes arbitrary code from the file; only point
    # this at weight files you produced yourself.
    reg_weights_path = BD_ROOT_PATH / f'data/predicted_features/{subj_dir}/pca_reduced/pca_50/vdvae_regression_weights.pkl'
    with open(reg_weights_path, "rb") as f:
        reg_weights = pickle.load(f)

    model = BottleneckLinear(train_fmri.shape[-1], bottleneck_size, d_vdvae, norm_mean_train, norm_scale_train, embed_w=pca_components, reg_weights=reg_weights, pca_mean=pca_mean)

    model = model.to(reg_cfg.device)
    #if device=="cuda":
    #    model= nn.DataParallel(model)
    log.info("Training fMRI2latent mapping")

    model = train_linear_mapping(model, train_loader, val_loader, reg_cfg)

    log.info("fMRI2latent test evaluation")
    test_data = fMRI2latent(test_fmri, vdvae_embeds_test)
    test_loader = DataLoader(test_data, batch_size=reg_cfg.batch_size, shuffle=False)#TODO

    vdvae_preds = eval_model(model, test_loader, reg_cfg.device)

    # Save both the rescaled and the raw predictions.
    scaled_vdvae_preds = scale_latents(vdvae_preds, vdvae_embeds_train)

    save_preds(scaled_vdvae_preds,sub, bottleneck_size, "vdvae_preds")

    save_preds(vdvae_preds,sub, bottleneck_size, "unscaled_vdvae_preds")

@hydra.main(config_path="conf")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: load one subject's fMRI data and train every configured bottleneck size.

    Reads ``sub``, ``bottlenecks`` and ``reg`` from ``cfg.exp``. The output
    directory (``cfg.exp.out_dir`` when present, otherwise the current working
    directory) is only logged here.
    """
    log.info("Run testing for all electrodes in all test_subjects")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))

    working_dir = os.getcwd()
    log.info(f'Working directory {working_dir}')
    out_dir = cfg.exp.out_dir if "out_dir" in cfg.exp else working_dir
    log.info(f'Output directory {out_dir}')

    exp_cfg = cfg.exp
    sub = exp_cfg["sub"]
    # The fMRI arrays are loaded once and shared by every bottleneck run.
    train_fmri, test_fmri = get_fmri_inputs(sub)

    bottleneck_sizes = exp_cfg["bottlenecks"]
    reg_cfg = exp_cfg.reg
    for bottleneck_size in bottleneck_sizes:
        train_all(sub, bottleneck_size, train_fmri, test_fmri, reg_cfg)

# Script entry point: run the Hydra-decorated main() when executed directly.
if __name__=="__main__":
    # Debug aid: uncomment the two lines below to replace the real command
    # line with a fixed set of Hydra overrides.
    # _debug = '''train.py +exp=latent_reg ++exp.bottlenecks=[5] ++exp.reg.batch_size=128 ++exp.reg.n_epochs=1 ++exp.reg.optim="SGD" ++exp.reg.device="cpu"'''
    # sys.argv = _debug.split(" ")
    main()


