#Code is adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py 

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
# import copy
import random
import os
# import argparse
from Resnet import ResNet18
import gc
import time
import matplotlib.pyplot as plt

import pdb

#Hyperparameters
total_iters=6001                  # number of optimizer steps per run
batchsize=200                     # mini-batch size
eta1=0.1**3                       # common step size shared by every algorithm
# Step size used by each algorithm (all equal to eta1 here; kept per-algorithm so they can be tuned separately).
eta=dict.fromkeys(["SGD","fixed-shuffle","begin-shuffle","uniform-shuffle"],eta1)
# Iterations between two consecutive evaluations; 50000 is the CIFAR-10 training-set size,
# so this is one evaluation per pass over the data.
between_eval=int(50000/batchsize)
num_exprs=10                      # independent repetitions of each algorithm

# Seed the Python, NumPy and PyTorch RNGs with the same value so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(1)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("I am using device " + device, flush=True)

# Data
print('==> Preparing data..',flush=True)
# Training transform: random 32x32 crop (4-pixel padding) + random horizontal flip,
# then per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test transform: normalization only, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
n=len(trainset)
# Materialize the whole training set as tensors once. NOTE: the random
# augmentation in transform_train is therefore sampled only once per image,
# and every algorithm/run below trains on this same fixed copy.
train_inputs=torch.zeros((n,3,32,32))
train_targets=torch.zeros(n).long()
for i in range(n):
    # Index the dataset once per sample: each access decodes the image and
    # re-applies the transform, so reading [0] and [1] separately did it twice.
    img, label = trainset[i]
    train_inputs[i]=img
    train_targets[i]=label
del trainset
gc.collect()

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
n_test=len(testset)
if torch.cuda.is_available():
    test_inputs=torch.zeros((n_test,3,32,32))
    test_targets=torch.zeros(n_test).long()
    for i in range(n_test):
        img, label = testset[i]
        test_inputs[i]=img
        test_targets[i]=label
else:
    # CPU fallback: keep only the first n_test_each test images of each class
    # so that evaluation stays cheap.
    n_test_each=10
    num_classes=10
    test_inputs=torch.zeros((n_test_each*num_classes,3,32,32))
    test_targets=torch.zeros(n_test_each*num_classes).long()
    counts=np.zeros(num_classes)
    k=0
    for i in range(len(testset)):
        # Read the label from the dataset's label list instead of testset[i],
        # which would decode and transform the image just to get its class.
        y=testset.targets[i]
        if counts[y]<n_test_each:
            test_inputs[k]=testset[i][0]
            test_targets[k]=y
            counts[y]+=1
            k+=1
        if counts.min()>=n_test_each:
            break
    # Keep only the k samples actually collected, and make n_test count them
    # (not the full test set) so per-sample averages are correct.
    test_inputs=test_inputs[:k]
    test_targets=test_targets[:k]
    n_test=k
del testset
gc.collect()

# CIFAR-10 class names (label order).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Default reduction is 'mean': the loss is averaged over the mini-batch.
criterion = nn.CrossEntropyLoss()

# Mini-batch sampling schemes being compared:
#   "SGD":             every step draws batchsize distinct indices uniformly at random, independently of earlier steps.
#   "fixed-shuffle":   every pass pi^{(t)} visits the data in its original order.
#   "begin-shuffle":   one random permutation is drawn before training and reused for every pass pi^{(t)}.
#   "uniform-shuffle": each pass pi^{(t)} uses a fresh, independent, uniformly random permutation.
algorithms = [
    "SGD",
    "fixed-shuffle",
    "begin-shuffle",
    "uniform-shuffle",
]

# Iteration indices at which the model is evaluated (and the x-axis of every plot).
x_plot=list(range(0,total_iters,between_eval))
len_x_plot=len(x_plot)
# One zero-initialised (num_exprs, len_x_plot) array per algorithm and metric:
# row = repetition, column = evaluation checkpoint. Each dict and array is a fresh object.
train_losses, test_losses, train_accs, test_accs = (
    {alg: np.zeros((num_exprs, len_x_plot)) for alg in algorithms}
    for _ in range(4)
)

def _evaluate(model, inputs, targets, eval_batch):
    """Return (mean per-sample loss, accuracy) of `model` on (inputs, targets).

    Uses the module-level `criterion` and `device`. Runs under torch.no_grad()
    with the model in eval mode (so normalization/dropout layers, presumably
    present in ResNet18, use inference behaviour), then restores the previous
    train/eval mode. Works for any number of samples, including a final
    partial batch.
    """
    was_training = model.training
    model.eval()
    num = inputs.shape[0]
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for start in range(0, num, eval_batch):
            x = inputs[start:start + eval_batch].to(device)
            y = targets[start:start + eval_batch].to(device)
            out = model(x)
            # criterion averages over the batch; re-weight by the batch size so
            # dividing by `num` below gives a true per-sample mean.
            total_loss += criterion(out, y).item() * y.shape[0]
            _, predicted = out.max(1)
            total_correct += (predicted == y).sum().item()
    model.train(was_training)
    return total_loss / num, total_correct / num


# Accumulated wall-clock minutes spent by each algorithm over all repetitions.
minutes_algs={alg:0 for alg in algorithms}
for expr_k in range(num_exprs):
    for alg in algorithms:
        # Fresh model and optimizer for every (repetition, algorithm) pair.
        net = ResNet18().to(device)
        if device == 'cuda':
            net = torch.nn.DataParallel(net)
            cudnn.benchmark = True
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=eta[alg])
        time_start=time.time()
        if alg=="begin-shuffle":
            # Drawn once, then reused for every pass over the data.
            i_begin=random.sample(range(n),n)
        order=[]    # index order of the current pass (shuffling variants only)
        pos=0       # next position to read in `order`
        eval_k=0    # index of the current evaluation checkpoint
        for t in range(total_iters):
            if alg=="SGD":
                i=random.sample(range(n),batchsize)
            else:
                if pos>=len(order):
                    # Start a new pass over the data with this algorithm's permutation.
                    if alg=="fixed-shuffle":
                        order=list(range(n))
                    elif alg=="begin-shuffle":
                        order=i_begin
                    else:  # uniform-shuffle
                        order=random.sample(range(n),n)
                    pos=0
                i=order[pos:pos+batchsize]
                pos+=batchsize
            input1 = train_inputs[i].to(device)
            target1 = train_targets[i].to(device)
            # Clear last step's gradients: backward() accumulates into .grad,
            # so without this every update would use the sum of all past gradients.
            optimizer.zero_grad()
            loss1 = criterion(net(input1), target1)
            loss1.backward()
            optimizer.step()

            if t % between_eval==0:
                train_loss, train_acc = _evaluate(net, train_inputs, train_targets, 500)
                test_loss, test_acc = _evaluate(net, test_inputs, test_targets, 100)
                train_losses[alg][expr_k,eval_k]=train_loss
                train_accs[alg][expr_k,eval_k]=train_acc
                test_losses[alg][expr_k,eval_k]=test_loss
                test_accs[alg][expr_k,eval_k]=test_acc
                print(str(expr_k)+"-th implement of "+alg+": Iteration "+str(t)
                      +", train loss="+str(train_losses[alg][expr_k,eval_k])
                      +", train accuracy="+str(train_accs[alg][expr_k,eval_k])
                      +", test loss="+str(test_losses[alg][expr_k,eval_k])
                      +", test accuracy="+str(test_accs[alg][expr_k,eval_k]),flush=True)
                eval_k+=1
        minutes_algs[alg]+=(time.time()-time_start)/60


# Alternative output directory that also encodes the step size:
# result_dir="Cifar10_result_stepsize"+str(eta1)+"_server/"
result_dir="Cifar10_result/"
os.makedirs(result_dir, exist_ok=True)

# Save each algorithm's raw (num_exprs x checkpoints) metric arrays as .npy files.
_metric_tables = {
    "TrainLoss_": train_losses,
    "TestLoss_": test_losses,
    "TrainAcc_": train_accs,
    "TestAcc_": test_accs,
}
for alg in algorithms:
    for prefix, table in _metric_tables.items():
        np.save(result_dir + prefix + alg + ".npy", table[alg])

# Plot styling shared by the figures below.
label_size=16   # axis-label font size
num_size=14     # tick-label font size
lgd_size=18     # legend font size
percentile=90   # shaded band spans the [100-percentile, percentile] percentiles across runs
colors=['red','black','blue','green','cyan','purple','gold','lime','darkorange']
markers=['P','v','*','s','.']
legends={"SGD":"SGD","fixed-shuffle":"Fixed-shuffling","begin-shuffle":"Shuffle-once","uniform-shuffle":"Uniform-shuffling"}

# Record the run configuration and wall-clock cost next to the saved results.
total_minutes=sum(minutes_algs[alg] for alg in algorithms)
# `with` guarantees the file is closed even if a write fails.
with open(result_dir+'hyperparameters.txt','w') as hyp_txt:
    hyp_txt.write('number of experiments='+str(num_exprs)+'\n')
    hyp_txt.write('number of iterations='+str(total_iters)+'\n')
    hyp_txt.write('batch size='+str(batchsize)+'\n')   # needed to reproduce a run
    for alg in algorithms:
        hyp_txt.write('Stepsize eta for '+str(alg)+':'+str(eta[alg])+'\n')
    hyp_txt.write('\n')
    for alg in algorithms:
        hyp_txt.write('Time consumption for '+str(alg)+':'+str(minutes_algs[alg])+' minutes.\n')
    hyp_txt.write('Total time consumption:'+str(total_minutes)+' minutes.\n')  # one earlier run reported about 5.7 hours

# Font sizes must be set before any figure is created: axis labels read their
# size from rcParams when the axes are built, so setting them afterwards
# leaves the first figure unstyled.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels


def _plot_metric(tables, ylabel, filename, scale=1.0, sci_y=False):
    """Plot one metric for every algorithm and save it as result_dir+filename.

    tables: dict mapping algorithm name -> (num_exprs, len(x_plot)) array.
    Values are multiplied by `scale` (e.g. 100 for percentages). The mean over
    repetitions is drawn as a line; when num_exprs > 1 the band between the
    (100-percentile)-th and percentile-th percentiles is shaded. With
    num_exprs == 1 the mean is just the single run.
    sci_y: force scientific notation on the y-axis tick labels.
    """
    plt.figure()
    for alg_k, alg in enumerate(algorithms):
        y_runs = tables[alg] * scale
        color = colors[alg_k % len(colors)]
        # markevery is a stride over plotted points (len(x_plot) of them), not
        # over iterations; stagger it per algorithm so markers don't coincide.
        mark_every = max(1, len(x_plot) // (alg_k + 6))
        plt.plot(x_plot, np.mean(y_runs, axis=0), color=color,
                 marker=markers[alg_k % len(markers)], markevery=mark_every,
                 label=legends[alg])
        if num_exprs > 1:
            lower = np.percentile(y_runs, 100 - percentile, axis=0)
            upper = np.percentile(y_runs, percentile, axis=0)
            plt.fill_between(x_plot, lower, upper, color=color, alpha=0.3, edgecolor="none")
    plt.legend(prop={'size': lgd_size})
    plt.xlabel("Iteration t")
    plt.ylabel(ylabel)
    if sci_y:
        plt.ticklabel_format(style='sci', scilimits=(0, 0), axis='y')
    plt.gcf().subplots_adjust(bottom=0.15, left=0.15)
    plt.grid(True)
    plt.savefig(result_dir + filename, dpi=200)
    plt.close()


_plot_metric(train_losses, "Training Loss", "TrainLoss.png", sci_y=True)
_plot_metric(test_losses, "Test Loss", "TestLoss.png", sci_y=True)
_plot_metric(train_accs, "Training Accuracy (%)", "TrainAccuracy.png", scale=100)
_plot_metric(test_accs, "Test Accuracy (%)", "TestAccuracy.png", scale=100)

        
