import torch
import torch.nn as nn
import torchvision
import os
import pickle
import scipy.io
import imageio
import numpy as np
import torch.nn.functional as F

from torch.autograd import Variable
from torch import optim
from models.mnist_to_svhn.model import G12, G21
from models.mnist_to_svhn.model import D1, D2
import pdb


class MyLoss(torch.nn.Module):
  """Loss equal to the negated batch mean of a per-image-pair "ERSMI" score.

  For every batch index ``i`` a scalar score is computed between
  ``fakeI[i]`` and ``realI[i]``.  The scores are averaged over the batch and
  the sign is flipped, so minimising this loss maximises the average score.

  All intermediate tensors are created on the device of the inputs, so the
  module works on CPU as well as on GPU.  Both inputs are expected to live on
  the same device.
  """

  # Kernel centres used for each of the two images.  Every x-centre is paired
  # with every y-centre, giving len(_CENTRES) ** 2 kernel rows (81 here).
  _CENTRES = (-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0)
  # Value added to the diagonal of H2 before it is inverted.
  _RIDGE = 0.05

  def __init__(self):
    super(MyLoss, self).__init__()

  @staticmethod
  def _kernel(pixels, centres, sigma):
    """Evaluate the kernel of every centre against every pixel value.

    Args:
      pixels: tensor of pixel values; it is flattened to a single row.
      centres: 1-D tensor of kernel centres.
      sigma: kernel width.

    Returns:
      Tensor of shape ``(centres.numel(), pixels.numel())``.
    """
    # Column of centres minus row of pixels broadcasts to all pairwise
    # differences, which is what the explicit repeat() calls used to build.
    diff = centres.reshape(-1, 1) - pixels.reshape(1, -1)
    # NOTE(review): the exponent is positive, exp(+d^2 / (2 sigma^2)).  A
    # Gaussian kernel would use a negative exponent.  Kept exactly as the
    # original code had it because changing it alters the loss -- confirm.
    return torch.exp(diff.pow(2) / (2 * sigma ** 2))

  @classmethod
  def _ersmi(cls, img_a, img_b):
    """Return the scalar score for one pair of images.

    Args:
      img_a: image tensor, indexed as (channels, height, width).
      img_b: image tensor with the same number of elements as ``img_a``, or
        a single-channel image that is tiled to ``img_a``'s channel count.

    Returns:
      0-dim tensor holding ``2 * a^T h2 - a^T H2 a - 1`` where
      ``a = (H2 + _RIDGE * I)^-1 h2``.
    """
    n_pix = img_a.numel()
    if img_b.shape[0] == 1:
      # Tile a single-channel image so both images have the same pixel count
      # (this was a hard-coded repeat of 3).
      img_b = img_b.repeat(img_a.shape[0], 1, 1)

    # Build the centres next to the data instead of forcing them onto CUDA.
    dtype = img_a.dtype if img_a.is_floating_point() else torch.get_default_dtype()
    mu = torch.tensor(cls._CENTRES, dtype=dtype, device=img_a.device)
    n_mu = mu.numel()
    x_centres = mu.repeat(n_mu)                          # mu0..mu8, mu0..mu8, ...
    y_centres = mu.view(-1, 1).repeat(1, n_mu).view(-1)  # mu0 x9, mu1 x9, ...

    mat_k = cls._kernel(img_a, x_centres, 1)  # [n_mu^2, n_pix]
    mat_l = cls._kernel(img_b, y_centres, 1)  # [n_mu^2, n_pix]

    # The original also built an H1 matrix and an earlier h2 that were never
    # read; both are dropped here (H1 cost two n_pix x n_pix matmuls).
    joint = mat_k.mul(mat_l)
    H2 = joint.t().mm(joint)                  # [n_pix, n_pix]
    h2 = mat_k.sum(0).view(-1, 1).mul(mat_l.sum(0).view(-1, 1)) / (n_pix ** 2)

    ridge = cls._RIDGE * torch.eye(n_pix, dtype=H2.dtype, device=H2.device)
    alpha = (H2 + ridge).inverse().mm(h2)
    return (2 * alpha.t().mm(h2) - alpha.t().mm(H2).mm(alpha) - 1).squeeze()

  def forward(self, fakeI, realI):
    """Return the negated mean score over the batch.

    Args:
      fakeI: batch of images; the first dimension is the batch index.
      realI: batch of images paired index-by-index with ``fakeI``.

    Returns:
      0-dim tensor: ``-(sum of per-pair scores) / fakeI.shape[0]``.

    Raises:
      ZeroDivisionError: if ``fakeI`` has a batch dimension of 0.
    """
    batch_size_cur = fakeI.shape[0]
    total = 0
    for i in range(batch_size_cur):
      total = total + self._ersmi(fakeI[i], realI[i])
    return -total / batch_size_cur


class Solver(object):
  def __init__(self, config, svhn_loader, mnist_loader):
    """Store loaders and hyper-parameters, then build the networks.

    Args:
      config: object exposing the hyper-parameters read below as attributes.
      svhn_loader: training loader for SVHN (labelled "A domain" here).
      mnist_loader: training loader for MNIST (labelled "B domain" here).
    """
    self.config = config

    # Training data sources.
    self.svhn_loader = svhn_loader    # A domain
    self.mnist_loader = mnist_loader  # B domain

    # Networks and optimizers are filled in by build_model() below.
    self.g21 = self.d1 = None
    self.g_optimizer = self.d_optimizer = None

    # Model-related settings.
    self.use_reconst_loss = config.use_reconst_loss
    self.num_classes = config.num_classes
    self.g_conv_dim = config.g_conv_dim
    self.d_conv_dim = config.d_conv_dim

    # Optimisation settings.
    self.lr = config.lr
    self.beta1 = config.beta1
    self.beta2 = config.beta2
    self.batch_size = config.batch_size
    self.train_iters = config.train_iters

    # Logging, sampling and checkpoint settings.
    self.log_step = config.log_step
    self.sample_step = config.sample_step
    self.sample_path = config.sample_path
    self.model_path = config.model_path

    # Extra loss term and its weight in the generator objective.
    self.criterionMI = MyLoss()
    self.lambda_MI = config.lambda_MI

    self.build_model()

  def build_model(self):
    """Builds the generator, the discriminator and their Adam optimizers.

    Sets ``self.g21``, ``self.d1``, ``self.g_optimizer`` and
    ``self.d_optimizer``.  When CUDA is available both networks are moved to
    the GPU.
    """
    self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
    self.d1 = D1(conv_dim=self.d_conv_dim)

    # Move the networks to the GPU *before* creating the optimizers, so the
    # optimizers are constructed over the parameters in their final location
    # (the torch.optim documentation recommends this order).
    if torch.cuda.is_available():
      self.g21.cuda()
      self.d1.cuda()

    g_params = self.g21.parameters()
    d_params = self.d1.parameters()

    self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
    self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])

  def merge_images(self, sources, targets, k=10):
    _, _, h, w = sources.shape
    row = int(np.sqrt(self.batch_size))
    merged = np.zeros([3, row*h, row*w*2])
    for idx, (s, t) in enumerate(zip(sources, targets)):
      i = idx // row
      j = idx % row
      merged[:, i*h:(i+1)*h, (j*2)*h:(j*2+1)*h] = s
      merged[:, i*h:(i+1)*h, (j*2+1)*h:(j*2+2)*h] = t
    return merged.transpose(1, 2, 0)

  def to_var(self, x):
    """Converts numpy to variable."""
    if torch.cuda.is_available():
      x = x.cuda()
    return Variable(x)

  def to_data(self, x):
    """Converts variable to numpy."""
    if torch.cuda.is_available():
      x = x.cpu()
    return x.data.numpy()

  def reset_grad(self):
    """Zeros the gradient buffers."""
    self.g_optimizer.zero_grad()
    self.d_optimizer.zero_grad()


  def as_np(self, data):
    return data.cpu().data.numpy()


  def test(self, svhn_test_loader, mnist_test_loader):
    """Translates every SVHN test batch with a saved generator and saves image grids.

    A fresh ``G21`` is created, the checkpoint ``g21-40000.pkl`` under
    ``self.model_path`` is loaded into it, and for each batch of
    ``svhn_test_loader`` a merged source/translation grid is written to
    ``self.sample_path`` as ``sample-<batch index>-s-m.png``.

    Args:
      svhn_test_loader: iterable of batches; element 0 of each batch is the
        image tensor.
      mnist_test_loader: unused; kept so existing callers keep working.
    """
    index = 0  # position of the image tensor inside each batch

    # NOTE(review): the checkpoint step (40000) and the G21 arguments
    # (1, conv_dim=64) are hard-coded and differ from build_model(), which
    # passes self.config and self.g_conv_dim -- confirm this is intended.
    g21_path = os.path.join(self.model_path, 'g21-%d.pkl' %(40000))
    self.g21 = G21(1, conv_dim=64)
    # Load onto the CPU first so a checkpoint saved from a GPU run can also be
    # restored on a machine without CUDA.
    self.g21.load_state_dict(torch.load(g21_path, map_location='cpu'))

    # Same availability check as build_model()/to_var(), instead of an
    # unconditional .cuda() that fails on CPU-only hosts.
    if torch.cuda.is_available():
      self.g21.cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
      # Iterating the loader replaces iterator.next(), which recent
      # DataLoader iterators do not provide.
      for i, batch in enumerate(svhn_test_loader):
        fixed_svhn = self.to_var(batch[index])
        fake_mnist = self.g21(fixed_svhn)

        fake_mnist = self.to_data(fake_mnist)
        svhn = self.to_data(fixed_svhn)

        merged = self.merge_images(svhn, fake_mnist)
        path = os.path.join(self.sample_path, 'sample-%d-s-m.png' %(i))
        imageio.imsave(path, merged)
        print('saved %s' % path)


  def train(self, svhn_test_loader, mnist_test_loader):
    """Runs the adversarial training loop for ``self.train_iters + 1`` steps.

    Each step performs two discriminator updates (real MNIST batch, then a
    generated batch) followed by one generator update whose loss is the
    adversarial term plus ``self.lambda_MI`` times ``self.criterionMI``.
    Progress is printed every ``self.log_step`` steps, a sample grid is saved
    every ``self.sample_step`` steps, and both networks are checkpointed every
    5000 steps.

    Args:
      svhn_test_loader: loader whose first batch is used as the fixed SVHN
        input for the periodic sample grids.
      mnist_test_loader: unused; kept so existing callers keep working.
    """
    svhn_iter = iter(self.svhn_loader)
    mnist_iter = iter(self.mnist_loader)
    # At least 1 so that the modulo below cannot divide by zero when the
    # shorter loader yields a single batch.
    iter_per_epoch = max(1, min(len(self.svhn_loader), len(self.mnist_loader)) - 1)

    # Fixed SVHN batch for sampling.  next() is used instead of
    # iterator.next(), which recent DataLoader iterators do not provide.
    fixed_svhn = self.to_var(next(iter(svhn_test_loader))[0])

    for step in range(self.train_iters+1):
      # Restart both iterators before either one can run out of batches.
      if (step+1) % iter_per_epoch == 0:
        mnist_iter = iter(self.mnist_loader)
        svhn_iter = iter(self.svhn_loader)

      # Load one batch per domain; the labels are not used by any loss.
      svhn = self.to_var(next(svhn_iter)[0])
      mnist = self.to_var(next(mnist_iter)[0])

      #============ train D ============#

      # Real images: least-squares loss pulling D's output toward 1.
      self.reset_grad()
      out = self.d1(mnist)
      d_real_loss = torch.mean((out-1)**2)
      d_real_loss.backward()
      self.d_optimizer.step()

      # Generated images: least-squares loss pulling D's output toward 0.
      # detach() keeps the generator out of this graph; its gradients were
      # never applied in this phase (only d_optimizer steps here).
      self.reset_grad()
      fake_mnist = self.g21(svhn).detach()
      out = self.d1(fake_mnist)
      d_fake_loss = torch.mean(out**2)
      d_fake_loss.backward()
      self.d_optimizer.step()

      #============ train G ============#

      self.reset_grad()
      fake_mnist = self.g21(svhn)
      out_mnist = self.d1(fake_mnist)

      # Adversarial term: the generator wants D to output 1 on its images.
      gen_loss_B = torch.mean((out_mnist - 1) ** 2)

      # Extra term computed between each SVHN input and its translation.
      loss_G_MI = self.criterionMI(svhn, fake_mnist)
      g_loss = gen_loss_B + self.lambda_MI * loss_G_MI
      g_loss.backward()
      self.g_optimizer.step()

      # print the log info
      if (step+1) % self.log_step == 0:
        print('Step [%d/%d], d_real_loss: %.4f, d_fake_loss: %.4f, gen_loss_B: %.4f, loss_G_MI: %.4f, g_loss: %.4f'
              %(step+1, self.train_iters, d_real_loss.item(), d_fake_loss.item(), gen_loss_B.item(), loss_G_MI.item(), g_loss.item()))

      # save the sampled images
      if (step+1) % self.sample_step == 0:
        # Sampling needs no gradients.
        with torch.no_grad():
          sample_fake = self.g21(fixed_svhn)

        sample_fake = self.to_data(sample_fake)
        sample_svhn = self.to_data(fixed_svhn)

        merged = self.merge_images(sample_svhn, sample_fake)
        path = os.path.join(self.sample_path, 'sample-%d-s-m.png' %(step+1))
        imageio.imsave(path, merged)
        print('saved %s' % path)

      if (step+1) % 5000 == 0:
        # save the model parameters
        g21_path = os.path.join(self.model_path, 'g21-%d.pkl' %(step+1))
        d1_path = os.path.join(self.model_path, 'd1-%d.pkl' %(step+1))
        torch.save(self.g21.state_dict(), g21_path)
        torch.save(self.d1.state_dict(), d1_path)

