import sys
sys.path.append('../../models/mnist/')
sys.path.append('../../utils/')
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import time
import numpy as np

from attacks import Adversary

from models import LeNet5, CNN

# Training configuration consumed by the Trainer defined below.
param = dict(
    # model architecture: 'LeNet' or 'CNN'
    model_type='CNN',
    train_batch_size=64,
    test_batch_size=100,
    num_epochs=50,
    learning_rate=1e-3,
    weight_decay=5e-4,
    # training type: 'normal', 'noisy' or 'adversarial'
    training_type='normal',
    # attack strategy: 'fgsm', 'pgd', 'pgd_linf', 'pgd_linf_rand' or 'random_walk'
    attack='fgsm',
    # 'variance' setting used by noisy training
    variance=0.4,
)


class Trainer:
    """Train an MNIST classifier and evaluate it on the test set after every epoch.

    Supports three training types (``param['training_type']``):
    ``'normal'``, ``'noisy'`` (Gaussian noise added to each batch) and
    ``'adversarial'`` (an extra update on adversarial inputs per batch).
    Per-epoch metrics go to stdout and TensorBoard; final weights are saved
    under ``model_save_path``.
    """

    def __init__(self, param):
        """Store the settings from the *param* dict.

        The model, optimizer, loss and data loaders are only created later,
        by ``load_model`` and ``load_data``.
        """
        self.model_type = param['model_type']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.train_batch_size = param['train_batch_size']
        self.test_batch_size = param['test_batch_size']
        self.num_epochs = param['num_epochs']
        self.curr_epoch = None
        self.curr_batch = None
        self.learning_rate = param['learning_rate']
        self.weight_decay = param['weight_decay']
        self.training_type = param['training_type']
        # project attack helper used to build adversarial inputs
        self.adversary = Adversary(param['attack'], self.device)
        self.adv_step_size = 0.25
        self.adv_max_iter = 100
        # Examples requested per batch by the non-random-walk attacks. This was
        # read during adversarial training but never set (AttributeError).
        # Default 1 mirrors the random-walk call.
        # NOTE(review): confirm the intended value for the other attacks.
        self.adv_num_examples = param.get('adv_num_examples', 1)
        self.attack = param['attack']
        self.variance = param['variance']
        self.model = None
        self.optimizer = None
        self.criterion = None
        self.train_loader = None
        self.test_loader = None
        self.writer = SummaryWriter()
        self.model_save_path = './trained_models/'

    def save_model(self):
        """Write the model's ``state_dict`` to ``<model_save_path>/<model_type>.pth``.

        The directory is created if missing, so a long run does not fail at
        its very last step.
        """
        import os
        os.makedirs(self.model_save_path, exist_ok=True)
        torch.save(self.model.state_dict(),
                   os.path.join(self.model_save_path, self.model_type + '.pth'))

    def run(self):
        """Build model and data, train, then save the weights.

        The TensorBoard writer is closed even if training fails.
        """
        self.load_model()
        self.load_data()
        try:
            self.train()
            self.save_model()
        finally:
            # flush pending TensorBoard events to disk
            self.writer.close()
        print('model trained and saved!')

    def load_model(self):
        """Create the selected model on ``self.device``, plus its loss and optimizer.

        Raises:
            ValueError: if ``model_type`` is neither ``'LeNet'`` nor ``'CNN'``.
        """
        if self.model_type == 'LeNet':
            self.model = LeNet5().to(self.device)
        elif self.model_type == 'CNN':
            # Instantiate before moving to the device.
            # NOTE(review): assumes CNN takes no constructor arguments, like LeNet5.
            self.model = CNN().to(self.device)
        else:
            raise ValueError(
                "please select one of the available model LeNet or CNN "
                "(got {!r})".format(self.model_type))
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=self.learning_rate,
                weight_decay=self.weight_decay)

    def load_data(self):
        """Build the MNIST train/test loaders.

        Data is downloaded to ``../data/`` if needed. Only the train loader
        shuffles; evaluation order does not affect the metrics.
        """
        train_dataset = MNIST(root='../data/', train=True, download=True,
            transform=transforms.ToTensor())
        self.train_loader = torch.utils.data.DataLoader(train_dataset,
            batch_size=self.train_batch_size, shuffle=True)
        test_dataset = MNIST(root='../data/', train=False, download=True,
            transform=transforms.ToTensor())
        self.test_loader = torch.utils.data.DataLoader(test_dataset,
            batch_size=self.test_batch_size, shuffle=False)

    def train(self):
        """Train for ``num_epochs`` epochs, testing after each one.

        For both splits, each epoch reports the mean loss per sample and the
        fraction of correct predictions. Results are printed and logged via
        ``write_tb``.
        """
        for self.curr_epoch in range(1, self.num_epochs + 1):
            train_loss, train_correct = self._train_epoch()
            test_loss, test_correct = self._evaluate()

            print('epoch {}/{}, train loss: {}, val loss: {}'.format(
                self.curr_epoch, self.num_epochs, train_loss, test_loss))

            print('train correct: {}, test correct: {}'.format(train_correct, test_correct))
            self.write_tb(train_loss, train_correct, test_loss, test_correct)

    def _train_epoch(self):
        """Run one pass over ``train_loader``; return (mean loss per sample, accuracy).

        In adversarial mode, each batch gets a second update on adversarial
        inputs. Those samples are counted in the returned statistics too.
        """
        self.model.train()
        loss_sum, correct, total = 0., 0, 0
        for self.curr_batch, (x, y) in enumerate(self.train_loader):
            x, y = x.to(self.device), y.to(self.device)
            if self.training_type == 'noisy':
                # Standard-normal noise scaled by `variance`.
                # NOTE(review): this makes `variance` the std; confirm intent.
                x = x + torch.randn_like(x) * self.variance
            step_loss, step_correct, step_n = self._optimize_step(x, y)
            loss_sum, correct, total = loss_sum + step_loss, correct + step_correct, total + step_n

            if self.training_type == 'adversarial':
                x_adv, y_adv = self._adversarial_batch(x, y)
                step_loss, step_correct, step_n = self._optimize_step(x_adv, y_adv)
                loss_sum, correct, total = loss_sum + step_loss, correct + step_correct, total + step_n
        return loss_sum / total, correct / total

    def _optimize_step(self, x, y):
        """Do one optimizer update on (x, y).

        Returns (batch loss summed over samples, #correct predictions, batch size).
        """
        outcome = self.model(x)
        loss = self.criterion(outcome, y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        batch = y.size(0)
        # criterion is a mean over the batch; re-weight so epoch averages are per sample
        return loss.item() * batch, (outcome.argmax(dim=1) == y).sum().item(), batch

    def _adversarial_batch(self, x, y):
        """Return the (inputs, labels) used for the extra adversarial update."""
        if self.adversary.strategy == 'random_walk':
            # this strategy's result is used directly as new (inputs, labels)
            x_adv, y_adv = self.adversary.get_adversarial_examples(
                    self.model, x, y,
                    step_size=self.adv_step_size,
                    num_examples=1,
                    max_iter=self.adv_max_iter)
            return x_adv.to(self.device), y_adv.to(self.device)
        # for the other strategies the result is added to x as a perturbation
        delta = self.adversary.get_adversarial_examples(
                self.model, x, y,
                step_size=self.adv_step_size,
                num_examples=self.adv_num_examples,
                max_iter=self.adv_max_iter)
        return x + delta, y

    def _evaluate(self):
        """Evaluate on ``test_loader``; return (mean loss per sample, accuracy).

        The model runs in eval mode without gradients. Its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        loss_sum, correct, total = 0., 0, 0
        try:
            with torch.no_grad():
                for self.curr_batch, (x, y) in enumerate(self.test_loader):
                    x, y = x.to(self.device), y.to(self.device)
                    outcome = self.model(x)
                    batch = y.size(0)
                    loss_sum += self.criterion(outcome, y).item() * batch
                    correct += (outcome.argmax(dim=1) == y).sum().item()
                    total += batch
        finally:
            self.model.train(was_training)
        return loss_sum / total, correct / total

    def write_tb(self, train_loss, train_correct, test_loss, test_correct):
        """Log the epoch's losses and accuracies to TensorBoard at step ``curr_epoch``."""
        self.writer.add_scalar('Loss/train', train_loss, self.curr_epoch)
        self.writer.add_scalar('Loss/test', test_loss, self.curr_epoch)
        self.writer.add_scalar('Accuracy/train', train_correct, self.curr_epoch)
        self.writer.add_scalar('Accuracy/test', test_correct, self.curr_epoch)


# Train only when executed as a script. Importing this module (e.g. to reuse
# Trainer) must not start a full training run.
if __name__ == '__main__':
    trainer = Trainer(param)
    trainer.run()


