#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import datetime
import json
import math
import os
import pickle
import time
from typing import Callable, Iterable, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import sls
from SMB import SMB, SMBi

# Number of passes over the training set made by every optimizer run below.
epochs = 100

# When True, models and batches are moved to the GPU with .cuda().
use_GPU = True

# Directory that receives one JSON results file per optimizer run.
results_dir = "./results/MNIST_MLP/"
# Create it up front: each run finishes with open(..., 'w') inside this
# directory, which raises FileNotFoundError when the directory is missing --
# and that would only surface after the whole training run has completed.
os.makedirs(results_dir, exist_ok=True)


# Seed the NumPy and PyTorch (CPU and CUDA) random number generators.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def set_torch_seed(seed):
    """Seed the NumPy, PyTorch CPU and PyTorch CUDA random number generators.

    Args:
        seed: Value handed unchanged to each of the three seeding functions.
    """
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
    
    
##############################################################################
### DATA #####################################################################
##############################################################################


# Mini-batch size used by the training loader.
batch_size = 128


def _load_mnist(train):
    """Return one MNIST split, downloaded into the "Datasets" directory.

    Args:
        train: True selects the training split, False the test split.

    Each image is converted to a tensor and then normalized with mean 0.5
    and standard deviation 0.5.
    """
    preprocessing = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    return torchvision.datasets.MNIST("Datasets", train=train,
                                      download=True,
                                      transform=preprocessing)


train_set = _load_mnist(train=True)
test_set = _load_mnist(train=False)

# Shuffled training batches; the incomplete final batch is dropped.
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)

# Number of training batches per epoch (passed to the SLS optimizer below).
n_batches_per_epoch = len(train_loader)



##############################################################################
### MODEL ####################################################################
##############################################################################

class Mlp(nn.Module):
    """Fully connected ReLU network that maps flattened inputs to class logits.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_sizes: Output width of each hidden linear layer, in order.
            Any sequence with at least one entry.
        n_classes: Number of output logits.
        bias: Whether the linear layers carry a bias term.
        dropout: If True, dropout with p=0.5 follows every hidden activation
            while the module is in training mode.
    """

    def __init__(self, input_size=784,
                 hidden_sizes=(512, 256),
                 n_classes=10,
                 bias=True, dropout=False):
        super().__init__()

        # Copy into a list: callers may pass any sequence, and the immutable
        # default avoids sharing one list object between instances.
        hidden_sizes = list(hidden_sizes)

        self.dropout = dropout
        self.input_size = input_size

        # Layer i consumes the width produced by layer i-1 (or the input).
        layer_inputs = [self.input_size] + hidden_sizes[:-1]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out, bias=bias)
             for n_in, n_out in zip(layer_inputs, hidden_sizes)]
        )
        self.output_layer = nn.Linear(hidden_sizes[-1], n_classes, bias=bias)

    def forward(self, x):
        """Return the logits for `x`.

        `x` is first reshaped to (-1, input_size), so any input whose element
        count is a multiple of `input_size` is accepted.
        """
        out = x.view(-1, self.input_size)
        for layer in self.hidden_layers:
            out = F.relu(layer(out))

            if self.dropout:
                # F.dropout defaults to training=True; forward the module's
                # mode so dropout is switched off after model.eval().
                out = F.dropout(out, p=0.5, training=self.training)

        return self.output_layer(out)
    
    
    
##############################################################################
### METRICS ##################################################################
##############################################################################


def softmax_loss(model, images, labels, backwards=False):
    """Return the mean cross-entropy between `model(images)` and `labels`.

    Args:
        model: Callable that produces class logits for `images`.
        images: Input batch passed to `model`.
        labels: Target class indices; flattened to one dimension before use.
        backwards: If True and the loss requires grad, `backward()` is called
            on it before returning.

    Returns:
        The scalar loss tensor.
    """
    predictions = model(images)
    loss = F.cross_entropy(predictions, labels.view(-1), reduction="mean")

    wants_backward = backwards and loss.requires_grad
    if wants_backward:
        loss.backward()

    return loss

def softmax_accuracy(model, images, labels):
    """Return the fraction of samples whose arg-max logit equals the label.

    Args:
        model: Callable that produces class logits for `images`.
        images: Input batch passed to `model`.
        labels: Target class indices compared element-wise with the
            predictions.

    Returns:
        A scalar float tensor holding the mean of the per-sample matches.
    """
    predicted = torch.argmax(model(images), dim=1)
    hits = predicted == labels
    return hits.float().mean()

def compute_loss(model, dataset):
    """Return the mean softmax loss of `model` over every sample in `dataset`.

    The model is switched to eval mode and left there; callers re-enable
    training mode themselves. The data is processed in batches of 1024 and
    each batch mean is weighted by its batch size, so a smaller final batch
    is averaged correctly.

    Args:
        model: Network to evaluate.
        dataset: Dataset yielding (images, labels) pairs.

    Returns:
        The per-sample average loss as a Python float.
    """
    metric_function = softmax_loss

    model.eval()

    loader = DataLoader(dataset, drop_last=False, batch_size=1024)

    score_sum = 0.
    # Evaluation only: no gradients are needed, so skip building autograd
    # graphs for these forward passes.
    with torch.no_grad():
        for images, labels in loader:
            if use_GPU:
                images, labels = images.cuda(), labels.cuda()

            batch_mean = metric_function(model, images, labels).item()
            score_sum += batch_mean * images.shape[0]

    score = float(score_sum / len(loader.dataset))

    return score

def compute_accuracy(model, dataset):
    """Return the classification accuracy of `model` over all of `dataset`.

    The model is switched to eval mode and left there; callers re-enable
    training mode themselves. The data is processed in batches of 1024 and
    each batch accuracy is weighted by its batch size, so a smaller final
    batch is averaged correctly.

    Args:
        model: Network to evaluate.
        dataset: Dataset yielding (images, labels) pairs.

    Returns:
        The fraction of correctly classified samples as a Python float.
    """
    metric_function = softmax_accuracy

    model.eval()

    loader = DataLoader(dataset, drop_last=False, batch_size=1024)

    score_sum = 0.
    # Evaluation only: no gradients are needed, so skip building autograd
    # graphs for these forward passes.
    with torch.no_grad():
        for images, labels in loader:
            if use_GPU:
                images, labels = images.cuda(), labels.cuda()

            batch_mean = metric_function(model, images, labels).item()
            score_sum += batch_mean * images.shape[0]

    score = float(score_sum / len(loader.dataset))

    return score

##############################################################################


def train_test_network(epochs, train_loader, train_set, test_set, model, criterion, optimizer, use_GPU):
    """Train `model` with a standard zero_grad / backward / step loop.

    After every epoch the full-training-set loss and the test-set accuracy
    are measured with compute_loss and compute_accuracy.

    Args:
        epochs: Number of passes over `train_loader`.
        train_loader: Iterable of (data, target) training batches.
        train_set: Dataset used for the per-epoch training loss.
        test_set: Dataset used for the per-epoch test accuracy.
        model: Network being trained.
        criterion: Callable `criterion(model, data, target)` returning the
            batch loss tensor.
        optimizer: Optimizer exposing `zero_grad()` and `step()`.
        use_GPU: If True, each batch is moved to the GPU with `.cuda()`.

    Returns:
        Tuple of four lists: per-epoch training loss, per-epoch test
        accuracy, per-epoch training time in seconds (evaluation excluded),
        and the loss of every individual training batch.
    """
    epoch_train_losses = []
    batch_losses = []
    epoch_test_accs = []
    epoch_durations = []

    for epoch in range(1, epochs + 1):
        started = time.time()

        # ---- one pass over the training batches ----
        model.train()
        for data, target in train_loader:
            if use_GPU:
                data = data.cuda()
                target = target.cuda()

            # reset gradients, evaluate the batch loss, back-propagate
            optimizer.zero_grad()
            loss = criterion(model, data, target)
            loss.backward()

            batch_losses.append(loss.item())

            # update parameters
            optimizer.step()

        # Timing stops before evaluation so only training is measured.
        elapsed = time.time() - started

        # ---- end-of-epoch metrics ----
        epoch_train_losses.append(compute_loss(model, train_set))
        epoch_test_accs.append(compute_accuracy(model, test_set))
        epoch_durations.append(elapsed)

        # Progress indicator: epoch numbers on a single line.
        print(epoch, end=' ')

    return epoch_train_losses, epoch_test_accs, epoch_durations, batch_losses
   
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################



##############################################################################
# Train with SMB_samebatch optimizer
##############################################################################

# Run configuration for SMB. The measured curves are merged into this dict
# after training and the whole dict is written to JSON.
# 'data' and 'model' label the MNIST dataset and the Mlp network that this
# script actually trains; both are also embedded in the results filename.
opt_out = {'name':'SMB', 'autoschedule':False, 'gamma':0.05, 'beta':0.9,
           'lr':0.5, 'c':0.1, 'eta':0.99, 'epochs':epochs,
           'data':'MNIST', 'model':'MLP', 'loss_func':'Softmax', 'GPU':use_GPU}

# model
model = Mlp()
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

optimizer = SMB(model.parameters(), lr=opt_out['lr'], c=opt_out['c'], eta=opt_out['eta'])

print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))

train_loss_list = []         # full-training-set loss after each epoch
train_iter_loss_list = []    # loss returned by optimizer.step for every batch
test_acc_list = []           # test-set accuracy after each epoch
run_time_list = []           # seconds spent in each epoch's training loop
model_step_ratio_list = []   # per-epoch mean of the step-type flags
learning_rate_list = []      # learning rate in effect during each epoch


def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of `optimizer` to `lr`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


curr_lr = opt_out['lr']


for epoch in range(1, epochs+1):

    # Second value returned by optimizer.step for each batch.
    # NOTE(review): presumably truthy when SMB took a "model step" -- confirm
    # against SMB.py.
    step_type = []

    begin = time.time()

    # training steps
    model.train()

    for data, target in train_loader:

        # moves tensors to GPU
        if use_GPU:
            data, target = data.cuda(), target.cuda()

        # Loss closure handed to the SMB optimizer: clears the gradients and
        # returns the batch loss. backward() is deliberately not called here.
        # NOTE(review): presumably SMB computes the gradients itself -- confirm.
        def closure():
            optimizer.zero_grad()
            loss = criterion(model, data, target)
            return loss

        loss, s_stype = optimizer.step(closure=closure)

        train_iter_loss_list.append(loss.item())

        step_type.append(s_stype)

    # Timing stops before evaluation so only training is measured.
    end = time.time()

    train_loss = compute_loss(model, train_set)
    test_acc = compute_accuracy(model, test_set)

    model_step_ratio = sum(step_type)/len(step_type)

    train_loss_list.append(train_loss)
    test_acc_list.append(test_acc)
    run_time_list.append(end-begin)
    model_step_ratio_list.append(model_step_ratio)
    learning_rate_list.append(curr_lr)

    # Display loss statistics
    print(f'Epoch: {epoch}   -   Training Loss: {round(train_loss, 6)}  -  Test Accuracy: {round(test_acc, 6)}  -  Time: {round(end-begin, 2)}')
    print('Model steps taken:', sum(step_type), '/', len(step_type))

    # Optional learning-rate schedule, active from the third epoch on:
    # shrink by beta when the step ratio reaches gamma, otherwise grow by
    # 1/beta as long as the result stays at or below 20.
    if opt_out['autoschedule'] and epoch > 2:
        if model_step_ratio >= opt_out['gamma']:
            curr_lr *= opt_out['beta']
            update_lr(optimizer, curr_lr)
            print("\n Now the learning rate is:", curr_lr, "\n")
        elif (curr_lr/opt_out['beta']) <= 20:
            curr_lr /= opt_out['beta']
            update_lr(optimizer, curr_lr)
            print("\n Now the learning rate is:", curr_lr, "\n")


opt_out.update({'train_loss':train_loss_list,
                'test_acc':test_acc_list,
                'run_time':run_time_list,
                'model_step_ratio':model_step_ratio_list,
                'train_iter_loss':train_iter_loss_list,
                'learning_rate':learning_rate_list,
                })

# Timestamp without spaces or colons so it is safe inside a filename.
now = datetime.datetime.now()
date_time = now.strftime("%Y_%m_%d_%H_%M_%S")


filename = results_dir + "{}_{}_{}_epochs_{}_autoschedule_{}_gamma_{}_beta_{}_lr_{}_c_{}_{}.json".format(
    opt_out['name'],
    opt_out['data'],
    opt_out['model'],
    opt_out['epochs'],
    opt_out['autoschedule'],
    opt_out['gamma'],
    opt_out['beta'],
    opt_out['lr'],
    opt_out['c'],
    date_time,
)

with open(filename, 'w') as f:
    json.dump(opt_out, f)






##############################################################################
# Train with SMBi optimizer
##############################################################################


# Run configuration for SMBi. The measured curves are merged into this dict
# after training and the whole dict is written to JSON.
# 'data' and 'model' label the MNIST dataset and the Mlp network that this
# script actually trains; both are also embedded in the results filename.
opt_out = {'name':'SMBi', 'autoschedule':False, 'gamma':0.05, 'beta':0.9,
           'lr':0.5, 'c':0.1, 'eta':0.99, 'epochs':epochs,
           'data':'MNIST', 'model':'MLP', 'loss_func':'Softmax', 'GPU':use_GPU}

# model
model = Mlp()
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

optimizer = SMBi(model.parameters(), lr=opt_out['lr'], c=opt_out['c'], eta=opt_out['eta'])

print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))

train_loss_list = []         # full-training-set loss after each epoch
train_iter_loss_list = []    # loss returned by optimizer.step for every batch
test_acc_list = []           # test-set accuracy after each epoch
run_time_list = []           # seconds spent in each epoch's training loop
model_step_ratio_list = []   # per-epoch mean of the step-type flags
learning_rate_list = []      # learning rate in effect during each epoch


def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of `optimizer` to `lr`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


curr_lr = opt_out['lr']


for epoch in range(1, epochs+1):

    # Second value returned by optimizer.step for each batch.
    # NOTE(review): presumably truthy when SMBi took a "model step" -- confirm
    # against SMB.py.
    step_type = []

    begin = time.time()

    # training steps
    model.train()

    for data, target in train_loader:

        # moves tensors to GPU
        if use_GPU:
            data, target = data.cuda(), target.cuda()

        # Loss closure handed to the SMBi optimizer: clears the gradients and
        # returns the batch loss. backward() is deliberately not called here.
        # NOTE(review): presumably SMBi computes the gradients itself -- confirm.
        def closure():
            optimizer.zero_grad()
            loss = criterion(model, data, target)
            return loss

        loss, s_stype = optimizer.step(closure=closure)

        train_iter_loss_list.append(loss.item())

        step_type.append(s_stype)

    # Timing stops before evaluation so only training is measured.
    end = time.time()

    train_loss = compute_loss(model, train_set)
    test_acc = compute_accuracy(model, test_set)

    model_step_ratio = sum(step_type)/len(step_type)

    train_loss_list.append(train_loss)
    test_acc_list.append(test_acc)
    run_time_list.append(end-begin)
    model_step_ratio_list.append(model_step_ratio)
    learning_rate_list.append(curr_lr)

    # Display loss statistics
    print(f'Epoch: {epoch}   -   Training Loss: {round(train_loss, 6)}  -  Test Accuracy: {round(test_acc, 6)}  -  Time: {round(end-begin, 2)}')
    print('Model steps taken:', sum(step_type), '/', len(step_type))

    # Optional learning-rate schedule, active from the third epoch on:
    # shrink by beta when the step ratio reaches gamma/2 (half the threshold
    # used for SMB), otherwise grow by 1/beta as long as the result stays at
    # or below 20.
    if opt_out['autoschedule'] and epoch > 2:
        if model_step_ratio >= (opt_out['gamma']/2):
            curr_lr *= opt_out['beta']
            update_lr(optimizer, curr_lr)
            print("\n Now the learning rate is:", curr_lr, "\n")
        elif (curr_lr/opt_out['beta']) <= 20:
            curr_lr /= opt_out['beta']
            update_lr(optimizer, curr_lr)
            print("\n Now the learning rate is:", curr_lr, "\n")


opt_out.update({'train_loss':train_loss_list,
                'test_acc':test_acc_list,
                'run_time':run_time_list,
                'model_step_ratio':model_step_ratio_list,
                'train_iter_loss':train_iter_loss_list,
                'learning_rate':learning_rate_list,
                })

# Timestamp without spaces or colons so it is safe inside a filename.
now = datetime.datetime.now()
date_time = now.strftime("%Y_%m_%d_%H_%M_%S")


filename = results_dir + "{}_{}_{}_epochs_{}_autoschedule_{}_gamma_{}_beta_{}_lr_{}_c_{}_{}.json".format(
    opt_out['name'],
    opt_out['data'],
    opt_out['model'],
    opt_out['epochs'],
    opt_out['autoschedule'],
    opt_out['gamma'],
    opt_out['beta'],
    opt_out['lr'],
    opt_out['c'],
    date_time,
)

with open(filename, 'w') as f:
    json.dump(opt_out, f)
    

 

##############################################################################
# Train with SLS optimizer
##############################################################################


# Run configuration for SLS. The measured curves are merged into this dict
# after training and the whole dict is written to JSON.
# 'data' and 'model' label the MNIST dataset and the Mlp network that this
# script actually trains; both are also embedded in the results filename.
opt_out = {'name':'SLS', 'lr':1, 'c':0.1, 'reset_option':1, 'epochs':epochs,
           'data':'MNIST', 'model':'MLP', 'loss_func':'Softmax', 'GPU':use_GPU}

# model
model = Mlp()
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss


optimizer = sls.Sls(model.parameters(),
                    init_step_size=opt_out['lr'],
                    reset_option=opt_out['reset_option'],
                    c=opt_out['c'],
                    n_batches_per_epoch=n_batches_per_epoch
                   )

print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))


train_loss_list = []         # full-training-set loss after each epoch
train_iter_loss_list = []    # loss returned by optimizer.step for every batch
test_acc_list = []           # test-set accuracy after each epoch
run_time_list = []           # seconds spent in each epoch's training loop

loss = None

for epoch in range(1, epochs+1):

    begin = time.time()

    # training steps
    model.train()

    for data, target in train_loader:

        # moves tensors to GPU
        if use_GPU:
            data, target = data.cuda(), target.cuda()

        # Loss closure for the SLS optimizer: returns the loss on the current
        # batch. It is only invoked inside the step() call just below, while
        # `data` and `target` still refer to this batch.
        closure = lambda :  criterion(model, data, target)
        # clears gradients
        optimizer.zero_grad()

        loss = optimizer.step(closure=closure)

        train_iter_loss_list.append(loss.item())

    # Timing stops before evaluation so only training is measured.
    end = time.time()

    train_loss = compute_loss(model, train_set)
    test_acc = compute_accuracy(model, test_set)

    train_loss_list.append(train_loss)
    test_acc_list.append(test_acc)
    run_time_list.append(end-begin)

    # Display loss statistics
    print(f'Epoch: {epoch}   -   Training Loss: {round(train_loss, 6)}  -  Test Accuracy: {round(test_acc, 6)}  -  Time: {round(end-begin, 2)}')

opt_out.update({'train_loss':train_loss_list,
                'test_acc':test_acc_list,
                'run_time':run_time_list,
                'train_iter_loss':train_iter_loss_list,
                })

# Same timestamp format as the SMB/SMBi runs. str(datetime.now()) contains
# spaces and colons, and colons are not allowed in Windows filenames.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

filename = results_dir + "{}_{}_{}_epochs_{}_lr_{}_c_{}_{}.json".format(
    opt_out['name'],
    opt_out['data'],
    opt_out['model'],
    opt_out['epochs'],
    opt_out['lr'],
    opt_out['c'],
    date_time,
)

with open(filename, 'w') as f:
    json.dump(opt_out, f)
    







##############################################################################
# Train with Adam optimizer
##############################################################################

# Run configuration for Adam. The measured curves are merged into this dict
# after training and the whole dict is written to JSON.
# 'data' and 'model' label the MNIST dataset and the Mlp network that this
# script actually trains; both are also embedded in the results filename.
opt_out = {'name':'Adam', 'lr':0.001, 'epochs':epochs,
           'data':'MNIST', 'model':'MLP', 'loss_func':'Softmax',
           'GPU':use_GPU}

# model
model = Mlp()
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

# optimizer
optimizer = optim.Adam(model.parameters(), lr = opt_out['lr'])


print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))


# Standard training loop; returns per-epoch loss/accuracy/time and the
# per-batch training losses.
train_loss_list, test_acc_list, run_time_list, train_iter_loss_list = train_test_network(epochs, train_loader, train_set, test_set, model, criterion, optimizer, use_GPU)


opt_out.update({'train_loss':train_loss_list,
                'test_acc':test_acc_list,
                'run_time':run_time_list,
                'train_iter_loss':train_iter_loss_list,
                })

# Same timestamp format as the SMB/SMBi runs. str(datetime.now()) contains
# spaces and colons, and colons are not allowed in Windows filenames.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

filename = results_dir + "{}_{}_{}_epochs_{}_lr_{}_{}.json".format(opt_out['name'], opt_out['data'], opt_out['model'], opt_out['epochs'], opt_out['lr'], date_time)

with open(filename, 'w') as f:
    json.dump(opt_out, f)






##############################################################################
# Train with SGD optimizer
##############################################################################

# Run configuration for SGD. The measured curves are merged into this dict
# after training and the whole dict is written to JSON.
# 'data' and 'model' label the MNIST dataset and the Mlp network that this
# script actually trains; both are also embedded in the results filename.
# NOTE(review): 'autoschedule', 'gamma' and 'beta' are recorded here, but
# train_test_network never changes the learning rate, so they have no effect
# on this run -- confirm whether 'autoschedule' should be False.
opt_out = {'name':'SGD', 'autoschedule':True, 'gamma':0.05, 'beta':0.9, 'lr':0.1, 'epochs':epochs,
           'data':'MNIST', 'model':'MLP', 'loss_func':'Softmax', 'GPU':use_GPU}

# model
model = Mlp()
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

# optimizer
optimizer = optim.SGD(model.parameters(), lr = opt_out['lr'])

print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))


# Standard training loop; returns per-epoch loss/accuracy/time and the
# per-batch training losses.
train_loss_list, test_acc_list, run_time_list, train_iter_loss_list = train_test_network(epochs, train_loader, train_set, test_set, model, criterion, optimizer, use_GPU)


opt_out.update({'train_loss':train_loss_list,
                'test_acc':test_acc_list,
                'run_time':run_time_list,
                'train_iter_loss':train_iter_loss_list,
                })

# Same timestamp format as the SMB/SMBi runs. str(datetime.now()) contains
# spaces and colons, and colons are not allowed in Windows filenames.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

filename = results_dir + "{}_{}_{}_epochs_{}_lr_{}_{}.json".format(opt_out['name'], opt_out['data'], opt_out['model'], opt_out['epochs'], opt_out['lr'], date_time)

with open(filename, 'w') as f:
    json.dump(opt_out, f)
    
    
    

 


    