import torch
import torch.nn as nn
import torchvision
import os
import pickle
import scipy.io
import imageio
import numpy as np
import torch.nn.functional as F

from torch.autograd import Variable
from torch import optim
from model import G12, G21
from model import D1, D2
import pdb


class MyLoss(torch.nn.Module):
  """Negated, batch-averaged ERSMI estimate between two image batches.

  For each sample, the flattened pixel intensities of the two images are
  treated as paired observations. A grid of 9 x 9 = 81 basis functions is
  evaluated on those pairs, a ridge-regularised 81 x 81 linear system is
  solved for the basis weights, and the resulting per-sample score is averaged
  over the batch and negated (so minimising the loss maximises the score).
  Judging by the names used by the caller, the score is presumably a
  squared-loss mutual-information estimate -- TODO confirm.

  The module has no learnable parameters. Every intermediate tensor is created
  on the device and with the dtype of the first input, so the loss works on
  CPU as well as on GPU.
  """

  # Basis-function centres along one intensity axis; the full basis is the
  # Cartesian product of this list with itself.
  _CENTERS = (-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0)
  # Width parameter of the basis functions.
  _SIGMA = 1
  # Ridge term added to the diagonal before the matrix is inverted.
  _RIDGE = 0.05

  def __init__(self):
    super(MyLoss, self).__init__()

  @staticmethod
  def _basis(values, centers, sigma):
    """Evaluate every basis function on every pixel.

    Args:
      values: tensor of shape ``[batch, n_pixels]``.
      centers: 1-D tensor holding one centre per basis function.
      sigma: scalar width of the basis functions.

    Returns:
      Tensor of shape ``[batch, len(centers), n_pixels]``.
    """
    diff = centers.view(1, -1, 1) - values.unsqueeze(1)
    # NOTE(review): the exponent is positive, i.e. exp(+d^2 / (2 sigma^2)),
    # whereas a Gaussian kernel would use a negative exponent. Kept exactly as
    # in the original implementation -- confirm whether this is intended.
    return torch.exp(diff.pow(2) / (2 * sigma ** 2))

  def forward(self, fakeI, realI):
    """Return the loss as a 0-dim tensor.

    Args:
      fakeI: image batch of shape ``[batch, channels, height, width]``.
      realI: image batch with the same batch size and, after channel
        expansion, the same number of values per sample. If exactly one of
        the two inputs is single-channel it is tiled along the channel axis
        to match the other.

    Returns:
      ``-mean(ersmi)`` over the batch.

    Raises:
      ValueError: if the two batches cannot be paired pixel by pixel.
    """
    # Tile a single-channel input so both batches have the same pixel count.
    if fakeI.shape[1] != realI.shape[1]:
      if realI.shape[1] == 1:
        realI = realI.repeat(1, fakeI.shape[1], 1, 1)
      elif fakeI.shape[1] == 1:
        fakeI = fakeI.repeat(1, realI.shape[1], 1, 1)
    if fakeI.shape[0] != realI.shape[0] or fakeI.numel() != realI.numel():
      raise ValueError(
        'MyLoss expects batches with matching sizes, got %s and %s'
        % (tuple(fakeI.shape), tuple(realI.shape)))

    batch_size = fakeI.shape[0]
    x = fakeI.reshape(batch_size, -1)
    y = realI.reshape(batch_size, -1)
    n_pixels = x.shape[1]

    # Build the 2-D grid of centres: x cycles fastest, y is held for a full
    # cycle of x, giving all (x_centre, y_centre) combinations.
    mu = torch.tensor(self._CENTERS, dtype=fakeI.dtype, device=fakeI.device)
    n_mu = mu.numel()
    x_centers = mu.repeat(n_mu)
    y_centers = mu.view(-1, 1).repeat(1, n_mu).view(-1)

    mat_K = self._basis(x, x_centers, self._SIGMA)  # [batch, 81, n_pixels]
    mat_L = self._basis(y, y_centers, self._SIGMA)  # [batch, 81, n_pixels]
    mat_KL = mat_K.mul(mat_L)

    H1 = mat_K.matmul(mat_K.transpose(1, 2)).mul(
      mat_L.matmul(mat_L.transpose(1, 2))) / (n_pixels ** 2)
    H2 = mat_KL.matmul(mat_KL.transpose(1, 2)) / n_pixels
    h = mat_K.sum(2, keepdim=True).mul(
      mat_L.sum(2, keepdim=True)) / (n_pixels ** 2)

    # Equal-weight blend of the two second-moment matrices.
    H = 0.5 * H1 + 0.5 * H2
    ridge = self._RIDGE * torch.eye(H.shape[-1], dtype=H.dtype, device=H.device)

    # Basis weights from the ridge-regularised system, one column per sample.
    alpha = (H + ridge).inverse().matmul(h)
    alpha_t = alpha.transpose(1, 2)
    ersmi = 2 * alpha_t.matmul(h) - alpha_t.matmul(H).matmul(alpha) - 1
    return -ersmi.mean()


class Solver(object):
  """Trains and samples a generator ``G21`` against a discriminator ``D1``.

  The generator is fed SVHN batches and its output is scored by the
  discriminator against MNIST batches using least-squares GAN losses. When
  ``config.use_MIR`` is set, ``MyLoss`` between the SVHN input and the
  generated image is added to the generator loss, weighted by
  ``config.lambda_MI``.
  """

  def __init__(self, config, svhn_loader, mnist_loader):
    """Store the configuration and data loaders, then build the networks.

    Args:
      config: object exposing the hyper-parameters read below as attributes.
      svhn_loader: training loader yielding ``(images, labels)`` for SVHN.
      mnist_loader: training loader yielding ``(images, labels)`` for MNIST.
    """
    self.svhn_loader = svhn_loader # A domain
    self.mnist_loader = mnist_loader # B domain
    self.g21 = None
    self.d1 = None
    self.g_optimizer = None
    self.d_optimizer = None
    self.use_reconst_loss = config.use_reconst_loss
    self.use_MIR = config.use_MIR
    self.num_classes = config.num_classes
    self.beta1 = config.beta1
    self.beta2 = config.beta2
    self.g_conv_dim = config.g_conv_dim
    self.d_conv_dim = config.d_conv_dim
    self.train_iters = config.train_iters
    self.batch_size = config.batch_size
    self.lr = config.lr
    self.log_step = config.log_step
    self.sample_step = config.sample_step
    self.sample_path = config.sample_path
    self.model_path = config.model_path
    self.config = config
    self.criterionMI = MyLoss()
    self.lambda_MI = config.lambda_MI
    self.build_model()

  def build_model(self):
    """Builds a generator and a discriminator with one Adam optimizer each."""
    self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
    self.d1 = D1(conv_dim=self.d_conv_dim)

    g_params = self.g21.parameters()
    d_params = self.d1.parameters()

    self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
    self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])

    if torch.cuda.is_available():
      self.g21.cuda()
      self.d1.cuda()

  def merge_images(self, sources, targets, k=10):
    """Tile source/target image pairs side by side into one square grid.

    Args:
      sources: array of shape ``[n, c, h, w]``.
      targets: array with ``n`` images whose shape broadcasts against a
        ``[3, h, w]`` cell (e.g. single-channel images).
      k: unused; kept for interface compatibility.

    Returns:
      Array of shape ``[row*h, row*w*2, 3]`` where
      ``row = int(sqrt(self.batch_size))``. Cells beyond the supplied images
      stay zero; images beyond ``row*row`` are dropped.
    """
    _, _, h, w = sources.shape
    row = int(np.sqrt(self.batch_size))
    merged = np.zeros([3, row * h, row * w * 2])
    for idx, (s, t) in enumerate(zip(sources, targets)):
      # The grid only has row*row cells; extra images would fall outside it.
      if idx >= row * row:
        break
      i, j = divmod(idx, row)
      # Columns are measured in units of the image width, rows in the height.
      merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
      merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
    return merged.transpose(1, 2, 0)

  def to_var(self, x):
    """Moves a tensor to the GPU when one is available and wraps it in Variable."""
    if torch.cuda.is_available():
      x = x.cuda()
    return Variable(x)

  def to_data(self, x):
    """Converts a tensor/variable to a numpy array (via the CPU if needed)."""
    if torch.cuda.is_available():
      x = x.cpu()
    return x.data.numpy()

  def reset_grad(self):
    """Zeros the gradient buffers."""
    self.g_optimizer.zero_grad()
    self.d_optimizer.zero_grad()

  def as_np(self, data):
    """Returns the tensor's data as a numpy array on the CPU."""
    return data.cpu().data.numpy()

  def test(self, svhn_test_loader, mnist_test_loader):
    """Load the step-40000 generator checkpoint and save one grid per batch.

    Args:
      svhn_test_loader: loader yielding ``(images, labels)``; every batch is
        run through the generator and written to ``self.sample_path``.
      mnist_test_loader: unused; kept for interface compatibility.
    """
    g21_path = os.path.join(self.model_path, 'g21-%d.pkl' % (40000))
    # NOTE(review): build_model passes self.config and self.g_conv_dim here;
    # the literal arguments are kept as found -- confirm against G21.
    self.g21 = G21(1, conv_dim=64)
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU hosts.
    state = torch.load(g21_path, map_location=lambda storage, loc: storage)
    self.g21.load_state_dict(state)

    if torch.cuda.is_available():
      self.g21.cuda()

    for i, batch in enumerate(svhn_test_loader):
      fixed_svhn = self.to_var(batch[0])
      # Inference only: no autograd graph is needed.
      with torch.no_grad():
        fake_mnist = self.g21(fixed_svhn)

      fake_mnist = self.to_data(fake_mnist)
      svhn = self.to_data(fixed_svhn)

      merged = self.merge_images(svhn, fake_mnist)
      path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (i))
      imageio.imsave(path, merged)
      print ('saved %s' % path)

  def train(self, svhn_test_loader, mnist_test_loader):
    """Run ``self.train_iters + 1`` adversarial training steps.

    Each step updates the discriminator twice (real batch, then generated
    batch) and the generator once. Losses are printed every ``log_step``
    steps, a sample grid is written every ``sample_step`` steps and both
    networks are checkpointed every 5000 steps.

    Args:
      svhn_test_loader: loader whose first batch is reused for all samples.
      mnist_test_loader: unused; kept for interface compatibility.
    """
    svhn_iter = iter(self.svhn_loader)
    mnist_iter = iter(self.mnist_loader)
    # Restart the iterators one step before the shorter loader runs out;
    # clamp to 1 so a single-batch loader does not cause a modulo by zero.
    iter_per_epoch = max(1, min(len(svhn_iter), len(mnist_iter)) - 1)

    # fixed svhn batch for sampling
    fixed_svhn = self.to_var(next(iter(svhn_test_loader))[0])

    for step in range(self.train_iters + 1):
      # reset data_iter for each epoch
      if (step + 1) % iter_per_epoch == 0:
        mnist_iter = iter(self.mnist_loader)
        svhn_iter = iter(self.svhn_loader)

      # load svhn and mnist dataset (labels are not used)
      svhn, _ = next(svhn_iter)
      svhn = self.to_var(svhn)
      mnist, _ = next(mnist_iter)
      mnist = self.to_var(mnist)

      #============ train D ============#

      # train with real images
      self.reset_grad()
      out = self.d1(mnist)
      d_real_loss = torch.mean((out - 1) ** 2)
      d_real_loss.backward()
      self.d_optimizer.step()

      # train with fake images
      self.reset_grad()
      # Detach so the discriminator update does not backpropagate into the
      # generator; those gradients would be discarded by reset_grad anyway.
      fake_mnist = self.g21(svhn).detach()
      out = self.d1(fake_mnist)
      d_fake_loss = torch.mean(out ** 2)
      d_fake_loss.backward()
      self.d_optimizer.step()

      #============ train G ============#

      self.reset_grad()
      fake_mnist = self.g21(svhn)
      out_mnist = self.d1(fake_mnist)

      gen_loss_B = torch.mean((out_mnist - 1) ** 2)
      loss_G_MI = 0
      if self.use_MIR:
        loss_G_MI = self.criterionMI(svhn, fake_mnist)
      g_loss = gen_loss_B + self.lambda_MI * loss_G_MI
      g_loss.backward()
      self.g_optimizer.step()

      # print the log info
      if (step + 1) % self.log_step == 0:
        # loss_G_MI is either the int 0 or a 0-dim tensor; float() covers both.
        print('Step [%d/%d], d_real_loss: %.4f, d_fake_loss: %.4f, gen_loss_B: %.4f, loss_G_MI: %.4f, g_loss: %.4f'
              % (step + 1, self.train_iters, d_real_loss.item(), d_fake_loss.item(),
                 gen_loss_B.item(), float(loss_G_MI), g_loss.item()))

      # save the sampled images
      if (step + 1) % self.sample_step == 0:
        with torch.no_grad():
          sample_fake = self.g21(fixed_svhn)

        merged = self.merge_images(self.to_data(fixed_svhn), self.to_data(sample_fake))
        path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
        imageio.imsave(path, merged)
        print ('saved %s' % path)

      if (step + 1) % 5000 == 0:
        # save the model parameters every 5000 steps
        g21_path = os.path.join(self.model_path, 'g21-%d.pkl' % (step + 1))
        d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
        torch.save(self.g21.state_dict(), g21_path)
        torch.save(self.d1.state_dict(), d1_path)

