import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nets

import numpy as np
import math

from tqdm import tqdm 
import yaml

from dro_mev_functions.DRO_MEV_nn import *
from dro_mev_functions.DRO_MEV_train import *
from dro_mev_functions.DRO_MEV_util import *

# Fix both global RNGs (NumPy first, then torch) so resampling, x_eval and
# network initialisation are reproducible from run to run.
_SEED = 12
np.random.seed(_SEED)
torch.manual_seed(_SEED)


if __name__ == '__main__':
    # Driver script: reads a YAML config, loads a pickled data set and, for each
    # value of a sweep (ambiguity radius eps, or sample size N), fits a P0 model
    # (P0Module_pp) to a resample of the data, calls train_pp, and records the P0
    # and adversarial risk estimates. Results are pickled and plotted under
    # save_path.
    import argparse
    import os
    import pickle

    parser = argparse.ArgumentParser(
                    prog = 'MEVDRONet',
                    description = 'Computes adversarial risk for some data.')
    parser.add_argument('filename')
    args = parser.parse_args()

    with open(args.filename, 'r') as f:
        params = yaml.safe_load(f)

    d = params['d'] #2

    n_epochs = params['n_epochs'] #5001
    n_lam    = params['n_lam'] #20

    width  = params['width'] #500
    n_data = params['n_data'] #100
    block_size = params['block_size']

    n_max = params['n_max'] #10
    rate = params['rate'] #10

    use_softmax = params['use_softmax'] #False
    experiment  = params['experiment'] #'evd'

    risk = params['risk'] #'cvar'

    n_eps  = params['n_eps'] #20
    n_runs = params['n_runs'] #1

    eps_max = params['eps_max'] #0.5

    data_file = params['data_file']
    cost_norm = params['cost_norm']

    fit_margins = params['fit_margins']

    gen_p0 = params['gen_p0']
    alpha  = params['alpha']

    # Optional keys: a missing key falls back to the default.
    eps_coef = params.get('eps_coef')
    pretrain = params.get('pretrain', False)

    # Transport cost ||z - x||_p along the last axis (not used further below).
    c = lambda z, x : (z - x).norm(cost_norm, dim=-1)

    # Random evaluation point in [0, 1)^d.
    x_eval = torch.rand(d)
    synthetic_rate = params['synthetic_rate'] #0

    if not gen_p0:
        # The experiment loop can only obtain P0 by fitting P0Module_pp; without
        # it net_p0 would be undefined there, so fail before any work is done.
        raise ValueError('gen_p0 must be true: no other P0 model is implemented in this script.')

    # NOTE: pickle.load can execute arbitrary code -- only load trusted data files.
    with open('{}.p'.format(data_file),'rb') as f:
        data = pickle.load(f)

    # Reference risk: fraction of samples whose every coordinate is below
    # 1 / x_eval. Assumes `data` is an (n, d) torch tensor -- TODO confirm.
    true_risk = (((data < 1 / x_eval).min(-1)[0]).float()).mean()

    print('True Risk: {}'.format(true_risk))

    losses = []

    # Output directory; eps_coef (when given) replaces eps_max in the name.
    if use_softmax:
        save_path = data_file + '_{}_{}_softmax_gen_eps{}/'.format(experiment, risk, eps_max)
    else:
        eps_tag = eps_coef if eps_coef is not None else eps_max
        save_path = data_file + '_{}_{}_gen_eps{}_data{}_blocksize{}/'.format(experiment, risk, eps_tag, n_max, block_size)
    try:
        os.makedirs(save_path, exist_ok=True)
    except OSError as error:
        print(error)

    print('Saving in {}'.format(save_path))

    if data.shape[-1] == 2:
        plt.scatter(data[:,0], data[:,1], alpha=0.3)
        plt.title(data_file)
        plt.xlabel(r'$X_1$')
        plt.ylabel(r'$X_2$')
        plt.savefig(save_path + 'original_data.pdf')
        plt.close('all')

        # nets.CFGEstimator fitted to the full data, evaluated at (t, 1 - t).
        cfg = nets.CFGEstimator(data)
        x_ = torch.linspace(0,1,100)
        x = torch.stack((x_, 1-x_),1)
        plt.plot(x_, cfg(x))
        plt.savefig(save_path + 'original_data_cfg.pdf')
        plt.close('all')

    if eps_max > 0:
        # Sweep the radius eps over [0, eps_max].
        use_eps = True
        loop_vars = torch.linspace(0,eps_max,n_eps)
    else:
        # Sweep the sample size over [n_data_min, n_data].
        use_eps = False
        n_data_min = params['n_data_min']
        loop_vars = torch.linspace(n_data_min, n_data, n_eps).int()

    n_loop = loop_vars.shape[0]

    # Per (sweep value, run) results.
    adv_loss = torch.zeros((n_loop, n_runs))
    pop_loss = torch.zeros((n_loop, n_runs))

    adv_loss_mae = torch.zeros((n_loop, n_runs))
    pop_loss_mae = torch.zeros((n_loop, n_runs))

    adv_risk = torch.zeros((n_loop, n_runs))
    pop_risk = torch.zeros((n_loop, n_runs))
    true_risk_np = torch.zeros((n_loop, n_runs))

    # NOTE(review): E_p0 is pickled but never written in this script.
    E_p0 = torch.zeros((n_loop, n_runs))

    # Leading sample dimension passed to net_p0.sample_z.
    N_E = 1000

    for loop_idx, loop_var in enumerate(loop_vars):
        for run in range(n_runs):
            # The sample size is reset every run: the block-maxima step below
            # overwrites n_data, and reusing it would shrink later runs.
            if use_eps:
                eps = loop_var
                # NOTE(review): hard-coded; params['n_data'] is not used in this mode.
                n_data = 100
            else:
                # NOTE(review): no radius is defined for the sample-size sweep;
                # eps = 0 is assumed here -- confirm the intended radius.
                eps = torch.tensor(0.0)
                n_data = loop_var.int().item()

            # Resample rows with replacement.
            sampled_data = data[np.random.choice(data.shape[0],n_data), :]

            if block_size > 0:
                # Componentwise maxima over the first reshaped axis; the reshape
                # requires n_data to be divisible by the number of groups.
                if block_size == 100:
                    sampled_data = sampled_data.reshape(5 * (1+loop_idx), -1, d).max(0)[0]
                else:
                    sampled_data = sampled_data.reshape(block_size, -1, d).max(0)[0]
                n_data = sampled_data.shape[0]

            # NOTE(review): cfg and pp are unused below. They are kept because
            # CFGEstimator's side effects are unknown here, and torch.rand
            # advances the RNG, so dropping pp would change seeded results.
            cfg = nets.CFGEstimator(sampled_data)
            pp = ( - torch.log(torch.rand(n_max, sampled_data.shape[0], 1))).cumsum(0)

            # Fit the P0 network to the resampled data.
            net_p0 = P0Module_pp(32, 2, d, d, act=nn.LeakyReLU())
            fit_p0_pp(net_p0, sampled_data, save_path)
            net_p0.eval()

            with torch.no_grad():
                Y_n = net_p0.sample_z(( N_E, 1000, d )).detach()

            v_n = (Y_n * x_eval).max(-1)[0]
            # Cumulative sums of Exp(1) draws (-log U): unit-rate Poisson arrival times.
            a_n = (-torch.rand_like(v_n).log()).cumsum(1)
            p0  = 1 - ( -(v_n).mean() ).exp()

            with torch.no_grad():
                plt.scatter(a_n, v_n)
                plt.savefig(save_path + 'vn_an.pdf')
                plt.close('all')

                if d >= 2:
                    # First coordinate against second coordinate of every sample.
                    plt.scatter(Y_n[..., 0].reshape(-1), Y_n[..., 1].reshape(-1))
                    plt.savefig(save_path + 'spec.pdf')
                    plt.close('all')

            print('Risk P0   = {}'.format(p0))

            pop, adv, adv2 = train_pp(net_p0, v_n, a_n, x_eval, Y_n, net_p0, eps, act=id_act, n_epochs=n_epochs, save_path=save_path, experiment=experiment)

            true_risk_np[loop_idx, run] = true_risk
            adv_risk[loop_idx, run] = adv
            pop_risk[loop_idx, run] = pop

            # Squared and absolute error of each estimate against true_risk;
            # these are the quantities plotted below.
            pop_err = float(pop) - float(true_risk)
            adv_err = float(adv) - float(true_risk)
            pop_loss[loop_idx, run] = pop_err ** 2
            adv_loss[loop_idx, run] = adv_err ** 2
            pop_loss_mae[loop_idx, run] = abs(pop_err)
            adv_loss_mae[loop_idx, run] = abs(adv_err)

        losses.append((loop_var, pop, adv, true_risk))


    pickle_dict = {'losses' : losses,
            'E_p0' : E_p0,
            'true_risk' : true_risk_np,
            'p0_risk' : pop_risk,
            'adv_risk': adv_risk,
            'pop_loss': pop_loss,
            'adv_loss': adv_loss,
            'pop_loss_mae': pop_loss_mae,
            'adv_loss_mae': adv_loss_mae}

    with open(save_path + '/stats{}_{}.p'.format(eps_max, n_data), 'wb') as f:
        pickle.dump(pickle_dict, f)

    # x-axis: the swept values (eps or N).
    x_axis = loop_vars.numpy()
    x_label = r'$\delta$' if use_eps else r'$N$'

    def _plot_error_curves(pop_vals, adv_vals, ylabel, out_file):
        """Plot the per-sweep mean error of P0 and P*, with a +/- 1 std band
        across runs when there is more than one run (std is NaN otherwise)."""
        for vals, label in ((pop_vals, r'$\mathbb{P}_0$'), (adv_vals, r'$\mathbb{P}_\star$')):
            mean = vals.mean(1).numpy()
            plt.plot(x_axis, mean, label=label)
            if vals.shape[1] > 1:
                std = vals.std(1).numpy()
                plt.fill_between(x_axis, mean - std, mean + std, alpha=0.3)
        plt.xlabel(x_label)
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_file)
        plt.close('all')

    _plot_error_curves(
        pop_loss, adv_loss,
        r'$(E_{X\sim archimax} [\ell(X)] - E_{X\sim P}[\ell(X)])^2$',
        '{}/err_vs_eps_{}_n={}_std_sm={}_eps={}.pdf'.format(save_path, experiment, n_data, use_softmax, eps_max))

    _plot_error_curves(
        pop_loss_mae, adv_loss_mae,
        r'$|E_{X\sim archimax} [\ell(X)] - E_{X\sim P}[\ell(X)]|$',
        '{}/err_vs_eps_{}_n={}_std_sm={}_eps={}_mae.pdf'.format(save_path, experiment, n_data, use_softmax, eps_max))