import torch
import torch.nn as nn
import torchvision
import os
import pickle
import scipy.io
import imageio
import numpy as np
import torch.nn.functional as F

from torch.autograd import Variable
from torch import optim
from model import G12, G21
from model import D1, D2
import pdb


class MyLoss(torch.nn.Module):
  """Negated kernel-based dependence score between two image batches.

  For every sample the pixels of both images are flattened and expanded on a
  9 x 9 grid of kernel centres (81 centre pairs).  With ``K`` and ``L`` the
  resulting ``81 x P`` feature matrices (``P`` = number of pixels), the code
  builds

      H = 0.5 * (K K^T) * (L L^T) / P**2  +  0.5 * (K*L)(K*L)^T / P
      h = rowsum(K) * rowsum(L) / P**2
      alpha = (H + 0.05 * I)^-1 h
      score = 2 * alpha^T h - alpha^T H alpha - 1

  (``*`` is element-wise) and returns ``-mean(score)`` over the batch, so that
  minimising the loss maximises the score.  The original code calls this
  estimator "ERSMI".

  All intermediate tensors are created on the device of the first input, so
  the loss works on CPU as well as on GPU.
  """

  # Kernel centres along one axis; the full grid is their 9 x 9 product.
  # Presumably pixel values are normalised to [-1, 1] -- confirm against the
  # data pipeline.
  _CENTERS = (-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0)
  # Kernel width passed to the exponent.
  _SIGMA = 1.0
  # Weight of the identity added to H before inversion.
  _RIDGE = 0.05

  def __init__(self):
    super(MyLoss, self).__init__()

  @staticmethod
  def _kernel_matrix(pixels, centers, sigma):
    """Return the ``(B, C, P)`` kernel features of ``pixels`` (``(B, P)``).

    Broadcasting replaces the explicit ``repeat`` calls of the original
    implementation; the values are the same but no ``(B, C, P)`` copies of
    the centres and pixels are materialised first.
    """
    diff = centers.view(1, -1, 1) - pixels.unsqueeze(1)
    # NOTE(review): a Gaussian kernel would use a *negative* exponent; the
    # positive sign is kept exactly as in the original -- confirm intent.
    return torch.exp(diff.pow(2) / (2 * sigma ** 2))

  def forward(self, fakeI, realI):
    """Compute the loss for two 4-D image batches ``(B, C, H, W)``.

    Args:
      fakeI: first image batch; its device is used for all temporaries.
      realI: second image batch.  If it has a single channel while ``fakeI``
        has several, it is repeated along the channel axis to match.

    Returns:
      A 0-dim tensor: minus the batch mean of the per-sample score.
    """
    batch_size = fakeI.shape[0]
    channels = fakeI.shape[1]
    img_size = fakeI.shape[1] * fakeI.shape[2] * fakeI.shape[3]
    if realI.shape[1] == 1 and channels != 1:
      # Match the channel count of fakeI (the original hard-coded 3).
      realI = realI.repeat(1, channels, 1, 1)

    device = fakeI.device
    mu = torch.tensor(self._CENTERS, device=device)
    grid = mu.numel()
    # x centres cycle fastest, y centres slowest, so that row r of K and L
    # together address one (x, y) pair of the grid.
    x_centers = mu.repeat(grid)
    y_centers = mu.view(-1, 1).repeat(1, grid).view(-1)

    mat_K = self._kernel_matrix(fakeI.reshape(batch_size, -1), x_centers, self._SIGMA)
    mat_L = self._kernel_matrix(realI.reshape(batch_size, -1), y_centers, self._SIGMA)
    mat_KL = mat_K.mul(mat_L)

    H1 = mat_K.matmul(mat_K.transpose(1, 2)).mul(
      mat_L.matmul(mat_L.transpose(1, 2))) / (img_size ** 2)
    H2 = mat_KL.matmul(mat_KL.transpose(1, 2)) / img_size
    h = mat_K.sum(2, keepdim=True).mul(mat_L.sum(2, keepdim=True)) / (img_size ** 2)

    H = 0.5 * H1 + 0.5 * H2
    # The identity term keeps the matrix invertible.
    regularised = H + self._RIDGE * torch.eye(grid * grid, device=device)
    alpha = regularised.inverse().matmul(h)

    alpha_t = alpha.transpose(1, 2)
    ersmi = 2 * alpha_t.matmul(h) - alpha_t.matmul(H).matmul(alpha) - 1
    return -ersmi.mean()


class Solver(object):
  """Trains and samples a one-directional GAN: generator G12 against D2.

  G12 is applied to MNIST batches and its output is scored by D2 together
  with real SVHN batches.  Both networks use squared-error GAN losses; the
  generator can additionally be regularised with ``MyLoss``.
  """

  def __init__(self, config, svhn_loader, mnist_loader):
    """Copy the hyper-parameters out of ``config`` and build the networks.

    Args:
      config: object exposing the attributes read below (``lr``, ``beta1``,
        ``beta2``, ``batch_size``, ``train_iters``, ...).
      svhn_loader: iterable of ``(images, labels)`` SVHN training batches.
      mnist_loader: iterable of ``(images, labels)`` MNIST training batches.
    """
    self.mnist_loader = mnist_loader  # domain 1
    self.svhn_loader = svhn_loader  # domain 2
    self.g12 = None  # generator, fed with MNIST
    self.d2 = None
    self.g_optimizer = None
    self.d_optimizer = None
    self.use_reconst_loss = config.use_reconst_loss
    self.use_MIR = config.use_MIR
    self.num_classes = config.num_classes
    self.beta1 = config.beta1
    self.beta2 = config.beta2
    self.g_conv_dim = config.g_conv_dim
    self.d_conv_dim = config.d_conv_dim
    self.train_iters = config.train_iters
    self.batch_size = config.batch_size
    self.lr = config.lr
    self.log_step = config.log_step
    self.sample_step = config.sample_step
    self.sample_path = config.sample_path
    self.model_path = config.model_path
    self.config = config
    self.criterionMI = MyLoss()
    self.lambda_MI = config.lambda_MI
    self.build_model()

  def build_model(self):
    """Builds the generator, the discriminator and their Adam optimizers."""
    self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
    self.d2 = D2(conf=self.config, conv_dim=self.d_conv_dim)

    g_params = self.g12.parameters()
    d_params = self.d2.parameters()

    self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
    self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])

    if torch.cuda.is_available():
      self.g12.cuda()
      self.d2.cuda()

  def merge_images(self, sources, targets, k=10):
    """Tile source/target pairs side by side into one ``(H, W, 3)`` image.

    The grid has ``row = int(sqrt(batch_size))`` rows and ``row`` pairs per
    row; each pair occupies ``2 * w`` columns (source left, target right).
    Pairs beyond ``row * row`` are dropped instead of raising, which matters
    when ``batch_size`` is not a perfect square.

    Args:
      sources: array of shape ``(N, C, h, w)``.
      targets: array whose items broadcast to ``(3, h, w)``.
      k: unused; kept for interface compatibility.

    Returns:
      ``numpy`` array of shape ``(row * h, row * w * 2, 3)``.
    """
    _, _, h, w = sources.shape
    row = int(np.sqrt(self.batch_size))
    merged = np.zeros([3, row * h, row * w * 2])
    for idx, (s, t) in enumerate(zip(sources, targets)):
      if idx >= row * row:
        break
      i, j = divmod(idx, row)
      top = i * h
      # Column offsets must be measured in widths, not heights, otherwise
      # non-square images are written to the wrong place.
      left = 2 * j * w
      merged[:, top:top + h, left:left + w] = s
      merged[:, top:top + h, left + w:left + 2 * w] = t
    return merged.transpose(1, 2, 0)

  def to_var(self, x):
    """Moves tensor ``x`` to the GPU when one is available and wraps it in a Variable."""
    if torch.cuda.is_available():
      x = x.cuda()
    return Variable(x)

  def to_data(self, x):
    """Converts a tensor (on any device) to a numpy array."""
    if torch.cuda.is_available():
      x = x.cpu()
    return x.data.numpy()

  def reset_grad(self):
    """Zeros the gradient buffers of both optimizers."""
    self.g_optimizer.zero_grad()
    self.d_optimizer.zero_grad()

  def as_np(self, data):
    """Returns ``data`` as a numpy array, copying it to the CPU first."""
    return data.cpu().data.numpy()

  def test(self, svhn_test_loader, mnist_test_loader, checkpoint_step=40000):
    """Translate MNIST test batches with a saved generator and save the grids.

    Loads ``g12-<checkpoint_step>.pkl`` from ``self.model_path`` into the
    generator created by ``build_model`` (so the architecture always matches
    the training configuration) and writes one ``sample-<i>-s-m.png`` per
    batch into ``self.sample_path``.

    Args:
      svhn_test_loader: only its length is used, as an upper bound on the
        number of batches processed.
      mnist_test_loader: iterable of ``(images, labels)`` MNIST batches.
      checkpoint_step: training step of the checkpoint to load.
    """
    g12_path = os.path.join(self.model_path, 'g12-%d.pkl' % checkpoint_step)
    # Load onto the CPU first so that checkpoints written on a GPU machine
    # can still be read where no GPU is present.
    state = torch.load(g12_path, map_location=lambda storage, loc: storage)
    self.g12.load_state_dict(state)
    if torch.cuda.is_available():
      self.g12.cuda()

    max_batches = len(svhn_test_loader)
    with torch.no_grad():
      # zip() stops at the shorter of the two, so a short MNIST loader no
      # longer raises StopIteration.
      for i, batch in zip(range(max_batches), mnist_test_loader):
        fixed_mnist = self.to_var(batch[0])
        fake_svhn = self.g12(fixed_mnist)

        merged = self.merge_images(self.to_data(fixed_mnist), self.to_data(fake_svhn))
        path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % i)
        imageio.imsave(path, merged)
        print('saved %s' % path)

  def train(self, svhn_test_loader, mnist_test_loader):
    """Run the adversarial training loop.

    Every step updates D2 on a real SVHN batch, then on a generated batch,
    and finally updates G12.  Losses are logged every ``log_step`` steps, a
    sample grid is written every ``sample_step`` steps and both networks are
    checkpointed every 5000 steps.

    Args:
      svhn_test_loader: unused; kept for interface compatibility.
      mnist_test_loader: its first batch is used as the fixed sampling input.
    """
    svhn_iter = iter(self.svhn_loader)
    mnist_iter = iter(self.mnist_loader)
    # At least 1 so that the modulo below cannot divide by zero when the
    # shorter loader has a single batch.
    iter_per_epoch = max(1, min(len(self.svhn_loader), len(self.mnist_loader)) - 1)

    # Fixed MNIST batch used for all sample grids.
    fixed_mnist = self.to_var(next(iter(mnist_test_loader))[0])

    for step in range(self.train_iters + 1):
      # Restart both iterators before either loader can run out.
      if (step + 1) % iter_per_epoch == 0:
        mnist_iter = iter(self.mnist_loader)
        svhn_iter = iter(self.svhn_loader)

      # Labels are not used by any loss below, so only the images are kept.
      svhn = self.to_var(next(svhn_iter)[0])
      mnist = self.to_var(next(mnist_iter)[0])

      #============ train D ============#

      # Real images should be scored as 1.
      self.reset_grad()
      d_real_loss = torch.mean((self.d2(svhn) - 1) ** 2)
      d_real_loss.backward()
      self.d_optimizer.step()

      # Generated images should be scored as 0.  The generator output is
      # detached: only D is updated here, and G's gradients would be zeroed
      # before its own update anyway.
      self.reset_grad()
      fake_svhn = self.g12(mnist)
      d_fake_loss = torch.mean(self.d2(fake_svhn.detach()) ** 2)
      d_fake_loss.backward()
      self.d_optimizer.step()

      #============ train G ============#

      self.reset_grad()
      fake_svhn = self.g12(mnist)
      gen_loss_A = torch.mean((self.d2(fake_svhn) - 1) ** 2)
      loss_G_MI = 0
      if self.use_MIR:
        loss_G_MI = self.criterionMI(fake_svhn, mnist)
      g_loss = gen_loss_A + self.lambda_MI * loss_G_MI
      g_loss.backward()
      self.g_optimizer.step()

      # print the log info
      if (step + 1) % self.log_step == 0:
        print('Step [%d/%d], d_real_loss: %.4f, d_fake_loss: %.4f, gen_loss_A: %.4f, loss_G_MI: %.4f, g_loss: %.4f'
              % (step + 1, self.train_iters, d_real_loss.item(), d_fake_loss.item(),
                 gen_loss_A.item(), float(loss_G_MI), g_loss.item()))

      # save the sampled images
      if (step + 1) % self.sample_step == 0:
        with torch.no_grad():
          fake_sample = self.g12(fixed_mnist)

        merged = self.merge_images(self.to_data(fixed_mnist), self.to_data(fake_sample))
        path = os.path.join(self.sample_path, 'sample-%d-m-s.png' % (step + 1))
        imageio.imsave(path, merged)
        print('saved %s' % path)

      # save the model parameters
      if (step + 1) % 5000 == 0:
        g12_path = os.path.join(self.model_path, 'g12-%d.pkl' % (step + 1))
        d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
        torch.save(self.g12.state_dict(), g12_path)
        torch.save(self.d2.state_dict(), d2_path)

