import wandb
import copy
from omegaconf import DictConfig, OmegaConf
import hydra
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import os
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm as tqdm
import argparse
import logging
import json
import sys
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import models

# Configure the root logger so INFO-level messages are emitted.
logging.basicConfig(level = logging.INFO)

log = logging.getLogger(__name__)
# Choose where predicted features are written, based on the $USER login name.
USER = os.getenv('USER')
if USER == "user1":
    SAVE_ROOT_PATH = Path(f'/storage/user1/BrainBitsWIP/data/predicted_features/')
# NOTE(review): this condition is identical to the one above, so this branch
# can never run -- presumably it was meant to test a different user name; confirm.
elif USER == "user1":
    SAVE_ROOT_PATH = Path(f'/storage/user1/projects/brainbits/BrainBitsWIP/data/predicted_features/')
else:
    # Raised at import time when $USER is unset or not one of the names above.
    raise ValueError(f"Unknown user {USER}")
# Root directory that the fMRI / latent loaders in this file read from.
BD_ROOT_PATH = Path('/storage/user1/brain-diffuser')


class fMRI2latent(Dataset):
    """Dataset that serves an fMRI sample together with its three targets.

    The four sequences are indexed with the same ``idx``; the dataset length
    is the length of ``fmri_data``.
    """

    def __init__(self, fmri_data, vdvae_embeds, clip_text_embeds, clip_vision_embeds):
        self.fmri_data = fmri_data
        self.vdvae_embeds = vdvae_embeds
        self.clip_text_embeds = clip_text_embeds
        self.clip_vision_embeds = clip_vision_embeds

    def __len__(self):
        """Return the number of fMRI samples."""
        return len(self.fmri_data)

    def __getitem__(self, idx):
        """Return a dict of float tensors for sample ``idx``."""
        # (output key, backing sequence) pairs, in the order the keys are emitted.
        fields = (
            ("inputs", self.fmri_data),
            ("vdvae_targets", self.vdvae_embeds),
            ("clip_text_targets", self.clip_text_embeds),
            ("clip_vision_targets", self.clip_vision_embeds),
        )
        return {key: torch.FloatTensor(source[idx]) for key, source in fields}

class BottleneckLinear(nn.Module):
    """Linear fMRI -> bottleneck projection followed by three linear heads.

    Args:
        input_size: Number of input features per fMRI sample.
        bottleneck_size: Width of the shared bottleneck.
        d_vdvae: Output width of the VDVAE head.
        d_clip: Per-token CLIP embedding width.
        n_text: Number of CLIP text tokens; the text head emits ``n_text * d_clip``.
        n_vision: Number of CLIP vision tokens; the vision head emits
            ``n_vision * d_clip``.
        cfg: Mapping-like config. A truthy ``pca_preload`` entry makes the
            projection bias-free and initialises its weight from ``embed_w``.
        embed_w: Array-like of shape ``(bottleneck_size, input_size)``;
            required only when ``pca_preload`` is set.
        multi_gpu: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If ``pca_preload`` is set but ``embed_w`` is missing or
            does not have shape ``(bottleneck_size, input_size)``.
    """

    def __init__(self, input_size, bottleneck_size, d_vdvae, d_clip, n_text, n_vision, cfg, embed_w=None, multi_gpu=False):
        super().__init__()

        preload = bool(cfg.get("pca_preload", False))
        # Build the projection once: bias-free when the weight is preloaded,
        # with a bias otherwise.
        self.fmri2embed = nn.Sequential(
            nn.Linear(input_size, bottleneck_size, bias=not preload),
        )
        if preload:
            if embed_w is None:
                raise ValueError("pca_preload is set but embed_w was not provided")
            # clone() so the parameter never aliases the caller's array.
            weight = torch.as_tensor(embed_w, dtype=torch.float32).clone()
            expected_shape = (bottleneck_size, input_size)
            if tuple(weight.shape) != expected_shape:
                raise ValueError(
                    f"embed_w has shape {tuple(weight.shape)}, expected {expected_shape}"
                )
            self.fmri2embed[0].weight = nn.Parameter(weight)

        self.vdvae_embed = nn.Linear(bottleneck_size, d_vdvae)
        self.clip_text_embed = nn.Linear(bottleneck_size, n_text*d_clip)
        self.clip_vision_embed = nn.Linear(bottleneck_size, n_vision*d_clip)

    def forward(self, fmri_inputs):
        """Return ``(vdvae, clip_text, clip_vision, bottleneck)`` for the inputs."""
        bottleneck_mapping = self.fmri2embed(fmri_inputs)
        vdvae_mapping = self.vdvae_embed(bottleneck_mapping)
        clip_text_mapping = self.clip_text_embed(bottleneck_mapping)
        clip_vision_mapping = self.clip_vision_embed(bottleneck_mapping)
        return vdvae_mapping, clip_text_mapping, clip_vision_mapping, bottleneck_mapping

def get_loss(vdvae_preds, clip_text_preds, clip_vision_preds, batch, reg_cfg, n_batch):
    """Compute the weighted MSE objective over the three prediction heads.

    Targets are read from ``batch`` and moved to ``reg_cfg.device``. The CLIP
    targets are flattened to ``(n_batch, -1)`` before comparison.

    Returns:
        ``(loss, vdvae_loss, clip_text_loss, clip_vision_loss)`` where
        ``loss = vdvae_loss + 2*clip_text_loss + 4*clip_vision_loss``.
    """
    device = reg_cfg.device
    text_targets = batch["clip_text_targets"].to(device)
    vision_targets = batch["clip_vision_targets"].to(device)
    vdvae_targets = batch["vdvae_targets"].to(device)

    # Mean-reduced squared error for each head.
    vdvae_loss = F.mse_loss(vdvae_preds, vdvae_targets)
    clip_text_loss = F.mse_loss(clip_text_preds, text_targets.reshape(n_batch, -1))
    clip_vision_loss = F.mse_loss(clip_vision_preds, vision_targets.reshape(n_batch, -1))

    total = vdvae_loss + 2*clip_text_loss + 4*clip_vision_loss
    return total, vdvae_loss, clip_text_loss, clip_vision_loss

def get_eval_loss(model, val_loader, reg_cfg):
    """Return the mean per-batch combined loss of ``model`` over ``val_loader``.

    The model is switched to eval mode for the pass and its previous
    train/eval mode is restored afterwards, so calling this between training
    epochs does not leave the model in eval mode.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    n_batches = len(val_loader)
    if n_batches == 0:
        raise ValueError("val_loader is empty; cannot compute an evaluation loss")

    was_training = model.training
    model.eval()
    total_loss = 0
    try:
        with torch.no_grad():
            for batch in tqdm(val_loader):
                inputs = batch["inputs"].to(reg_cfg.device)
                n_batch = inputs.shape[0]
                vdvae_preds, clip_text_preds, clip_vision_preds, _ = model(inputs)
                loss, _, _, _ = get_loss(vdvae_preds, clip_text_preds, clip_vision_preds, batch, reg_cfg, n_batch)
                total_loss += loss.item()
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)
    return total_loss/n_batches

def scale_preds(vdvae_preds, train_stats):
    """Standardise predictions over the batch axis and rescale to train stats.

    Args:
        vdvae_preds: Tensor of predictions; statistics are taken over axis 0.
        train_stats: ``(train_mean, train_std)`` pair used as the target
            shift and scale.

    Returns:
        Tensor on the same device as ``vdvae_preds``.
    """
    target_mean, target_std = train_stats
    device = vdvae_preds.device
    # Batch statistics are computed on a detached copy of the predictions.
    detached = vdvae_preds.detach()
    epsilon = 0.0001
    batch_mean = torch.mean(detached, axis=0)
    # A NaN standard deviation is replaced by epsilon before dividing.
    batch_std = torch.nan_to_num(torch.std(detached, axis=0), nan=epsilon)
    standardized = (vdvae_preds - batch_mean) / batch_std
    scale = torch.FloatTensor(target_std).to(device)
    shift = torch.FloatTensor(target_mean).to(device)
    return standardized * scale + shift

def _set_trainable_stage(model, freeze_embed):
    """Enable gradients for one half of the model and disable the other.

    With ``freeze_embed`` true, parameters whose name contains ``fmri2embed``
    are frozen and every other parameter trains; otherwise the reverse.
    """
    for name, param in model.named_parameters():
        is_embed = 'fmri2embed' in name
        param.requires_grad = (not is_embed) if freeze_embed else is_embed


def train_linear_mapping(model, train_loader, val_loader, reg_cfg, train_stats, save_path_dir):
    """Train ``model`` and return a copy of it at its best validation loss.

    Args:
        model: Module returning ``(vdvae, clip_text, clip_vision, bottleneck)``
            and exposing its projection as ``fmri2embed[0]``.
        train_loader: Loader of training batches.
        val_loader: Loader evaluated after every epoch via ``get_eval_loss``.
        reg_cfg: Config providing ``optim``, ``lr``, ``n_epochs``, ``device``
            and optionally ``use_iterative``, ``n_epoch_per_stage`` and
            ``save_checkpoints``.
        train_stats: Accepted for interface compatibility; not used here.
        save_path_dir: Directory for the periodic ``weights_{epoch}.pth`` files.

    Returns:
        A deep copy of the model taken at the lowest validation loss seen
        (the initial model if the loss never improved, e.g. it was NaN).

    Raises:
        ValueError: If ``reg_cfg.optim`` is neither ``"SGD"`` nor ``"Adam"``.
    """
    if reg_cfg.optim == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=reg_cfg.lr, momentum=0.0, weight_decay=0.001)
    elif reg_cfg.optim == "Adam":
        optimizer = optim.AdamW(model.parameters(), lr=reg_cfg.lr, weight_decay=0.001)
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``optimizer`` a few lines further down.
        raise ValueError(f"Unsupported optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'")

    iterative_on = True
    if 'use_iterative' in reg_cfg:
        iterative_on = reg_cfg.use_iterative

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
    # Start from +inf so the first finite validation loss always counts as an
    # improvement, whatever the scale of the loss.
    min_eval_loss = float("inf")
    best_model = copy.deepcopy(model)

    freeze_embed = False
    # lr_1 remembers the learning rate of the projection-training stage and
    # lr_2 that of the head-training stage, so each stage resumes where its
    # own scheduler left off.
    lr_1, lr_2 = reg_cfg.lr, reg_cfg.lr
    n_epoch_per_stage = reg_cfg.get("n_epoch_per_stage", 20)
    for epoch in range(reg_cfg.n_epochs):
        if iterative_on and epoch % n_epoch_per_stage == 0:
            freeze_embed = not freeze_embed
            if freeze_embed:
                lr_1 = optimizer.param_groups[0]['lr']
                stage_lr = lr_2
            else:
                lr_2 = optimizer.param_groups[0]['lr']
                stage_lr = lr_1
            # NOTE(review): every stage switch (including epoch 0) installs
            # AdamW, so reg_cfg.optim == "SGD" only takes effect when
            # use_iterative is off -- confirm this is intended.
            optimizer = optim.AdamW(model.parameters(), lr=stage_lr, weight_decay=0.001)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        if iterative_on:
            _set_trainable_stage(model, freeze_embed)

        if epoch % 10 == 0 and reg_cfg.get("save_checkpoints", False):
            torch.save(model.state_dict(), save_path_dir / f"weights_{epoch}.pth")

        with tqdm(total=len(train_loader)) as bar:
            bar.set_description(f"Epoch {epoch}")
            train_loss, train_vdvae_loss = 0, 0
            for batch in train_loader:
                inputs = batch["inputs"].to(reg_cfg.device)
                n_batch = inputs.shape[0]
                # Drop gradients entirely so frozen parameters keep grad=None
                # and the optimizer (including its weight decay) skips them.
                optimizer.zero_grad(set_to_none=True)

                vdvae_preds, clip_text_preds, clip_vision_preds, _ = model(inputs)
                loss, vdvae_loss, clip_text_loss, clip_vision_loss = get_loss(vdvae_preds, clip_text_preds, clip_vision_preds, batch, reg_cfg, n_batch)
                loss.backward()
                optimizer.step()
                bar.set_postfix({"v":float(vdvae_loss)})
                bar.update()
                train_loss += float(loss)
                train_vdvae_loss += float(vdvae_loss)

            avg_loss = train_loss/len(train_loader)
            avg_vdvae_loss = train_vdvae_loss/len(train_loader)

            eval_loss = get_eval_loss(model, val_loader, reg_cfg)
            bar.set_postfix({"eval": eval_loss, "mse":avg_loss, "v": avg_vdvae_loss})
            wandb.log({"val_loss": eval_loss, "train_loss": avg_loss})
        if eval_loss < min_eval_loss:
            min_eval_loss = eval_loss
            best_model = copy.deepcopy(model)
            arr = model.fmri2embed[0].weight.detach().cpu().numpy()
            print(arr.sum())
            print("eval loss is better")
        # NOTE(review): the plateau scheduler is driven by the training loss,
        # not the validation loss -- confirm this is intended.
        scheduler.step(avg_loss)

    return best_model

def eval_model(model, test_loader, device): #, test_fmri):#TODO):
    """Run ``model`` over ``test_loader`` and collect all four outputs.

    The model is put in eval mode and run without gradient tracking.

    Returns:
        ``(vdvae_preds, clip_text_preds, clip_vision_preds, intermediates)``
        as NumPy arrays, each concatenated over batches in loader order.
    """
    model.eval()
    # One bucket per model output, in the order the model returns them.
    buckets = ([], [], [], [])
    with torch.no_grad():
        for batch in tqdm(test_loader):
            vdvae_out, text_out, vision_out, bottleneck_out = model(batch["inputs"].to(device))
            for bucket, output in zip(buckets, (vdvae_out, text_out, vision_out, bottleneck_out)):
                bucket.append(output)
        merged = [torch.cat(bucket) for bucket in buckets]
    vdvae_arr, text_arr, vision_arr, bottleneck_arr = (
        tensor.cpu().detach().numpy() for tensor in merged
    )
    return vdvae_arr, text_arr, vision_arr, bottleneck_arr

def scale_latents(pred_test_latent, train_latents):
    """Rescale predictions so their axis-0 mean/std match the training latents.

    Predictions are standardised with their own mean and standard deviation
    along axis 0, then multiplied by the training std and shifted by the
    training mean. A coordinate whose predictions are all identical (std of
    exactly 0) maps to the training mean instead of producing NaN/inf from a
    division by zero.

    Args:
        pred_test_latent: Array of predictions, samples along axis 0.
        train_latents: Array of training latents, samples along axis 0.

    Returns:
        Array with the same shape as ``pred_test_latent``.
    """
    pred_mean = np.mean(pred_test_latent, axis=0)
    pred_std = np.std(pred_test_latent, axis=0)
    # The numerator is 0 wherever the std is 0, so dividing by 1 there yields
    # a standardised value of 0.
    safe_std = np.where(pred_std == 0, np.ones_like(pred_std), pred_std)
    std_norm_test_latent = (pred_test_latent - pred_mean) / safe_std
    pred_latents = std_norm_test_latent * np.std(train_latents, axis=0) + np.mean(train_latents, axis=0)
    return pred_latents

def get_vdvae_targets(sub):
    """Load the VDVAE latent targets for subject ``sub``.

    Reads ``nsd_vdvae_features_31l.npz`` from the subject's extracted-features
    directory under ``BD_ROOT_PATH``.

    Returns:
        ``(train_latents, test_latents)`` arrays from the archive.
    """
    log.info("Getting VDVAE targets")

    nsd_path = 'data/extracted_features/subj{:02d}/nsd_vdvae_features_31l.npz'.format(sub)
    # Use the archive as a context manager so its file handle is closed once
    # both arrays have been read into memory.
    with np.load(BD_ROOT_PATH / nsd_path) as nsd_features:
        train_latents = nsd_features['train_latents']
        test_latents = nsd_features['test_latents']

    return train_latents, test_latents

def get_fmri_inputs(sub):
    """Load and standardise the train/test fMRI arrays for subject ``sub``.

    Both arrays are divided by 300, then z-scored per column using the mean
    and sample standard deviation (``ddof=1``) of the training array.

    Returns:
        ``(train_fmri, test_fmri)``.
    """
    log.info("Getting fMRI inputs")
    train_path = 'data/processed_data/subj{:02d}/nsd_train_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
    test_path = 'data/processed_data/subj{:02d}/nsd_test_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
    raw_train = np.load(BD_ROOT_PATH / train_path)
    raw_test = np.load(BD_ROOT_PATH / test_path)

    scaled_train = raw_train/300
    scaled_test = raw_test/300

    # The statistics come from the training split only and are applied to both.
    column_mean = np.mean(scaled_train, axis=0)
    column_std = np.std(scaled_train, axis=0, ddof=1)

    def _standardize(values):
        return (values - column_mean) / column_std

    return _standardize(scaled_train), _standardize(scaled_test)

def save_preds(arr, sub, bottleneck_size, out_name):
    """Save ``arr`` as ``{out_name}.npy`` in the subject/bottleneck output directory.

    The directory is created, parents included, if it does not exist.
    """
    target_dir = SAVE_ROOT_PATH / f'subj{sub:02d}/train_single/train_single_{bottleneck_size}/'
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / f"{out_name}.npy"
    np.save(destination, arr)

def save_weights(model, sub, bottleneck_size):
    """Save the projection weight matrix and the full state dict of ``model``.

    Writes ``compression_weights.npy`` (the weight of ``model.fmri2embed[0]``)
    and ``final_weights.pth`` into the subject/bottleneck output directory,
    which is created if needed.
    """
    projection = model.fmri2embed[0].weight
    weight_arr = projection.detach().cpu().numpy()
    target_dir = SAVE_ROOT_PATH / f'subj{sub:02d}/train_single/train_single_{bottleneck_size}/'
    target_dir.mkdir(parents=True, exist_ok=True)
    np.save(target_dir / f"compression_weights.npy", weight_arr)
    state = model.state_dict()
    torch.save(state, target_dir / f"final_weights.pth")

def get_clip_text_latents(sub):
    """Load the CLIP text targets for subject ``sub``.

    Reads ``nsd_cliptext_train.npy`` and ``nsd_cliptext_test.npy`` from the
    subject's extracted-features directory under ``BD_ROOT_PATH``.

    Returns:
        ``(train_clip, test_clip)``.
    """
    log.info("Getting CLIP text targets")

    train_path = 'data/extracted_features/subj{:02d}/nsd_cliptext_train.npy'.format(sub)
    train_clip = np.load(BD_ROOT_PATH / train_path)
    test_path = 'data/extracted_features/subj{:02d}/nsd_cliptext_test.npy'.format(sub)
    test_clip = np.load(BD_ROOT_PATH / test_path)

    return train_clip, test_clip

def get_clip_vision_latents(sub):
    """Load the CLIP vision targets for subject ``sub``.

    Reads ``nsd_clipvision_train.npy`` and ``nsd_clipvision_test.npy`` from
    the subject's extracted-features directory under ``BD_ROOT_PATH``.

    Returns:
        ``(train_clip, test_clip)``.
    """
    log.info("Getting CLIP vision targets")

    train_path = 'data/extracted_features/subj{:02d}/nsd_clipvision_train.npy'.format(sub)
    train_clip = np.load(BD_ROOT_PATH / train_path)
    test_path = 'data/extracted_features/subj{:02d}/nsd_clipvision_test.npy'.format(sub)
    test_clip = np.load(BD_ROOT_PATH / test_path)

    return train_clip, test_clip


def train_all(sub, bottleneck_size, train_fmri, test_fmri, reg_cfg, save_path_dir):
    """Train one bottleneck model for ``sub`` and save its predictions and weights.

    Steps: load the three target sets, split the training data into
    train/validation, train a ``BottleneckLinear``, evaluate it on the test
    set, and save scaled and unscaled predictions, the bottleneck activations
    for the test and full training data, and the model weights.

    Args:
        sub: Subject number used to locate inputs and outputs.
        bottleneck_size: Width of the model bottleneck.
        train_fmri: Training fMRI array, samples along axis 0.
        test_fmri: Test fMRI array, samples along axis 0.
        reg_cfg: Config providing at least ``batch_size`` and ``device``, plus
            whatever ``train_linear_mapping`` and ``BottleneckLinear`` read.
        save_path_dir: Directory passed on to ``train_linear_mapping``.
    """
    clip_text_train, clip_text_test = get_clip_text_latents(sub)
    clip_vision_train, clip_vision_test = get_clip_vision_latents(sub)
    vdvae_embeds_train, vdvae_embeds_test = get_vdvae_targets(sub)

    _, n_text, d_clip_embed = clip_text_train.shape
    _, n_vision, _ = clip_vision_train.shape

    _, d_vdvae = vdvae_embeds_train.shape
    n_test, _ = vdvae_embeds_test.shape

    val_split = 0.15 #TODO hardcode
    all_train_data = fMRI2latent(train_fmri, vdvae_embeds_train, clip_text_train, clip_vision_train)
    train_idx, val_idx = train_test_split(list(range(len(all_train_data))), test_size=val_split)

    # The PCA components are only used to initialise the projection when
    # pca_preload is set, so skip the fit otherwise.
    pca_components = None
    if reg_cfg.get("pca_preload", False):
        pca = PCA(n_components=bottleneck_size)
        pca.fit(train_fmri[train_idx])
        pca_components = pca.components_

    train_latent_arr = vdvae_embeds_train[train_idx]
    train_std = np.std(train_latent_arr, axis=0)
    train_mean = np.mean(train_latent_arr, axis=0)
    train_stats = (train_mean, train_std)

    train_data = Subset(all_train_data, train_idx)
    val_data = Subset(all_train_data, val_idx)
    train_loader = DataLoader(train_data, batch_size=reg_cfg.batch_size, shuffle=True)
    # Validation needs no shuffling; a fixed order keeps batch composition,
    # and therefore the per-batch average, identical from epoch to epoch.
    val_loader = DataLoader(val_data, batch_size=reg_cfg.batch_size, shuffle=False)

    model = BottleneckLinear(train_fmri.shape[-1], bottleneck_size, d_vdvae, d_clip_embed, n_text, n_vision, reg_cfg, embed_w=pca_components)

    model = model.to(reg_cfg.device)
    log.info("Training fMRI2latent mapping")

    model = train_linear_mapping(model, train_loader, val_loader, reg_cfg, train_stats, save_path_dir)

    log.info("fMRI2latent test evaluation")
    test_data = fMRI2latent(test_fmri, vdvae_embeds_test, clip_text_test, clip_vision_test)
    test_loader = DataLoader(test_data, batch_size=reg_cfg.batch_size, shuffle=False)

    vdvae_preds, text_preds, vision_preds, test_intermediates = eval_model(model, test_loader, reg_cfg.device)
    text_preds = text_preds.reshape(n_test, n_text, d_clip_embed)
    vision_preds = vision_preds.reshape(n_test, n_vision, d_clip_embed)

    scaled_vdvae_preds = scale_latents(vdvae_preds, vdvae_embeds_train)
    scaled_vision_preds = scale_latents(vision_preds, clip_vision_train)
    scaled_text_preds = scale_latents(text_preds, clip_text_train)

    save_preds(scaled_vdvae_preds, sub, bottleneck_size, "vdvae_preds")
    save_preds(scaled_text_preds, sub, bottleneck_size, "clip_text_preds")
    save_preds(scaled_vision_preds, sub, bottleneck_size, "clip_vision_preds")

    save_preds(vdvae_preds, sub, bottleneck_size, "unscaled_vdvae_preds")
    save_preds(text_preds, sub, bottleneck_size, "unscaled_clip_text_preds")
    save_preds(vision_preds, sub, bottleneck_size, "unscaled_clip_vision_preds")

    save_preds(test_intermediates, sub, bottleneck_size, "test_intermediates")

    # Iterate the full training set in its stored order so that row i of the
    # saved intermediates corresponds to training sample i.
    #TODO really I should use same train and val split
    all_train_loader = DataLoader(all_train_data, batch_size=reg_cfg.batch_size, shuffle=False)
    _, _, _, train_intermediates = eval_model(model, all_train_loader, reg_cfg.device)
    print(train_intermediates.shape)
    save_preds(train_intermediates, sub, bottleneck_size, "train_intermediates")

    save_weights(model, sub, bottleneck_size)

@hydra.main(config_path="conf")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: train one model per configured bottleneck size.

    Logs the resolved config, starts a wandb run in project ``brainbits``,
    loads the fMRI inputs for ``cfg.exp.sub``, and calls ``train_all`` for
    every entry of ``cfg.exp.bottlenecks`` using ``cfg.exp.reg`` as the
    training config.
    """
    log.info(f"Run testing for all electrodes in all test_subjects")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    working_dir = os.getcwd()
    log.info(f'Working directory {os.getcwd()}')
    # An explicit exp.out_dir overrides the working directory; the value is
    # only reported in the log below.
    out_dir = cfg.exp.out_dir if "out_dir" in cfg.exp else working_dir
    log.info(f'Output directory {out_dir}')

    # Track the run's hyperparameters and metadata in wandb.
    run_config = OmegaConf.to_container(cfg, resolve=True)
    wandb.init(project="brainbits", config=run_config)

    sub = cfg.exp["sub"]
    train_fmri, test_fmri = get_fmri_inputs(sub)

    bottleneck_sizes = cfg.exp["bottlenecks"]
    reg_cfg = cfg.exp.reg

    for size in bottleneck_sizes:
        save_path_dir = SAVE_ROOT_PATH / f'subj{sub:02d}/train_single/train_single_{size}/'
        save_path_dir.mkdir(parents=True, exist_ok=True)
        train_all(sub, size, train_fmri, test_fmri, reg_cfg, save_path_dir)
    wandb.finish()

# Run the Hydra entry point only when this file is executed as a script.
if __name__=="__main__":
    # Debugging aid: uncomment the two lines below to replace sys.argv with a
    # fixed set of Hydra overrides before main() runs.
    # _debug = '''train.py +exp=latent_reg ++exp.bottlenecks=[5] ++exp.reg.batch_size=128 ++exp.reg.n_epochs=1 ++exp.reg.optim="SGD" ++exp.reg.device="cpu"'''
    # sys.argv = _debug.split(" ")
    main()


