#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1" # export OMP_NUM_THREADS=1
os.environ["OPENBLAS_NUM_THREADS"] = "1" # export OPENBLAS_NUM_THREADS=1
os.environ["MKL_NUM_THREADS"] = "1" # export MKL_NUM_THREADS=1
os.environ["VECLIB_MAXIMUM_THREADS"] = "1" # export VECLIB_MAXIMUM_THREADS=1
os.environ["NUMEXPR_NUM_THREADS"] = "1" # export NUMEXPR_NUM_THREADS=1

import warnings
warnings.filterwarnings('ignore')

import sys
from pathlib import Path
sys.path.append(str(Path('.').absolute().parent))

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from time import process_time as time

# internal imports
from src.gwgan.utils import *
from src.gwgan.data import *
from src.gwgan.model_mlp import Generator, GeneratorSmall
from src.gwgan.model_mlp import weights_init_generator, weights_init_zeros
from src.sse_sgw_utils import sgw_gpu, ssegromov_gpu, distrib_sse, min_sse, distributional_min_sse

import torch.nn as nn
from torch import optim

# Default floating-point dtype for newly created tensors: float32 on CPU.
# `torch.set_default_tensor_type(torch.FloatTensor)` is deprecated in recent
# PyTorch releases; no default device is configured anywhere in this script,
# so setting only the default dtype is equivalent.
torch.set_default_dtype(torch.float32)

class TransformLatenttoOrig(nn.Module):
    """Map latent vectors to unit-norm vectors of size ``dim_orig``.

    The mapping is a three-layer MLP with hidden width 100 and a sigmoid after
    every linear layer. Each output row is then divided by its Euclidean norm.
    Sigmoid outputs are positive, so the rows lie on the positive part of the
    unit sphere.

    Args:
        dim_latent: number of input features per row.
        dim_orig: number of output features per row.
    """

    def __init__(self, dim_latent, dim_orig):
        super().__init__()
        self.dim_latent = dim_latent
        self.dim_orig = dim_orig
        self.dim_hidden = 100

        # Consecutive widths define the Linear layers; each Linear is followed
        # by a Sigmoid, so `net` holds 6 modules (indices 0..5).
        widths = (self.dim_latent, self.dim_hidden, self.dim_hidden, self.dim_orig)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``net(input)`` with every row (dim=1) rescaled to unit L2 norm."""
        activations = self.net(input)
        row_norms = activations.pow(2).sum(dim=1, keepdim=True).sqrt()
        return activations / row_norms


class TransformNet(nn.Module):
    """Residual MLP whose output rows are projected onto the unit sphere.

    ``forward`` computes ``input + weight * net(input)`` and divides every row
    by its Euclidean norm. ``net`` maps size -> 50 -> 50 -> 50 -> size with
    LeakyReLU(0.2, inplace) between the linear layers.

    Args:
        size: number of features per row (input and output).
    """

    def __init__(self, size):
        super().__init__()
        self.size = size
        self.hidden = 50
        # NOTE(review): with weight == 0. the MLP adds nothing to the output
        # (and its parameters receive zero gradient), so forward() just
        # normalises `input`. Nothing in this class changes `weight`; confirm
        # whether callers are expected to set it.
        self.weight = 0.

        widths = [self.size, self.hidden, self.hidden, self.hidden, self.size]
        layers = []
        for k in range(len(widths) - 1):
            if layers:
                # activation between consecutive Linear layers only
                layers.append(nn.LeakyReLU(0.2, True))
            layers.append(nn.Linear(widths[k], widths[k + 1]))
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Return the row-normalised residual update of ``input``."""
        shifted = input + self.net(input) * self.weight
        return shifted / shifted.pow(2).sum(dim=1, keepdim=True).sqrt()
    


# Synthetic target distributions selectable with --modes.
# '5mode3d' reuses the 2-D five-mode sampler; what differs for that setting is
# the generator output dimension (dim_t = 3), not the data.
_SIMULATORS = (
    ('4mode', gaussians_4mode),
    ('5mode', gaussians_5mode),
    ('5mode3d', gaussians_5mode),
    ('8mode', gaussians_8mode),
    ('3d_4mode', gaussians_3d_4mode),
)
FUNCTION_MAP = dict(_SIMULATORS)

# plotting preferences
matplotlib.rcParams.update({
    'font.sans-serif': 'Times New Roman',
    'font.family': 'sans-serif',
    'font.size': 10,
})


# `argparse` is not imported explicitly at the top of the file; import it here
# so the CLI definition does not depend on a wildcard import providing it.
import argparse

parser = argparse.ArgumentParser()

# general arguments
parser.add_argument('--train_iter', type=int, default=30000)
# NOTE(review): --data and --cuda are parsed but never read in this script
# (the device is hard-coded to 'cpu').
parser.add_argument('--data', default='mnist')
parser.add_argument('--cuda', action='store_true')
# Restrict to the simulators actually registered so a typo fails at parse time
# instead of as a KeyError when the simulator is looked up.
parser.add_argument('--modes', default='3d_4mode', choices=list(FUNCTION_MAP))
# 'sse_gromow' is handled by the training loop, so it must be selectable here.
parser.add_argument('--loss_name', default='sgw',
                    choices=['sgw', 'distrib_sse', 'min_sse', 'distrib_min_sse', 'sse_gromow'])
# The weight is formatted as a real number in the run name, so accept
# non-integer values; the default (5) is left unchanged.
parser.add_argument('--lam', type=float, default=5)
parser.add_argument('--network', default='large')


args = parser.parse_args()

# system preferences
torch.set_default_dtype(torch.float)

# Single fixed seed for reproducibility. Every random draw in the script
# (data simulation, initialisation, latent samples) happens after this point.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# settings!
loss_name = args.loss_name
train_iter = args.train_iter
lam_ap = args.lam

n_samples = 3000
batch_size = 300

z_dim = 256
lr = 0.001
nproj = 2000         # number of projections for the non-optimized losses (sgw, sse_gromow, min_sse)
nproj_dist = 10      # number of projections for the distributional losses
nproj_dist_d = 10    # not passed to any loss here; only appears in the run name
nb_inner_iter = 50   # num_epochs passed to distributional_min_sse
smallgenerator = args.network != 'large'

plot_every = n_samples // batch_size   # one plot per pass over the data
dim_latent = 50
modes = args.modes

model = f"{modes}_dist_{loss_name}_{batch_size:d}_{n_samples:d}_{nproj}_{nproj_dist}_{nproj_dist_d}_lr{lr:2.5f}_iter{train_iter:d}_lap{lam_ap:2.2f}_small{smallgenerator}/"
print(model)

simulation = FUNCTION_MAP[modes]

# data simulation
data, y = simulation(n_samples)
data, y = data.astype(float), y.astype(float)
data_size = len(data)
# Pad with a copy of the first `batch_size` rows so a mini-batch slice that
# starts near the end of the real samples is never short.
data = np.concatenate((data, data[:batch_size, :]), axis=0)
y = np.concatenate((y, y[:batch_size]), axis=0)

save_fig_path = 'result/' + model
os.makedirs(save_fig_path, exist_ok=True)


#-----------------------------------------------------------------------------
#                       figure
#-----------------------------------------------------------------------------

# Plot (and save) the first 1000 real samples as the "target" figure.
real = data[:1000]
real_y = y[:1000]
fig1 = plt.figure(figsize=(4, 4))
if modes == '3d_4mode':
    # this simulator produces three data columns: scatter them in 3-D
    df = pd.DataFrame({'x1': real[:, 0],
                       'x2': real[:, 1],
                       'x3': real[:, 2],
                       'in': real_y})
    ax1 = fig1.add_subplot(111, projection='3d')
    ax1.scatter3D(df.x1, df.x2, df.x3, c='#1B263B')
    ax1.set_zlim([-4, 4])
    ax1.view_init(25, -45)   # (elevation, azimuth)
    ax1.set_zlabel('x3')
    ax1.tick_params(axis='both', size=7, labelsize=7)
    ax1.set_xlim([-4, 4])
    ax1.set_ylim([-4, 4])
else:
    df = pd.DataFrame({'x1': real[:, 0],
                       'x2': real[:, 1],
                       'in': real_y})
    ax1 = fig1.add_subplot(111)
    # NOTE(review): positional x/y plus shade=/n_levels= is the pre-0.12
    # seaborn kdeplot API; seaborn >= 0.12 rejects a positional second
    # argument and expects x=, y=, fill=, levels=.
    sns.kdeplot(df.x1, df.x2, shade=True, cmap='Blues', n_levels=20,
                legend=False, ax=ax1)
    ax1.tick_params(axis='both', size=7, labelsize=7)
    ax1.set_xlim([-3, 3])
    ax1.set_ylim([-3, 3])
ax1.set_title(r'target')
fig1.tight_layout()
fig1.savefig(os.path.join(save_fig_path, 'real.pdf'))
plt.close(fig1)   # the figure is not used again; release it

# fixed latent sample, reused for every generator snapshot during training
z_ex = sample_z(1000, z_dim)
# counter used to number the saved snapshot files
i = 0
#--------------------------------------------------------------------------------
# (dim_t, dim_s) per mode. dim_t is the generator output size; dim_s is the size
# given to the fs/fsd projection networks (presumably the data dimension: for
# '3d_4mode' the target figure uses three data columns). Any other mode is 2-D
# on both sides.
_DIMS_BY_MODE = {
    '3d_4mode': (2, 3),
    '5mode3d': (3, 2),
}
dim_t, dim_s = _DIMS_BY_MODE.get(modes, (2, 2))

# define networks and parameters
generator_cls = GeneratorSmall if smallgenerator else Generator
generator = generator_cls(output_size=dim_t)
generator.apply(weights_init_generator)
g_optimizer = torch.optim.Adam(generator.parameters(), lr)

# per-iteration records, written to disk at the end of the run
loss_history = []
time_history = []
device = 'cpu'

# NOTE(review): ps / pt are not read anywhere in this script.
ps = pt = None

device = 'cpu'


def _generator_loss(g, x):
    """Evaluate the loss selected by ``loss_name`` between generated ``g`` and real ``x``.

    Projection networks (and their Adam optimizers) are built fresh on every
    call. Only the ones the selected loss actually consumes are created, rather
    than all five every iteration.

    Returns:
        (loss, elapsed): the loss tensor and the process time spent evaluating
        the loss itself (network construction excluded).

    Raises:
        ValueError: if ``loss_name`` is not one of the supported losses.
    """
    if loss_name == 'distrib_sse':
        fs = TransformNet(dim_s).to(device)
        ft = TransformNet(dim_t).to(device)
        fs.apply(weights_init_zeros)
        ft.apply(weights_init_zeros)
        fs_optim = optim.Adam(fs.parameters(), lr=0.005, betas=(0.5, 0.999))
        ft_optim = optim.Adam(ft.parameters(), lr=0.005, betas=(0.5, 0.999))
        tic = time()
        loss = distrib_sse(g, x, ft, fs, ft_optim, fs_optim,
                           nproj=nproj_dist, max_iter=20, lam=0.00001)
    elif loss_name in ('min_sse', 'distrib_min_sse'):
        ftd = TransformLatenttoOrig(dim_latent, dim_t).to(device)
        fsd = TransformLatenttoOrig(dim_latent, dim_s).to(device)
        ftd_optim = optim.Adam(ftd.parameters(), lr=0.001, betas=(0.5, 0.999))
        fsd_optim = optim.Adam(fsd.parameters(), lr=0.001, betas=(0.5, 0.999))
        if loss_name == 'min_sse':
            tic = time()
            loss = min_sse(g, x, ftd, fsd, ftd_optim, fsd_optim, dim_latent,
                           nproj=nproj, max_iter=200)
        else:
            fdistrib = TransformNet(dim_latent).to(device)
            fdistrib_optim = optim.Adam(fdistrib.parameters(), lr=0.001,
                                        betas=(0.5, 0.999), weight_decay=0.005)
            tic = time()
            loss = distributional_min_sse(g, x, fdistrib, ftd, fsd, fdistrib_optim,
                                          ftd_optim, fsd_optim, dim_latent,
                                          nproj=nproj_dist, num_epochs=nb_inner_iter,
                                          num_sup_iter=2, num_inf_iter=2,
                                          verbose=False, lam_ap=lam_ap)
    elif loss_name == 'sgw':
        tic = time()
        # small L2 penalty on the generated points
        loss = sgw_gpu(x, g, nproj=nproj, device=device) + 0.01 * torch.sum(g ** 2)
    elif loss_name == 'sse_gromow':
        tic = time()
        loss = ssegromov_gpu(g, x, nproj=nproj, device=device)
    else:
        raise ValueError(f"unsupported loss_name: {loss_name!r}")
    return loss, time() - tic


def _generator_snapshot(samples, iteration):
    """Return a figure of generated ``samples`` (an array with ``dim_t`` columns).

    Only called when ``dim_t`` is 2 (KDE plot) or 3 (3-D scatter).
    """
    fig = plt.figure(figsize=(4, 4))
    if dim_t == 2:
        ax = fig.add_subplot(111)
        result = pd.DataFrame({'x1': samples[:, 0],
                               'x2': samples[:, 1]})
        # NOTE(review): pre-0.12 seaborn kdeplot API (positional x/y, shade=,
        # n_levels=); seaborn >= 0.12 expects x=, y=, fill=, levels=.
        sns.kdeplot(result.x1, result.x2,
                    shade=True, cmap='Blues', n_levels=20, legend=False, ax=ax)
        ax.tick_params(axis='both', labelsize=7)
    else:
        # a single 3-D axes (previously a stray 2-D axes was created underneath)
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter3D(samples[:, 0], samples[:, 1], samples[:, 2], c='#1B263B')
        ax.view_init(55, -45)   # (elevation, azimuth)
        ax.set_zlabel('x3')
        ax.tick_params(axis='both', size=7, labelsize=7)
    ax.set_title(r'iteration {}'.format(iteration), fontsize=14)
    fig.tight_layout()
    return fig


for it in range(train_iter):
    start_idx = it * batch_size % data_size
    # Reshuffle once per pass, exactly when the read position wraps back to the
    # start of the real samples. Only the first `data_size` rows are permuted,
    # and `data` and `y` share the same permutation so they stay aligned.
    # The padding rows are then refreshed so they again mirror the first rows.
    if it > 0 and start_idx < batch_size:
        perm = torch.randperm(data_size).numpy()
        data[:data_size] = data[perm]
        y[:data_size] = y[perm]
        data[data_size:] = data[:batch_size]
        y[data_size:] = y[:batch_size]
    X_mb = data[start_idx:start_idx + batch_size, :]

    # sample points from latent space
    z = sample_z(batch_size, z_dim)

    # get data mini batch (.float() copies, so later reshuffles don't alias it)
    x = torch.from_numpy(X_mb).float()

    g = generator(z)

    loss, elapsed = _generator_loss(g, x)
    time_history.append(elapsed)

    g_optimizer.zero_grad()
    loss.backward()
    g_optimizer.step()
    print('iter:', it+1, loss_name, 'loss:', loss.item())

    loss_history.append(loss.item())

    # plotting
    if (it + 1) % plot_every == 0 and dim_t in (2, 3):
        with torch.no_grad():
            g_ex = generator(z_ex).numpy()
        fig2 = _generator_snapshot(g_ex, it + 1)
        fig2.savefig(os.path.join(save_fig_path, 'gen_{}.pdf'.format(
                     str(i).zfill(3))))
        # close every snapshot: thousands of open figures would otherwise
        # accumulate (the matplotlib warning is silenced by the global filter)
        plt.close(fig2)
        i += 1

                    


#----------------------------------------------------------------------------
# Persist the loss curve and the per-iteration loss-evaluation times next to
# the figures (save_fig_path already ends with '/').
results_path = save_fig_path + 'resultat.npz'
histories = {
    'loss_history': np.array(loss_history),
    'time_history': np.array(time_history),
}
np.savez(results_path, **histories)



