from sklearn.manifold import TSNE
import matplotlib
matplotlib.use('Agg')
import matplotlib.pylab as plt
import numpy as np
import torch
from loader import D2GDataset, Others
import pdb
import random
from util import decode_igraph_to_NAS201_str

# Datasets keyed by display name. The key order matters: the plotting
# helpers in this file iterate over DATASETS.keys() to assign class indices
# and colour-bar labels.
DATASETS = {
    label: Others(max_img=5, dataset=key)
    for label, key in (('CIFAR-10', 'cifar10'),
                       ('CIFAR-100', 'cifar100'),
                       ('MNIST', 'mnist'),
                       ('AIRCRAFT', 'aircraft'))
}
DATASETS['MetaD2A'] = D2GDataset(max_img=5, mode='te')


def save_tsne(t, time, name):
  """Scatter-plot a 2-D embedding coloured by dataset and save it to disk.

  Args:
    t: Array-like whose transpose yields the x coordinates in row 0 and the
      y coordinates in row 1 (e.g. the output of ``TSNE.fit_transform``
      with two components). Points must be grouped by dataset, in
      ``DATASETS`` key order, with ``time`` consecutive points per dataset.
    time: Number of points belonging to each dataset.
    name: Suffix of the output file name; the figure is written to
      ``"1.latent_" + name``.
  """
  datasets = list(DATASETS.keys())
  nc = len(datasets)

  t = np.transpose(t)
  # One integer class id per point: `time` copies of 0, then of 1, ...
  # Derived from nc so that it stays in sync with DATASETS.
  C = [i for i in range(nc) for _ in range(time)]

  # plt.get_cmap is used instead of plt.cm.get_cmap, which was removed from
  # Matplotlib in 3.9.
  cmap = plt.get_cmap('Set3', nc)

  plt.figure(figsize=(6, 4))
  plt.scatter(t[0], t[1], c=C, s=30, cmap=cmap, marker='o')

  cbar = plt.colorbar(ticks=range(nc), cmap=cmap)
  cbar.ax.get_yaxis().set_ticks([])
  # The class ids span [0, nc - 1] and the colour bar splits that range into
  # nc equal bands; put each dataset name at the centre of its band.
  band = (nc - 1) / nc
  for j, lab in enumerate(datasets):
    cbar.ax.text(.5, (j + 0.5) * band, lab, ha='center', va='center')
  cbar.ax.get_yaxis().labelpad = 15
  cbar.ax.set_ylabel('Datasets', rotation=270)

  plt.savefig("1.latent_"+name)
  plt.close()

  
  
def vis_latent_per_datasets(model):
  """Save t-SNE plots of the raw inputs and of the model's latent codes.

  For every entry of ``DATASETS`` item 0 is fetched 30 times. Two figures
  are produced through ``save_tsne``:

  * ``x.png`` - t-SNE of the inputs after averaging over dimension 1,
  * ``z.png`` - t-SNE of the latent codes obtained with ``model.encode``
    followed by ``model.reparameterize``.

  Args:
    model: Object providing ``eval()``, ``encode(x)`` returning
      ``(mu, logvar)`` and ``reparameterize(mu, logvar)``. Inputs are moved
      to CUDA before encoding.
  """
  model.eval()
  names = list(DATASETS.keys())
  time = 30
  tsne = TSNE(n_components=2, perplexity=10)

  # x_batch[k] holds `time` samples drawn from the k-th dataset.
  x_batch = []
  for d in names:
    x_batch.append([DATASETS[d][0] for _ in range(time)])

  # Average every sample over dim 1, then flatten to rows of 512 features
  # (assumes the averaged feature size is 512 - TODO confirm against loader).
  averaged = [torch.stack(samples).mean(1) for samples in x_batch]
  d_batch = torch.stack(averaged).view(-1,512)

  embedded_x = tsne.fit_transform(d_batch)
  save_tsne(embedded_x, time, 'x.png')

  z_batch = []
  with torch.no_grad():
    for samples in x_batch:
      inputs = torch.stack(samples).cuda()
      mu, logvar = model.encode(inputs)
      z = model.reparameterize(mu, logvar)
      # Collect the latent codes one row at a time, on the CPU.
      z_batch.extend(sz.cpu() for sz in z)

  embedded_z = tsne.fit_transform(torch.stack(z_batch))
  save_tsne(embedded_z, time, 'z.png')
  


def heatmap(data, ax=None, title=None, cbar_kw=None, cbarlabel="", **kwargs):
  """Draw `data` as a heatmap with a colour bar and a white cell grid.

  Args:
    data: 2-D array-like exposing ``shape``; rendered with ``ax.imshow``.
    ax: Axes to draw into. When ``None`` the current axes (``plt.gca()``)
      are used.
    title: Currently unused (the call that sets the title is disabled).
    cbar_kw: Optional dict of keyword arguments forwarded to
      ``figure.colorbar``. ``None`` means no extra arguments.
    cbarlabel: Label written along the colour bar.
    **kwargs: Forwarded unchanged to ``ax.imshow``.

  Returns:
    Tuple ``(im, cbar)``: the image returned by ``imshow`` and the colour
    bar attached to it.
  """
  if ax is None:
    ax = plt.gca()
  # A None default avoids sharing one mutable dict between calls.
  if cbar_kw is None:
    cbar_kw = {}

  # Plot the heatmap
  im = ax.imshow(data, **kwargs)

  # Create colorbar
  cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
  cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

  # We want to show all ticks...
  ax.set_xticks(np.arange(data.shape[1]))
  ax.set_yticks(np.arange(data.shape[0]))
  # ... and label them with the respective list entries.
  #ax.set_xticklabels(col_labels)
  #ax.set_yticklabels(row_labels)

  # Let the horizontal axes labeling appear on top.
  ax.tick_params(top=True, bottom=False,
                 labeltop=True, labelbottom=False)

  # Rotate the tick labels and set their alignment.
  plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
           rotation_mode="anchor")

  # Turn spines off and create white grid.
  for spine in ax.spines.values():
    spine.set_visible(False)

  # Minor ticks sit on the cell borders so the minor grid separates cells.
  ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
  ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
  ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
  ax.tick_params(which="minor", bottom=False, left=False)

  #ax.set_title(title, fontsize=20)
  return im, cbar


def norm0to1(v):
  """Min-max scale a tensor so that its values lie in [0, 1].

  The minimum and maximum are taken over the whole tensor (not per row or
  column).

  Args:
    v: ``torch.Tensor`` to normalise.

  Returns:
    A tensor with the shape of `v` holding ``(v - min) / (max - min)``. If
    the value range is not strictly positive (all elements equal, or NaN
    present), or if `v` is empty, an all-zero tensor is returned. The zero
    tensor lives on the same device as `v` and has the dtype the division
    would produce: the dtype of `v` for floating-point input, the default
    float dtype otherwise.
  """
  # Dtype produced by true division, so both branches agree on it.
  out_dtype = v.dtype if v.is_floating_point() else torch.get_default_dtype()
  if v.numel() == 0:
    # torch.min/max raise on empty tensors; nothing to normalise anyway.
    return torch.zeros_like(v, dtype=out_dtype)

  min_v = torch.min(v)
  range_v = torch.max(v) - min_v
  if range_v > 0:
    return (v - min_v) / range_v
  return torch.zeros_like(v, dtype=out_dtype)


def vis_latent_heatmap(model):
  """Print the architecture string decoded from each dataset's latent code.

  For every entry of ``DATASETS`` two samples (item 0, fetched twice) are
  encoded, reparameterised and decoded on CUDA. The first decoded graph is
  converted with ``decode_igraph_to_NAS201_str`` and printed next to the
  dataset name. The first latent code is also reshaped to rows of 4 values
  and min-max normalised; it is only needed by the heatmap rendering, which
  is currently disabled.

  Args:
    model: Object providing ``eval()``, ``encode(x)`` returning
      ``(mu, logvar)``, ``reparameterize(mu, logvar)`` and ``decode(z)``.
  """
  model.eval()
  names = list(DATASETS.keys())

  x_batch = []
  for d in names:
    first = DATASETS[d][0]
    second = DATASETS[d][0]
    x_batch.append([first, second])

  z_batch = []
  s_batch = []
  with torch.no_grad():
    for pair in x_batch:
      inputs = torch.stack(pair).cuda()
      mu, logvar = model.encode(inputs)
      z = model.reparameterize(mu, logvar)
      g_recon = model.decode(z)
      s_batch.append(decode_igraph_to_NAS201_str(g_recon[0]))
      z_batch.append(norm0to1(z[0].view(-1,4).cpu()))

  for name, arch_str in zip(names, s_batch):
    print(name, arch_str)
    # Disabled heatmap rendering (uses the matching entry of z_batch):
    #fig, ax = plt.subplots(figsize=(4, 9))
    #im, cbar = heatmap(z_batch[i], ax=ax, title=datasets[i], cmap="RdPu")
    #fig.tight_layout()
    #plt.savefig("heatmap_{}.png".format(datasets[i]))
    #plt.close()


def chart():
  """Scatter-plot search time against top-1 accuracy and save the figure.

  Five hard-coded methods are drawn, each with its own colour and marker
  and a text label just above the point. The figure is written to
  ``time_acc.png`` in the current working directory.
  """
  methods = ['ENAS', 'DARTSv1', 'SETN', 'GDAS', 'OURS']
  accuracies = [29.905, 44.935, 63.605, 65.24, 69.61333333]
  gpu_hours = [68.7442, 2.698372222, 8.538530556, 7.020008333, 0.636490355]

  # One spare colour/marker is listed; zip stops at the number of methods.
  palette = ['salmon', 'orange', 'steelblue', 'c', 'y', 'black']
  shapes = ['o', 'x', '^', 's', 'p', '.']

  _, ax = plt.subplots(figsize=(7, 6))
  ax.set_facecolor('gainsboro')

  for label, hours, accuracy, color, shape in zip(
      methods, gpu_hours, accuracies, palette, shapes):
    plt.scatter(hours, accuracy, c=color, s=40, label=label, marker=shape)
    # Name tag slightly above the marker.
    plt.text(hours, accuracy+1, label, fontsize=10)

  plt.legend(fontsize=12, loc='upper right')

  plt.xlabel('Search Time(GPU hours)', fontsize=14)
  plt.ylabel('Top-1 Accuracy', fontsize=14)

  plt.grid(True, axis='both', color='white', alpha=0.5, linestyle='--')

  plt.savefig("time_acc.png")
  plt.close()
  
def interpolation():
  task1 =  [i for i in range(610,630)]
  task2 =  [i for i in range(620,640)]
  task3 = [i for i in range(610,620)] + [i for i in range(630,640)]