import sys
import matplotlib.pylab as plt
import torch
from tqdm import tqdm
import os
import gpytorch
import numpy as np
from numpy.random import RandomState
import random
sys.path.append('..')
from dataset_mine import load_dataset
from models.NG_GPLVM import NG_GPLVM
from models.NG_MVLVM import NG_MVLVM
from visualizer import Visualizer
from metrics import (knn_classify,
                     mean_squared_error,
                     r_squared)
def save_models(model, optimizer, epoch, losses, result_dir, data_name, save_model=True):
    '''Save a training checkpoint for ``model`` and ``optimizer``.

    The checkpoint is a dict with keys ``'model'`` and ``'optimizer'`` (their
    ``state_dict()``s), ``'epoch'`` and ``'losses'``, written with ``torch.save``
    to ``<result_dir>/<data_name>_epoch<epoch>.pt``.

    Parameters
    ----------
    model       :           module whose ``state_dict()`` is saved
    optimizer   :           optimizer whose ``state_dict()`` is saved
    epoch       :           epoch / iteration number; stored and used in the file name
    losses      :           loss value(s) stored alongside the weights
    result_dir  :           result saving directory; created if it does not exist
    data_name   :           data name, used as the file-name prefix
    save_model  :           whether to actually save; when False nothing is written

    Returns
    -------
    str or None
        Path of the written checkpoint, or None when ``save_model`` is False.
    '''
    if not save_model:
        # Skip collecting the state dicts entirely when nothing will be written.
        return None

    # os.path.join only inserts a separator when result_dir lacks one; plain
    # string concatenation silently wrote *next to* the directory otherwise.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    log_dir = os.path.join(result_dir, f"{data_name}_epoch{epoch}.pt")

    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'epoch': epoch,
             'losses': losses}
    torch.save(state, log_dir)
    return log_dir


random_seed = 8  # global seed shared by reset_seed() and the data RandomState


def reset_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNGs with *seed*.

    ``torch.cuda.manual_seed`` is a no-op when CUDA is unavailable, so this is
    safe to call on CPU-only machines.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

reset_seed(random_seed)
# Switch to e.g. "cuda:1" to train on a GPU.
device = 'cpu'


# -------------------------------- data loader --------------------------------
# Both datasets are drawn from the same RandomState stream, so the load order
# (s-curve2 first, then s-curve) is part of what makes a run reproducible.
rng = RandomState(random_seed)

ds2 = load_dataset(rng, 's-curve2', 'gaussian')
Y2 = ds2.Y / np.linalg.norm(ds2.Y)   # scale the view to unit overall 2-norm

ds1 = load_dataset(rng, 's-curve', 'gaussian')
Y1 = ds1.Y / np.linalg.norm(ds1.Y)

# Both views are fit with one latent X (mu_x is plotted against each dataset
# below, and setting_dict['N'] is taken from view 1 alone), so the number of
# observations must agree.
if ds1.Y.shape[0] != ds2.Y.shape[0]:
    raise ValueError(f"views must have the same number of observations, "
                     f"got {ds1.Y.shape[0]} (s-curve) and {ds2.Y.shape[0]} (s-curve2)")

# -------------------------------- parameter settings --------------------------------
setting_dict = {}
setting_dict['num_m'] = [2, 2]            # presumably one entry per view; if num_m = 1, it is using SE kernel
setting_dict['num_sample_pt'] = 50
setting_dict['num_total_pt'] = [m * setting_dict['num_sample_pt'] for m in setting_dict['num_m']]
setting_dict['num_batch'] = 1
setting_dict['lr_hyp'] = .01
setting_dict['iter'] = 10000
setting_dict['num_repexp'] = 1
setting_dict['kl_option'] = True          # if adding X regularization in loss function
setting_dict['noise_err'] = 100.0
setting_dict['latent_dim'] = ds1.latent_dim
setting_dict['N'] = ds1.Y.shape[0]

Y = [Y1, Y2]
# (data, dataset) per view; the list index is the model's `view` argument.
views = [(Y1, ds1), (Y2, ds2)]

# NOTE(review): machine-specific absolute output path.
RESULTS_ROOT = '/mnt/NG-MV-RFLVM/MVresults/S/'
EVAL_EVERY = 500   # iterations between diagnostics / plots

GPLVM_model = NG_MVLVM(setting_dict['num_batch'],
                       setting_dict['num_sample_pt'],
                       setting_dict,
                       Y, device=device, ifPCA=True).to(device)
optimizer = torch.optim.Adam(GPLVM_model.parameters(), lr=setting_dict['lr_hyp'])

epochs_iter = tqdm(range(setting_dict['iter'] + 1), desc="Epoch")
for i in epochs_iter:

    GPLVM_model.train()

    optimizer.zero_grad()
    losstotal = GPLVM_model.compute_loss(batch_y=Y, kl_option=setting_dict['kl_option'])
    losstotal.backward()
    optimizer.step()

    if i % EVAL_EVERY != 0:
        continue

    # NOTE(review): this prints the value compute_loss returns (the minimized
    # objective); confirm whether that is the ELBO or its negative.
    print(f'\nELBO: {losstotal.item()}')

    # Diagnostics only: nothing here is back-propagated, so don't record
    # autograd graphs for these extra forward passes.
    with torch.no_grad():
        print(f"X_KL: {GPLVM_model._kl_div_qp().item()}")
        # The latent means are shared by both views; convert them once.
        X_hat = GPLVM_model.mu_x.cpu().detach().numpy()

        for view, (Y_view, ds_view) in enumerate(views):
            # F (the prediction) is not plotted at the moment; if it is, shift
            # it by -5 first because the data generation adds mean=5.
            _F, K = GPLVM_model.f_eval(batch_y_view=Y_view, view=view, x_star=None)
            K = K.cpu().detach().numpy()
            model_name = f"MngsmRFLVM_Y{view + 1}_RBF_m{setting_dict['num_m']}"
            res_dir = f'{RESULTS_ROOT}{model_name}/'
            viz = Visualizer(res_dir + 'figures', ds_view)
            viz.plot_iteration(i + 1, Y=0, F=0, K=K, X=X_hat)

    # Checkpointing is switched off (save_model=False). res_dir is the last
    # view's directory, so an enabled checkpoint would go into view 2's results.
    save_models(model=GPLVM_model, optimizer=optimizer, epoch=i, losses=losstotal,
                result_dir=res_dir, data_name='s-curve2', save_model=False)

    if ds1.has_true_X:
        # NOTE(review): only view 1's ground-truth X is compared; confirm
        # whether ds2.X should be checked as well.
        r2_X = r_squared(X_hat, ds1.X)
        print(f'R2 X: {r2_X}')

    print("\n")
