#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import copy
import datetime
import json
import math
import os
import pickle
import time
from typing import Callable, Iterable, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import sls
from SMB import SMB, SMBi

# One shared seed for every random number generator the script relies on
# (NumPy, torch CPU, and all CUDA devices), applied in that order.
seed = 43
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn



##############################################################################

# epochs to train for
epochs = 200

# In order to use GPU
use_GPU = True

# Every run writes its metrics JSON into this directory. Create it up front:
# otherwise the final open(filename, 'w') of each run raises
# FileNotFoundError only after all training epochs have completed.
results_dir = "./results/CIFAR10_RESNET34/"
os.makedirs(results_dir, exist_ok=True)


##


##############################################################################
### DATA #####################################################################
##############################################################################


batch_size = 128

# Per-channel mean / std handed to transforms.Normalize by both pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop (4-pixel padding) and random
# horizontal flip augmentation, then tensor conversion and normalization.
transform_function = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize(cifar10_mean, cifar10_std),
                                        ])

# Evaluation pipeline: deterministic (no random crop/flip), so test accuracy
# is measured on the actual test images and is repeatable across evaluations.
test_transform_function = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize(cifar10_mean, cifar10_std),
                                             ])

# NOTE(review): train_set keeps the augmenting transform because train_loader
# trains on it; compute_loss(model, train_set) therefore reports training loss
# on randomly augmented images — confirm this is intended.
train_set = torchvision.datasets.CIFAR10(root='Datasets',
                                         train=True,
                                         download=True,
                                         transform=transform_function
                                        )

test_set = torchvision.datasets.CIFAR10(root="Datasets",
                                        train=False,
                                        download=True,
                                        transform=test_transform_function
                                       )

# drop_last=True: every training step sees a full batch of batch_size images.
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                          )

n_batches_per_epoch = len(train_loader)

##############################################################################
### MODEL ####################################################################
##############################################################################

class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem followed by four residual stages.

    ``ResNet([3, 4, 6, 3])`` with the default ``BasicBlock`` gives the
    ResNet-34 layout used in this script.

    Args:
        num_blocks: number of residual blocks in each of the four stages.
        num_classes: length of the output logit vector.
        block: residual block class. Defaults to ``BasicBlock``. It must be
            constructible as ``block(in_planes, planes, stride)`` and expose an
            integer ``expansion`` class attribute.
    """

    def __init__(self, num_blocks, num_classes=10, block=None):
        super().__init__()
        # Resolved at call time because BasicBlock is defined after this class.
        if block is None:
            block = BasicBlock
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage. Only its first block uses ``stride``.

        Side effect: advances ``self.in_planes`` to ``planes * block.expansion``
        so the next stage's first block receives the right channel count.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs (4x4 final map) this matches
        # the former avg_pool2d(out, 4), but it also works for other input
        # resolutions instead of feeding a wrongly-sized vector to self.linear.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut branch.

    The shortcut is the identity unless the block changes the spatial size
    (``stride != 1``) or the channel count. In that case it is a 1x1 conv +
    BatchNorm projection.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Registration order (conv1, bn1, conv2, bn2, shortcut) is kept as-is so
        # parameter initialisation consumes the RNG in the same sequence.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # empty Sequential == identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return F.relu(main_path + self.shortcut(x))
 
    
    
##############################################################################
### METRICS ##################################################################
##############################################################################


def softmax_loss(model, images, labels, backwards=False):
    """Return the batch-mean cross-entropy of ``model(images)`` against ``labels``.

    ``labels`` is flattened with ``view(-1)`` before use. When ``backwards``
    is true and the loss is attached to an autograd graph, ``loss.backward()``
    is called before returning.
    """
    flat_labels = labels.view(-1)
    loss = F.cross_entropy(model(images), flat_labels, reduction="mean")

    should_backprop = backwards and loss.requires_grad
    if should_backprop:
        loss.backward()

    return loss

def softmax_accuracy(model, images, labels):
    """Return the fraction of ``images`` whose arg-max logit equals ``labels``.

    The result is a 0-dim float tensor (mean of a 0/1 correctness vector).
    """
    predictions = torch.argmax(model(images), dim=1)
    return predictions.eq(labels).float().mean()

def compute_loss(model, dataset):
    """Return the mean softmax (cross-entropy) loss of ``model`` over ``dataset``.

    The dataset is scanned in order in batches of 1024. Each batch mean is
    weighted by the batch size, so the result is the per-example average.
    Batches are moved to CUDA when the module-level ``use_GPU`` flag is set.

    Side effect: switches ``model`` to eval mode and leaves it there. The
    training loops in this script call ``model.train()`` at the start of
    every epoch.
    """
    metric_function = softmax_loss

    model.eval()

    loader = DataLoader(dataset, drop_last=False, batch_size=1024)

    score_sum = 0.
    # Evaluation only: without no_grad, autograd would retain every batch's
    # intermediate activations for a backward pass that never happens.
    with torch.no_grad():
        for images, labels in loader:
            if use_GPU:
                images, labels = images.cuda(), labels.cuda()

            score_sum += metric_function(model, images, labels).item() * images.shape[0]

    score = float(score_sum / len(loader.dataset))

    return score

def compute_accuracy(model, dataset):
    """Return the top-1 classification accuracy of ``model`` over ``dataset``.

    The dataset is scanned in order in batches of 1024. Each batch accuracy
    is weighted by the batch size, so the result is the fraction of correctly
    classified examples. Batches are moved to CUDA when the module-level
    ``use_GPU`` flag is set.

    Side effect: switches ``model`` to eval mode and leaves it there. The
    training loops in this script call ``model.train()`` at the start of
    every epoch.
    """
    metric_function = softmax_accuracy

    model.eval()

    loader = DataLoader(dataset, drop_last=False, batch_size=1024)

    score_sum = 0.
    # Evaluation only: without no_grad, autograd would retain every batch's
    # intermediate activations for a backward pass that never happens.
    with torch.no_grad():
        for images, labels in loader:
            if use_GPU:
                images, labels = images.cuda(), labels.cuda()

            score_sum += metric_function(model, images, labels).item() * images.shape[0]

    score = float(score_sum / len(loader.dataset))

    return score

##############################################################################

def train_test_network(epochs, train_loader, train_set, test_set, model, criterion, optimizer, use_GPU):
    """Train ``model`` with a standard zero_grad/backward/step optimizer loop.

    Args:
        epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        train_set: dataset on which the full training loss is computed after each epoch.
        test_set: dataset on which test accuracy is computed after each epoch.
        model: the network being trained.
        criterion: callable ``criterion(model, inputs, targets)`` returning a
            scalar loss tensor (e.g. ``softmax_loss``).
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        use_GPU: move training batches to CUDA. compute_loss/compute_accuracy
            read the module-level ``use_GPU`` instead of this argument.

    Returns:
        ``(train_loss_list, test_acc_list, run_time_list, train_iter_loss_list)``:
        per-epoch training loss, per-epoch test accuracy, per-epoch training
        wall time in seconds (evaluation excluded), and per-iteration batch loss.
    """
    epoch_train_losses = []
    iteration_losses = []
    epoch_test_accs = []
    epoch_seconds = []

    for epoch in range(1, epochs + 1):
        started = time.time()

        model.train()
        for inputs, targets in train_loader:
            if use_GPU:
                inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            batch_loss = criterion(model, inputs, targets)
            batch_loss.backward()
            iteration_losses.append(batch_loss.item())
            optimizer.step()

        finished = time.time()

        epoch_train_loss = compute_loss(model, train_set)
        epoch_test_acc = compute_accuracy(model, test_set)

        epoch_train_losses.append(epoch_train_loss)
        epoch_test_accs.append(epoch_test_acc)
        epoch_seconds.append(finished - started)

        # Compact progress indicator: epoch numbers on one line.
        print(epoch, end=' ')

    return epoch_train_losses, epoch_test_accs, epoch_seconds, iteration_losses

##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################
##############################################################################


##############################################################################
# Train with SMB optimizer
##############################################################################

# Hyper-parameters and run metadata for SMB. The same dict, extended with the
# collected metrics below, is what gets written to the result JSON file.
opt_out = dict(
    name='SMB',
    autoschedule=False,  # when True, lr is adapted at the end of each epoch (see below)
    gamma=0.05,          # model-step-ratio threshold used by that adaptation
    beta=0.9,            # lr is multiplied or divided by this factor
    lr=1,
    c=0.1,
    eta=0.99,
    epochs=epochs,
    data='CIFAR10',
    model='ResNet34',
    loss_func='Softmax',
    GPU=use_GPU,
)

# Fresh network (stage depths 3/4/6/3) for this run.
model = ResNet([3, 4, 6, 3], num_classes=10)
if use_GPU:
    model.cuda()

criterion = softmax_loss
optimizer = SMB(model.parameters(), lr=opt_out['lr'], c=opt_out['c'], eta=opt_out['eta'])

print(f"\nStarting to train with {opt_out['name']} optimizer: For {opt_out['epochs']} epochs")

# Per-epoch metrics (one entry per epoch) and per-iteration batch losses.
train_loss_list, test_acc_list, run_time_list = [], [], []
model_step_ratio_list, learning_rate_list = [], []
train_iter_loss_list = []


def update_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group['lr'] = lr


def closure():
    """Zero the gradients and return the loss on the current batch.

    Reads the module-level ``data``/``target`` bound by the training loop.
    backward() is deliberately not called here; presumably SMB.step does it.
    TODO confirm against the SMB implementation.
    """
    optimizer.zero_grad()
    return criterion(model, data, target)


curr_lr = opt_out['lr']

for epoch in range(1, epochs + 1):
    step_type = []
    begin = time.time()

    model.train()
    for data, target in train_loader:
        if use_GPU:
            data, target = data.cuda(), target.cuda()
        # SMB.step returns the batch loss and a per-step value. These values
        # are summed below and reported as the number of model steps taken.
        loss, model_step = optimizer.step(closure=closure)
        train_iter_loss_list.append(loss.item())
        step_type.append(model_step)

    end = time.time()

    train_loss = compute_loss(model, train_set)
    test_acc = compute_accuracy(model, test_set)

    n_model_steps = sum(step_type)
    step_ratio = n_model_steps / len(step_type)

    train_loss_list.append(train_loss)
    test_acc_list.append(test_acc)
    run_time_list.append(end - begin)
    model_step_ratio_list.append(step_ratio)
    learning_rate_list.append(curr_lr)  # lr in effect during this epoch

    print(f'Epoch: {epoch}   -   Training Loss: {round(train_loss, 6)}  -  Test Accuracy: {round(test_acc, 6)}  -  Time: {round(end-begin, 2)}')
    print('Model steps taken:', n_model_steps, '/', len(step_type))

    # Adaptive lr (not only a decay), from epoch 3 on. Shrink lr by beta when
    # the model-step ratio reaches gamma; otherwise grow it by 1/beta.
    if opt_out['autoschedule'] and epoch > 2:
        if step_ratio >= opt_out['gamma']:
            curr_lr *= opt_out['beta']
        else:
            curr_lr /= opt_out['beta']
        update_lr(optimizer, curr_lr)
        print("\n Now the learning rate is:", curr_lr, "\n")

opt_out.update(
    train_loss=train_loss_list,
    test_acc=test_acc_list,
    run_time=run_time_list,
    model_step_ratio=model_step_ratio_list,
    train_iter_loss=train_iter_loss_list,
    learning_rate=learning_rate_list,
)

# Timestamp without spaces or colons so it is safe inside a filename.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
filename = (
    f"{results_dir}{opt_out['name']}_{opt_out['data']}_{opt_out['model']}"
    f"_epochs_{opt_out['epochs']}_autoschedule_{opt_out['autoschedule']}"
    f"_gamma_{opt_out['gamma']}_beta_{opt_out['beta']}"
    f"_lr_{opt_out['lr']}_c_{opt_out['c']}_{date_time}.json"
)

with open(filename, 'w') as results_file:
    json.dump(opt_out, results_file)

    


##############################################################################
# Train with SMBi optimizer
##############################################################################


# Hyper-parameters and run metadata for SMBi. The same dict, extended with the
# collected metrics below, is what gets written to the result JSON file.
opt_out = dict(
    name='SMBi',
    autoschedule=False,  # when True, lr is adapted at the end of each epoch (see below)
    gamma=0.05,          # gamma / 2 is the model-step-ratio threshold for that adaptation
    beta=0.9,            # lr is multiplied or divided by this factor
    lr=1,
    c=0.1,
    eta=0.99,
    epochs=epochs,
    data='CIFAR10',
    model='ResNet34',
    loss_func='Softmax',
    GPU=use_GPU,
)

# Fresh network (stage depths 3/4/6/3) for this run.
model = ResNet([3, 4, 6, 3], num_classes=10)
if use_GPU:
    model.cuda()

criterion = softmax_loss
optimizer = SMBi(model.parameters(), lr=opt_out['lr'], c=opt_out['c'], eta=opt_out['eta'])

print(f"\nStarting to train with {opt_out['name']} optimizer: For {opt_out['epochs']} epochs")

# Per-epoch metrics (one entry per epoch) and per-iteration batch losses.
train_loss_list, test_acc_list, run_time_list = [], [], []
model_step_ratio_list, learning_rate_list = [], []
train_iter_loss_list = []


def update_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group['lr'] = lr


def closure():
    """Zero the gradients and return the loss on the current batch.

    Reads the module-level ``data``/``target`` bound by the training loop.
    backward() is deliberately not called here; presumably SMBi.step does it.
    TODO confirm against the SMBi implementation.
    """
    optimizer.zero_grad()
    return criterion(model, data, target)


curr_lr = opt_out['lr']

for epoch in range(1, epochs + 1):
    step_type = []
    begin = time.time()

    model.train()
    for data, target in train_loader:
        if use_GPU:
            data, target = data.cuda(), target.cuda()
        # SMBi.step returns the batch loss and a per-step value. These values
        # are summed below and reported as the number of model steps taken.
        loss, model_step = optimizer.step(closure=closure)
        train_iter_loss_list.append(loss.item())
        step_type.append(model_step)

    end = time.time()

    train_loss = compute_loss(model, train_set)
    test_acc = compute_accuracy(model, test_set)

    n_model_steps = sum(step_type)
    step_ratio = n_model_steps / len(step_type)

    train_loss_list.append(train_loss)
    test_acc_list.append(test_acc)
    run_time_list.append(end - begin)
    model_step_ratio_list.append(step_ratio)
    learning_rate_list.append(curr_lr)  # lr in effect during this epoch

    print(f'Epoch: {epoch}   -   Training Loss: {round(train_loss, 6)}  -  Test Accuracy: {round(test_acc, 6)}  -  Time: {round(end-begin, 2)}')
    print('Model steps taken:', n_model_steps, '/', len(step_type))

    # Adaptive lr (not only a decay), from epoch 3 on. Shrink lr by beta when
    # the model-step ratio reaches gamma / 2 (half the SMB threshold);
    # otherwise grow it by 1/beta.
    if opt_out['autoschedule'] and epoch > 2:
        if step_ratio >= opt_out['gamma'] / 2:
            curr_lr *= opt_out['beta']
        else:
            curr_lr /= opt_out['beta']
        update_lr(optimizer, curr_lr)
        print("\n Now the learning rate is:", curr_lr, "\n")

opt_out.update(
    train_loss=train_loss_list,
    test_acc=test_acc_list,
    run_time=run_time_list,
    model_step_ratio=model_step_ratio_list,
    train_iter_loss=train_iter_loss_list,
    learning_rate=learning_rate_list,
)

# Timestamp without spaces or colons so it is safe inside a filename.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
filename = (
    f"{results_dir}{opt_out['name']}_{opt_out['data']}_{opt_out['model']}"
    f"_epochs_{opt_out['epochs']}_autoschedule_{opt_out['autoschedule']}"
    f"_gamma_{opt_out['gamma']}_beta_{opt_out['beta']}"
    f"_lr_{opt_out['lr']}_c_{opt_out['c']}_{date_time}.json"
)

with open(filename, 'w') as results_file:
    json.dump(opt_out, results_file)






##############################################################################
# Train with SLS optimizer
##############################################################################


# Hyper-parameters and run metadata for SLS. The dict, extended with the
# collected metrics, is written to the result JSON file.
opt_out = {'name':'SLS', 'lr':1, 'c':0.1, 'reset_option':1, 'epochs':epochs, 
           'data':'CIFAR10', 'model':'ResNet34', 'loss_func':'Softmax', 'GPU':use_GPU}

# model
model = ResNet([3, 4, 6, 3], num_classes=10)
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

# opt_out['lr'] is passed to sls.Sls as its initial step size.
optimizer = sls.Sls(model.parameters(),
                    init_step_size=opt_out['lr'],
                    reset_option=opt_out['reset_option'],
                    c=opt_out['c'],
                    n_batches_per_epoch=n_batches_per_epoch
                   )

print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))


train_loss_list = []
train_iter_loss_list = []
test_acc_list = []
run_time_list = []

for epoch in range(1, epochs+1):

    begin = time.time()

    # training steps
    model.train()

    for data, target in train_loader:

        # moves tensors to GPU
        if use_GPU:
            data, target = data.cuda(), target.cuda()

        # Loss closure handed to Sls.step. It only computes the batch loss
        # (no backward() here); gradients are cleared just before the step.
        def closure():
            return criterion(model, data, target)

        optimizer.zero_grad()

        loss = optimizer.step(closure=closure)

        train_iter_loss_list.append(loss.item())

    end = time.time()

    train_loss = compute_loss(model, train_set)
    test_acc = compute_accuracy(model, test_set)

    train_loss_list.append(train_loss)
    test_acc_list.append(test_acc)
    run_time_list.append(end-begin)

    # Display loss statistics
    print(f'Epoch: {epoch}   -   Training Loss: {round(train_loss, 6)}  -  Test Accuracy: {round(test_acc, 6)}  -  Time: {round(end-begin, 2)}')

opt_out.update({'train_loss':train_loss_list,
                 'test_acc':test_acc_list,
                 'run_time':run_time_list,
                 'train_iter_loss':train_iter_loss_list,
                })

# Filesystem-safe timestamp in the same format as the SMB/SMBi result files.
# str(datetime.now()) contains spaces and colons; colons are invalid in
# Windows filenames.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

filename = results_dir + "{}_{}_{}_epochs_{}_lr_{}_c_{}_{}.json".format(opt_out['name'],
                                                                   opt_out['data'],
                                                                   opt_out['model'],
                                                                   opt_out['epochs'],
                                                                   opt_out['lr'],
                                                                   opt_out['c'],
                                                                   date_time)

with open(filename, 'w') as f:
    json.dump(opt_out, f)
    



##############################################################################
# Train with Adam optimizer
##############################################################################

# Adam baseline: trained through train_test_network; its metadata dict,
# extended with the collected metrics, is written to the result JSON file.
opt_out = {'name':'Adam', 'lr':0.001, 'epochs':epochs, 
           'data':'CIFAR10', 'model':'ResNet34', 'loss_func':'Softmax', 
           'GPU':use_GPU}

# model
model = ResNet([3, 4, 6, 3], num_classes=10)
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

# optimizer
optimizer = optim.Adam(model.parameters(), lr=opt_out['lr'])


print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))


train_loss_list, test_acc_list, run_time_list, train_iter_loss_list = train_test_network(
    epochs, train_loader, train_set, test_set, model, criterion, optimizer, use_GPU)


opt_out.update({'train_loss':train_loss_list,
                 'test_acc':test_acc_list,
                 'run_time':run_time_list,
                 'train_iter_loss':train_iter_loss_list,
                })

# Filesystem-safe timestamp in the same format as the SMB/SMBi result files.
# str(datetime.now()) contains spaces and colons; colons are invalid in
# Windows filenames.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

filename = results_dir + "{}_{}_{}_epochs_{}_lr_{}_{}.json".format(
    opt_out['name'], opt_out['data'], opt_out['model'], opt_out['epochs'], opt_out['lr'], date_time)

with open(filename, 'w') as f:
    json.dump(opt_out, f)






##############################################################################
# Train with SGD optimizer
##############################################################################

# Plain SGD baseline: trained through train_test_network; its metadata dict,
# extended with the collected metrics, is written to the result JSON file.
# NOTE(review): 'autoschedule'/'gamma'/'beta' are only recorded here.
# train_test_network applies no learning-rate schedule, so the saved JSON
# reports autoschedule=True for a fixed-lr run. Confirm the intent.
opt_out = {'name':'SGD', 'autoschedule':True, 'gamma':0.05, 'beta':0.9, 'lr':0.1, 'epochs':epochs, 
           'data':'CIFAR10', 'model':'ResNet34', 'loss_func':'Softmax', 'GPU':use_GPU}

# model
model = ResNet([3, 4, 6, 3], num_classes=10)
if use_GPU:
    model.cuda()

# loss function
criterion = softmax_loss

# optimizer
optimizer = optim.SGD(model.parameters(), lr=opt_out['lr'])

print('\n' + 'Starting to train with {} optimizer: For {} epochs'.format(opt_out['name'], opt_out['epochs']))


train_loss_list, test_acc_list, run_time_list, train_iter_loss_list = train_test_network(
    epochs, train_loader, train_set, test_set, model, criterion, optimizer, use_GPU)


opt_out.update({'train_loss':train_loss_list,
                 'test_acc':test_acc_list,
                 'run_time':run_time_list,
                 'train_iter_loss':train_iter_loss_list,
                })

# Filesystem-safe timestamp in the same format as the SMB/SMBi result files.
# str(datetime.now()) contains spaces and colons; colons are invalid in
# Windows filenames.
date_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

filename = results_dir + "{}_{}_{}_epochs_{}_lr_{}_{}.json".format(
    opt_out['name'], opt_out['data'], opt_out['model'], opt_out['epochs'], opt_out['lr'], date_time)

with open(filename, 'w') as f:
    json.dump(opt_out, f)
    

    

 

