from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib import rcParams
import numpy as np
import pandas as pd
from pandas import DataFrame
import csv
from fractions import Fraction
from pandas.plotting import scatter_matrix
import scipy.stats
from scipy.stats.mstats import winsorize
import random
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as T
import copy
import os
import pdb
import time


# Hyperparameters
algorithms = ["SGD", "fixed-shuffle", "begin-shuffle", "uniform-shuffle"]

num_exprs = 100        # number of experiments
total_iters = 4001
gamma = {name: 1e-7 for name in algorithms}  # Stepsize (one entry per algorithm)
lambda0 = 0.01
eta0 = 0.1             # Initial eta

psi = "chi-square"
psi_grad = psi         # ["smoothCVaR",1]
penalty_hyps = {
    "coeff_penalty": 0.1,
    "eps_penalty": 1.0,
}

# Settings of the periodic evaluation performed inside the solver
epoch_eval = 10
is_eval_psi = True
eval_stepsize = 0.1
eval_thr = 1e-4
eval_maxiter = 1e4

# Wall-clock reference used for the total-time report at the end of the run
time_start = time.time()
# Define function L and its gradient
def L_dL(X_train,y_train,X_test,y_test,xmodel,eta,batch,psi="chi-square",psi_grad="chi-square",\
         lambda0=0.01,is_grad=True,coeff_penalty=0.1,eps_penalty=1.0):
    """Evaluate the mini-batch objective L(x, eta) and, optionally, its gradients.

    For every sample i selected by ``batch`` the per-sample loss is
    ``l_i = (X_train[i].dot(xmodel) - y_train[i])**2 / 2
            + coeff_penalty * sum_j log(1 + |xmodel_j| / eps_penalty)``
    and the returned value is ``lambda0 * mean_i psi((l_i - eta) / lambda0) + eta``.

    Args:
        X_train, y_train: feature matrix and targets that ``batch`` indexes into.
        X_test, y_test: accepted for interface compatibility; not used here.
        xmodel: model vector x.
        eta: scalar dual variable.
        batch: sample index/indices. A scalar (Python int, NumPy integer,
            0-d array), a sequence or an index array are all accepted and are
            normalised to a 1-D index array; a ``slice`` is used as is.
        psi: ``"chi-square"``, ``"KL"``, ``["smoothCVaR", alpha]`` or an
            already-built callable.
        psi_grad: derivative of ``psi``. It is only consulted when ``psi`` is a
            callable; for the named choices it is rebuilt to match ``psi``.
        lambda0: scaling parameter of the objective.
        is_grad: when False only the objective value is returned.
        coeff_penalty, eps_penalty: parameters of the log-sum penalty.

    Returns:
        ``LL`` when ``is_grad`` is False, otherwise ``(LL, x_grad, eta_grad)``.

    Raises:
        ValueError: if ``psi`` is a string or list that is not one of the
            supported specifications.
    """
    if isinstance(psi, list):
        if psi[0]!="smoothCVaR":
            raise ValueError("Unsupported psi specification: "+repr(psi))
        alpha=psi[1]

        def psi(t):
            # Branch on the sign of t so exp() is only taken of non-positive values.
            is_tensor=isinstance(t,torch.Tensor)
            xp=torch if is_tensor else np
            rr=t.clone() if is_tensor else t.copy()
            idp=(t>0)
            idn=~idp
            rr[idn]=xp.log(1-alpha+alpha*xp.exp(t[idn]))
            rr[idp]=xp.log((1-alpha)*xp.exp(-t[idp])+alpha)+t[idp]
            return rr/alpha

        def psi_grad(t):
            is_tensor=isinstance(t,torch.Tensor)
            xp=torch if is_tensor else np
            rr=t.clone() if is_tensor else t.copy()
            idp=(t>0)
            idn=~idp
            expt=xp.exp(t[idn])
            rr[idn]=expt/(1-alpha+alpha*expt)
            expt=xp.exp(-t[idp])
            rr[idp]=1/((1-alpha)*expt+alpha)
            return rr
    elif isinstance(psi, str):
        if psi=="chi-square":
            def psi(t):
                if isinstance(t,torch.Tensor):
                    return torch.maximum(t/2+1,torch.zeros_like(t))**2-1
                return np.maximum(t/2+1,np.zeros_like(t))**2-1

            def psi_grad(t):
                if isinstance(t,torch.Tensor):
                    return torch.maximum(t/2+1,torch.zeros_like(t))
                return np.maximum(t/2+1,np.zeros_like(t))
        elif psi=="KL":
            def psi(t):
                if isinstance(t,torch.Tensor):
                    return torch.exp(t)-1
                return np.exp(t)-1

            def psi_grad(t):
                if isinstance(t,torch.Tensor):
                    return torch.exp(t)
                return np.exp(t)
        else:
            raise ValueError("Unsupported psi specification: "+repr(psi))
    # Any other value (e.g. callables built by the caller) is used unchanged.

    # Always index with a 1-D array: a scalar index would collapse R to a NumPy
    # scalar, which the psi implementations (written for arrays) cannot handle.
    if not isinstance(batch, slice):
        batch=np.asarray(batch).reshape(-1)

    R=(X_train[batch].dot(xmodel)-y_train[batch])
    #Log-sum penalty: https://arxiv.org/abs/2103.02681
    l=(R**2)/2+coeff_penalty*np.log(1+np.abs(xmodel)/eps_penalty).sum()
    in_psi=(l-eta)/lambda0
    LL=lambda0*(psi(in_psi).mean())+eta
    if not is_grad:
        return LL

    psi_grad_vec=psi_grad(in_psi)
    mean_psi_grad=psi_grad_vec.mean()
    # Chain rule: psi'(.) weights both the data-fit gradient R*a_i and the
    # (sample-independent) penalty gradient sign(x)/(|x|+eps).
    x_grad=((R*psi_grad_vec).reshape(-1,1)*X_train[batch]).mean(axis=0)\
        +coeff_penalty*mean_psi_grad*np.sign(xmodel)/(np.abs(xmodel)+eps_penalty)
    eta_grad=1-mean_psi_grad
    return LL,x_grad,eta_grad
    
#Spider-DRO algorithm     
def Spider_DRO(X_train,y_train,X_test,y_test,total_iters,gamma,\
               x0=None,eta0=0.1,alg=algorithms[0],epoch_eval=1,psi="chi-square",\
               psi_grad="chi-square",lambda0=0.01,eval_stepsize=0.1,eval_thr=1e-7,eval_maxiter=1e+4,\
               is_eval_psi=True,penalty_hyps={"coeff_penalty":0.1,"eps_penalty":1.0},print_progress=False):        
    """Run single-sample stochastic gradient steps on L(x, eta) with a chosen sampling scheme.

    Each iteration picks one training index according to ``alg``, obtains the
    gradient from ``L_dL`` and updates ``x`` and ``eta`` with stepsize ``gamma``.
    (Despite the function name, the update in this function is a plain
    stochastic-gradient step.) Every ``epoch_eval`` iterations the full-data
    objective is recorded before the update.

    Sampling schemes (``alg``):
        "fixed-shuffle":   index ``k % n_train`` (data order as given).
        "begin-shuffle":   one permutation drawn before the loop and reused
                           in every pass.
        "uniform-shuffle": a fresh permutation whenever ``k % n_train == 0``.
        anything else:     one index drawn at random every iteration (SGD).

    Args:
        X_train, y_train: training data of shapes (n_train, d) and (n_train,).
        X_test, y_test: only shape-checked here and forwarded to ``L_dL``.
        total_iters: number of iterations.
        gamma: stepsize used for both ``x`` and ``eta``.
        x0: initial model; drawn from a standard normal when None. It is
            deep-copied, so the caller's array is not modified.
        eta0: initial value of eta.
        epoch_eval: evaluation period in iterations.
        psi, psi_grad: ``"chi-square"``, ``"KL"``, ``["smoothCVaR", alpha]`` or
            callables (callables are used as given).
        lambda0: scaling parameter of the objective.
        eval_stepsize, eval_thr, eval_maxiter: stepsize, gradient tolerance and
            iteration cap of the inner gradient descent over eta used to
            estimate Psi(x) = min_eta L(x, eta).
        is_eval_psi: whether to run that inner minimisation at evaluations.
        penalty_hyps: dict with keys 'coeff_penalty' and 'eps_penalty'. The
            default dict is only read, never mutated.
        print_progress: print progress information.

    Returns:
        ``(L_set, Psi_set, wt_x, wt_eta)``: recorded L values, recorded Psi
        estimates (empty when ``is_eval_psi`` is False), final model and final
        eta. If a recorded L value is NaN or infinite the function returns
        immediately, so the lists may be shorter than the number of scheduled
        evaluations.

    Raises:
        ValueError: if ``psi`` is a string or list that is not supported.
    """
    n_train,d=X_train.shape
    n_test=X_test.shape[0]  
    assert d==X_test.shape[1],"X_train and X_test should have the same dimensionality."
    assert n_train==y_train.shape[0], "X_train and y_train should have the same number of samples."
    assert n_test==y_test.shape[0], "X_test and y_test should have the same number of samples."  
    
    if isinstance(psi, list):
        if psi[0]!="smoothCVaR":
            raise ValueError("Unsupported psi specification: "+repr(psi))
        alpha=psi[1]

        def psi(t):
            # Branch on the sign of t so exp() is only taken of non-positive values.
            is_tensor=isinstance(t,torch.Tensor)
            xp=torch if is_tensor else np
            rr=t.clone() if is_tensor else t.copy()
            idp=(t>0)
            idn=~idp
            rr[idn]=xp.log(1-alpha+alpha*xp.exp(t[idn]))
            rr[idp]=xp.log((1-alpha)*xp.exp(-t[idp])+alpha)+t[idp]
            return rr/alpha

        def psi_grad(t):
            is_tensor=isinstance(t,torch.Tensor)
            xp=torch if is_tensor else np
            rr=t.clone() if is_tensor else t.copy()
            idp=(t>0)
            idn=~idp
            expt=xp.exp(t[idn])
            rr[idn]=expt/(1-alpha+alpha*expt)
            expt=xp.exp(-t[idp])
            rr[idp]=1/((1-alpha)*expt+alpha)
            return rr
    elif isinstance(psi, str):
        if psi=="chi-square":
            # Anything that is not a Tensor takes the NumPy path, so NumPy
            # scalars no longer fall through and yield None.
            def psi(t):
                if isinstance(t,torch.Tensor):
                    return torch.maximum(t/2+1,torch.zeros_like(t))**2-1
                return np.maximum(t/2+1,np.zeros_like(t))**2-1

            def psi_grad(t):
                if isinstance(t,torch.Tensor):
                    return torch.maximum(t/2+1,torch.zeros_like(t))
                return np.maximum(t/2+1,np.zeros_like(t))
        elif psi=="KL":
            def psi(t):
                if isinstance(t,torch.Tensor):
                    return torch.exp(t)-1
                return np.exp(t)-1

            def psi_grad(t):
                if isinstance(t,torch.Tensor):
                    return torch.exp(t)
                return np.exp(t)
        else:
            raise ValueError("Unsupported psi specification: "+repr(psi))

    if x0 is None:
        x0=np.random.normal(size=d)   #Initialize model parameter
    wt_x=copy.deepcopy(x0)
    wt_eta=eta0
    L_set=[]
    Psi_set=[]
    coeff_penalty=penalty_hyps['coeff_penalty']
    eps_penalty=penalty_hyps['eps_penalty']
    if alg=="begin-shuffle":
        # Drawn exactly once; redrawing it inside the loop would turn this
        # scheme into independent uniform sampling.
        begin_seq=np.random.choice(n_train, n_train, replace=False)
    for k in range(total_iters):
        if k%epoch_eval==0:
            if print_progress:
                print("Evaluating "+str(k)+"-th iteration")

            R=(X_train.dot(wt_x)-y_train)
            #Log-sum penalty: https://arxiv.org/abs/2103.02681
            l_full=(R**2)/2+coeff_penalty*np.log(1+np.abs(wt_x)/eps_penalty).sum()
            L_now=lambda0*np.mean(psi((l_full-wt_eta)/lambda0))+wt_eta
            L_set.append(L_now)
            if np.isnan(L_now) or np.isinf(L_now):
                # Diverged: stop and hand back what has been recorded so far.
                return L_set,Psi_set,wt_x,wt_eta
            if is_eval_psi:
                # Gradient descent over eta (x fixed), keeping the smallest objective seen.
                grad=eval_thr+1   # forces at least one pass through the loop
                eta_opt=wt_eta
                obj_min=np.inf
                eta_iter=0
                while abs(grad)>=eval_thr and eta_iter<=eval_maxiter:
                    psi_input=(l_full-eta_opt)/lambda0
                    obj=lambda0*psi(psi_input).mean()+eta_opt
                    obj_min=min(obj,obj_min)
                    grad=1-psi_grad(psi_input).mean()
                    eta_opt-=grad*eval_stepsize
                    if print_progress:
                        print("eta="+str(eta_opt)+"; grad="+str(grad)+"; obj="+str(obj))
                    eta_iter+=1
                Psi_set.append(obj_min)
            elif print_progress:
                print("L="+str(L_now))
            #End of evaluation
        
        if print_progress:
            print("Updating "+str(k)+"-th iteration")
            
        # Choose the sample for this step. Indices are wrapped in an ndarray so
        # that L_dL always receives an array-valued batch.
        pos=k%n_train
        if alg=="uniform-shuffle":
            if pos==0:
                remain_batch=np.random.choice(n_train, n_train, replace=False)
            batch=np.array(remain_batch[pos])
        elif alg=="fixed-shuffle":
            batch=np.array(pos)
        elif alg=="begin-shuffle":
            batch=np.array(begin_seq[pos])
        else:   #SGD
            batch=np.random.choice(n_train, 1, replace=False)
            
        _,vt_x,vt_eta=L_dL(X_train,y_train,X_test,y_test,wt_x,wt_eta,batch,psi,psi_grad,lambda0,is_grad=True,\
                                   coeff_penalty=coeff_penalty,eps_penalty=eps_penalty)
        wt_x=wt_x-gamma*vt_x
        wt_eta=wt_eta-gamma*vt_eta
    return L_set,Psi_set,wt_x,wt_eta

def num2str_neat(num):
    """Return a compact string for *num*.

    If the exact rational value of *num* has a numerator whose magnitude is at
    most 100, ``str(num)`` is returned. Otherwise the value is approximated
    with ``Fraction.limit_denominator()`` and rendered as ``'p/q'``.
    """
    exact = Fraction(num)
    if abs(exact.numerator) <= 100:
        return str(num)
    approx = exact.limit_denominator()
    return "{}/{}".format(approx.numerator, approx.denominator)

#https://www.telusinternational.com/insights/ai-data/article/10-open-datasets-for-linear-regression
# Find data for regression


# Get data
#Life expectancy data: https://www.kaggle.com/datasets/kumarajarshi/life-expectancy-who
#Python code on life expectancy data: https://thecleverprogrammer.com/2021/01/06/life-expectancy-analysis-with-python/ 
# The CSV is read via a relative path, i.e. from the current working directory.
life_expectancy = pd.read_csv("Life Expectancy Data.csv") #reading the file
# NOTE(review): the result of head() is not used, so this line has no effect in a script.
life_expectancy.head()
# Rename raw headers (several contain leading/trailing spaces) to identifier-like names.
life_expectancy.rename(columns = {" BMI " :"BMI", 
                              "Life expectancy ": "Life_expectancy",
                              "Adult Mortality":"Adult_mortality",
                              "infant deaths":"Infant_deaths",
                              "percentage expenditure":"Percentage_expenditure",
                              "Hepatitis B":"HepatitisB",
                              "Measles ":"Measles",
                              "under-five deaths ": "Under_five_deaths",
                              "Total expenditure":"Total_expenditure",
                              "Diphtheria ": "Diphtheria",
                              " thinness  1-19 years":"Thinness_1-19_years",
                              " thinness 5-9 years":"Thinness_5-9_years",
                              " HIV/AIDS":"HIV/AIDS",
                              "Income composition of resources":"Income_composition_of_resources"}, inplace = True)

#Fill in missing values with the corresponding column's median (computed per year)
# reset_index without drop=True keeps the previous index as an extra leading column.
life_expectancy.reset_index(inplace=True)
# NOTE(review): the interpolated frames returned by apply() are not assigned,
# so this statement leaves life_expectancy unchanged.
life_expectancy.groupby('Country').apply(lambda group: group.interpolate(method= 'linear'))
imputed_data = []
for year in list(life_expectancy.Year.unique()):
    year_data = life_expectancy[life_expectancy.Year == year].copy()
    # The first four columns are skipped (presumably index, Country, Year, Status —
    # TODO confirm against the CSV header order); the rest get that year's median.
    for col in list(year_data.columns)[4:]:
        year_data[col] = year_data[col].fillna(year_data[col].dropna().median()).copy()
    imputed_data.append(year_data)
# Concatenating the per-year frames leaves the rows grouped by year.
life_expectancy = pd.concat(imputed_data).copy()

#Row filtering: keep only rows whose count columns are below 1001
#(the actual winsorizing is done by the winsorize() calls further below)
life_expectancy = life_expectancy[life_expectancy['Infant_deaths'] < 1001]
life_expectancy = life_expectancy[life_expectancy['Measles'] < 1001]
life_expectancy = life_expectancy[life_expectancy['Under_five_deaths'] < 1001]

life_expectancy.drop(['BMI'], axis=1, inplace=True)
# Log-transformed copies of three columns; log(0) yields -inf, which the
# replace() below maps to 0 (the replacement is applied to the whole frame).
life_expectancy['log_Percentage_expenditure'] = np.log(life_expectancy['Percentage_expenditure'])
life_expectancy['log_Population'] = np.log(life_expectancy['Population'])
life_expectancy['log_GDP'] = np.log(life_expectancy['GDP'])
life_expectancy = life_expectancy.replace([np.inf, -np.inf], 0)
# NOTE(review): bare expression, no effect in a script.
life_expectancy['log_Percentage_expenditure']

# Winsorized copies of selected columns; the tuple gives the fraction clipped
# at the (lower, upper) end of each column.
life_expectancy['winz_Life_expectancy'] = winsorize(life_expectancy['Life_expectancy'], (0.05,0))
life_expectancy['winz_Adult_mortality'] = winsorize(life_expectancy['Adult_mortality'], (0,0.04))
life_expectancy['winz_Alcohol'] = winsorize(life_expectancy['Alcohol'], (0.0,0.01))
life_expectancy['winz_HepatitisB'] = winsorize(life_expectancy['HepatitisB'], (0.20,0.0))
life_expectancy['winz_Polio'] = winsorize(life_expectancy['Polio'], (0.20,0.0))
life_expectancy['winz_Total_expenditure'] = winsorize(life_expectancy['Total_expenditure'], (0.0,0.02))
life_expectancy['winz_Diphtheria'] = winsorize(life_expectancy['Diphtheria'], (0.11,0.0))
life_expectancy['winz_HIV/AIDS'] = winsorize(life_expectancy['HIV/AIDS'], (0.0,0.21))
life_expectancy['winz_Thinness_1-19_years'] = winsorize(life_expectancy['Thinness_1-19_years'], (0.0,0.04))
life_expectancy['winz_Thinness_5-9_years'] = winsorize(life_expectancy['Thinness_5-9_years'], (0.0,0.04))
life_expectancy['winz_Income_composition_of_resources'] = winsorize(life_expectancy['Income_composition_of_resources'], (0.05,0.0))
life_expectancy['winz_Schooling'] = winsorize(life_expectancy['Schooling'], (0.03,0.01))

# Column-name -> number mapping. NOTE(review): not referenced anywhere else in
# the visible part of this file.
col_dict_winz={'winz_Life_expectancy':1,'winz_Adult_mortality':2,'Infant_deaths':3,'winz_Alcohol':4,
            'log_Percentage_expenditure':5,'winz_HepatitisB':6,'Measles':7,'Under_five_deaths':8,'winz_Polio':9,
            'winz_Total_expenditure':10,'winz_Diphtheria':11,'winz_HIV/AIDS':12,'log_GDP':13,'log_Population':14,
            'winz_Thinness_1-19_years':15,'winz_Thinness_5-9_years':16,'winz_Income_composition_of_resources':17,
            'winz_Schooling':18}
# Convert the whole frame to an array; columns are addressed by position from here on.
X_train=np.array(life_expectancy)
# Column 4 is used as the regression target (presumably Life_expectancy, given
# the extra leading index column — TODO confirm) and is standardized.
y_train=X_train[:,4]
y_train=(y_train-y_train.mean())/y_train.std() 
# Drop columns 1, 3 and 4 (presumably Country, Status and the target — TODO confirm).
# NOTE(review): every other column is kept as a feature, including the former
# index column and 'winz_Life_expectancy', which is derived from the target.
X_train=np.delete(X_train, [1,3,4], axis=1)
X_train=X_train.astype(np.float64)
# Standardize each feature column to zero mean and unit standard deviation.
X_train=(X_train-(X_train.mean(axis=0).reshape(1,-1)))/(X_train.std(axis=0).reshape(1,-1))
n_train, d=X_train.shape

# Add seeded Gaussian noise with standard deviation y_std to the standardized target.
y_std=1.0
random.seed(1)
np.random.seed(1)
y_train=y_train.astype(np.float64)+np.random.normal(scale=y_std,size=n_train)

# Hold out the last n_test rows as the test split; the rest stays as training data.
n_test=413
n_train-=n_test   #2000 (author's note on the resulting training size; not verifiable from this file alone)
X_test=X_train[n_train:(n_train+n_test)]
y_test=y_train[n_train:(n_train+n_test)]
X_train=X_train[0:n_train]
y_train=y_train[0:n_train]


# Reseed both generators so the shared starting point x0 is reproducible.
random.seed(1)
np.random.seed(1)
x0 = np.random.normal(scale=1.0, size=d)  # Initial model x

# Iteration numbers at which the solver records its evaluation values.
x_plot = np.array(range(0, total_iters, epoch_eval))
len_x = x_plot.shape[0]

# Output directory for saved arrays, the hyperparameter log and the figure.
folder_final = 'DRO_results/'
os.makedirs(folder_final, exist_ok=True)
    
# Run every algorithm num_exprs times from the same x0 and collect the recorded curves.
# results[alg]['L'] and results[alg]['Psi'] have shape (num_exprs, len_x).
results={}
for alg in algorithms:
    results[alg]={y_type:np.zeros((num_exprs,len_x)) for y_type in ['L','Psi']}
    for expr_k in range(num_exprs):
        print('Begin '+str(expr_k)+"-th implementation of "+alg+" algorithm.")
        L_values,Psi_values,_,_=\
            Spider_DRO(X_train,y_train,X_test,y_test,total_iters,gamma[alg],x0,eta0,alg,epoch_eval,\
                       psi,psi_grad,lambda0,eval_stepsize,\
                       eval_thr,eval_maxiter,is_eval_psi,penalty_hyps,print_progress=False)                
        # Spider_DRO can return fewer than len_x values (Psi is empty when
        # is_eval_psi is False, and a NaN/inf objective ends the run early).
        # Pad with NaN instead of failing on a shape mismatch.
        for curve_name,curve in (('L',L_values),('Psi',Psi_values)):
            row=np.full(len_x,np.nan)
            row[:len(curve)]=curve
            if len(curve)<len_x:
                print('Warning: '+alg+' run '+str(expr_k)+' returned '+str(len(curve))\
                      +' of '+str(len_x)+' '+curve_name+' values; padding with NaN.')
            results[alg][curve_name][expr_k]=row
    # np.save(folder_final+alg+'_L',results[alg]['L'])
    np.save(folder_final+alg+'_Psi',results[alg]['Psi'])
time_seconds=time.time()-time_start
print('Total time consumption: '+str(time_seconds/60)+' minutes')

# Axis labels used when plotting the results.
xlabels={'iters':'Iteration t'}
ylabels={'L':r'$L(x_t,\eta_t)$','Psi':r'$\Psi(x_t)$'}

# Record the run configuration next to the results. The context manager
# closes the file even if one of the writes raises.
with open(folder_final+'hyperparameters.txt','w') as hyp_txt:
    hyp_txt.write('number of iterations='+str(total_iters)+'\n')
    for alg in algorithms:
        hyp_txt.write('Stepsize gamma for '+str(alg)+':'+str(gamma[alg])+'\n')
    hyp_txt.write('Total time consumption: '+str(time_seconds/60)+' minutes.\n\n')

# Plot the recorded Psi curves of all algorithms: a single curve when there is
# one experiment, otherwise the mean with a percentile band.
colors=['red','black','blue','green','cyan','purple','gold','lime','darkorange']
markers=['P','v','*','s','.']
legends={"SGD":"SGD","fixed-shuffle":"Fixed-shuffling",\
         "begin-shuffle":"Shuffle-once","uniform-shuffle":"Uniform-shuffling"}
label_size=16
num_size=14
lgd_size=18
percentile=95

# rc settings are read when axes, labels and ticks are created, so they have to
# be in place before the figure is drawn on; set afterwards they would not
# affect this figure.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels

y_type='Psi'
plt.figure().set_size_inches(5, 5)
for alg_k,alg in enumerate(algorithms):
    if num_exprs==1:
        # max(1, ...) keeps markevery valid (a step of 0 is rejected) for short runs.
        plt.plot(x_plot,results[alg][y_type][0],color=colors[alg_k],label=legends[alg],\
                 marker=markers[alg_k],markevery=max(1,int(len_x/(alg_k+6))))
    else:
        y_plot=results[alg][y_type].copy()
        upper_loss = np.percentile(y_plot, percentile, axis=0)
        lower_loss = np.percentile(y_plot, 100 - percentile, axis=0)
        avg_loss = np.mean(y_plot, axis=0)
        plt.plot(x_plot,avg_loss,color=colors[alg_k],marker=markers[alg_k],\
                 markevery=max(1,int(len(avg_loss)/(alg_k+10))),label=legends[alg])
        plt.fill_between(x_plot,lower_loss,upper_loss,color=colors[alg_k],alpha=0.3,edgecolor="none")
        
plt.legend(prop={'size':lgd_size},loc=1)
plt.xlabel('Iteration t')
plt.ylabel(ylabels[y_type])
plt.gcf().subplots_adjust(bottom=0.15)
plt.gcf().subplots_adjust(left=0.2)
plt.ticklabel_format(axis="x", style="sci", scilimits=(0,0))
plt.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
plt.grid(True)
plt.savefig(folder_final+'DRO_Results.png',dpi=200)


# Disabled block (the condition is the constant False, so it never runs).
# It re-plots previously saved Psi values of SGD and uniform shuffling that are
# loaded from .npy files under folder_now, using the same mean/percentile-band
# layout as the main figure, and saves the result as DRO_tmp.png.
if False:
    folder_now='./DRO_results_UnifShufflingOnly/Psi_values/'
    Psi_SGD=np.load(folder_now+'SGD.npy')
    Psi_uniform=np.load(folder_now+'Uniform.npy')
    # The x axis is the column position of the stored array, not the iteration number.
    x_plot=range(Psi_SGD.shape[1])
    colors=['red','black','blue','green','cyan','purple','gold','lime','darkorange']
    markers=['P','v','*','s','.']
    legends=["SGD","Uniform-shuffling"]
    label_size=16
    num_size=14
    lgd_size=18
    percentile=95
    
    plt.figure().set_size_inches(5, 5)
    alg_k=0
    for y_plot in [Psi_SGD,Psi_uniform]:
        y_plot=y_plot.copy()
        # Percentile band and mean taken over axis 0 of the loaded array.
        upper_loss = np.percentile(y_plot, percentile, axis=0)
        lower_loss = np.percentile(y_plot, 100 - percentile, axis=0)
        avg_loss = np.mean(y_plot, axis=0)
        plt.plot(x_plot,avg_loss,color=colors[alg_k],marker=markers[alg_k],\
                 markevery=int(len(avg_loss)/(alg_k+10)),label=legends[alg_k])
        plt.fill_between(x_plot,lower_loss,upper_loss,color=colors[alg_k],alpha=0.3,edgecolor="none")
        alg_k+=1
    plt.legend(prop={'size':lgd_size},loc=1)
    plt.xlabel('Iteration t')
    plt.ylabel(r'$\Psi(x_t)$')
    plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
    plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels
    plt.gcf().subplots_adjust(bottom=0.15)
    plt.gcf().subplots_adjust(left=0.2)
    plt.ticklabel_format(axis="x", style="sci", scilimits=(0,0))
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
    plt.grid(True)
    plt.savefig(folder_now+'DRO_tmp.png',dpi=200)

    


