import sys
sys.path.append("../../models/cifar10/")
sys.path.append("../../")
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import time
import numpy as np
import math
import config as cf

from utils.attacks import Adversary

# available models
from models import Wide_ResNet
from models import resnet20, resnet32, resnet44, resnet56, resnet110, resnet1202

from utils.utils import *


class Trainer:
    """Train and evaluate a CIFAR-10 classifier and log results to TensorBoard.

    The training mode is selected by ``training_type``: the value ``'noisy'``
    adds Gaussian noise to every input batch, ``'adversarial'`` performs an
    additional optimisation step per batch on examples produced by the
    configured ``Adversary``; any other value trains on the clean data only.
    """

    def __init__(self, cf):
        """Read hyper-parameters from ``cf`` and set up transforms and logging.

        Args:
            cf: configuration object. The attributes read here are
                ``model_type``, ``train_batch_size``, ``test_batch_size``,
                ``num_epochs``, ``lr``, ``weight_decay``, ``training_type``,
                ``variance``, ``attack``, ``save_file_model`` and the
                ``mean`` / ``std`` mappings (key ``'cifar10'``).
        """
        self.model_type = cf.model_type
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.train_batch_size = cf.train_batch_size
        self.test_batch_size = cf.test_batch_size
        self.num_epochs = cf.num_epochs
        self.start_epoch = 0
        self.curr_epoch = None
        self.curr_batch = None
        self.learning_rate = cf.lr
        self.momentum = 0.9
        self.weight_decay = cf.weight_decay
        self.training_type = cf.training_type
        # NOTE(review): despite its name this value multiplies standard-normal
        # noise directly (see ``train``), i.e. it acts as a standard deviation.
        self.variance = cf.variance
        self.adversary = Adversary(cf.attack, self.device)
        self.adv_num_examples = 1
        self.adv_step_size = 0.25
        self.adv_max_iter = 20

        # Populated by load_model() / load_data().
        self.model = None
        self.optimizer = None
        self.lr_scheduler = None
        self.criterion = None
        self.train_loader = None
        self.test_loader = None

        self.save_file_model = cf.save_file_model
        self.writer = SummaryWriter()
        self.save_path_model = ''

        self.transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean['cifar10'], cf.std['cifar10']),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cf.mean['cifar10'], cf.std['cifar10']),
        ])

    def save_model(self):
        """Write the model's ``state_dict`` to ``<save_path_model><save_file_model>.pth``."""
        torch.save(self.model.state_dict(), self.save_path_model + self.save_file_model + '.pth')

    def run(self):
        """Build the model, load the data, train, then save the weights."""
        self.load_model()
        self.load_data()
        self.train()
        # Make sure the scalars of the last epoch reach the event file.
        self.writer.flush()
        self.save_model()
        print('training successful!')
        return

    def reset(self):
        """Re-create model, optimiser and data loaders and clear the progress counters."""
        self.load_model()
        self.load_data()
        self.curr_epoch = None
        self.curr_batch = None

    def load_data(self):
        """Create the CIFAR-10 train/test datasets (under ``./data``) and their loaders."""
        self.train_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=self.transform_train)
        self.test_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=False, transform=self.transform_test)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.train_batch_size, shuffle=True, num_workers=2)
        # Use the configured test batch size instead of a hard-coded 100.
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.test_batch_size, shuffle=False, num_workers=2)

    def load_model(self):
        """Instantiate the network selected by ``model_type`` plus loss, optimiser and scheduler.

        Raises:
            ValueError: if ``model_type`` is not one of ``'wrn'``, ``'rn20'``,
                ``'rn32'``, ``'rn44'``, ``'rn56'``, ``'rn110'``, ``'rn1202'``.
        """
        # available models
        if self.model_type == 'wrn':
            self.model = Wide_ResNet(28, 10, 0.3, 10).to(self.device)
        elif self.model_type == 'rn20':
            self.model = resnet20().to(self.device)
        elif self.model_type == 'rn32':
            self.model = resnet32().to(self.device)
        elif self.model_type == 'rn44':
            self.model = resnet44().to(self.device)
        elif self.model_type == 'rn56':
            self.model = resnet56().to(self.device)
        elif self.model_type == 'rn110':
            self.model = resnet110().to(self.device)
        elif self.model_type == 'rn1202':
            self.model = resnet1202().to(self.device)
        else:
            # Raising a bare string is a TypeError in Python 3.
            raise ValueError(
                'please refer to one of the available models (see utils/config.py)')

        # Shape sanity check; no autograd graph is needed for it.
        with torch.no_grad():
            test = self.model(torch.randn(1, 3, 32, 32).to(self.device))
        print('test size: ', test.size())

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(
            self.model.parameters(), self.learning_rate,
            momentum=self.momentum, weight_decay=self.weight_decay)
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[100, 150], last_epoch=self.start_epoch - 1)

    def _fit_batch(self, x, y):
        """Run one optimisation step on ``(x, y)``.

        Returns:
            tuple: ``(loss_value, num_correct)`` as plain Python numbers, so
            that the caller never holds on to the autograd graph.
        """
        self.optimizer.zero_grad()
        outputs = self.model(x)
        loss = self.criterion(outputs, y)
        loss.backward()
        self.optimizer.step()
        _, pred = torch.max(outputs.data, 1)
        return loss.item(), pred.eq(y.data).sum().item()

    def _adversarial_batch(self, x, y):
        """Ask the adversary for adversarial inputs derived from ``(x, y)``.

        For the ``'random_walk'`` strategy the adversary's return value is
        unpacked as a new ``(x, y)`` pair; for every other strategy it is
        treated as a perturbation that is added to ``x``.
        """
        if self.adversary.strategy == 'random_walk':
            x, y = self.adversary.get_adversarial_examples(
                self.model, x, y,
                step_size=self.adv_step_size,
                num_examples=1,
                max_iter=self.adv_max_iter)
        else:
            delta = self.adversary.get_adversarial_examples(
                self.model, x, y,
                step_size=self.adv_step_size,
                num_examples=self.adv_num_examples,
                max_iter=self.adv_max_iter)
            x = x + delta
        return x.to(self.device), y.to(self.device)

    def _evaluate(self):
        """Evaluate on the test loader.

        Returns:
            tuple: ``(summed_batch_loss, accuracy_in_percent)``.
        """
        self.model.eval()
        test_loss = 0.
        test_correct = 0
        total = 0
        with torch.no_grad():
            for self.curr_batch, (x, y) in enumerate(self.test_loader):
                x, y = x.to(self.device), y.to(self.device)
                outcome = self.model(x)
                test_loss += self.criterion(outcome, y).item()
                _, pred = torch.max(outcome.data, 1)
                test_correct += pred.eq(y.data).sum().item()
                total += y.size(0)
        # max(..., 1) keeps an empty loader from dividing by zero.
        return test_loss, 100. * test_correct / max(total, 1)

    def train(self):
        """Train for ``num_epochs`` epochs, evaluating and logging after each one.

        The reported losses are the sums of the per-batch losses of the
        epoch; accuracies are percentages. In adversarial mode the
        adversarial step contributes to both the training loss and accuracy.
        """
        elapsed_time = 0
        num_batches = len(self.train_loader)

        for self.curr_epoch in range(1, self.num_epochs + 1):
            self.model.train()
            # Plain Python numbers: accumulating the loss tensor itself would
            # keep every batch's autograd graph alive until the epoch ends.
            train_loss = 0.
            train_correct = 0
            total = 0
            time_start = time.time()

            print('\n=> Training Epoch #%d, LR=%.4f' % (self.curr_epoch, self.optimizer.param_groups[0]['lr']))

            for self.curr_batch, (x, y) in enumerate(self.train_loader):
                x, y = x.to(self.device), y.to(self.device)
                # perturb data during noisy training
                if self.training_type == 'noisy':
                    x += torch.randn(x.size()).to(self.device) * self.variance

                batch_loss, batch_correct = self._fit_batch(x, y)
                train_loss += batch_loss
                train_correct += batch_correct
                total += y.size(0)

                # add training on adversarial perturbation during adv training
                if self.training_type == 'adversarial':
                    x, y = self._adversarial_batch(x, y)
                    batch_loss, batch_correct = self._fit_batch(x, y)
                    train_loss += batch_loss
                    train_correct += batch_correct
                    total += y.size(0)

                sys.stdout.write('\r')
                sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%'
                                 % (self.curr_epoch, self.num_epochs, self.curr_batch + 1,
                                    num_batches, train_loss, 100. * train_correct / total))
                sys.stdout.flush()

            self.lr_scheduler.step()

            train_acc = 100. * train_correct / max(total, 1)

            # testing
            test_loss, test_acc = self._evaluate()
            print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" % (self.curr_epoch, test_loss, test_acc))

            time_epoch = time.time() - time_start
            elapsed_time += time_epoch
            print('| Elapsed time : %d:%02d:%02d' % (cf.get_hms(elapsed_time)))
            self.write_tb(train_loss, train_acc, test_loss, test_acc)

    def write_tb(self, train_loss, train_correct, test_loss, test_correct):
        """Log the epoch's losses and accuracies to TensorBoard at step ``curr_epoch``.

        Args:
            train_loss: training loss to log under ``Loss/train``.
            train_correct: training accuracy to log under ``Accuracy/train``.
            test_loss: test loss to log under ``Loss/test``.
            test_correct: test accuracy to log under ``Accuracy/test``.
        """
        self.writer.add_scalar('Loss/train', train_loss, self.curr_epoch)
        self.writer.add_scalar('Loss/test', test_loss, self.curr_epoch)
        self.writer.add_scalar('Accuracy/train', train_correct, self.curr_epoch)
        self.writer.add_scalar('Accuracy/test', test_correct, self.curr_epoch)


if __name__ == '__main__':
    # Guarded so that importing this module does not start a training run.
    # This includes the re-import of the main module performed by DataLoader
    # worker processes when the ``spawn`` start method is in use.
    trainer = Trainer(cf)
    trainer.run()


