# %%
# %%
from tokenize import group
import torch
import random
import argparse
import numpy as np
from pathlib import Path
import ipdb as pdb
import os, pwd, yaml
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from dataset import DANS
import warnings
from utils import load_yaml
import wandb
warnings.filterwarnings('ignore')
from gen_dataset import gen_da_data_ortho
import quadprog
from components.beta import BetaVAE_MLP
# from GEMlearning import 
from sssa import SSA
import ipdb as pdb
import torch.nn as nn
import torch.nn.init as init
import pytorch_lightning as pl
import torch.distributions as D
from torch.nn import functional as F
from components.beta import BetaVAE_MLP
from components.transforms import NormalizingFlow
from metrics.correlation_nonlinear import compute_mcc_nonlinear
from metrics.block import compute_r2
import random as rd
import quadprog
import torch.optim as opt
from utils import split_data

# %%
# Locate the project root and load the experiment configuration.
# NOTE(review): '__file__' below is a quoted string, not the module attribute,
# so abspath() resolves it against the current working directory and
# script_dir becomes the parent of the CWD -- presumably a workaround for
# interactive cells where __file__ may be undefined; confirm before changing.
_cwd_anchor = os.path.abspath('__file__')
script_dir = os.path.dirname(os.path.dirname(_cwd_anchor))
rel_path = os.path.join('configs', 'test.yaml')
abs_file_path = os.path.join(script_dir, rel_path)
cfg = load_yaml(abs_file_path)

# Seed every RNG from the single configured seed, in the order
# torch -> random -> numpy.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(cfg['SEED'])


# %%
"""
train_loader and val_loader are lists of length n_domains.
Each element of train_loader is a DataLoader over one domain, containing:
    x: training data                   [n_samples_training, n_features]
    c: training labels                 [n_samples_training], values 0, 1, ..., 8
    z: the real data before mixing     [n_samples_training, n_features]
val_loader has the same structure; the only difference is the number of samples.
"""
# Build the dataset identifier from the DATA section of the config.
# The '_trangle_' spelling is part of the on-disk directory name and must
# stay exactly as written.
_data_cfg = cfg['DATA']
data_name = ''.join([
    'n_', str(_data_cfg['N_COMP']),
    '_c_', str(_data_cfg['N_COMP_S']),
    '_d_', str(_data_cfg['N_DOMAINS']),
    '_iid_', str(_data_cfg['IID_DOMAINS']),
    '_trangle_', str(_data_cfg['TRIANGLE']),
    '_mixed_', str(_data_cfg['MIXED_GAUSSIANS']),
])

# <script_dir>/data/<data_name>/seed_<SEED>
abs_data_path = os.path.join(script_dir, 'data')
data_path = os.path.join(abs_data_path, data_name)
data_seed_path = os.path.join(data_path, 'seed_' + str(cfg['SEED']))
# data_seed_path is already absolute (script_dir comes from abspath), so this
# join discards data_path and output_dir equals data_seed_path.
output_dir = os.path.join(data_path, data_seed_path)

# Sibling sub-directories for data, model checkpoints and figures.
output_dir_data, output_dir_model, output_dir_fig = [
    os.path.join(output_dir, _sub) for _sub in ('data', 'model', 'fig')
]

# Serialized train / test splits.
train_dir, test_dir = [
    os.path.join(output_dir_data, _fname)
    for _fname in ('train_data.pt', 'test_data.pt')
]

train_data = torch.load(train_dir)
test_data = torch.load(test_dir)

# Split the loaded data per domain (see utils.split_data for the layout).
output_train, output_test = split_data(
    train_data,
    test_data,
    _data_cfg["N_DOMAINS"],
    _data_cfg["N_TRAIN_SAMPLES_DOMAIN"],
    _data_cfg['N_TEST_SAMPLES_DOMAIN'],
)

print('we are loading', train_dir)

# %%
# %%


# Per-domain datasets and loaders; position k in each list belongs to domain k.
train_dataset = []
test_dataset = []
train_loader = []
val_loader = []
n_domain_used = cfg['CL']['N_DOMAIN_USED']
n_test_s_domain = cfg['DATA']['N_TEST_SAMPLES_DOMAIN']

# Pooled test set: the first n_domain_used * n_test_s_domain entries of each
# field. Assumes test_data is ordered domain by domain -- TODO confirm.
_n_large = n_domain_used * n_test_s_domain
test_data_large = {_key: test_data[_key][0:_n_large] for _key in ('x', 'y', 'c')}
test_dataset_large = DANS(test_data_large)

# Options common to every DataLoader built below.
_vae_cfg = cfg['VAE']
_shared_loader_opts = dict(
    pin_memory=_vae_cfg['PIN'],
    num_workers=_vae_cfg['CPU'],
    shuffle=False,
)

for domain_idx in range(n_domain_used):
    domain_train = DANS(output_train[domain_idx])
    train_dataset.append(domain_train)
    domain_test = DANS(output_test[domain_idx])
    test_dataset.append(domain_test)

    # Training: mini-batches of the configured size, keep the last partial batch.
    train_loader.append(
        DataLoader(
            domain_train,
            batch_size=_vae_cfg['TRAIN_BS'],
            drop_last=False,
            **_shared_loader_opts,
        )
    )
    # Validation: the whole domain test set in a single batch.
    val_loader.append(
        DataLoader(
            domain_test,
            batch_size=len(domain_test),
            drop_last=True,
            **_shared_loader_opts,
        )
    )


# One full-size batch over the pooled test set of all used domains.
test_loader_large = DataLoader(
    test_dataset_large,
    batch_size=len(test_dataset_large),
    drop_last=True,
    **_shared_loader_opts,
)


from themode_withspline2 import themodel

# Short aliases for the config sections used to configure the model.
_data_sec = cfg["DATA"]
_vae_sec = cfg['VAE']
_spline_sec = cfg['SPLINE']
_mcc_sec = cfg["MCC"]
_cl_sec = cfg['CL']

# Collect every constructor argument in one mapping, then build the model.
_model_kwargs = {
    # dimensions: c_dim is whatever is left of N_COMP after the N_COMP_S part
    'input_dim': _data_sec["N_COMP"],
    'c_dim': _data_sec["N_COMP"] - _data_sec["N_COMP_S"],
    's_dim': _data_sec["N_COMP_S"],
    'nclass': _data_sec["N_DOMAINS"],
    # encoder / decoder
    'hidden_dim': _vae_sec['ENC']['HIDDEN_DIM'],
    'embedding_dim': _vae_sec['EMBEDDING_DIM'],
    # spline flow
    'bound': _spline_sec['BOUND'],
    'n_flow_layers': _spline_sec['N_LAYERS'],
    'count_bins': _spline_sec['BINS'],
    'order': _spline_sec['ORDER'],
    # loss weights and optimisation
    'beta': _vae_sec['BETA'],
    'gamma': _vae_sec['GAMMA'],
    'sigma': _vae_sec['SIGMA'],
    'vae_slope': _vae_sec['SLOPE'],
    'lr': _vae_sec['LR'],
    'use_warm_start': _spline_sec['USE_WARM_START'],
    'spline_pth': _spline_sec['PATH'],
    'decoder_dist': _vae_sec['DEC']['DIST'],
    'correlation': _mcc_sec['CORR'],
    'encoder_n_layers': _vae_sec['ENC']['N_LAYERS'],
    'decoder_n_layers': _vae_sec['DEC']['N_LAYERS'],
    'optimizer': _vae_sec['OPTIMIZER'],
    'scheduler': _vae_sec['SCHEDULER'],
    'lr_factor': _vae_sec['LR_FACTOR'],
    'lr_patience': _vae_sec['LR_PATIENCE'],
    # HZ_TO_Z is optional in the config and defaults to False
    'hz_to_z': _mcc_sec["HZ_TO_Z"] if "HZ_TO_Z" in _mcc_sec else False,
    # data and continual-learning settings
    'train_dataset': train_dataset,
    'test_dataset': test_dataset,
    'n_mem': _cl_sec['N_MEM'],
    'save_all': _cl_sec['SAVE_ALL'],
    'seed': cfg['SEED'],
    'save_path': output_dir_model,
}
trainmodel = themodel(**_model_kwargs)
# %%
# Start a Weights & Biases run; the full config is logged with it.
run = wandb.init(
    project="GEM_Triangle",
    name='seed_' + str(cfg['SEED']) + data_name,
    config=cfg,
)
print(run.name)


max_epochs = cfg["CL"]["MAX_EPOCHS"]

# Train on the per-domain loaders, then evaluate on the per-domain validation
# loaders.
# NOTE(review): 'valdation_step' looks like a misspelling of 'validation_step',
# but it has to match the method defined in themode_withspline2 -- verify
# there before renaming.
trainmodel.training_step(
    train_loader=train_loader,
    max_epochs=max_epochs,
    val_loader_large=test_loader_large,
)
trainmodel.valdation_step(val_loader=val_loader)

# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure <model>/gem/<data_name>/ exists first.
save_str = os.path.join(output_dir_model, 'gem', data_name, 'model.pt')
os.makedirs(os.path.dirname(save_str), exist_ok=True)
torch.save(trainmodel.state_dict(), save_str)

# Upload the checkpoint as a W&B artifact and close the run.
artifact = wandb.Artifact('model', type='model')
artifact.add_file(save_str)
run.log_artifact(artifact)
# Run.join() is a deprecated alias of Run.finish().
run.finish()




# %%
