# %%
from tokenize import group
import torch
import random
import argparse
import numpy as np
from pathlib import Path
import ipdb as pdb
import os, pwd, yaml
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from dataset import DANS
import warnings
import wandb
warnings.filterwarnings('ignore')
from gen_dataset import gen_da_data_ortho
import quadprog
from components.beta import BetaVAE_MLP
# from GEMlearning import 
from sssa import SSA
import ipdb as pdb
import torch.nn as nn
import torch.nn.init as init
import pytorch_lightning as pl
import torch.distributions as D
from torch.nn import functional as F
from components.beta import BetaVAE_MLP
from components.transforms import NormalizingFlow
from metrics.correlation import compute_mcc
from metrics.block import compute_r2
import random as rd
import quadprog
import torch.optim as opt
import matplotlib.pyplot as plt
from themode_withspline2 import themodel
from glob import glob
import os


def load_yaml(filename):
    """
    Load a YAML config file and return its parsed contents.

    The file is decoded as UTF-8 (the YAML default) instead of the
    platform's locale encoding, and parsed with ``yaml.safe_load`` so no
    arbitrary Python objects are constructed from the file.

    Args:
        filename: path to the YAML file.

    Returns:
        The parsed document (a dict for the config files used here), or
        None if the file is empty.
    """
    with open(filename, 'r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)

# %%
def split_data(train_data, test_data, n_domains, n_obs_perdomain_train, n_obs_perdomain_test):
    """Cut domain-concatenated train/test arrays into one dict per domain.

    Domain ``k`` owns rows ``[k * n, (k + 1) * n)`` of every array, where
    ``n`` is the per-domain sample count of the respective split.

    Args:
        train_data, test_data: mappings with "y", "x" and "c" entries that
            can be indexed by a ``range`` (e.g. numpy arrays).
        n_domains: number of domains to extract.
        n_obs_perdomain_train: rows per domain in ``train_data``.
        n_obs_perdomain_test: rows per domain in ``test_data``.

    Returns:
        ``(output_train, output_test)``: two lists of length ``n_domains``;
        each element is a dict with keys "y", "x", "c".
    """
    fields = ("y", "x", "c")

    def _domain_slice(data, per_domain, k):
        # Index with a range (not a slice) so each entry is a copy, as before.
        rows = range(per_domain * k, per_domain * (k + 1))
        return {field: data[field][rows] for field in fields}

    output_train = [_domain_slice(train_data, n_obs_perdomain_train, k) for k in range(n_domains)]
    output_test = [_domain_slice(test_data, n_obs_perdomain_test, k) for k in range(n_domains)]
    return output_train, output_test



def get_model_dir(data_name, seed, model_type, mem, dom):
    """Locate a saved model checkpoint and return its path and bare name.

    Checkpoints are searched under ``<base>/data/<data_name>/seed_<seed>``,
    where ``<base>`` is the parent of the current working directory.

    Args:
        data_name: dataset directory name.
        seed: selects the ``seed_<seed>`` sub-directory; for ``'torch'`` it
            must also match the seed encoded in the checkpoint file name.
        model_type: ``'lighting'`` picks a ``*pt`` file from
            ``model/joint/<data_name>``; ``'torch'`` picks a ``*.pth`` file
            named like ``<tag>_<dom>_<tag>_<seed>_mem<mem>.pth``.
        mem: memory size to match (``'torch'`` only).
        dom: domain index to match (``'torch'`` only).

    Returns:
        ``(model_dir, the_name)``: full checkpoint path, and its file name
        up to the first '.'.

    Raises:
        FileNotFoundError: no matching checkpoint exists.
        ValueError: ``model_type`` is neither ``'lighting'`` nor ``'torch'``.
    """
    # NOTE: '__file__' is a string literal, not the module variable, so the
    # base directory is derived from the current working directory. This
    # keeps the function usable from interactive "# %%" cells (where
    # __file__ is undefined); run it from a sub-directory of the project.
    script_dir = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
    seed_dir = os.path.join(script_dir, 'data', data_name, 'seed_' + str(seed))

    def _bare_name(path):
        # File name without directory and without anything after the first '.'.
        return os.path.basename(path).split('.')[0]

    if model_type == "lighting":
        search_dir = os.path.join(seed_dir, 'model/joint', data_name)
        # sorted() makes the choice deterministic; raw glob order is arbitrary.
        model_names = sorted(glob(search_dir + '/*pt'))
        if not model_names:
            raise FileNotFoundError(f"No '*pt' checkpoint found in {search_dir}")
        model_dir = model_names[0]
        return model_dir, _bare_name(model_dir)

    if model_type == 'torch':
        model_dir = None
        for candidate in sorted(glob(seed_dir + "/*.pth")):
            parts = os.path.basename(candidate).split("_")
            try:
                the_dom = int(parts[1])
                the_seed = int(parts[3])
                the_mem = int(parts[4].split('mem')[1].split('.')[0])
            except (IndexError, ValueError):
                # Not a checkpoint following the naming scheme; ignore it.
                continue
            if the_mem == mem and the_seed == seed and the_dom == dom:
                model_dir = candidate  # the last match wins, as before
        if model_dir is None:
            raise FileNotFoundError(
                f"No '.pth' checkpoint with dom={dom}, seed={seed}, mem={mem} in {seed_dir}"
            )
        return model_dir, _bare_name(model_dir)

    raise ValueError(f"Unknown model_type {model_type!r}; expected 'lighting' or 'torch'")

# This file is for: 1. drawing scatter plots of the true vs. estimated latent variables;
#                   2. using wandb to record performance (MCC) on both the full validation set (all validation samples) and the per-domain validation subsets.
"""
    Input:
        cfg: dict, the loaded YAML configuration
        model_dir: str, path of the model checkpoint to load
        flag: str, 'all_domain' draws all domains in one figure; any other value draws only the domain selected by t
        t: int, index of the domain to draw (used only when flag != 'all_domain')
        output_dir: str, path where the figure is saved
        model_type: str, 'torch' or 'lighting'

    Output:
        the drawn figure, saved to output_dir (nothing is returned)
"""
def draw_all(cfg, model_dir, flag, t, output_dir, model_type):
    """Scatter every estimated latent against every true latent and save the grid.

    Args:
        cfg: loaded YAML configuration dict.
        model_dir: path of the checkpoint to load.
        flag: ``'all_domain'`` plots the pooled test set; any other value
            plots only the validation split of domain ``t``.
        t: domain index (used only when ``flag != 'all_domain'``).
        output_dir: file path the figure is saved to.
        model_type: ``'torch'`` or ``'lighting'``; forwarded to ``load_model``.

    The figure has one row per estimated latent (``mus`` column) and one
    column per true latent (``data['y']`` column). The last batch yielded
    by the selected loader is plotted. Nothing is returned.
    """
    train_dataset, test_dataset, train_loader, val_loader, test_loader_large = gen_data_eval(cfg)

    model = load_model(cfg, model_type, model_dir, train_dataset, test_dataset)
    # Pure inference: put layers such as dropout/batch-norm in eval mode and
    # skip building the autograd graph.
    model.eval()

    loader = test_loader_large if flag == 'all_domain' else val_loader[t]
    with torch.no_grad():
        for data in loader:
            _, mus, _ = model(data)

    # .cpu() is a no-op for CPU tensors and required before .numpy() on CUDA.
    mus = mus.detach().cpu().numpy()
    y = data['y'].detach().cpu().numpy()

    n_est, n_true = mus.shape[1], y.shape[1]
    # squeeze=False keeps axs 2-D even when either dimension is 1, so
    # axs[i, j] is always valid.
    fig, axs = plt.subplots(n_est, n_true, squeeze=False)
    fig.set_size_inches(16.18, 10)
    for i in range(n_est):
        for j in range(n_true):
            axs[i, j].scatter(mus[:, i], y[:, j], s=2)

    fig.savefig(output_dir)


def _shared_model_kwargs(cfg):
    """Return the constructor keyword arguments shared by ``SSA`` and ``themodel``.

    Every value is read from the nested config dict ``cfg`` (sections DATA,
    VAE, SPLINE, MCC); ``MCC.HZ_TO_Z`` is optional and defaults to False.
    """
    data_cfg = cfg["DATA"]
    vae_cfg = cfg["VAE"]
    spline_cfg = cfg["SPLINE"]
    mcc_cfg = cfg["MCC"]
    return dict(
        input_dim=data_cfg["N_COMP"],
        c_dim=data_cfg["N_COMP"] - data_cfg["N_COMP_S"],
        s_dim=data_cfg["N_COMP_S"],
        nclass=data_cfg["N_DOMAINS"],
        hidden_dim=vae_cfg['ENC']['HIDDEN_DIM'],
        embedding_dim=vae_cfg['EMBEDDING_DIM'],
        bound=spline_cfg['BOUND'],
        n_flow_layers=spline_cfg['N_LAYERS'],
        count_bins=spline_cfg['BINS'],
        order=spline_cfg['ORDER'],
        beta=vae_cfg['BETA'],
        gamma=vae_cfg['GAMMA'],
        sigma=vae_cfg['SIGMA'],
        vae_slope=vae_cfg['SLOPE'],
        lr=vae_cfg['LR'],
        use_warm_start=spline_cfg['USE_WARM_START'],
        spline_pth=spline_cfg['PATH'],
        decoder_dist=vae_cfg['DEC']['DIST'],
        correlation=mcc_cfg['CORR'],
        encoder_n_layers=vae_cfg['ENC']['N_LAYERS'],
        decoder_n_layers=vae_cfg['DEC']['N_LAYERS'],
        optimizer=vae_cfg['OPTIMIZER'],
        scheduler=vae_cfg['SCHEDULER'],
        lr_factor=vae_cfg['LR_FACTOR'],
        lr_patience=vae_cfg['LR_PATIENCE'],
        hz_to_z=mcc_cfg["HZ_TO_Z"] if "HZ_TO_Z" in mcc_cfg else False,
    )


def load_model(cfg, model_type, model_dir, train_dataset, test_dataset):
    """Build the model described by ``cfg`` and load its weights from ``model_dir``.

    Args:
        cfg: loaded YAML configuration dict.
        model_type: ``'lighting'`` builds an ``SSA`` module; ``'torch'``
            builds a ``themodel`` instance (which additionally receives
            ``CL.N_MEM`` and the datasets).
        model_dir: path of a saved ``state_dict``.
        train_dataset, test_dataset: per-domain datasets passed to
            ``themodel`` (unused for ``'lighting'``).

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        ValueError: ``model_type`` is neither ``'lighting'`` nor ``'torch'``.
    """
    # Validate first so a bad type fails with a clear message instead of an
    # UnboundLocalError on return.
    if model_type not in ('lighting', 'torch'):
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'lighting' or 'torch'")

    kwargs = _shared_model_kwargs(cfg)
    if model_type == 'lighting':
        model = SSA(**kwargs)
    else:
        model = themodel(
            **kwargs,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            n_mem=cfg['CL']['N_MEM'],
        )

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the tensors into the freshly built (CPU) model.
    model.load_state_dict(torch.load(model_dir, map_location='cpu'))
    return model

def gen_data_eval(cfg):
    """Generate the synthetic evaluation data and wrap it in datasets and loaders.

    Args:
        cfg: loaded YAML configuration dict (uses the DATA, VAE and SEED entries).

    Returns:
        ``(train_dataset, test_dataset, train_loader, val_loader, test_loader_large)``:
        per-domain lists of train/test ``DANS`` datasets and their loaders,
        plus one loader over the whole test set. Validation and pooled test
        loaders use a single batch holding the entire dataset.
    """
    data_cfg = cfg["DATA"]
    vae_cfg = cfg["VAE"]

    train_data, test_data = gen_da_data_ortho(
        Nsegment=data_cfg["N_DOMAINS"],
        Ncomp=data_cfg["N_COMP"],
        Ncomp_s=data_cfg["N_COMP_S"],
        Nlayer=data_cfg["N_LAYERS"],
        var_range_l=data_cfg["VAR_RANGE_L"],
        var_range_r=data_cfg["VAR_RANGE_R"],
        mean_range_l=data_cfg["MEAN_RANGE_L"],
        mean_range_r=data_cfg["MEAN_RANGE_R"],
        NsegmentObs_train=data_cfg["N_TRAIN_SAMPLES_DOMAIN"],
        NsegmentObs_test=data_cfg['N_TEST_SAMPLES_DOMAIN'],
        Nobs_test=data_cfg["N_TEST_SAMPLES"],
        varyMean=data_cfg["VARY_MEAN"],
        seed=cfg["SEED"],
        mixtures=data_cfg["MIXTURES"],
        n_modes_range_l=data_cfg["N_MODES_RANGE_L"],
        n_modes_range_r=data_cfg["N_MODES_RANGE_R"],
        p_domains_range_l=data_cfg["P_DOMAINS_RANGE_L"],
        p_domains_range_r=data_cfg["P_DOMAINS_RANGE_R"],
        linear_mixing_first=data_cfg["LINEAR_MIXING_FIRST"],
        save_all_datasets=data_cfg.get("SAVE_ALL_DATASETS", False),
        triangle=data_cfg['TRIANGLE'],
    )

    output_train, output_test = split_data(
        train_data,
        test_data,
        data_cfg["N_DOMAINS"],
        data_cfg["N_TRAIN_SAMPLES_DOMAIN"],
        data_cfg['N_TEST_SAMPLES_DOMAIN'],
    )
    test_dataset_large = DANS(test_data)

    # Options common to every loader built below.
    loader_opts = dict(
        pin_memory=vae_cfg['PIN'],
        num_workers=vae_cfg['CPU'],
        shuffle=False,
    )

    train_dataset, test_dataset, train_loader, val_loader = [], [], [], []
    for domain_train, domain_test in zip(output_train, output_test):
        train_set = DANS(domain_train)
        test_set = DANS(domain_test)
        train_dataset.append(train_set)
        test_dataset.append(test_set)

        train_loader.append(
            DataLoader(train_set, batch_size=vae_cfg['TRAIN_BS'], drop_last=False, **loader_opts)
        )
        # One batch containing the whole domain's test split.
        val_loader.append(
            DataLoader(test_set, batch_size=len(test_set), drop_last=True, **loader_opts)
        )

    test_loader_large = DataLoader(
        test_dataset_large,
        batch_size=len(test_dataset_large),
        drop_last=True,
        **loader_opts,
    )

    return train_dataset, test_dataset, train_loader, val_loader, test_loader_large
    # %%
   