import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

# Fix the global RNG seed so weight initialisation and noise draws are repeatable.
torch.manual_seed(123)

# Command-line options. Help texts use argparse's ``%(default)s`` placeholder so
# the documented default can never drift away from the actual ``default=`` value.
parser = argparse.ArgumentParser(description='PyTorch MNIST WAE-MMD')
parser.add_argument('-batch_size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('-batch_size_test', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
parser.add_argument('-epochs', type=int, default=50,
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('-lr', type=float, default=0.001,
                    help='learning rate (default: %(default)s)')
parser.add_argument('-dim_h', type=int, default=128,
                    help='hidden dimension (default: %(default)s)')
parser.add_argument('-n_z', type=int, default=8,
                    help='hidden dimension of z (default: %(default)s)')
parser.add_argument('-LAMBDA', type=float, default=10,
                    help='regularization coef MMD term (default: %(default)s)')
parser.add_argument('-n_channel', type=int, default=1,
                    help='input channels (default: %(default)s)')
# The value multiplies standard-normal noise when sampling the prior, so it
# acts as a standard deviation rather than a variance.
parser.add_argument('-sigma', type=float, default=1,
                    help='standard deviation of the latent prior (default: %(default)s)')
args = parser.parse_args()

# Both MNIST splits are stored under ./data/ (download requested) and every
# image is converted to a tensor by the ToTensor transform.
trainset = MNIST('./data/', train=True, transform=transforms.ToTensor(), download=True)
testset = MNIST('./data/', train=False, transform=transforms.ToTensor(), download=True)

# Training batches are shuffled; test batches are served in dataset order.
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=args.batch_size_test, shuffle=False)


def free_params(module: nn.Module):
    """Turn gradient tracking on for every parameter of ``module`` (in place)."""
    for param in module.parameters():
        param.requires_grad_(True)


def frozen_params(module: nn.Module):
    """Turn gradient tracking off for every parameter of ``module`` (in place)."""
    for param in module.parameters():
        param.requires_grad_(False)


class Encoder(nn.Module):
    """Convolutional encoder mapping an image plus a label embedding to a latent code.

    Four stride-2 convolutions (kernel 4, padding 1) widen the channels from
    ``args.n_channel`` to ``args.dim_h * 8``. The final linear layer expects
    exactly ``dim_h * 8`` image features, i.e. a 1x1 spatial map, which is
    what a 28x28 input produces (28 -> 14 -> 7 -> 3 -> 1). Those features are
    concatenated with a 10-dimensional label embedding and projected to
    ``args.n_z`` dimensions.

    Args:
        args: object exposing ``n_channel``, ``dim_h`` and ``n_z`` attributes.
    """

    def __init__(self, args):
        super(Encoder, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        self.main = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
        )
        # The extra 10 inputs are the label embedding concatenated in forward().
        self.fc = nn.Linear(self.dim_h * (2 ** 3) + 10, self.n_z)

    def forward(self, x, y_emd):
        """Encode a batch of images conditioned on their label embeddings.

        Args:
            x: image batch whose channel count matches ``n_channel``.
            y_emd: label embedding with 10 columns and one row per image.

        Returns:
            Tensor with one ``n_z``-dimensional latent code per image.
        """
        x = self.main(x)
        # Flatten everything except the batch dimension. A bare squeeze()
        # would also drop the batch dimension for a single-image batch and
        # make the concatenation below fail.
        x = x.flatten(start_dim=1)
        x = self.fc(torch.cat((x, y_emd), dim=1))
        return x


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code plus label embedding to an image.

    The latent code is concatenated with a 10-dimensional label embedding,
    projected to a ``dim_h * 8`` x 7 x 7 feature map, and upsampled by three
    transposed convolutions (7 -> 10 -> 13 -> 28). A sigmoid keeps the
    output in ``[0, 1]``.

    Args:
        args: object exposing ``n_channel``, ``dim_h`` and ``n_z`` attributes.
    """

    def __init__(self, args):
        super(Decoder, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        # The extra 10 inputs are the label embedding concatenated in forward().
        self.proj = nn.Sequential(
            nn.Linear(self.n_z + 10, self.dim_h * 8 * 7 * 7),
            nn.ReLU()
        )

        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            # Emit as many channels as the encoder consumes instead of a
            # hard-coded 1 (identical for the default n_channel=1).
            nn.ConvTranspose2d(self.dim_h * 2, self.n_channel, 4, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x, y_emd):
        """Decode latent codes conditioned on their label embeddings.

        Args:
            x: latent codes with ``n_z`` columns.
            y_emd: label embedding with 10 columns and one row per code.

        Returns:
            Tensor of shape ``(batch, n_channel, 28, 28)`` with values in ``[0, 1]``.
        """
        x = self.proj(torch.cat((x, y_emd), dim=1))
        x = x.view(-1, self.dim_h * 8, 7, 7)
        x = self.main(x)
        return x


def imq_kernel(X: torch.Tensor,
               Y: torch.Tensor,
               h_dim: int):
    """Multi-scale inverse multiquadratic (IMQ) kernel discrepancy between two samples.

    For each scale ``s`` the kernel is ``k(a, b) = C / (C + ||a - b||^2)`` with
    ``C = 2 * h_dim * s``. The within-sample kernel sums (diagonal excluded)
    are divided by ``batch_size - 1`` and the cross-sample sum is multiplied
    by ``2 / batch_size``; the difference is accumulated over all scales.
    The training loop in this file divides the result by the batch size once
    more.

    Args:
        X: 2-D tensor with one sample per row.
        Y: 2-D tensor with the same number of rows and columns as ``X``.
        h_dim: latent dimensionality used to set the kernel bandwidth.

    Returns:
        Scalar tensor holding the summed statistic. With a single row the
        within-sample term is 0 / 0, so the result is NaN.
    """
    batch_size = X.size(0)

    norms_x = X.pow(2).sum(1, keepdim=True)  # batch_size x 1
    prods_x = torch.mm(X, X.t())  # batch_size x batch_size
    dists_x = norms_x + norms_x.t() - 2 * prods_x

    norms_y = Y.pow(2).sum(1, keepdim=True)  # batch_size x 1
    prods_y = torch.mm(Y, Y.t())  # batch_size x batch_size
    dists_y = norms_y + norms_y.t() - 2 * prods_y

    dot_prd = torch.mm(X, Y.t())
    dists_c = norms_x + norms_y.t() - 2 * dot_prd

    # Zero-diagonal mask, built once on the same device as the inputs.
    # Choosing the device from torch.cuda.is_available() breaks when the
    # tensors live on the CPU of a CUDA-capable machine.
    off_diag = 1 - torch.eye(batch_size, device=X.device, dtype=X.dtype)

    stats = 0
    for scale in [.1, .2, .5, 1., 2., 5., 10.]:
        C = 2 * h_dim * 1.0 * scale
        res1 = C / (C + dists_x)
        res1 += C / (C + dists_y)
        res1 = off_diag * res1  # drop the self-similarity terms

        res1 = res1.sum() / (batch_size - 1)
        res2 = C / (C + dists_c)
        res2 = res2.sum() * 2. / (batch_size)
        stats += res1 - res2

    return stats


def rbf_kernel(X: torch.Tensor,
               Y: torch.Tensor,
               h_dim: int):
    """Multi-scale exponential (RBF-style) kernel discrepancy between two samples.

    For each scale ``s`` the kernel is ``k(a, b) = exp(-C * ||a - b||^2)`` with
    ``C = 2 * h_dim / s``. The within-sample kernel sums (diagonal excluded)
    are divided by ``batch_size - 1`` and the cross-sample sum is multiplied
    by ``2 / batch_size``; the difference is accumulated over all scales.

    Args:
        X: 2-D tensor with one sample per row.
        Y: 2-D tensor with the same number of rows and columns as ``X``.
        h_dim: latent dimensionality used to set the kernel bandwidth.

    Returns:
        Scalar tensor holding the summed statistic. With a single row the
        within-sample term is 0 / 0, so the result is NaN.
    """
    batch_size = X.size(0)

    norms_x = X.pow(2).sum(1, keepdim=True)  # batch_size x 1
    prods_x = torch.mm(X, X.t())  # batch_size x batch_size
    dists_x = norms_x + norms_x.t() - 2 * prods_x

    norms_y = Y.pow(2).sum(1, keepdim=True)  # batch_size x 1
    prods_y = torch.mm(Y, Y.t())  # batch_size x batch_size
    dists_y = norms_y + norms_y.t() - 2 * prods_y

    dot_prd = torch.mm(X, Y.t())
    dists_c = norms_x + norms_y.t() - 2 * dot_prd

    # Zero-diagonal mask, built once on the same device as the inputs.
    # Choosing the device from torch.cuda.is_available() breaks when the
    # tensors live on the CPU of a CUDA-capable machine.
    off_diag = 1 - torch.eye(batch_size, device=X.device, dtype=X.dtype)

    stats = 0
    for scale in [.1, .2, .5, 1., 2., 5., 10.]:
        C = 2 * h_dim * 1.0 / scale
        res1 = torch.exp(-C * dists_x)
        res1 += torch.exp(-C * dists_y)
        res1 = off_diag * res1  # drop the self-similarity terms

        res1 = res1.sum() / (batch_size - 1)
        res2 = torch.exp(-C * dists_c)
        res2 = res2.sum() * 2. / batch_size
        stats += res1 - res2

    return stats


# Build the conditional encoder/decoder pair (encoder first, so the seeded
# weight initialisation order is unchanged) and the reconstruction criterion.
encoder = Encoder(args)
decoder = Decoder(args)
# To resume from saved weights, uncomment:
# encoder.load_state_dict(torch.load('wae_encoder_cond.pth'))
# decoder.load_state_dict(torch.load('wae_decoder_cond.pth'))
criterion = nn.MSELoss()

# Both networks start in training mode.
encoder.train()
decoder.train()

# Constant +1 / -1 tensors; nothing in the portion of the file shown reads them.
one = torch.Tensor([1])
mone = one * -1

# Move the networks and the constants to the GPU when one is present.
if torch.cuda.is_available():
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    one = one.cuda()
    mone = mone.cuda()

# One Adam optimizer per network, sharing the same learning rate.
enc_optim = optim.Adam(encoder.parameters(), lr=args.lr)
dec_optim = optim.Adam(decoder.parameters(), lr=args.lr)

# Step schedulers configured to halve each learning rate every 20 scheduler steps.
enc_scheduler = StepLR(enc_optim, step_size=20, gamma=0.5)
dec_scheduler = StepLR(dec_optim, step_size=20, gamma=0.5)

# Follow whatever device the networks already live on, so batches, labels and
# noise are placed consistently on both GPU and CPU-only machines.
device = next(encoder.parameters()).device
output_dir = './data/reconst_images_t'

for epoch in range(args.epochs):
    # ``step`` counts completed batches starting at 1, so the logged number
    # matches the batch that was actually just processed.
    for step, (images, label) in enumerate(tqdm(train_loader), start=1):
        images = images.to(device)
        labels_one_hot = F.one_hot(label, 10).float().to(device)

        enc_optim.zero_grad()
        dec_optim.zero_grad()

        # ======== Train Generator ======== #

        batch_size = images.size(0)

        z = encoder(images, labels_one_hot)
        x_recon = decoder(z, labels_one_hot)

        recon_loss = criterion(x_recon, images)

        # ======== MMD Kernel Loss ======== #

        # Prior sample: standard-normal noise scaled by args.sigma. Drawn on
        # the CPU and then moved, which keeps the seeded CPU random stream.
        z_fake = (torch.randn(batch_size, args.n_z) * args.sigma).to(device)

        mmd_loss = imq_kernel(z, z_fake, h_dim=encoder.n_z)
        mmd_loss = mmd_loss / batch_size

        # NOTE(review): args.LAMBDA is parsed but not applied here, so the MMD
        # term has weight 1. Confirm whether ``args.LAMBDA * mmd_loss`` was
        # intended before changing it, since that alters the objective.
        total_loss = recon_loss + mmd_loss
        total_loss.backward()

        enc_optim.step()
        dec_optim.step()

        if step % 300 == 0:
            print("Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f, MMD Loss %.4f" %
                  (epoch + 1, args.epochs, step, len(train_loader), recon_loss.item(),
                   mmd_loss.item()))

    if (epoch + 1) % 10 == 0:
        # Visualise reconstructions and prior samples on the first test batch.
        test_images, test_labels = next(iter(test_loader))
        n_test = test_images.size(0)
        y = F.one_hot(test_labels, 10).float().to(device)

        # Evaluation mode makes BatchNorm use its running statistics and
        # leaves them untouched by test data; no_grad skips graph building.
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            z_real = encoder(test_images.to(device), y)
            reconst = decoder(z_real, y).cpu().view(n_test, 1, 28, 28)
            # Sample from the same sigma-scaled prior that training matches against.
            z_prior = torch.randn_like(z_real) * args.sigma
            sample = decoder(z_prior, y).cpu().view(n_test, 1, 28, 28)
        encoder.train()
        decoder.train()

        os.makedirs(output_dir, exist_ok=True)

        save_image(test_images.view(-1, 1, 28, 28), './data/reconst_images_t/wae_mmd_input.png')
        save_image(reconst, './data/reconst_images_t/wae_mmd_images_%d.png' % (epoch + 1))
        save_image(sample, './data/reconst_images_t/wae_mmd_samples_%d.png' % (epoch + 1))

    # Advance the learning-rate schedules once per epoch; without this the
    # StepLR objects never change the optimizers' learning rates.
    enc_scheduler.step()
    dec_scheduler.step()

