import torch
import torch.nn as nn
import torchvision
import os
import pickle
import scipy.io
import numpy as np
import torch.nn.functional as F
import imageio
from torch.autograd import Variable
from torch import optim
from model import G12, G21
from model import D1, D2
import time
import pdb

class MyLoss(torch.nn.Module):
    """Negative relative squared-loss mutual information (RSMI) between two image batches.

    Each image is flattened; the pixel at position ``i`` of the first image is
    paired with the pixel at position ``i`` of the second.  The dependence
    between the two pixel streams is estimated with a regularised
    least-squares fit over a grid of Gaussian basis functions: basis ``k``
    pairs an x-centre ``mu[k % 9]`` with a y-centre ``mu[k // 9]`` (81 bases).
    ``forward`` returns the negated estimate averaged over the batch, so
    minimising this loss maximises the estimated mutual information.

    Args:
        sigma: Bandwidth of the Gaussian basis functions.
        alpha: Weight of the paired-sample (joint) second-moment term;
            ``1 - alpha`` goes to the product-of-marginals term.
        reg: Ridge term added to the diagonal before inversion.
    """

    # Basis centres; presumably chosen to span images normalised to [-1, 1]
    # (TODO confirm against the data loaders).
    _CENTRES = (-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0)

    def __init__(self, sigma=1.0, alpha=0.5, reg=0.05):
        super(MyLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reg = reg

    def _gaussian_basis(self, values, centres):
        """Evaluate ``exp(-(c - v)^2 / (2 sigma^2))`` for every centre/pixel pair.

        Args:
            values: ``(B, N)`` flattened pixel values.
            centres: ``(M,)`` basis centres on the same device/dtype.

        Returns:
            ``(B, M, N)`` tensor of basis responses in ``(0, 1]``.
        """
        # Broadcast instead of materialising repeated copies of both operands.
        diff = centres.view(1, -1, 1) - values.unsqueeze(1)
        # The minus sign makes this a decaying Gaussian bump; without it the
        # response grows with distance and overflows for far-away pixels.
        return torch.exp(-diff.pow(2) / (2 * self.sigma ** 2))

    def forward(self, fakeI, realI):
        """Return ``-mean_b RSMI(fakeI[b], realI[b])`` as a 0-dim tensor.

        Args:
            fakeI: ``(B, C, H, W)`` image batch.
            realI: ``(B, C', H, W)`` image batch.  If exactly one of the two
                batches is single-channel it is repeated to the other's
                channel count so both flatten to the same length.

        Raises:
            ValueError: if the two batches still flatten to different shapes.
        """
        I1, I2 = fakeI, realI
        if I2.shape[1] == 1 and I1.shape[1] != 1:
            I2 = I2.repeat(1, I1.shape[1], 1, 1)
        elif I1.shape[1] == 1 and I2.shape[1] != 1:
            I1 = I1.repeat(1, I2.shape[1], 1, 1)

        x = I1.reshape(I1.shape[0], -1)
        y = I2.reshape(I2.shape[0], -1)
        if x.shape != y.shape:
            raise ValueError('MyLoss: inputs flatten to different sizes %s vs %s'
                             % (tuple(x.shape), tuple(y.shape)))
        n_pix = x.shape[1]

        # Build the centre grid on the inputs' device/dtype (works on CPU too).
        mu = torch.tensor(self._CENTRES, dtype=x.dtype, device=x.device)
        n_mu = mu.numel()
        x_centres = mu.repeat(n_mu)                           # mu[k % n_mu]
        y_centres = mu.unsqueeze(1).repeat(1, n_mu).view(-1)  # mu[k // n_mu]

        K = self._gaussian_basis(x, x_centres)  # (B, M, N)
        L = self._gaussian_basis(y, y_centres)  # (B, M, N)
        phi = K * L                             # basis evaluated on paired pixels

        # Second moment over all (i, j) pixel pairs (product of marginals) ...
        H_marg = K.matmul(K.transpose(1, 2)) * L.matmul(L.transpose(1, 2)) / (n_pix ** 2)
        # ... and over the paired pixels (i, i).
        H_joint = phi.matmul(phi.transpose(1, 2)) / n_pix
        H = self.alpha * H_joint + (1 - self.alpha) * H_marg

        # NOTE(review): the usual RSMI estimator takes h from the paired
        # samples; this keeps the original choice of the product of marginals.
        h = K.sum(2, keepdim=True) * L.sum(2, keepdim=True) / (n_pix ** 2)

        eye = torch.eye(H.shape[-1], dtype=H.dtype, device=H.device)
        theta = torch.inverse(H + self.reg * eye).matmul(h)  # (B, M, 1)
        theta_t = theta.transpose(1, 2)
        ersmi = 2 * theta_t.matmul(h) - theta_t.matmul(H).matmul(theta) - 1
        return -ersmi.view(-1).mean()


class Solver(object):
    """Trains and samples an MNIST <-> SVHN image-translation model.

    ``g12`` maps MNIST-domain images to the SVHN domain and ``g21`` maps SVHN
    to MNIST; ``d1`` scores MNIST-domain images and ``d2`` SVHN-domain images.
    Discriminators and generators use least-squares GAN losses; the
    generators additionally get an L2 cycle-reconstruction loss and, when
    ``config.use_MIR`` is set, ``config.lambda_MI`` times the ``MyLoss`` term.
    """

    def __init__(self, config, svhn_loader, mnist_loader):
        self.svhn_loader = svhn_loader  # A domain
        self.mnist_loader = mnist_loader  # B domain
        self.g12 = None  # MNIST -> SVHN generator
        self.g21 = None  # SVHN -> MNIST generator
        self.d1 = None  # discriminator for the MNIST domain
        self.d2 = None  # discriminator for the SVHN domain
        self.g_optimizer = None
        self.d_optimizer = None
        self.use_reconst_loss = config.use_reconst_loss
        self.use_MIR = config.use_MIR
        self.num_classes = config.num_classes
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.train_iters = config.train_iters
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.config = config
        self.criterionMI = MyLoss()
        self.lambda_MI = config.lambda_MI
        self.build_model()

    def build_model(self):
        """Builds a generator and a discriminator."""
        self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
        self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim)
        self.d2 = D2(self.config, conv_dim=self.d_conv_dim)

        g_params = list(self.g12.parameters()) + list(self.g21.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.g12.cuda()
            self.g21.cuda()
            self.d1.cuda()
            self.d2.cuda()

    def merge_images(self, sources, targets, k=10):
        """Tile (source, target) pairs side by side into one HWC image.

        Args:
            sources: ``(N, C, H, W)`` array; C must broadcast to 3 (1 or 3).
            targets: ``(N, C, H, W)`` array with the same H and W.
            k: Unused; kept for backward compatibility.

        Returns:
            ``(row * H, row * W * 2, 3)`` array, where ``row`` is the smallest
            square-grid side holding ``max(batch_size, N)`` pairs.
        """
        _, _, h, w = sources.shape
        # ceil so a non-square batch size (or a larger sample batch) still fits.
        n_slots = max(self.batch_size, len(sources))
        row = int(np.ceil(np.sqrt(n_slots)))
        merged = np.zeros([3, row * h, row * w * 2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i, j = divmod(idx, row)
            # Column offsets are multiples of the image WIDTH.
            merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
            merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
        return merged.transpose(1, 2, 0)

    def to_var(self, x):
        """Move tensor ``x`` to the GPU when one is available and wrap it as a Variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Return tensor ``x`` as a NumPy array (copied to the CPU first when CUDA is available)."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data.numpy()

    def reset_grad(self):
        """Zeros the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def as_np(self, data):
        """Return ``data`` as a NumPy array on the CPU."""
        return data.cpu().data.numpy()

    def _save_grid(self, sources, targets, filename):
        """Merge ``sources``/``targets`` with merge_images and write the grid to sample_path/filename."""
        merged = self.merge_images(sources, targets)
        path = os.path.join(self.sample_path, filename)
        imageio.imsave(path, merged)
        print('saved %s' % path)

    def test(self, svhn_test_loader, mnist_test_loader, step=40000):
        """Translate every SVHN test batch with a saved ``g21`` and save the sample grids.

        Args:
            svhn_test_loader: Loader whose batches' element 0 is the image tensor.
            mnist_test_loader: Unused; kept for interface compatibility.
            step: Training step of the checkpoint to load.
        """
        g12_path = os.path.join(self.model_path, 'g12-%d.pkl' % (step))
        g21_path = os.path.join(self.model_path, 'g21-%d.pkl' % (step))
        # Rebuild the generators exactly as build_model does so the
        # checkpoint's parameter shapes match.
        self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
        self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
        self.g12.load_state_dict(torch.load(g12_path, map_location='cpu'))
        self.g21.load_state_dict(torch.load(g21_path, map_location='cpu'))
        if torch.cuda.is_available():
            self.g12.cuda()
            self.g21.cuda()

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            for i, batch in enumerate(svhn_test_loader):
                fixed_svhn = self.to_var(batch[0])
                fake_mnist = self.to_data(self.g21(fixed_svhn))
                svhn = self.to_data(fixed_svhn)
                self._save_grid(svhn, fake_mnist, 'sample-%d-s-m.png' % (i))

    def train(self, svhn_test_loader, mnist_test_loader):
        """Alternate discriminator and generator updates for ``train_iters + 1`` steps.

        Args:
            svhn_test_loader: Its first batch is kept as a fixed SVHN sampling set.
            mnist_test_loader: Its first batch is kept as a fixed MNIST sampling set.
        """
        svhn_iter = iter(self.svhn_loader)
        mnist_iter = iter(self.mnist_loader)
        # Iterators are re-created every ``iter_per_epoch`` steps so they never
        # run past the end of the shorter loader; clamp to 1 so a one-batch
        # loader does not cause a modulo by zero.
        iter_per_epoch = max(min(len(svhn_iter), len(mnist_iter)) - 1, 1)

        # fixed mnist and svhn for sampling
        fixed_svhn = self.to_var(next(iter(svhn_test_loader))[0])
        fixed_mnist = self.to_var(next(iter(mnist_test_loader))[0])

        for step in range(self.train_iters + 1):
            # reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                mnist_iter = iter(self.mnist_loader)
                svhn_iter = iter(self.svhn_loader)

            # load svhn and mnist batches (labels are not used)
            svhn, _ = next(svhn_iter)
            svhn = self.to_var(svhn)
            mnist, _ = next(mnist_iter)
            mnist = self.to_var(mnist)

            # ============ train D ============#

            # train with real images
            self.reset_grad()
            d1_loss = torch.mean((self.d1(mnist) - 1) ** 2)
            d2_loss = torch.mean((self.d2(svhn) - 1) ** 2)

            d_mnist_loss = d1_loss
            d_svhn_loss = d2_loss
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()

            # train with fake images; detach so nothing is back-propagated into
            # the generators (their grads are reset before the G step anyway)
            self.reset_grad()
            fake_svhn = self.g12(mnist).detach()
            d2_loss = torch.mean(self.d2(fake_svhn) ** 2)

            fake_mnist = self.g21(svhn).detach()
            d1_loss = torch.mean(self.d1(fake_mnist) ** 2)

            d_fake_loss = d1_loss + d2_loss
            d_fake_loss.backward()
            self.d_optimizer.step()

            # ============ train G ============#
            # NOTE(review): self.use_reconst_loss is stored in __init__ but never
            # consulted; the cycle loss below is always applied.

            # train mnist-svhn-mnist cycle
            self.reset_grad()
            fake_svhn = self.g12(mnist)
            out_svhn = self.d2(fake_svhn)
            reconst_mnist = self.g21(fake_svhn)

            gen_loss_A = torch.mean((out_svhn - 1) ** 2)
            loss_G_MI = 0
            if self.use_MIR:
                loss_G_MI = self.criterionMI(fake_svhn, mnist)
            loss_G_circle = torch.mean((mnist - reconst_mnist) ** 2)

            g_loss = gen_loss_A + self.lambda_MI * loss_G_MI + loss_G_circle
            g_loss.backward()
            self.g_optimizer.step()

            # train svhn-mnist-svhn cycle
            self.reset_grad()
            fake_mnist = self.g21(svhn)
            out_mnist = self.d1(fake_mnist)
            reconst_svhn = self.g12(fake_mnist)
            loss_G_MI = 0
            if self.use_MIR:
                loss_G_MI = self.criterionMI(svhn, fake_mnist)
            loss_G_circle = torch.mean((svhn - reconst_svhn) ** 2)

            gen_loss_B = torch.mean((out_mnist - 1) ** 2)
            g_loss = gen_loss_B + self.lambda_MI * loss_G_MI + loss_G_circle

            g_loss.backward()
            self.g_optimizer.step()

            # print the log info
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
                      'd_fake_loss: %.4f, gen_loss_A: %.4f, gen_loss_B: %.4f'
                      % (step + 1, self.train_iters, d_real_loss.item(), d_mnist_loss.item(),
                         d_svhn_loss.item(), d_fake_loss.item(), gen_loss_A.item(), gen_loss_B.item()))

            # save the sampled images (inference only, no autograd graph)
            if (step + 1) % self.sample_step == 0:
                with torch.no_grad():
                    sample_svhn = self.to_data(self.g12(fixed_mnist))
                    sample_mnist = self.to_data(self.g21(fixed_svhn))
                self._save_grid(self.to_data(fixed_mnist), sample_svhn,
                                'sample-%d-m-s.png' % (step + 1))
                self._save_grid(self.to_data(fixed_svhn), sample_mnist,
                                'sample-%d-s-m.png' % (step + 1))

            # save the model parameters every 5000 steps
            if (step + 1) % 5000 == 0:
                for name, net in (('g12', self.g12), ('g21', self.g21),
                                  ('d1', self.d1), ('d2', self.d2)):
                    ckpt_path = os.path.join(self.model_path, '%s-%d.pkl' % (name, step + 1))
                    torch.save(net.state_dict(), ckpt_path)

