import torch
import torchvision
from torchvision import datasets, transforms
import os
#import util
import sys
sys.path.append('../')
from AGNES import AGNES


class trainer:
    """Training / evaluation harness for an image classifier.

    By default the CIFAR-10 train and test splits are used (downloaded to
    ``../data/cifar``), but any iterable of ``(images, labels)`` batches can
    be supplied to :meth:`train` and :meth:`test` instead.

    The instance keeps running histories of per-batch training loss/accuracy
    and per-evaluation test loss/accuracy, and can write/read them together
    with the model and optimizer state via :meth:`save_parameters` and
    :meth:`load_parameters`.
    """

    def __init__(self, model, opt_name):
        """Set up the model, the optimizer and empty metric histories.

        Parameters
        ----------
        model : torch.nn.Module
            The network to train. It is moved to the GPU when CUDA is available.
        opt_name : str
            Python source that is executed as the right-hand side of
            ``self.optimizer = <opt_name>``; ``self``, ``model`` and the
            module-level names (``torch``, ``AGNES``, ...) are in scope, e.g.
            ``'torch.optim.SGD(self.net.parameters(), lr=1e-3)'``.
            The string is also stored in checkpoints, and a value ending in
            ``'fixed'`` disables rescaling of the ``'correction'`` step size
            in :meth:`train`.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_cuda = torch.cuda.is_available()
        self.net = model.cuda() if self.use_cuda else model
        self.opt_name = opt_name
        self.train_accuracies = []  # one entry per training batch
        self.train_losses = []      # one entry per training batch
        self.test_accuracies = []   # one entry per call to test() made by train()
        self.test_losses = []       # one entry per call to test() made by train()
        self.start_epoch = 0        # last completed epoch (set by load_parameters)
        # SECURITY: opt_name is executed as arbitrary Python code. Only pass
        # strings from trusted sources (never user-supplied / external input).
        exec('self.optimizer =' + opt_name)

    def train(self, save_dir, num_epochs=100, batch_size=50, schedule_lr_epochs=0, lr_factor=1,
              test_each_epoch=True, verbose=False, seed=None, data_loader=None, test_loader=None):
        """Trains the network.

        Parameters
        ----------
        save_dir : str
            Passed to ``save_parameters`` as ``directory`` once training ends
        num_epochs : int
            Index of the last epoch to run; epochs ``start_epoch + 1`` to
            ``num_epochs`` (inclusive) are executed
        batch_size : int
            The batch size of the default CIFAR-10 training loader (and of the
            evaluation performed before the first epoch)
        schedule_lr_epochs : int
            Number of epochs after which the learning rate (and correction step size for AGNES)
            is multiplied by a factor of lr_factor.
            If schedule_lr_epochs==0, then a constant learning rate is used
        lr_factor : float
            Multiplicative factor applied by the schedule above
        test_each_epoch : boolean
            True: Test the network after every training epoch, False: no testing
        verbose : boolean
            Currently unused; kept for backward compatibility
        seed : int or None
            If not None, passed to ``torch.manual_seed`` before training
        data_loader : iterable or None
            Iterable of ``(images, labels)`` training batches. If None, a shuffled
            CIFAR-10 training loader with random crop / horizontal flip is built
        test_loader : iterable or None
            Forwarded to ``test`` as ``data_loader`` for every evaluation
        """
        if seed is not None:
            torch.manual_seed(seed)

        if data_loader is None:
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
            train_dataset = datasets.CIFAR10('../data/cifar', train=True, download=True,
                                             transform=train_transform)
            data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                      shuffle=True, num_workers=4)
        criterion = torch.nn.CrossEntropyLoss().to(self.device)

        if self.start_epoch == 0:
            # Record test loss and accuracy of the untrained network.
            test_loss, test_accuracy = self.test(batch_size=batch_size, data_loader=test_loader)
            self.test_losses.append(test_loss)
            self.test_accuracies.append(test_accuracy)

        # Pre-bind so the final save works even when the loop body never runs
        # (e.g. a resumed checkpoint that already reached num_epochs).
        epoch = self.start_epoch
        for epoch in range(self.start_epoch + 1, num_epochs + 1):
            print('Epoch {}/{}'.format(epoch, num_epochs))

            for images, labels in data_loader:
                images = images.to(self.device)
                # flatten() (unlike squeeze) keeps the batch dimension when the
                # batch holds a single sample.
                labels = labels.to(self.device).flatten()

                self.optimizer.zero_grad()
                outputs = self.net(images)
                loss = criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                _, predicted = torch.max(outputs.detach(), dim=1)
                batch_total = labels.size(0)
                batch_correct = (predicted == labels).sum().item()

                self.train_accuracies.append(batch_correct / batch_total)
                self.train_losses.append(loss.item())

            if test_each_epoch:
                test_loss, test_accuracy = self.test(data_loader=test_loader)
                self.test_losses.append(test_loss)
                self.test_accuracies.append(test_accuracy)

            if schedule_lr_epochs and epoch % schedule_lr_epochs == 0:
                # Rescale the step sizes every schedule_lr_epochs epochs.
                for group in self.optimizer.param_groups:
                    group['lr'] *= lr_factor
                    if 'correction' in group and not self.opt_name.endswith('fixed'):
                        group['correction'] *= lr_factor  # correction step for AGNES

        # A single checkpoint is written once training has finished.
        self.save_parameters(epoch, directory=save_dir)

    def test(self, batch_size=250, data_loader=None):
        """Tests the network.

        The network is switched to eval mode for the evaluation and back to
        train mode afterwards.

        Parameters
        ----------
        batch_size : int
            The batch size of the default CIFAR-10 test loader
        data_loader : iterable or None
            Iterable of ``(images, labels)`` batches. If None, the CIFAR-10 test
            split is used

        Returns
        -------
        tuple of (float, float)
            Mean cross-entropy loss per sample and fraction of correctly
            classified samples

        Raises
        ------
        ValueError
            If the loader yields no samples
        """
        self.net.eval()

        if data_loader is None:
            test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
            test_dataset = datasets.CIFAR10('../data/cifar', train=False, download=True,
                                            transform=test_transform)
            data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                                      shuffle=False, num_workers=4)
        # Sum (not mean) per batch so that a smaller last batch is weighted
        # correctly when normalising by the total number of samples.
        criterion = torch.nn.CrossEntropyLoss(reduction='sum').to(self.device)

        correct = 0
        total = 0
        loss_sum = 0.0
        with torch.no_grad():
            for images, labels in data_loader:
                images = images.to(self.device)
                # flatten() keeps the batch dimension for single-sample batches.
                labels = labels.to(self.device).flatten()

                outputs = self.net(images)

                _, predicted = torch.max(outputs, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                loss_sum += criterion(outputs, labels).item()

        self.net.train()
        if total == 0:
            raise ValueError('test data loader yielded no samples')
        return (loss_sum / total, correct / total)

    def save_parameters(self, epoch, directory):
        """Saves the network, optimizer state and metric histories.

        ``directory`` is created if missing. Note that the checkpoint itself is
        written to the path ``directory + '_' + str(epoch) + '.pth'``, i.e.
        ``directory`` acts as a file-name prefix.

        Parameters
        ----------
        epoch : int
            The current epoch (stored in the checkpoint and used in the file name)
        directory : str
            The directory to create / prefix of the checkpoint file name
        """
        os.makedirs(directory, exist_ok=True)
        torch.save({
            'opt_name': self.opt_name,
            'epoch': epoch,
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_accuracies': self.train_accuracies,
            'train_losses': self.train_losses,
            'test_accuracies': self.test_accuracies,
            'test_losses': self.test_losses,
        }, directory + '_' + str(epoch) + '.pth')

    def load_parameters(self, path):
        """Loads a checkpoint written by ``save_parameters``.

        Restores the model and optimizer state and the metric histories, and sets
        ``start_epoch`` to the stored epoch so that ``train`` resumes after it.
        The stored ``opt_name`` is not restored; the optimizer built in
        ``__init__`` must match the checkpoint.

        Parameters
        ----------
        path : str
            The file path pointing to the file containing the parameters
        """
        # NOTE: torch.load unpickles data; only load checkpoints you trust.
        checkpoint = torch.load(path, map_location=self.device)

        self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_accuracies = checkpoint['train_accuracies']
        self.train_losses = checkpoint['train_losses']
        self.test_accuracies = checkpoint['test_accuracies']
        self.test_losses = checkpoint['test_losses']
        self.start_epoch = checkpoint['epoch']

