import logging
import os
import time
import torch
from opacus.utils.batch_memory_manager import BatchMemoryManager
from torch import nn
from opacus import PrivacyEngine

from mia.vanilla import VanillaMia
from representation.aug.aug_selector import get_aug
from representation.loss.relaxloss import RelaxLoss
from conf import settings
from dataloader.info import num_classes

from util.averagemeters import AverageMeter
from util.metrics import accuracy
from util.yaml_loader import load_loss_conf


class VanillaMia_DP_SGD(VanillaMia):
    """Vanilla MIA pipeline whose target and shadow models are trained with DP-SGD.

    Training wraps the model, optimizer and train loader with Opacus'
    ``PrivacyEngine.make_private`` using the ``sigma`` (noise multiplier) and
    ``max_norm`` (gradient clipping norm) entries of the loss configuration
    loaded by :meth:`load_components`.
    """

    # A logical batch is processed in physical batches of at most
    # ``args.batch_size // PHYSICAL_BATCH_DIVISOR`` samples.
    PHYSICAL_BATCH_DIVISOR = 16

    def __init__(
            self,
            target_model, shadow_models, attack_model,
            target_train, target_test, shadow_train_list, shadow_test_list,
            target_ref=None, shadow_ref=None
    ):
        """Forward every argument unchanged to ``VanillaMia.__init__``.

        ``producer`` and ``loss_confs`` start as ``None`` and are filled in by
        :meth:`load_components`.
        """
        super().__init__(
            target_model, shadow_models, attack_model,
            target_train, target_test, shadow_train_list, shadow_test_list,
            target_ref, shadow_ref
        )
        self.producer = None
        self.loss_confs = None

    def load_components(self, args):
        """Load the augmentation producer and the loss configuration.

        Must be called before :meth:`train_target_model` or
        :meth:`train_shadow_model`, which read ``loss_confs['sigma']`` and
        ``loss_confs['max_norm']``.

        Args:
            args: run configuration; ``args.loss_conf`` is handed to
                ``load_loss_conf`` and ``args`` itself to ``get_aug``.
        """
        self.producer = get_aug(args)
        self.loss_confs = load_loss_conf(args.loss_conf)

    def train_target_model(self, args):
        """Train the target model with DP-SGD and save it to disk.

        Logs to ``<save_folder>/<arch>/<arch_conf>/log_<cmd>_target.txt`` and
        writes the trained module to
        ``<save_folder>/<arch>/<arch_conf>/target/model_latest_target.path.tar``.
        Sets ``args.save_path`` and ``args.logger_file`` as a side effect.

        Raises:
            RuntimeError: if :meth:`load_components` has not been called.
        """
        # Fail before any file is created or truncated.
        self._require_loss_confs()
        save_path = args.save_path = os.path.join(args.save_folder, args.arch, args.arch_conf)
        os.makedirs(save_path, exist_ok=True)
        args.logger_file = os.path.join(save_path, 'log_{}_target.txt'.format(args.cmd))
        logger = self._fresh_logger('logger_target', args.logger_file)

        self.data_aug = get_aug(args=args)
        args = self.load_epoch_conf(args, model_type='target', option='train')
        model = self.target_model.to(args.device)

        model = self._fit_private(
            args,
            model=model,
            train_loader=self.target_train,
            test_loader=self.target_test,
            logger=logger,
        )
        self._save_model(
            model,
            os.path.join(save_path, 'target'),
            'model_latest_{}.path.tar'.format('target'),
        )
        print("-----------------------Target End---------------------")

    def train_shadow_model(self, args, idx):
        """Train shadow model number ``idx`` with DP-SGD and save it to disk.

        Uses ``shadow_models[idx]``, ``shadow_train_list[idx]`` and
        ``shadow_test_list[idx]``. Logs to ``log_<cmd>_shadow_<idx>.txt`` and
        writes the trained module to ``shadow/model_latest_shadow_<idx>.path.tar``,
        both under ``<save_folder>/<arch>/<arch_conf>``. Sets ``args.save_path``
        and ``args.logger_file`` as a side effect.

        Raises:
            RuntimeError: if :meth:`load_components` has not been called.
        """
        # Fail before any file is created or truncated.
        self._require_loss_confs()
        save_path = args.save_path = os.path.join(args.save_folder, args.arch, args.arch_conf)
        os.makedirs(save_path, exist_ok=True)
        args.logger_file = os.path.join(save_path, 'log_{}_shadow_{}.txt'.format(args.cmd, idx))
        logger = self._fresh_logger('logger_shadow_{}'.format(idx), args.logger_file)

        self.data_aug = get_aug(args=args)
        args = self.load_epoch_conf(args, model_type='shadow', option='train')
        model = self.shadow_models[idx].to(args.device)

        model = self._fit_private(
            args,
            model=model,
            train_loader=self.shadow_train_list[idx],
            test_loader=self.shadow_test_list[idx],
            logger=logger,
        )
        self._save_model(
            model,
            os.path.join(save_path, 'shadow'),
            'model_latest_{}_{}.path.tar'.format('shadow', idx),
        )

    def classification_train(self, args, model, dataloader, optimizer, criterion, current_epoch, logger):
        """Run one training epoch over ``dataloader``.

        Each batch is moved to ``args.device``, augmented with ``self.data_aug``
        and fed to ``model``, which must return an ``(output, features)`` pair;
        only ``output`` is used. Running loss and top-1 precision are logged
        every ``args.print_freq`` iterations.
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()
        total_losses = AverageMeter()
        top1 = AverageMeter()

        # A preceding evaluation pass may have left the model in eval mode;
        # Opacus only records per-sample gradients for modules in training mode.
        model.train()

        end = time.time()
        for i, (inputs, target) in enumerate(dataloader):
            data_time.update(time.time() - end)
            target = target.view(-1).long().to(args.device)
            inputs = self.data_aug(inputs.to(args.device))

            output, _ = model(inputs)
            loss = criterion(output, target)

            batch_size = inputs.size(0)
            total_losses.update(loss.item(), batch_size)
            prec1 = accuracy(output.data, target, topk=(1,))
            top1.update(prec1[0], batch_size)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                logger.info(
                    "Epoch: [{0}]\t"
                    "Iter: [{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
                        current_epoch,
                        i,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=total_losses,
                        top1=top1,
                    )
                )

    def _require_loss_confs(self):
        """Raise a clear error when the loss configuration has not been loaded."""
        if self.loss_confs is None:
            raise RuntimeError(
                'loss_confs is not loaded; call load_components(args) before training'
            )

    @staticmethod
    def _fresh_logger(name, log_file):
        """Return logger ``name`` writing to ``log_file`` (truncated) and the console."""
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        # Drop handlers left by an earlier call with the same name; otherwise
        # every record is emitted once per previous call and stale file
        # handles stay open.
        for stale in list(logger.handlers):
            logger.removeHandler(stale)
            stale.close()
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(message)s'))
        logger.addHandler(file_handler)
        logger.addHandler(logging.StreamHandler())
        return logger

    def _fit_private(self, args, model, train_loader, test_loader, logger):
        """Train ``model`` with DP-SGD, validating after every epoch.

        Returns:
            The module returned by ``PrivacyEngine.make_private`` after training.
        """
        criterion = nn.CrossEntropyLoss().to(args.device)
        eval_criterion = nn.CrossEntropyLoss().to(args.device)
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        privacy_engine = PrivacyEngine()
        model, optimizer, private_loader = privacy_engine.make_private(
            module=model,
            optimizer=optimizer,
            data_loader=train_loader,
            noise_multiplier=self.loss_confs['sigma'],
            max_grad_norm=self.loss_confs['max_norm'],
        )
        # Integer division yields 0 for batch sizes below the divisor; a
        # physical batch must hold at least one sample.
        physical_batch = max(1, args.batch_size // self.PHYSICAL_BATCH_DIVISOR)

        best_prec1 = 0
        with BatchMemoryManager(
                data_loader=private_loader,
                max_physical_batch_size=physical_batch,
                optimizer=optimizer
        ) as memory_safe_loader:
            for current_epoch in range(args.start_epoch, args.epoch):
                optimizer = self.adjust_learning_rate(args.mile_stone, optimizer, current_epoch, logger)
                self.classification_train(
                    args,
                    model=model, dataloader=memory_safe_loader,
                    optimizer=optimizer, criterion=criterion,
                    current_epoch=current_epoch, logger=logger
                )
                prec1 = self.validate(
                    args,
                    test_loader=test_loader, model=model,
                    criterion=eval_criterion, logger=logger
                )
                best_prec1 = max(prec1, best_prec1)
                # float() accepts both single-element tensors and plain numbers.
                logger.info(' * best \t{:.3f}\t'.format(float(best_prec1)))
                torch.cuda.empty_cache()
        return model

    @staticmethod
    def _save_model(model, folder, filename):
        """Move ``model`` to the CPU and save the whole module to ``folder/filename``."""
        os.makedirs(folder, exist_ok=True)
        torch.save(model.cpu(), os.path.join(folder, filename))


if __name__ == '__main__':
    # Smoke check: construct the base attack with seven placeholder (None)
    # positional arguments.
    a = VanillaMia(*((None,) * 7))
