import logging
import random
import math
import time

import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score

import torch
from torch import nn
import torch.nn.functional as F


class ADTrainer:
    """Train a network with SGD on an MSE objective and checkpoint the best model.

    The trainer iterates over ``training_data`` for ``epoch`` epochs, evaluates
    the mean per-batch MSE on both loaders every ``test_every`` epochs, and
    writes a checkpoint whenever the test loss improves on the best value seen
    so far in the current ``train()`` call.
    """

    def __init__(self, training_data, test_data, b_size, device,
                 net, epoch=20, lr=0.01, lr_decay=0.0, weight_decay=1e-4,
                 test_every=1, mu=None, std=None,
                 checkpoint_path="checkpoint.pth.tar"):
        """Store the configuration; no training work happens here.

        Args:
            training_data: Sized iterable yielding ``(inputs, targets)`` tensor
                batches (a dataloader object).
            test_data: Iterable yielding ``(inputs, targets)`` tensor batches.
            b_size: Batch size; stored only, not read by this class.
            device: Device the network and every batch are moved to.
            net: The ``torch.nn.Module`` to optimise.
            epoch: Number of training epochs.
            lr: Initial SGD learning rate.
            lr_decay: Stored only (see the NOTE in ``train``).
            weight_decay: Stored only (see the NOTE in ``train``).
            test_every: Evaluate every this many epochs (starting at epoch 0).
            mu: Saved verbatim in the checkpoint; presumably the dataset
                normalisation mean -- confirm against the data loader.
            std: Saved verbatim in the checkpoint; presumably the dataset
                normalisation standard deviation -- confirm as above.
            checkpoint_path: File path the checkpoint is written to.
        """
        # dataset
        self.training_data = training_data
        self.test_data = test_data
        self.b_size = b_size

        self.device = device

        # network to optimise
        self.net = net

        # optimisation hyper-parameters
        self.lr = lr
        self.lr_decay = lr_decay
        self.weight_decay = weight_decay

        self.mu, self.std = mu, std

        # training schedule
        self.epoch = epoch
        self.test_every = test_every

        self.checkpoint_path = checkpoint_path

    def train(self):
        """Run the full training loop.

        Each epoch performs one SGD pass over ``self.training_data``. Every
        ``self.test_every`` epochs the model is evaluated and checkpointed if
        its test loss is the lowest seen so far. The learning rate is
        multiplied by 0.1 after epochs 300 and 600.
        """
        self.net = self.net.to(self.device)
        self.net.train()

        # NOTE(review): the constructor's ``weight_decay`` and ``lr_decay`` are
        # not used here; the weight decay (5e-4) and the LR milestones are
        # hard-coded. Left as-is so existing runs keep their behaviour --
        # confirm whether the constructor values were meant to apply.
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.net.parameters()),
            lr=self.lr,
            weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[300, 600], gamma=0.1)

        test_times = []
        # Start from +inf so the first evaluation always produces a checkpoint,
        # whatever the scale of the loss.
        best_loss = math.inf
        for j in range(self.epoch):
            print("===============================")
            print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))

            running_loss = 0.0
            train_num = 0
            loop = tqdm.tqdm(self.training_data, total=len(self.training_data))
            for inputs, targets in loop:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                optimizer.zero_grad()
                loss = F.mse_loss(self.net(inputs), targets)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                train_num += 1

                loop.set_description(f'Epoch [{j}/{self.epoch}]')
                loop.set_postfix(loss=running_loss / train_num)

            # max(..., 1) keeps an empty training loader from dividing by zero.
            logging.info('Full Epoch: %s\tLoss: %s',
                         j, running_loss / max(train_num, 1))

            if j % self.test_every == 0:
                print('start testing')
                test_time_start = time.perf_counter()
                test_loss = self._test_total_model(j, self.training_data, self.test_data)
                test_times.append(time.perf_counter() - test_time_start)

                best_loss = self._checkpoint_if_improved(test_loss, best_loss)

                # Evaluation switched the network to eval mode; switch back.
                self.net.train()
            scheduler.step()

            elapsed = np.array(test_times)
            print('test_times:', elapsed.mean(), elapsed.std(), elapsed.shape)

    def save(self):
        """Write the network weights plus ``mu``/``std`` to ``checkpoint_path``."""
        torch.save(
            {
                'state_dict': self.net.state_dict(),
                'mu': self.mu,
                'std': self.std
            }, self.checkpoint_path)

    def _checkpoint_if_improved(self, test_loss, best_loss):
        """Save a checkpoint when ``test_loss`` beats ``best_loss``.

        Returns:
            The best loss after this evaluation. A NaN ``test_loss`` never
            compares as smaller, so it neither saves nor replaces the best.
        """
        if test_loss < best_loss:
            self.save()
            return test_loss
        return best_loss

    def _mean_loss(self, loader):
        """Return the mean per-batch MSE of the network over ``loader``.

        The network is put in eval mode and left there; gradients are
        disabled for the duration. Returns NaN when ``loader`` yields no
        batches.
        """
        total_loss = 0.0
        num_batches = 0
        self.net.eval()
        with torch.no_grad():
            for data, target in loader:
                data = data.to(self.device)
                target = target.to(self.device)
                total_loss += F.mse_loss(self.net(data), target).item()
                num_batches += 1
        if num_batches == 0:
            return float('nan')
        return total_loss / num_batches

    def _test_total_model(self, j, train_data, test_data):
        """Evaluate on both loaders, log the losses, and return the test loss.

        Args:
            j: Epoch index, used only in the log message.
            train_data: Loader evaluated for the logged training loss.
            test_data: Loader evaluated for the returned test loss.
        """
        train_loss = self._mean_loss(train_data)
        test_loss = self._mean_loss(test_data)
        logging.info("################test global model : {}".format(j))
        logging.info(f'train_loss: {train_loss}\t test_loss:{test_loss}')
        return test_loss


if __name__ == '__main__':
    # Script entry point: train a VanilaMLP on the dataset in the current
    # directory and checkpoint it via ADTrainer.
    from datasets import get_loader
    from network import VanilaMLP

    # Root logger at DEBUG so the trainer's logging.info() records are emitted.
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # One batch size shared by the loaders and the trainer.
    batch_size = 256
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    net = VanilaMLP()
    train_loader, test_loader, mu, std = get_loader('.', batch_size, normalize=True)

    trainer = ADTrainer(train_loader, test_loader, b_size=batch_size, device=device,
                        net=net, epoch=800, lr=0.01, lr_decay=0.0, weight_decay=1e-4,
                        test_every=1, mu=mu, std=std,
                        checkpoint_path="kde_auc_net_softmax_sgd-no_pos.pth.tar")

    trainer.train()

