from __future__ import print_function
import GPUtil
import numpy as np
import torch.utils.data
import torch
import sklearn.datasets as sk_data
from sklearn.mixture import GaussianMixture as GM
from src.train import train
import src.utils.log_utils as LLU
import src.datamodules.generate_data as g_data
import src.utils.pytorch_utils as PTU
import src.datamodules.data_utils as DTU
from src.datamodules.record_mean_cov import select_mean_and_cov_gauss
from configs.cfg import CfgGMM
cfg = CfgGMM()

# Choose a GPU at random among those below 60% load and 60% memory use.
# GPUtil polls up to 5 times, `interval` seconds apart, before giving up.
gpus_choice = GPUtil.getFirstAvailable(
    order='random',
    maxLoad=0.6,
    maxMemory=0.6,
    attempts=5,
    interval=900,
    verbose=False,
)

# Point PTU at the selected GPU; PTU.device is read by the code below.
PTU.set_gpu_mode(True, gpus_choice[0])
# Build the parameters of the target density Q for the selected experiment.
# GMM-type targets ('GM', 'two_moons') define mean_Q, cov_Q and weights;
# 'Gauss2Gauss' defines mean_Q and inv_cov_Q. Any other value is rejected here,
# before output paths are created and before later code hits undefined names.
if cfg.type_data == 'GM':
    cfg.MEAN, cfg.COV, cfg.INPUT_DIM, cfg.NUM_GMM_COMPONENT = DTU.get_gmm_param(
        cfg.TRIAL, num_component=cfg.NUM_GMM_COMPONENT, seed=cfg.seed)
    # Index 1 of MEAN/COV is used as the target Q (presumably index 0 is the
    # other distribution -- TODO confirm against DTU.get_gmm_param).
    mean_Q = PTU.numpy2torch(cfg.MEAN[1]).cuda(PTU.device)
    cov_Q = PTU.numpy2torch(cfg.COV[1]).cuda(PTU.device)
    # Uniform mixture weights, one per row of mean_Q (presumably per component).
    weights = torch.ones(mean_Q.shape[0]).cuda(PTU.device) / mean_Q.shape[0]

elif cfg.type_data == 'Gauss2Gauss':
    cfg.MEAN, cfg.COV = select_mean_and_cov_gauss(cfg.INPUT_DIM, int(cfg.TRIAL))
    mean_Q = PTU.numpy2torch(cfg.MEAN).cuda(PTU.device)
    # Q is handed to the model through its inverse covariance (precision).
    inv_cov = np.linalg.inv(cfg.COV)
    inv_cov_Q = PTU.numpy2torch(inv_cov).cuda(PTU.device)
    # NOTE(review): flag consumed outside this file; meaning not visible here.
    cfg.T_linear = 1

elif cfg.type_data == 'two_moons':
    # Approximate the two-moons distribution with a 16-component,
    # full-covariance GMM and use its parameters as Q. The moons sampler is
    # seeded as well as the GMM fit, so Q is reproducible across runs.
    samples, _ = sk_data.make_moons(10000, noise=0.05, random_state=cfg.seed)
    gmm16 = GM(n_components=16, covariance_type='full', random_state=0).fit(samples)
    mean_Q = PTU.numpy2torch(gmm16.means_).cuda(PTU.device)
    cov_Q = PTU.numpy2torch(gmm16.covariances_).cuda(PTU.device)
    weights = PTU.numpy2torch(gmm16.weights_).cuda(PTU.device)
    # First entry is fixed to 1; the second is filled from the fitted means.
    cfg.NUM_GMM_COMPONENT = [1, 0]
    cfg.INPUT_DIM, cfg.NUM_GMM_COMPONENT[1] = DTU.get_gm_dim_component_one(mean_Q)

else:
    raise ValueError(
        "Unsupported cfg.type_data {!r}; expected 'GM', 'Gauss2Gauss' or "
        "'two_moons'".format(cfg.type_data))

results_save_path, image_save_path, P_save_path, results = LLU.init_path(cfg)

# Target density Q in the layout the model classes receive:
#   Gauss2Gauss -> [mean, inverse covariance]
#   GMM targets -> [means, covariances, mixture weights]
if cfg.type_data == 'Gauss2Gauss':
    density_q = [mean_Q, inv_cov_Q]
else:
    density_q = [mean_Q, cov_Q, weights]

# Reference density mu, stored as [mean, inverse covariance].
if cfg.debug_h:
    # Estimate mu empirically from generated samples. P_k must be a torch
    # tensor, since its covariance goes through torch.inverse; the meaning of
    # the trailing index 1 is defined by g_data.gaussian_data (TODO confirm).
    P_k = g_data.gaussian_data(cfg.N_TRAIN_SAMPLES, cfg)[:, :, 1]
    mean_mu = P_k.mean(axis=0)
    centered = P_k - mean_mu
    # Unbiased sample covariance (normalised by N - 1).
    cov_mu = centered.T @ centered / (P_k.shape[0] - 1)
    inv_cov_mu = torch.inverse(cov_mu)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if inv_cov_mu.shape[0] != cfg.INPUT_DIM:
        raise ValueError(
            'Estimated mu covariance has dimension {} but cfg.INPUT_DIM is {}'
            .format(inv_cov_mu.shape[0], cfg.INPUT_DIM))
else:
    # Zero-mean isotropic Gaussian with variance cfg.mu_var.
    mean_mu = torch.zeros(cfg.INPUT_DIM).cuda(PTU.device)
    inv_cov_mu = torch.eye(cfg.INPUT_DIM) / cfg.mu_var

# Move BOTH parameters to the training device. In the debug_h branch mean_mu
# is otherwise left wherever gaussian_data put it, which could split
# density_mu across devices; .cuda() is a no-op when already on PTU.device.
density_mu = [mean_mu.cuda(PTU.device), inv_cov_mu.cuda(PTU.device)]

# Training data: Gaussian targets use the Gaussian loader. All other targets
# use a GMM loader; in debug_h mode that loader is given the empirical mean
# and covariance of mu (mean_mu, cov_mu) computed earlier in this script.
if cfg.type_data == 'Gauss2Gauss':
    total_data = g_data.importa_samp_data_gauss(cfg)
elif cfg.debug_h:
    total_data = g_data.importa_debug_data_gmm(cfg, mean_mu, cov_mu)
else:
    total_data = g_data.importa_samp_data_gmm(cfg)

# Pick the Lightning system class from the map parameterisation
# (cfg.map_type == 'T' vs. the gradient-map variants) and the target type.
# Combinations with no matching class fail here with a clear message instead
# of a NameError on `light_system` at model construction time.
light_system = None
if cfg.map_type == 'T':
    if cfg.fb_method:
        from src.models.kl_fb_map import GM_2step_Tmap as light_system
    elif cfg.type_data == 'GM' or cfg.type_data == 'two_moons':
        from src.models.kl_Tmap import GM_Tmap as light_system
    elif cfg.type_data == 'Gauss2Gauss':
        from src.models.kl_Tmap import Gaussian_Tmap as light_system
else:
    if cfg.type_data == 'GM':
        from src.models.kl_gmap import GM_nabla_gmap as light_system
    elif cfg.type_data == 'Gauss2Gauss':
        from src.models.kl_gmap import Gaussian_nabla_gmap as light_system

if light_system is None:
    raise ValueError(
        'No model class for map_type={!r} with type_data={!r}'
        .format(cfg.map_type, cfg.type_data))

# Build the selected system from the target/reference densities, the training
# data and the output paths, then train it on the GPU chosen at startup.
model = light_system(
    cfg,
    density_q,
    density_mu,
    total_data,
    P_save_path,
    image_save_path=image_save_path,
)
train(model, cfg, results_save_path, device_number=gpus_choice[0])
