import os
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pickle
import sys
import shutil
import copy
import datetime

from models.attack_model_base import AttackModel


## code taken from https://github.com/mathcbc/advGAN_pytorch

class Generator(nn.Module):
    """Encoder / residual bottleneck / decoder network producing a perturbation.

    The layer geometry is sized for 28x28 inputs (MNIST):
    the encoder maps 28 -> 26 -> 12 -> 5 spatially (channels 8, 16, 32),
    four 32-channel residual blocks keep 5x5, and the decoder maps
    5 -> 11 -> 23 -> 28 back to ``image_nc`` channels. The final ``Tanh``
    keeps every output value in (-1, 1).

    Args:
        gen_input_nc: number of channels of the input image.
        image_nc: number of channels of the produced output.
    """

    def __init__(self, gen_input_nc, image_nc):
        super(Generator, self).__init__()

        def down(in_ch, out_ch, stride):
            # Unpadded 3x3 conv followed by InstanceNorm and ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=0, bias=True),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(),
            ]

        def up(in_ch, out_ch):
            # Unpadded stride-2 3x3 transposed conv followed by InstanceNorm and ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=0, bias=False),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(),
            ]

        # Layers are created strictly in the original order so parameter
        # initialization and Sequential indices (state_dict keys) are unchanged.
        encoder_layers = down(gen_input_nc, 8, 1) + down(8, 16, 2) + down(16, 32, 2)
        residual_layers = [ResnetBlock(32) for _ in range(4)]
        decoder_layers = up(32, 16) + up(16, 8) + [
            # 23 + 6 - 1 = 28: restores the 28x28 input resolution.
            nn.ConvTranspose2d(8, image_nc, kernel_size=6, stride=1, padding=0, bias=False),
            nn.Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.bottle_neck = nn.Sequential(*residual_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode, refine through the residual bottleneck, then decode."""
        return self.decoder(self.bottle_neck(self.encoder(x)))


class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: pad -> 3x3 conv -> norm -> ReLU -> [Dropout(0.5)] ->
    pad -> 3x3 conv -> norm. Padding is either an explicit 1-pixel
    reflection/replication layer, or zero padding of 1 applied by the conv
    itself, so the spatial size is preserved.

    Args:
        dim: number of input and output channels.
        padding_type: 'reflect', 'replicate' or 'zero'.
        norm_layer: normalization layer constructor, called with ``dim``.
        use_dropout: insert ``nn.Dropout(0.5)`` after the first ReLU.
        use_bias: whether the convolutions have a bias term.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return ``(explicit padding layers, conv padding amount)`` for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the two-convolution residual branch as an ``nn.Sequential``."""
        pad_layers, conv_pad = self._padding(padding_type)
        layers = pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        pad_layers, conv_pad = self._padding(padding_type)
        layers += pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias),
            norm_layer(dim),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Add the residual branch to the identity path."""
        residual = self.conv_block(x)
        return x + residual


class Discriminator(nn.Module):
    """Convolutional discriminator scoring images with a sigmoid probability.

    For a 28x28 input the three stride-2, 4x4, unpadded convolutions reduce
    the spatial size 28 -> 13 -> 5 -> 1 (channels 8, 16, 32). A 1x1 conv then
    maps to one channel, followed by a Sigmoid. ``forward`` squeezes all
    singleton dimensions, so a batch of N 28x28 images yields shape ``(N,)``
    (and a 0-d tensor when N == 1).

    Args:
        image_nc: number of channels of the input image.
    """

    def __init__(self, image_nc):
        super(Discriminator, self).__init__()
        # First stage has no normalization.
        layers = [
            nn.Conv2d(image_nc, 8, kernel_size=4, stride=2, padding=0, bias=True),
            nn.LeakyReLU(0.2),
        ]
        # Two normalized stages: 8 -> 16 -> 32 channels.
        for in_ch, out_ch in ((8, 16), (16, 32)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        # 1x1 projection to a single score, squashed to (0, 1).
        layers += [nn.Conv2d(32, 1, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-image scores with singleton dimensions removed."""
        scores = self.model(x)
        return scores.squeeze()


class AdvGAN(AttackModel):
    """AdvGAN attack model: a generator/discriminator pair trained to attack a classifier.

    The generator ``netG`` maps an image to a perturbation. The perturbation is
    clipped to ``[-epsilon, epsilon]`` and added to the image, and the result is
    clipped to ``[box_min, box_max]``; training and inference use the same
    clipping. During training the discriminator ``netDisc`` learns (least-squares
    GAN loss) to separate clean from perturbed images. The generator is trained
    to fool both the discriminator and the attacked classifier ``self.model``,
    which must be supplied via :meth:`set_base_model` before training.
    """

    def __init__(self, defender, args):
        """Build the generator, the discriminator and their Adam optimizers.

        Args:
            defender: stored as ``self.defender``; not used by the methods of this class.
            args: config namespace; reads ``num_labels``, ``num_channels``,
                ``min_clamp``, ``max_clamp``, ``adv_epsilon`` and ``adv_lr``.
        """
        print("Initializing AdvGAN....")
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_num_labels = args.num_labels
        self.defender = defender
        self.input_nc = args.num_channels
        self.output_nc = args.num_channels
        self.box_min = args.min_clamp
        self.box_max = args.max_clamp
        self.epsilon = args.adv_epsilon
        self.requires_training = False
        self.image_nc = self.input_nc

        self.gen_input_nc = args.num_channels
        self.netG = Generator(self.gen_input_nc, self.image_nc).to(self.device)
        self.netDisc = Discriminator(self.image_nc).to(self.device)

        # Initialize all weights (see weights_init).
        self.netG.apply(self.weights_init)
        self.netDisc.apply(self.weights_init)

        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=args.adv_lr)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(), lr=args.adv_lr)
        print("AdvGAN init complete!")

    def weights_init(self, m):
        """Initialize one module in place; intended for ``nn.Module.apply``.

        Modules whose class name contains 'Conv' (Conv2d, ConvTranspose2d) get
        weights ~ N(0, 0.02). Modules whose class name contains 'BatchNorm' get
        scales ~ N(1, 0.02) and zero bias. Everything else, including conv
        biases, is left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def clone(self, requires_grad=False):
        """
        Return a cloned model of yourself.

        Args:
            requires_grad (bool, optional): Should the cloned attack model be trainable. Defaults to False.
        """
        raise NotImplementedError()

    def set_base_model(self, model):
        """Set the classifier attacked during training (called on perturbed images)."""
        self.model = model

    def get_loss(self, points, labels, requires_mean=True, save_losses=False):
        """
        Get loss of attack model on the set of data points.

        Args:
            points (PyTorch Tensor): Tensor of input points
            labels (PyTorch Tensor): Tensor of corresponding labels
            requires_mean (bool, optional): If true, returns the mean else vector of pointwise losses. Defaults to True.
            save_losses (bool, optional): If true, save reconstruction and defender losses. Defaults to False.
        """
        raise NotImplementedError()

    def train_advgan(self, dataset, args, epochs):
        """Train the generator and discriminator on ``dataset``.

        Each epoch prints the mean of each loss over that epoch's batches. Every
        ``args.save_freq`` epochs the generator weights are saved to
        ``<args.save_dir>/AdvGAN/netG_epoch_<epoch>.pth``.

        Args:
            dataset (torch.utils.data.Dataset): yields ``(image, label)`` pairs.
            args: config namespace; reads ``adv_gan_batch_size``, ``save_dir``
                and ``save_freq``.
            epochs (int): number of passes over ``dataset``.
        """
        batch_size = args.adv_gan_batch_size
        # NOTE(review): the shuffling generator is placed on self.device. This
        # presumably matches a global default CUDA device/tensor type set
        # elsewhere. With a CPU default device, a CUDA generator is rejected by
        # the sampler — confirm.
        generator = torch.Generator(self.device)
        train_dl = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)

        model_path = os.path.join(args.save_dir, 'AdvGAN')
        os.makedirs(model_path, exist_ok=True)

        # Guard against an empty loader so the per-epoch means do not divide by zero.
        num_batch = max(len(train_dl), 1)

        for epoch in range(1, epochs + 1):
            # Reset every epoch so the printed values are this epoch's means,
            # not running totals over all previous epochs.
            loss_D_sum = 0
            loss_G_fake_sum = 0
            loss_perturb_sum = 0
            loss_adv_sum = 0

            for images, labels in train_dl:
                images, labels = images.to(self.device), labels.to(self.device)

                loss_D_batch, loss_G_fake_batch, loss_perturb_batch, loss_adv_batch = self.train_batch(images, labels)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch

            print("epoch %d:\nloss_D: %.3f, loss_G_fake: %.3f,\nloss_perturb: %.3f, loss_adv: %.3f, \n" %
                  (epoch, loss_D_sum / num_batch, loss_G_fake_sum / num_batch,
                   loss_perturb_sum / num_batch, loss_adv_sum / num_batch))

            if epoch % args.save_freq == 0:
                # Save inside the AdvGAN directory created above.
                netG_file_path = os.path.join(model_path, 'netG_epoch_' + str(epoch) + '.pth')
                torch.save(self.netG.state_dict(), netG_file_path)

    def train_batch(self, x, labels):
        """Run one discriminator update and one generator update on a batch.

        Args:
            x (torch.Tensor): clean images, already on ``self.device``.
            labels (torch.Tensor): integer class labels for ``x``, on ``self.device``.

        Returns:
            tuple of float: ``(loss_D, loss_G_fake, loss_perturb, loss_adv)``
            for this batch.

        Raises:
            AttributeError: if no classifier was set via :meth:`set_base_model`.
        """
        if getattr(self, 'model', None) is None:
            raise AttributeError("AdvGAN has no base model to attack; call set_base_model() first")

        self.netG.train()
        perturbation = self.netG(x)

        # Clip with the same epsilon that get_perturbed uses, so the generator
        # is trained under the constraint it is evaluated with.
        adv_images = torch.clamp(perturbation, -self.epsilon, self.epsilon) + x
        adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

        # --- Discriminator update: clean -> 1, perturbed -> 0 (MSE / LSGAN loss).
        self.optimizer_D.zero_grad()
        pred_real = self.netDisc(x)
        loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))
        loss_D_real.backward()

        # detach(): the discriminator step must not push gradients into netG.
        pred_fake = self.netDisc(adv_images.detach())
        loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake, device=self.device))
        loss_D_fake.backward()
        loss_D_GAN = loss_D_fake + loss_D_real
        self.optimizer_D.step()

        # --- Generator update.
        self.optimizer_G.zero_grad()

        # GAN term: make the discriminator score perturbed images as clean.
        pred_fake = self.netDisc(adv_images)
        loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake, device=self.device))
        # retain_graph: the netG part of this graph is backpropagated again via loss_G.
        loss_G_fake.backward(retain_graph=True)

        # Mean (over the batch) L2 norm of the raw, unclipped perturbation; no hinge applied.
        loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))

        # Untargeted C&W-style margin on softmax probabilities. The loss is
        # positive while the true-class probability exceeds the best other class,
        # and is summed over the batch.
        logits_model = self.model(adv_images)
        probs_model = F.softmax(logits_model, dim=1)
        onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels]

        real = torch.sum(onehot_labels * probs_model, dim=1)
        # Subtracting a large constant on the true class excludes it from the max.
        other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
        zeros = torch.zeros_like(other)
        loss_adv = torch.sum(torch.max(real - other, zeros))

        adv_lambda = 10
        pert_lambda = 1
        loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
        loss_G.backward()
        self.optimizer_G.step()

        return loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(), loss_adv.item()

    def get_perturbed(self, X, y=None):
        """Return adversarial versions of ``X`` produced by the generator.

        The generator's perturbation is clipped to ``[-epsilon, epsilon]`` and
        added to ``X``. The result is clipped to ``[box_min, box_max]``, the same
        constraints applied in :meth:`train_batch`. This leaves ``netG`` in eval mode.

        Args:
            X (torch.Tensor): input images; moved to ``self.device``.
            y: unused here.

        Returns:
            torch.Tensor: perturbed images on ``self.device``.
        """
        self.netG.eval()
        X = X.to(self.device)
        perturbation = torch.clamp(self.netG(X), -self.epsilon, self.epsilon)
        adv_img = perturbation + X
        return torch.clamp(adv_img, self.box_min, self.box_max)

    def indices_to_points(self, indices):
        """
        Utility function: Takes in indices and returns corresponding dataset slice.

        NOTE(review): ``self.dataset`` is not assigned anywhere in this class;
        presumably it is set by the AttackModel base class or by a caller — confirm.

        Args:
            indices (List): List of indices. Could also be a numpy array.

        Returns:
            X, Y: A slice of the dataset, indexed by the input indices.
        """
        return self.dataset[indices]
    
if __name__=="__main__":
    # Manual test entry point for this module.
    print("Testing AdvGAN")
    # Train on MNIST!