import torch
import torch.nn as nn
import torchvision
import os
import pickle
import scipy.io
import numpy as np
import torch.nn.functional as F

from torch.autograd import Variable
from torch import optim
from models.mnist_to_svhn.model import G12, G21
from models.mnist_to_svhn.model import D1, D2
import pdb


class MyLoss(torch.nn.Module):
    """Loss equal to the negated batch-mean ERSMI score of paired images.

    For every pair ``(fakeI[i], realI[i])`` both images are flattened and
    expanded on a fixed 9 x 9 grid of kernel centres (81 basis functions).
    The per-pair score is a ridge-regularised quadratic form computed in
    ``_ersmi``; the loss is the negative mean of those scores.

    All intermediate tensors are created with the dtype and device of
    ``fakeI``, so the module works on CPU as well as on CUDA tensors.
    """

    # One-dimensional kernel centres; their Cartesian grid yields the 81
    # (x-centre, y-centre) basis functions used by the estimator.
    _CENTERS = (-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0)
    _SIGMA = 1     # kernel width handed to _kernel
    _RIDGE = 0.05  # weight of the identity added before the matrix inverse

    def __init__(self):
        super(MyLoss, self).__init__()

    @staticmethod
    def _kernel(values, centers, sigma):
        """Return the ``(len(centers), values.numel())`` kernel matrix.

        NOTE(review): the exponent is positive, exactly as in the previous
        implementation. A conventional Gaussian kernel would negate it;
        confirm which form is intended before changing it.
        """
        # Broadcasting (K, 1) - (1, N) replaces the explicit repeat() copies.
        diff = centers.reshape(-1, 1) - values.reshape(1, -1)
        return torch.exp(diff.pow(2) / (2 * sigma ** 2))

    def _ersmi(self, img_a, img_b, x_centers, y_centers):
        """Score one image pair; both images must hold the same element count."""
        img_size = img_a.numel()

        mat_k = self._kernel(img_a, x_centers, self._SIGMA)
        mat_l = self._kernel(img_b, y_centers, self._SIGMA)
        joint = mat_k.mul(mat_l)

        h_prod = mat_k.mm(mat_k.t()).mul(mat_l.mm(mat_l.t())) / (img_size ** 2)
        h_joint = joint.mm(joint.t()) / img_size
        h_vec = (mat_k.sum(1).view(-1, 1)
                 .mul(mat_l.sum(1).view(-1, 1)) / (img_size ** 2))

        h_mat = 0.5 * h_prod + 0.5 * h_joint
        # The ridge term keeps the matrix invertible.
        ridge = self._RIDGE * torch.eye(
            h_mat.shape[0], dtype=h_mat.dtype, device=h_mat.device)
        alpha = (h_mat + ridge).inverse().mm(h_vec)

        score = 2 * alpha.t().mm(h_vec) - alpha.t().mm(h_mat).mm(alpha) - 1
        return score.squeeze()

    def forward(self, fakeI, realI):
        """Return ``-mean_i ERSMI(fakeI[i], realI[i])`` as a 0-dim tensor.

        Args:
            fakeI: batch of images; the first dimension indexes samples.
            realI: batch paired with ``fakeI``; each sample must contain as
                many elements as the matching ``fakeI`` sample and live on
                the same device.
        """
        # The centre grid is constant, so build it once per call instead of
        # once per sample, directly on the input's device.
        n_centers = len(self._CENTERS)
        centers = torch.tensor(
            self._CENTERS, dtype=fakeI.dtype, device=fakeI.device)
        x_centers = centers.repeat(n_centers)
        y_centers = centers.view(-1, 1).repeat(1, n_centers).reshape(-1)

        batch_size_cur = fakeI.shape[0]
        total = 0
        for idx, fake_img in enumerate(fakeI):
            total = total + self._ersmi(
                fake_img, realI[idx], x_centers, y_centers)
        return -(total / batch_size_cur)


class Solver(object):
  def __init__(self, config, svhn_loader, mnist_loader):
    """Keep the loaders, copy the hyper-parameters off ``config`` and build the networks."""
    self.config = config

    # Training data for the two domains.
    self.svhn_loader = svhn_loader
    self.mnist_loader = mnist_loader

    # Networks and optimizers stay None until build_model() assigns them.
    for slot in ('g12', 'g21', 'd1', 'd2', 'g_optimizer', 'd_optimizer'):
      setattr(self, slot, None)

    # Settings mirrored one-to-one from the config object.
    mirrored = (
        'use_reconst_loss', 'num_classes', 'beta1', 'beta2',
        'g_conv_dim', 'd_conv_dim', 'train_iters', 'batch_size', 'lr',
        'log_step', 'sample_step', 'sample_path', 'model_path', 'lambda_MI',
    )
    for field in mirrored:
      setattr(self, field, getattr(config, field))

    # Extra generator criterion, weighted by lambda_MI during training.
    self.criterionMI = MyLoss()

    self.build_model()

  def build_model(self):
    """Builds both generators, both discriminators and their Adam optimizers.

    The two generators share ``g_optimizer`` and the two discriminators share
    ``d_optimizer``. When CUDA is available the networks are moved to the GPU
    before the optimizers are created, as the torch.optim documentation
    recommends.
    """
    self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
    self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
    self.d1 = D1(conv_dim=self.d_conv_dim)
    self.d2 = D2(conv_dim=self.d_conv_dim)

    # Device placement comes first so the optimizers are constructed from the
    # parameters that will actually be trained.
    if torch.cuda.is_available():
      for net in (self.g12, self.g21, self.d1, self.d2):
        net.cuda()

    g_params = list(self.g12.parameters()) + list(self.g21.parameters())
    d_params = list(self.d1.parameters()) + list(self.d2.parameters())

    betas = (self.beta1, self.beta2)
    self.g_optimizer = optim.Adam(g_params, self.lr, betas)
    self.d_optimizer = optim.Adam(d_params, self.lr, betas)

  def merge_images(self, sources, targets, k=10):
    _, _, h, w = sources.shape
    row = int(np.sqrt(self.batch_size))
    merged = np.zeros([3, row*h, row*w*2])
    for idx, (s, t) in enumerate(zip(sources, targets)):
      i = idx // row
      j = idx % row
      merged[:, i*h:(i+1)*h, (j*2)*h:(j*2+1)*h] = s
      merged[:, i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h] = t
    return merged.transpose(1, 2, 0)

  def to_var(self, x):
    """Converts numpy to variable."""
    if torch.cuda.is_available():
      x = x.cuda()
    return Variable(x)

  def to_data(self, x):
    """Converts variable to numpy."""
    if torch.cuda.is_available():
      x = x.cpu()
    return x.data.numpy()

  def reset_grad(self):
    """Zeros the gradient buffers."""
    self.g_optimizer.zero_grad()
    self.d_optimizer.zero_grad()
  

  def as_np(self, data):
    return data.cpu().data.numpy()


  def test(self, svhn_test_loader, mnist_test_loader):
    """Translates every SVHN test batch with g21 and saves comparison grids.

    Loads the generator checkpoints written at step 40000 from
    ``self.model_path`` and, for each batch of ``svhn_test_loader``, writes
    ``sample-<batch index>-s-m.png`` (inputs next to their translations) to
    ``self.sample_path``.

    Args:
      svhn_test_loader: iterable of batches whose first element is the image
        tensor.
      mnist_test_loader: unused; accepted so the signature matches ``train``.
    """
    g12_path = os.path.join(self.model_path, 'g12-%d.pkl' %(40000))
    g21_path = os.path.join(self.model_path, 'g21-%d.pkl' %(40000))

    # NOTE(review): the generators are rebuilt with a literal 1 as config and
    # conv_dim=64, unlike build_model(), which passes self.config and
    # self.g_conv_dim -- confirm this matches the saved checkpoints.
    self.g12 = G12(1, conv_dim=64)
    self.g21 = G21(1, conv_dim=64)

    # Deserialize onto the CPU first so checkpoints written on a GPU can also
    # be opened on a machine without CUDA.
    self.g12.load_state_dict(torch.load(g12_path, map_location='cpu'))
    self.g21.load_state_dict(torch.load(g21_path, map_location='cpu'))

    # Same guard as to_var(), so model and inputs end up on the same device.
    if torch.cuda.is_available():
      self.g12.cuda()
      self.g21.cuda()

    for i, batch in enumerate(svhn_test_loader):
      fixed_svhn = self.to_var(batch[0])
      # No gradients are needed when only generating samples.
      with torch.no_grad():
        fake_mnist = self.g21(fixed_svhn)

      fake_mnist = self.to_data(fake_mnist)
      svhn = self.to_data(fixed_svhn)

      merged = self.merge_images(svhn, fake_mnist)
      path = os.path.join(self.sample_path, 'sample-%d-s-m.png' %(i))
      # scipy.misc.imsave no longer exists in current SciPy releases;
      # torchvision's save_image with normalize=True applies the same
      # min-max scaling to the 8-bit range before writing the file.
      image = torch.from_numpy(
          np.ascontiguousarray(merged.transpose(2, 0, 1))).float()
      torchvision.utils.save_image(image, path, normalize=True)
      print ('saved %s' %path)


  def train(self, svhn_test_loader, mnist_test_loader):
    """Runs the adversarial training loop for ``self.train_iters + 1`` steps.

    Each step draws one SVHN and one MNIST batch and performs, in order:
      1. a discriminator update on real images (least-squares target 1),
      2. a discriminator update on generated images (target 0),
      3. a g12 (MNIST -> SVHN) generator update with the adversarial loss plus
         ``lambda_MI`` times the ``criterionMI`` term,
      4. a g21 (SVHN -> MNIST) generator update with the adversarial loss.

    Losses are printed every ``log_step`` steps, sample grids of a fixed test
    batch are written to ``sample_path`` every ``sample_step`` steps, and the
    state dicts of all four networks are written to ``model_path`` every 5000
    steps.

    Args:
      svhn_test_loader: loader whose first batch supplies the fixed SVHN
        images used for sampling.
      mnist_test_loader: loader whose first batch supplies the fixed MNIST
        images used for sampling.
    """
    svhn_iter = iter(self.svhn_loader)
    mnist_iter = iter(self.mnist_loader)
    # The iterators are restarted before the shorter loader runs dry. The
    # value is clamped to 1 so a very short loader cannot make the modulus
    # below divide by zero.
    iter_per_epoch = max(min(len(svhn_iter), len(mnist_iter)) - 1, 1)

    # Fixed mnist and svhn batches, reused for every sample grid.
    fixed_svhn = self.to_var(next(iter(svhn_test_loader))[0])
    fixed_mnist = self.to_var(next(iter(mnist_test_loader))[0])

    def save_sample(sources, targets, pattern, step_no):
      """Merge one source/target batch pair into a grid and write it as an image."""
      merged = self.merge_images(sources, targets)
      path = os.path.join(self.sample_path, pattern % step_no)
      # scipy.misc.imsave no longer exists in current SciPy releases;
      # torchvision's save_image with normalize=True applies the same
      # min-max scaling to the 8-bit range before writing the file.
      image = torch.from_numpy(
          np.ascontiguousarray(merged.transpose(2, 0, 1))).float()
      torchvision.utils.save_image(image, path, normalize=True)
      print ('saved %s' %path)

    for step in range(self.train_iters+1):
      # Reset the data iterators at the end of each pass over the data.
      if (step+1) % iter_per_epoch == 0:
        mnist_iter = iter(self.mnist_loader)
        svhn_iter = iter(self.svhn_loader)

      # Load one batch per domain; the labels are not used by any loss.
      svhn, _ = next(svhn_iter)
      mnist, _ = next(mnist_iter)
      svhn = self.to_var(svhn)
      mnist = self.to_var(mnist)

      #============ train D ============#

      # Real images should be scored as 1.
      self.reset_grad()
      d_mnist_loss = torch.mean((self.d1(mnist) - 1) ** 2)
      d_svhn_loss = torch.mean((self.d2(svhn) - 1) ** 2)
      d_real_loss = d_mnist_loss + d_svhn_loss
      d_real_loss.backward()
      self.d_optimizer.step()

      # Generated images should be scored as 0. They are detached because
      # only the discriminators are updated here.
      self.reset_grad()
      fake_svhn = self.g12(mnist).detach()
      d2_loss = torch.mean(self.d2(fake_svhn) ** 2)

      fake_mnist = self.g21(svhn).detach()
      d1_loss = torch.mean(self.d1(fake_mnist) ** 2)

      d_fake_loss = d1_loss + d2_loss
      d_fake_loss.backward()
      self.d_optimizer.step()

      #============ train G ============#

      # mnist -> svhn direction
      self.reset_grad()
      fake_svhn = self.g12(mnist)
      out_svhn = self.d2(fake_svhn)
      # NOTE(review): the reconstruction is never added to the loss and
      # self.use_reconst_loss is never consulted. The forward pass is kept
      # because it may update normalisation statistics inside g21 -- confirm
      # whether it can be dropped.
      reconst_mnist = self.g21(fake_svhn)

      gen_loss_A = torch.mean((out_svhn-1)**2)
      # NOTE(review): this line previously referenced the undefined names
      # `fake_B` and `self.real_A`, which raised NameError on the first step.
      # They are mapped to this cycle's translated and input batches. The
      # criterion presumably needs both to have the same number of elements
      # per sample -- confirm against the generator's output shape.
      loss_G_MI = self.criterionMI(fake_svhn, mnist)
      g_loss = gen_loss_A + self.lambda_MI * loss_G_MI
      g_loss.backward()
      self.g_optimizer.step()

      # svhn -> mnist direction
      self.reset_grad()
      fake_mnist = self.g21(svhn)
      out_mnist = self.d1(fake_mnist)
      # NOTE(review): unused for the same reason as reconst_mnist above.
      reconst_svhn = self.g12(fake_mnist)

      gen_loss_B = torch.mean((out_mnist - 1) ** 2)
      g_loss = gen_loss_B
      g_loss.backward()
      self.g_optimizer.step()

      # Print the log info.
      if (step+1) % self.log_step == 0:
        print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
              'd_fake_loss: %.4f, gen_loss_A: %.4f, gen_loss_B: %.4f'
              %(step+1, self.train_iters, d_real_loss.item(), d_mnist_loss.item(),
                d_svhn_loss.item(), d_fake_loss.item(), gen_loss_A.item(), gen_loss_B.item()))

      # Save sample grids for the fixed test batches.
      if (step+1) % self.sample_step == 0:
        # No gradients are needed when only generating samples.
        with torch.no_grad():
          sample_svhn = self.g12(fixed_mnist)
          sample_mnist = self.g21(fixed_svhn)

        save_sample(self.to_data(fixed_mnist), self.to_data(sample_svhn),
                    'sample-%d-m-s.png', step+1)
        save_sample(self.to_data(fixed_svhn), self.to_data(sample_mnist),
                    'sample-%d-s-m.png', step+1)

      # Save the parameters of all four networks every 5000 steps.
      if (step+1) % 5000 == 0:
        networks = (('g12', self.g12), ('g21', self.g21),
                    ('d1', self.d1), ('d2', self.d2))
        for name, net in networks:
          ckpt_path = os.path.join(self.model_path, '%s-%d.pkl' % (name, step+1))
          torch.save(net.state_dict(), ckpt_path)

