import torch
import torch.nn as nn
from torch.optim import Adam
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils.weight_norm as weightNorm


def soft_update(target, source, tau):
    """Blend ``source``'s parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally via ``zip`` over ``.parameters()``,
    so iteration stops at the shorter of the two parameter lists.

    Args:
        target: Module whose parameters are overwritten.
        source: Module whose parameters are blended in (left unchanged).
        tau: Interpolation weight given to ``source``.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        # Work on ``.data`` so the in-place write is invisible to autograd.
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)


def hard_update(target, source):
    """Make ``target`` an exact, independent copy of ``source``'s state.

    Buffers are replaced module-by-module (modules paired positionally via
    ``zip`` over ``.modules()``), and parameter values are copied in place
    (paired positionally via ``zip`` over ``.parameters()``).

    Args:
        target: Module to overwrite.
        source: Module to copy from (left unchanged).
    """
    with torch.no_grad():
        for m1, m2 in zip(target.modules(), source.modules()):
            # A shallow dict copy would leave target and source holding the
            # *same* buffer tensors, so later in-place buffer updates on the
            # source (e.g. running statistics) would leak into the target.
            # Clone each buffer so the target owns its own storage.
            buffers = m2._buffers.copy()
            for name, buf in list(buffers.items()):
                if buf is not None:
                    buffers[name] = buf.clone()
            m1._buffers = buffers
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)


class TReLU(nn.Module):
    """Thresholded ReLU with a single learnable threshold ``alpha``.

    Evaluates ``relu(x - alpha) + alpha``, which equals ``max(x, alpha)``
    elementwise. ``alpha`` is a one-element float32 parameter initialised
    to 0, so a freshly built module behaves like a plain ReLU.
    """

    def __init__(self):
        super().__init__()
        initial = torch.zeros(1, dtype=torch.float32)
        self.alpha = nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        shifted = F.relu(x - self.alpha)
        return shifted + self.alpha


class Discriminator(nn.Module):
    """Critic built from weight-normalised strided convolutions.

    The input must have 6 channels (``conv0`` is ``Conv2d(6, ...)``). Four
    5x5, stride-2, padding-2 convolutions, each followed by a
    :class:`TReLU`, halve the spatial size (rounding up) at each stage.
    After that:

    * ``if_patch=False``: a fifth 5x5 stride-2 convolution yields a
      single-channel map, which is averaged to one score per sample.
      The output shape is ``(B, 1)``.
    * ``if_patch=True``: a 1x1 convolution yields one score per spatial
      location, flattened per sample. The output shape is ``(B, H' * W')``,
      which is ``(B, 64)`` for 128x128 inputs.

    Args:
        if_patch: Select patch-wise scores instead of one global score.
    """

    def __init__(self, if_patch=False):
        super(Discriminator, self).__init__()
        self._if_patch = if_patch

        self.conv0 = weightNorm(nn.Conv2d(6, 16, 5, 2, 2))
        self.conv1 = weightNorm(nn.Conv2d(16, 32, 5, 2, 2))
        self.conv2 = weightNorm(nn.Conv2d(32, 64, 5, 2, 2))
        self.conv3 = weightNorm(nn.Conv2d(64, 128, 5, 2, 2))
        if self._if_patch:
            self.conv4 = weightNorm(nn.Conv2d(128, 1, 1, 1, 0))
        else:
            self.conv4 = weightNorm(nn.Conv2d(128, 1, 5, 2, 2))
        self.relu0 = TReLU()
        self.relu1 = TReLU()
        self.relu2 = TReLU()
        self.relu3 = TReLU()

    def forward(self, x):
        x = self.relu0(self.conv0(x))
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.relu3(self.conv3(x))
        x = self.conv4(x)
        batch = x.size(0)
        if self._if_patch:
            # Patch Q: one score per spatial location. Flatten per sample
            # instead of hard-coding 64, so inputs other than 128x128 do not
            # get their spatial positions folded into the batch dimension.
            return x.view(batch, -1)
        # Global average of the score map. For 128x128 inputs the map is
        # 4x4, so this equals avg_pool2d(x, 4). Unlike a fixed kernel, it
        # also yields exactly one score per sample for other input sizes.
        return F.adaptive_avg_pool2d(x, 1).view(batch, 1)


class WGAN(object):
    """WGAN-GP critic plus a slowly tracking target copy used for rewards.

    ``netD`` is trained with the WGAN-GP objective in :meth:`update`.
    ``target_netD`` is initialised from ``netD``, is never optimised
    directly, and follows ``netD`` through a soft update (tau = 0.001) after
    every training step. It is the network queried by :meth:`cal_reward`.

    Args:
        device: Device both critics are moved to.
        distributed: Stored on the instance; not read within this class.
        dim: Spatial side length used by the gradient penalty. Samples are
            viewed as ``(6, dim, dim)`` there.
        if_patch: Forwarded to both ``Discriminator`` instances.
    """

    def __init__(self, device="cpu", distributed=False, dim=128, if_patch=False):
        self._device = device
        self.distributed = distributed
        self.dim = dim

        self.netD = Discriminator(if_patch=if_patch).to(self._device)
        self.target_netD = Discriminator(if_patch=if_patch).to(self._device)

        # The target critic only tracks netD via hard/soft updates.
        for param in self.target_netD.parameters():
            param.requires_grad = False

        hard_update(self.target_netD, self.netD)

        self.optimizerD = Adam(self.netD.parameters(), lr=3e-4, betas=(0.5, 0.999))
        self.LAMBDA = 10  # Gradient penalty lambda hyperparameter

    def cal_gradient_penalty(self, real_data, fake_data, batch_size):
        """Return the WGAN-GP gradient penalty on random real/fake interpolates.

        Computes ``LAMBDA * mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2)``, where
        ``x_hat = alpha * real + (1 - alpha) * fake`` and ``alpha`` is drawn
        uniformly in [0, 1) once per sample.

        Args:
            real_data: Real critic inputs, reshaped to ``(batch_size, 6, dim, dim)``.
            fake_data: Fake critic inputs, reshaped to ``(batch_size, 6, dim, dim)``.
            batch_size: Number of samples in each of the two inputs.

        Returns:
            A scalar tensor. It is differentiable w.r.t. ``netD``'s
            parameters because the gradients are taken with ``create_graph=True``.
        """
        # One mixing coefficient per sample, broadcast over C/H/W rather than
        # materialised at full size. It is drawn on the CPU generator, as
        # before, and then moved to the target device.
        alpha = torch.rand(batch_size, 1, 1, 1).to(self._device)
        real_data = real_data.detach().reshape(batch_size, 6, self.dim, self.dim)
        fake_data = fake_data.detach().reshape(batch_size, 6, self.dim, self.dim)
        interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
        disc_interpolates = self.netD(interpolates)
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,  # the penalty itself must be backpropagated
            retain_graph=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA

    def cal_reward(self, fake_data, real_data):
        """Score ``fake_data`` conditioned on ``real_data`` with the target critic.

        The two inputs are concatenated along dim 1 as ``[real, fake]``,
        matching the layout of the fake pairs built in :meth:`update`.
        """
        return self.target_netD(torch.cat([real_data, fake_data], 1))

    def save_gan(self, path, num_episodes):
        """Save ``netD``'s state dict to ``{path}/wgan_{num_episodes:05}.pkl``."""
        torch.save(self.netD.state_dict(), '{}/wgan_{:05}.pkl'.format(path, num_episodes))

    def load_gan(self, path, map_location, num_episodes):
        """Load ``netD`` from a :meth:`save_gan` checkpoint and resync the target.

        Args:
            path: Directory containing the checkpoint.
            map_location: Forwarded to ``torch.load``.
            num_episodes: Episode number embedded in the file name.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load('{}/wgan_{:05}.pkl'.format(path, num_episodes), map_location=map_location)
        self.netD.load_state_dict(state)
        # Only netD is checkpointed. Without this sync, target_netD (the
        # reward network) would keep its random initial weights after loading.
        hard_update(self.target_netD, self.netD)

    def update(self, fake_data, real_data):
        """Run one WGAN-GP optimisation step on ``netD`` and soft-update the target.

        Args:
            fake_data: Generated canvas; detached before use.
            real_data: Ground truth; detached before use.

        Returns:
            ``(mean D(fake), mean D(real), gradient penalty)`` as detached
            scalar tensors. The means are computed before the optimiser step.
        """
        fake_data = fake_data.detach()
        real_data = real_data.detach()

        # Standard conditional training: the ground truth is always the first
        # half of the channels, paired with either the canvas or itself.
        fake = torch.cat([real_data, fake_data], 1)
        real = torch.cat([real_data, real_data], 1)

        D_real = self.netD(real)
        D_fake = self.netD(fake)
        gradient_penalty = self.cal_gradient_penalty(real, fake, real.shape[0])

        D_fake_mean = D_fake.mean()
        D_real_mean = D_real.mean()
        self.optimizerD.zero_grad()
        D_cost = D_fake_mean - D_real_mean + gradient_penalty
        D_cost.backward()
        self.optimizerD.step()
        soft_update(self.target_netD, self.netD, 0.001)
        return D_fake_mean.detach(), D_real_mean.detach(), gradient_penalty.detach()

    def train(self):
        """Put both critics in training mode."""
        self.netD.train()
        self.target_netD.train()

    def eval(self):
        """Put both critics in evaluation mode."""
        self.netD.eval()
        self.target_netD.eval()
