import os
import sys
from platform import python_version
import numpy as np
import torch
from torch.optim.lr_scheduler import MultiStepLR
import time
import scipy.io as sio
import copy
import pandas as pd
import pickle
from DTI_dataset import N_n_ToyDataSet
from gae2 import GDAE_N_n, DAE_DTI
from gae2_trainer import gae_N_n_trainer_batchall
from gae2_score_estimation import gae_N_n_estimate_score_error, gae_N_n_estimate_score_error_truncated, DTI_dae_estimate_score_error_truncated
from tensor_data_util import *
from torch_batch_svd import svd
from DTI_meanShift2 import N_n_meanShift_dae, DTI_meanShift_vectordae
import argparse

import matplotlib.pyplot as plt
import matplotlib.lines as lines
from mpl_toolkits.mplot3d import Axes3D

def S_linear_field(x, center, S0, velocity):
    """Evaluate a matrix field that varies linearly with position.

    For every row ``x[n]`` the result is
    ``S0 + sum_i (x[n, i] - center[0, i]) * velocity[i]``.

    Args:
        x: query positions, shape ``(bs, 2)``.
        center: reference position, shape ``(1, 2)`` (broadcast against ``x``).
        S0: value of the field at ``center``, shape ``(2, 2)``.
        velocity: rate of change of the field per input coordinate, shape
            ``(2, 2, 2)`` (input dim x symmetric matrices).

    Returns:
        Tensor of shape ``(bs, 2, 2)`` on the same device as the inputs.
    """
    # The result is built directly from the inputs instead of a scratch tensor
    # forced onto CUDA, so the function works on whatever device the inputs
    # live on (CUDA inputs give exactly the same result as before).
    displacement = x - center
    return S0.unsqueeze(0) + torch.einsum('ni, ijk -> njk', displacement, velocity)

def S_polar_field(x, center, S0, velocity, offset=0, D0=0, reverse=False):
    """Evaluate a matrix field expressed in a frame that rotates with the polar angle.

    For each position the field is ``R (D + D0) R^T + S0`` where
    ``D = velocity * (r - offset)``, ``theta = atan2(dy, dx)`` is the polar
    angle of ``x - center`` and ``R`` is the 2x2 rotation
    ``[[cos, sin], [-sin, cos]]`` (or its transpose when ``reverse`` is true).

    Args:
        x: query positions, shape ``(bs, 2)``.
        center: origin of the polar frame, shape ``(1, 2)``.
        S0: constant matrix added after the rotation, shape ``(2, 2)``.
        velocity: matrix scaled by ``r - offset``, shape ``(1, 2, 2)``.
        offset: scalar subtracted from ``r`` before scaling ``velocity``.
        D0: matrix (shape ``(2, 2)``) or scalar added to ``D`` before rotating.
        reverse: if true, use the opposite sign for the sine terms of ``R``.

    Returns:
        Tensor of shape ``(bs, 2, 2)`` on the same device as the inputs.
    """
    delta = x - center
    # NOTE: ``r`` is the sum of squared coordinates, i.e. the *squared*
    # distance to ``center`` (no square root is taken).  This is kept as in
    # the original implementation.
    r = (delta ** 2).sum(dim=-1)
    theta = torch.atan2(delta[:, 1], delta[:, 0])

    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    # Entry (0, 1) of R; entry (1, 0) always carries the opposite sign.
    upper_right = -sin_t if reverse else sin_t
    # R is assembled from the angle tensors, so it inherits their device and
    # dtype rather than being allocated on CUDA unconditionally.
    R = torch.stack(
        (
            torch.stack((cos_t, upper_right), dim=-1),
            torch.stack((-upper_right, cos_t), dim=-1),
        ),
        dim=-2,
    )

    D = velocity * (r.view(-1, 1, 1) - offset)
    S = torch.bmm(torch.bmm(R, D + D0), R.permute(0, 2, 1)) + S0.unsqueeze(0)
    return S


def DTI2dim2ellipseInfoFromCov(data, scale, num = 20):
    """Return the outline points of one ellipse per row of ``data``.

    Each row holds a 2-D position in its first two columns followed by a
    vectorised matrix (converted with ``vector2tensor_1dim``).  The matrix is
    decomposed with the batched ``svd``; the (sign-corrected) singular values
    are used as the two semi-axis lengths and ``U`` as the orientation.

    Args:
        data: tensor whose columns are ``[pos (2), matrix vector (rest)]``.
        scale: factor applied to the ellipse before translating it to ``pos``.
        num: number of outline points sampled on ``[0, 2*pi]``.

    Returns:
        CUDA tensor of shape ``(N, num, 2)`` with the outline coordinates.
    """
    data = data.cuda()
    n_points = data.shape[0]
    centers, mat_vec = data[:, :2], data[:, 2:]

    U, S, V = svd(vector2tensor_1dim(mat_vec))
    # diag(U^T V) is negative where the left and right singular vectors point
    # in opposite directions; the corresponding singular value gets a minus sign.
    alignment = torch.diagonal(torch.matmul(U.permute(0, 2, 1), V), dim1=1, dim2=2)
    radii = torch.where(alignment < 0, -S, S)

    angles = torch.linspace(0, 2*np.pi, num).cuda().view(1, -1)
    # (N, 1) * (1, num) -> (N, num) per coordinate, stacked into (N, num, 2).
    axis_aligned = torch.stack(
        (radii[:, 0:1] * torch.cos(angles), radii[:, 1:2] * torch.sin(angles)),
        dim=2,
    )
    rotated = torch.matmul(axis_aligned, U.permute(0, 2, 1))
    return scale * rotated + centers.view(n_points, 1, 2)

def draw_P2_field(x, P, scale, num=20):
    """Plot one ellipse per data point on the current matplotlib axes.

    Args:
        x: positions, concatenated column-wise with ``P``.
        P: vectorised matrices, one row per position.
        scale: ellipse scale forwarded to ``DTI2dim2ellipseInfoFromCov``.
        num: number of outline points per ellipse.
    """
    outlines = DTI2dim2ellipseInfoFromCov(torch.cat((x, P), dim=1), scale, num = num)

    # Each outline is a (num, 2) array of x/y coordinates.
    for outline in outlines.cpu().numpy():
        plt.plot(outline[:, 0], outline[:, 1])

def P2_field_generator(type='1'):
    """Build the synthetic tensor field made of four adjacent 20x20 patches.

    Only ``type == '1'`` is implemented.  A 20x20 grid on ``[-1, 1]^2`` is
    evaluated with two linear fields (``S_linear_field``) and two polar fields
    (``S_polar_field``); each field is mapped through ``Exp_mat`` and the four
    grids are translated so that they sit next to each other.

    Args:
        type: identifier of the field layout.

    Returns:
        Tuple ``(x, P, dx)``: the concatenated positions of the four patches,
        the matching ``Exp_mat(..., returnVec=True)`` outputs, and the grid
        spacing.

    Raises:
        NotImplementedError: for any ``type`` other than ``'1'``.
    """
    if type != '1':
        raise NotImplementedError

    def _diag(a, b):
        # 2x2 diagonal matrix on the GPU.
        m = torch.zeros(2,2).cuda()
        m[0,0] = a
        m[1,1] = b
        return m

    ticks = torch.linspace(-1, 1, steps=20)
    grid_x, grid_y = torch.meshgrid(ticks, torch.linspace(-1, 1, steps=20))
    dx = ticks[1] - ticks[0]
    x = torch.cat((grid_x.reshape(-1,1), grid_y.reshape(-1,1)), dim=1).cuda()
    origin = torch.zeros((1, 2)).cuda()

    # Patch 1: linear field.
    vel_1 = torch.tensor([
        [[1.0, 0.0], [0.0, 1.0]],
        [[0.0, 1.0], [1.0, 0.0]]
    ]).cuda()
    P2_1 = Exp_mat(S_linear_field(x, origin, _diag(-2.5, -1.5), 0.3*vel_1), returnVec=True)

    # Patch 2: linear field with the first velocity component negated.
    vel_2 = torch.tensor([
        [[-1.0, 0.0], [0.0, -1.0]],
        [[0.0, 1.0], [1.0, 0.0]]
    ]).cuda()
    P2_2 = Exp_mat(S_linear_field(x, origin, _diag(-2.5, -1.5), 0.3*vel_2), returnVec=True)

    # Patch 3: polar field with zero velocity and reversed rotation.
    vel_3 = torch.tensor([
        [[0.0, 0.0], [0.0, 0.0]]
    ]).cuda()
    S_polar1 = S_polar_field(x, origin, _diag(-1., -1.), 0.3*vel_3,
                             D0=_diag(-2, -1.), reverse=True)
    P2_p1 = Exp_mat(S_polar1, returnVec=True)

    # Patch 4: polar field with zero velocity.
    vel_4 = torch.tensor([
        [[0.0, 0.0], [0.0, 0.0]]
    ]).cuda()
    S_polar2 = S_polar_field(x, origin, _diag(-2.5, -2.5), 0.3*vel_4,
                             D0=_diag(-0.5, 0.5))
    P2_p2 = Exp_mat(S_polar2, returnVec=True)

    # Translate the shared grid once per patch, in the same order as P below.
    shifts = [
        [1.+dx, -1],
        [-1, 1.+dx],
        [1.+dx, 1.+dx],
        [-1, -1],
    ]
    x = torch.cat([x + torch.tensor([shift]).cuda() for shift in shifts], dim=0)
    P = torch.cat([P2_1, P2_2, P2_p1, P2_p2], dim=0)
    return (x, P, dx)

def corrupt_tensor(x, noise_type, noise_level):
    """Return a noisy copy of the vectorised matrices in ``x``.

    For ``'tangent_Gaussian'`` a Gaussian sample with standard deviation
    ``noise_level`` is drawn per entry, the middle column is scaled by
    ``1/sqrt(2)``, the sample is mapped through ``Exp`` and applied to the
    transposed ``get_sqrt_sym(x)`` with ``group_action``.

    Args:
        x: vectorised matrices, one per row (3 columns for the coefficient
            vector used here).
        noise_type: name of the noise model; only ``'tangent_Gaussian'`` is
            supported.
        noise_level: standard deviation of the Gaussian sample.

    Returns:
        The corrupted matrices as returned by ``group_action(..., returnVec=True)``.

    Raises:
        NotImplementedError: if ``noise_type`` is not supported.  (The original
            code only evaluated the exception class without raising it and
            then failed on an unbound ``x_cor``.)
    """
    if noise_type != 'tangent_Gaussian':
        raise NotImplementedError(f"unsupported noise_type: {noise_type!r}")

    cov_sqrt = get_sqrt_sym(x)
    # Device-aware replacements for the legacy torch.cuda.FloatTensor
    # constructors: float32 tensors created on the device of ``x``.
    cov_noise_coeff = torch.tensor([1.0, 1.0/np.sqrt(2.0), 1.0],
                                   dtype=torch.float32, device=x.device).view(1,3)
    epsilon_cov = torch.empty(x.shape[0], x.shape[1], dtype=torch.float32,
                              device=x.device).normal_(0.0, noise_level) * cov_noise_coeff
    Exp_epsilon = Exp(epsilon_cov, returnVec = False)
    x_cor = group_action(Exp_epsilon, cov_sqrt.permute(0,2,1), returnVec = True)
    return x_cor

def gae_Toy_N_n_trainer_batchall(dataset, model, optimizer, scheduler, max_iter_num, use_gpu = True, 
                saveModel = False, printEpochPeriod = 1000, weight_mode = None):
    """Train ``model`` with full-batch steps on every point of ``dataset``.

    Each iteration calls ``optimizer.step`` with a closure that evaluates
    ``model.calculate_loss`` on the whole data set, followed by
    ``scheduler.step()``.  The loss is re-evaluated and printed every
    ``printEpochPeriod`` iterations and on the last iteration.

    Args:
        dataset: object exposing ``posAndCov``, ``covInv_sqrt``, ``cov_sqrt``,
            ``cov_eigvec`` and ``cov_eigval`` tensors.
        model: object exposing ``calculate_loss`` and ``state_dict``.
        optimizer: optimizer whose ``step`` accepts a closure.
        scheduler: learning-rate scheduler stepped once per iteration.
        max_iter_num: number of optimisation steps.
        use_gpu: move all tensors to CUDA before training.
        saveModel: return a deep copy of the model's state dict instead of
            the loss.
        printEpochPeriod: logging period in iterations.
        weight_mode: accepted for interface compatibility; not used here.

    Returns:
        A deep-copied state dict when ``saveModel`` is true, otherwise the
        last printed loss divided by the number of points.
    """
    num_points = dataset.posAndCov.size()[0]
    tensors = [
        dataset.posAndCov.clone(),
        dataset.cov_sqrt.clone(),
        dataset.covInv_sqrt.clone(),
        dataset.cov_eigvec,
        dataset.cov_eigval,
    ]
    if use_gpu:
        tensors = [t.cuda() for t in tensors]
    points, sqrt_cov, inv_sqrt_cov, eigvec, eigval = tensors

    # One placeholder per point; passed through to calculate_loss unchanged.
    log_jacobian = [None]*len(points)

    def closure():
        optimizer.zero_grad()
        step_loss = model.calculate_loss(points, sqrt_cov, inv_sqrt_cov, log_jacobian,
                                         eigvec, eigval, weight = None)
        step_loss.backward()
        return step_loss

    started = time.time()
    for it in range(max_iter_num):
        optimizer.step(closure)
        if (it % printEpochPeriod == 0) or it == max_iter_num-1:
            print("iter: %d ---- time %.1f" % (it, time.time() - started))
            loss = model.calculate_loss(points, sqrt_cov, inv_sqrt_cov, log_jacobian,
                                        eigvec, eigval)
            print('loss: %f' % (loss.data/num_points))
        scheduler.step()

    if not saveModel:
        return loss.data/num_points
    # Snapshot of the weights, detached from further training.
    return copy.deepcopy(model.state_dict())

def dae_Toy_trainer_batchall(dataset, model, optimizer, scheduler, max_iter_num, use_gpu = True, 
                saveModel = False, printEpochPeriod = 1000, useFixedNoise = True):
    """Train a denoising autoencoder with full-batch steps on ``dataset.posAndCov``.

    With ``useFixedNoise`` a single noise tensor is sampled up front (scaled
    by ``model.noise_std``) and passed to every ``model.calculate_loss`` call;
    otherwise ``calculate_loss`` is called with the data only.

    Args:
        dataset: object exposing a ``posAndCov`` tensor.
        model: object exposing ``calculate_loss``, ``state_dict`` and, when
            ``useFixedNoise`` is true, ``noise_std``.
        optimizer: optimizer whose ``step`` accepts a closure.
        scheduler: learning-rate scheduler stepped once per iteration.
        max_iter_num: number of optimisation steps.
        use_gpu: move the data (and the noise) to CUDA.
        saveModel: return a deep copy of the model's state dict instead of
            the loss.
        printEpochPeriod: logging period in iterations.
        useFixedNoise: reuse one noise sample for the whole training run.

    Returns:
        A deep-copied state dict when ``saveModel`` is true, otherwise the
        most recently evaluated loss divided by the number of points.
    """
    num_points = dataset.posAndCov.size()[0]
    points = dataset.posAndCov.clone()
    if use_gpu:
        points = points.cuda()

    def _sample_noise():
        # CPU: unit normal scaled by noise_std.
        if not use_gpu:
            return torch.FloatTensor(points.size()).normal_(0.0, 1.0) * model.noise_std
        # GPU with a scalar std: sample at that std directly.
        if isinstance(model.noise_std, float):
            return torch.cuda.FloatTensor(points.size()).normal_(0.0, model.noise_std)
        # GPU with a tensor std: unit normal scaled by the std tensor.
        return torch.cuda.FloatTensor(points.size()).normal_(0.0, 1.0) * model.noise_std.cuda()

    def _evaluate():
        if useFixedNoise:
            return model.calculate_loss(points, epsilon)
        return model.calculate_loss(points)

    if useFixedNoise:
        # sample fixed noise
        epsilon = _sample_noise()

        # print initial loss
        loss = _evaluate()
        print('initial loss: %f' % (loss.data/num_points))

    def closure():
        optimizer.zero_grad()
        step_loss = _evaluate()
        step_loss.backward()
        return step_loss

    started = time.time()
    for it in range(max_iter_num):
        optimizer.step(closure)
        if (it % printEpochPeriod == 0) or it == max_iter_num-1:
            print("iter: %d ---- time %.1f" % (it, time.time() - started))
            loss = _evaluate()
            print('loss: %f' % (loss.data/num_points))
        scheduler.step()

    if not saveModel:
        return loss.data/num_points
    # Snapshot of the weights, detached from further training.
    return copy.deepcopy(model.state_dict())


def filtering_exp_gdae(gdae, input, data_true, step_size, max_iter):
    """Run mean-shift filtering with a trained GDAE and score every saved iterate.

    Builds an ``N_n_meanShift_dae`` around ``gdae`` (Riemannian position
    metric, iterates saved every step) and, for each entry of the result's
    ``shiftedPointsSet``, computes the sum of squared entries of
    ``Log_mat(P_filtered^{-1} P_true)`` divided by the number of points.

    Args:
        gdae: trained model handed to ``N_n_meanShift_dae``.
        input: noisy data, columns ``[pos (2), matrix vector]``.
        data_true: clean data with the same layout as ``input``.
        step_size: mean-shift step size; also scales the stopping thresholds.
        max_iter: maximum number of mean-shift iterations.

    Returns:
        Tuple ``(filteringResults, errors)`` where ``errors`` is a NumPy array
        with one value per saved iterate.
    """
    # Bounds: per-coordinate position extrema followed by the extreme
    # singular values of the input matrices.
    pos_lo, _ = torch.min(input[:,:2],dim=0)
    pos_hi, _ = torch.max(input[:,:2],dim=0)

    _, sing_vals, _ = svd(vector2tensor_1dim(input[:,2:]))
    s_lo = torch.min(sing_vals)
    s_hi = torch.max(sing_vals)
    print(s_lo)
    print(s_hi)
    lower = torch.cat((pos_lo.view(-1), s_lo.view(-1)), 0)
    upper = torch.cat((pos_hi.view(-1), s_hi.view(-1)), 0)

    metric = 'riemannian'
    mean_shift = N_n_meanShift_dae(
        gdae, step_size,
        pos_threshold = 0.002*step_size,
        cov_threshold = 0.0001*step_size,
        pos_group_threshold = 10,
        cov_group_threshold = 0.5,
        eliminate_threshold = 10,
        lowerBound = lower, upperBound = upper, pos_dim = 2,
        pos_metric = metric)
    filteringResults = mean_shift.run_meanShift(
        input.clone(), max_iter, save_prefix = '_toy',
        pos_metric = metric, save_iter=1, cleanInput = data_true.clone(), error_weight = None)

    def _mean_sq_log_error(points):
        # Squared entries of Log(P_filtered^{-1} P_true), averaged over points.
        log_ratio = Log_mat(torch.bmm(torch.inverse(vector2tensor_1dim(points[:,2:])),
                                      vector2tensor_1dim(data_true[:,2:])))
        return (log_ratio*log_ratio).sum().item()/log_ratio.shape[0]

    errors = [_mean_sq_log_error(points) for points in filteringResults.shiftedPointsSet]
    return filteringResults, np.array(errors)

def filtering_exp_dae(dae, input, data_true, step_size, max_iter):
    """Run mean-shift filtering with a trained vector DAE and score every saved iterate.

    Builds a ``DTI_meanShift_vectordae`` around ``dae`` (no position metric,
    iterates saved every 10 steps) and, for each entry of the result's
    ``shiftedPointsSet``, computes the sum of squared entries of
    ``Log_mat(P_filtered^{-1} P_true)`` divided by the number of points.

    Args:
        dae: trained model handed to ``DTI_meanShift_vectordae``.
        input: noisy data, columns ``[pos (2), matrix vector]``.
        data_true: clean data with the same layout as ``input``.
        step_size: mean-shift step size; also scales the stopping thresholds.
        max_iter: maximum number of mean-shift iterations.

    Returns:
        Tuple ``(filteringResults, errors)`` where ``errors`` is a NumPy array
        with one value per saved iterate.
    """
    # Column-wise extrema of the raw input serve as the bounds.
    lower, _ = torch.min(input,dim=0)
    upper, _ = torch.max(input,dim=0)

    mean_shift = DTI_meanShift_vectordae(
        dae, step_size,
        pos_threshold = 0.002*step_size,
        cov_threshold = 0.0001*step_size,
        pos_group_threshold = 10,
        cov_group_threshold = 0.5,
        eliminate_threshold = 10,
        lowerBound = lower, upperBound = upper, pos_dim = 2, pos_metric = None)

    filteringResults = mean_shift.run_meanShift(
        input.clone(), max_iter, save_prefix = '_toy',
        save_iter=10, cleanInput = data_true.clone(), error_weight = None)

    def _mean_sq_log_error(points):
        # Squared entries of Log(P_filtered^{-1} P_true), averaged over points.
        log_ratio = Log_mat(torch.bmm(torch.inverse(vector2tensor_1dim(points[:,2:])),
                                      vector2tensor_1dim(data_true[:,2:])))
        return (log_ratio*log_ratio).sum().item()/log_ratio.shape[0]

    errors = [_mean_sq_log_error(points) for points in filteringResults.shiftedPointsSet]
    return filteringResults, np.array(errors)

def main(args):
    """Generate the synthetic data set or run a filtering experiment.

    With ``args.mode == 'data_generation'`` a clean field and a noisy copy are
    created for every (noise level, run) combination and written, together
    with PDF plots, to ``<save_path>/<data>_<noise_type>_<noise_level>_<run>``
    as ``x.pt``, ``P.pt`` and ``P_noisy.pt``.

    Otherwise, for every entry of ``args.noise_levels`` and ``args.runs`` the
    stored data is loaded, a model selected by ``args.method`` (``'GDAE'`` or
    ``'DAE'``) is trained, fine-tuned in ten short rounds (keeping the weights
    with the lowest estimated score error), and used for mean-shift filtering
    with every step size in ``args.step_sizes``.  The best filtered field is
    plotted and all results are pickled next to the plot.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Raises:
        ValueError: if an experiment is requested with an unknown
            ``args.method``.
    """
    torch.cuda.set_device(args.gpu)
    os.makedirs(args.save_path, exist_ok=True)
    if args.mode == 'data_generation':
        # data generation
        # NOTE(review): P2_field_generator, as visible in this file, only
        # implements type '1'; confirm that type '8' is available before
        # relying on this branch.
        for data_num in ['8']:
            for noise_type in ['tangent_Gaussian']:
                for noise_level in [0.2, 0.05, 0.02]:
                    for run in [1, 2, 3, 4, 5]:
                        # generate data
                        save_path = os.path.join(args.save_path, f'{data_num}_{noise_type}_{noise_level}_{run}')
                        os.makedirs(save_path, exist_ok=True)
                        
                        (x, P, dx) = P2_field_generator(type=data_num)
                        data_true = torch.cat((x,P), dim=1)
                        plt.figure()
                        draw_P2_field(data_true[:,:2], data_true[:,2:], args.ellipse_scale)
                        plt.savefig(os.path.join(save_path, 'clean_data.pdf'), 
                                    format='pdf', bbox_inches='tight')
                        # Close (not just clear) the figure so figures do not
                        # accumulate across the loops.
                        plt.close()
                        P_cor = corrupt_tensor(P, noise_type, noise_level)
                        posAndCov = torch.cat((x,P_cor), dim=1)
                        
                        plt.figure()
                        draw_P2_field(posAndCov[:,:2], posAndCov[:,2:], args.ellipse_scale)
                        plt.savefig(os.path.join(save_path, 'corrupted_data.pdf'), 
                                    format='pdf', bbox_inches='tight')
                        plt.close()
                        
                        torch.save(x.cpu(), os.path.join(save_path, 'x.pt'))
                        torch.save(P.cpu(), os.path.join(save_path, 'P.pt'))
                        torch.save(P_cor.cpu(), os.path.join(save_path, 'P_noisy.pt'))
                        
    else:
        # Fail early with a clear message; previously an unknown method skipped
        # both branches below and crashed later with a NameError.
        if args.method not in ('GDAE', 'DAE'):
            raise ValueError(f"unknown method {args.method!r}; expected 'GDAE' or 'DAE'")

        for noise_level in args.noise_levels:
            for run in args.runs:
                save_path = os.path.join(args.save_path, f'{args.data}_{args.noise_type}_{noise_level}_{run}')
                # load data
                x = torch.load(os.path.join(save_path, 'x.pt')).cuda()
                P = torch.load(os.path.join(save_path, 'P.pt')).cuda()
                P_cor = torch.load(os.path.join(save_path, 'P_noisy.pt')).cuda()
                dx = x[1,1]-x[0,1]
                data_true = torch.cat((x,P), dim=1)
                posAndCov = torch.cat((x,P_cor), dim=1)
                dataset = N_n_ToyDataSet(posAndCov, dx=dx)
                # Per-column RMS of the noisy data, used to rescale the first
                # layer's weights.
                data_rms = torch.sqrt(torch.mean(posAndCov**2, dim=0)).view(1,-1)


                if args.method == 'GDAE':
                    covInv = torch.inverse(vector2tensor_1dim(P))
                    cov_sqrt = get_sqrt_sym(P)



                    ### get constant for setting noise_std and covMetricCoeff for balancing noise scale to be 
                    ### comparable to dae case
                    if args.cov_coeff:
                        posMetric_sq = torch.bmm(covInv, covInv)
                        const1 = torch.sum(covInv * torch.eye(2).view(1,2,2).cuda()) \
                        / torch.sum(posMetric_sq * torch.eye(2).view(1,2,2).cuda())
                        print(const1)
                        posMetricInv = torch.bmm(cov_sqrt.permute(0,2,1), cov_sqrt)
                        posMetricInv_sq = torch.bmm(posMetricInv, posMetricInv)
                        const2 = torch.sum(posMetricInv_sq * torch.eye(2).view(1,2,2).cuda()) / \
                        torch.sum(posMetricInv * torch.eye(2).view(1,2,2).cuda())
                        print(const2)
                        covCoeff = 1.0 / const1.cpu().numpy() * 2.0
                    else:
                        covCoeff = 1.

                    dim = [5, args.h_dims]
                    approx_order = 1
                    use_exp_map_sqrt = args.use_exp_map_sqrt
                    use_exp_map_corrupt = False

                    gdae = GDAE_N_n(dim, args.num_hidden_layers, args.noise_std / np.sqrt(2.0), covCoeff = covCoeff,
                                     posMetricFunc = posMetric_func_N_n, pos_dim = 2, useLeakyReLU = False, approx_order = approx_order,
                                   use_exp_map_sqrt = use_exp_map_sqrt, use_exp_map_corrupt = use_exp_map_corrupt)
                    gdae.autoencoder[0].weight.data /= data_rms.cpu()
                    gdae = gdae.cuda()

                    optimizer = torch.optim.Adam( gdae.parameters(), lr=args.lr, weight_decay=args.weight_decay)
                    scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)

                    model_wts = gae_Toy_N_n_trainer_batchall(dataset, gdae, optimizer, scheduler, args.epochs, use_gpu = True, 
                                saveModel = True, printEpochPeriod = 1000, weight_mode = None)

                    #### post iterations for GDAE
                    max_iter_num = 100
                    score_error_set = []
                    # NOTE: the scheduler passed below is still bound to the
                    # main-training optimizer, so it does not change the
                    # learning rate of this new optimizer.
                    optimizer = torch.optim.Adam( gdae.parameters(), lr=1e-6, weight_decay=1e-6)
                    for i in range(10):
                        model_wts = gae_Toy_N_n_trainer_batchall(dataset, gdae, optimizer, scheduler, max_iter_num, use_gpu = True, 
                                    saveModel = True, printEpochPeriod = 100, weight_mode = None)

                        score_error_set.append(
                            gae_N_n_estimate_score_error_truncated(posAndCov, covInv, 
                                                           cov_sqrt, [0,len(posAndCov)], gdae, 
                                                           printError = False)
                        )
                        # Keep the weights with the lowest estimated score error.
                        if i == 0:
                            best_model = model_wts
                            min_val = score_error_set[-1]
                        elif score_error_set[-1] < min_val:
                            best_model = model_wts
                            min_val = score_error_set[-1]
                        print(score_error_set[-1])

                    gdae.load_state_dict(best_model)

                    filtering_results_set = []
                    min_ai_error_set = []
                    argmin_ai_error_set = []
                    ai_error_set = []
                    for step_size in args.step_sizes:
                        filteringResults, ai_errorsSet = filtering_exp_gdae(gdae, posAndCov, data_true, step_size, args.n_iter)
                        filtering_results_set.append(filteringResults)
                        min_ai_error_set.append(np.min(ai_errorsSet))
                        argmin_ai_error_set.append(np.argmin(ai_errorsSet))
                        ai_error_set.append(ai_errorsSet)

                elif args.method == 'DAE':
                    dim = [5, args.h_dims]
                    dae = DAE_DTI(dim, args.num_hidden_layers, args.noise_std, useLeakyReLU = False, pos_dim = 2)
                    dae.autoencoder[0].weight.data /= data_rms.cpu()
                    dae = dae.cuda()

                    optimizer = torch.optim.Adam( dae.parameters(), lr=args.lr, weight_decay=args.weight_decay)
                    scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
                    model_wts = dae_Toy_trainer_batchall(dataset, dae, optimizer, scheduler, args.epochs, use_gpu = True, 
                                    saveModel = True, printEpochPeriod = 100, useFixedNoise = False)


                    #### post iterations for DAE
                    max_iter_num = 100
                    score_error_set = []
                    # NOTE: the scheduler passed below is still bound to the
                    # main-training optimizer, so it does not change the
                    # learning rate of this new optimizer.
                    optimizer = torch.optim.Adam( dae.parameters(), lr=1e-6, weight_decay=1e-6)
                    for i in range(10):
                        model_wts = dae_Toy_trainer_batchall(dataset, dae, optimizer, scheduler, max_iter_num, use_gpu = True, 
                                    saveModel = True, printEpochPeriod = 100, useFixedNoise = False)
                        score_error_set.append(
                            DTI_dae_estimate_score_error_truncated(posAndCov, [0,len(posAndCov)], dae, 
                                                           printError = False)
                        )
                        # Keep the weights with the lowest estimated score error.
                        if i == 0:
                            best_model = model_wts
                            min_val = score_error_set[-1]
                        elif score_error_set[-1] < min_val:
                            best_model = model_wts
                            min_val = score_error_set[-1]
                        print(score_error_set[-1])
                    dae.load_state_dict(best_model)

                    filtering_results_set = []
                    min_ai_error_set = []
                    argmin_ai_error_set = []
                    ai_error_set = []
                    for step_size in args.step_sizes:
                        filteringResults, ai_errorsSet = filtering_exp_dae(dae, posAndCov, data_true, step_size, args.n_iter)
                        filtering_results_set.append(filteringResults)
                        min_ai_error_set.append(np.min(ai_errorsSet))
                        argmin_ai_error_set.append(np.argmin(ai_errorsSet))
                        ai_error_set.append(ai_errorsSet)
                # save best filtering results as a figure
                # Best step size first, then the best saved iterate for that
                # step size.
                best_idx = np.argmin(min_ai_error_set)
                data_filtered = filtering_results_set[best_idx].shiftedPointsSet[argmin_ai_error_set[best_idx]]

                plt.figure()
                draw_P2_field(data_true[:,:2], data_filtered[:,2:], args.ellipse_scale)
                filename = args.method+'_noise_std'+str(args.noise_std)+'_epochs'+str(args.epochs)+'_run'+str(run)\
                +'_datatype'+str(args.data)+'_scale'+str(args.ellipse_scale)+'_noise'+str(noise_level)
                if args.no_cov_coeff:
                    filename = 'nocoeff_'+filename
                if args.no_use_exp_map_sqrt:
                    filename = 'noExpSqrt_'+filename
                plt.savefig(args.save_path+'/'+filename+'.pdf', format='pdf', bbox_inches='tight')
                # Release the figure; one is created per (noise level, run).
                plt.close()

                

                res = {
                    "filtering_results_set": filtering_results_set,
                    "ai_error_set": ai_error_set,
                    "min_ai_error": min(min_ai_error_set),
                    "min_score_error_est": min_val
                }
                filename_pk = args.save_path+'/'+filename+'.pickle'

                with open(filename_pk, 'wb') as handle:
                    pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return

if __name__ == "__main__":
    # Command-line interface.  Every option has a default, so none of them is
    # actually mandatory despite the group title.
    parser = argparse.ArgumentParser()
    required_arg = parser.add_argument_group('required arguments')
    required_arg.add_argument('--data_path', type=str, default='./N2_filtering_toy')
    required_arg.add_argument('--save_path', type=str, default='./N2_filtering_toy')
    required_arg.add_argument('--data', type=str, default='8')
    required_arg.add_argument('--noise_type', type=str, default='tangent_Gaussian')
    required_arg.add_argument('--noise_levels', nargs='*', type=float, default=[0.02, 0.05, 0.1, 0.2])
    required_arg.add_argument('--runs', nargs='*', type=int, default=[1,2,3,4,5])
    required_arg.add_argument('--method', type=str, default='GDAE')
    
    required_arg.add_argument('--noise_std', type=float, default=0.05)
    required_arg.add_argument('--h_dims', type=int, default=1000)
    required_arg.add_argument('--num_hidden_layers', type=int, default=2)
    required_arg.add_argument('--epochs', type=int, default=10000)
    required_arg.add_argument("--lr", type=float, default=1e-4)
    required_arg.add_argument("--weight_decay", type=float, default=1e-6)
    required_arg.add_argument("--milestones", nargs='*', type=int, default=[5000])
    required_arg.add_argument("--gamma", type=float, default=0.01)
    required_arg.add_argument('--gpu', type=int, default=0)
    required_arg.add_argument('--n_iter', type=int, default=300)
    required_arg.add_argument('--step_sizes', nargs='*', type=float, default=[0.01, 0.03, 0.1, 0.3, 1.0])
    required_arg.add_argument('--ellipse_scale', type=float, default=0.1)
    required_arg.add_argument('--mode', type=str, default=None)
    
    
    # The help strings below replace a copy-pasted 'label smoothing' text that
    # did not describe these switches.
    required_arg.add_argument('--no-cov-coeff',
                        action='store_true',
                        help='use covCoeff = 1 for the GDAE instead of the coefficient computed from the data')
    required_arg.add_argument('--no-use_exp_map_sqrt',
                        action='store_true',
                        help='build the GDAE with use_exp_map_sqrt=False')

    args = parser.parse_args()
    
    # Positive-sense aliases of the two "no-..." switches, read by main().
    args.cov_coeff = not (args.no_cov_coeff)
    args.use_exp_map_sqrt = not (args.no_use_exp_map_sqrt)
    
    main(args)