#!/usr/bin/env python3

import argparse
import time
import collections
import functools
import os
import pickle
from pathlib import Path

import matplotlib
from matplotlib.transforms import Transform
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torchvision

import scipy
from scipy.cluster import hierarchy
import numpy as np
import numpy.random as npr
from omegaconf import OmegaConf

from uimnet import datasets
from uimnet import workers
from uimnet import utils
from uimnet import dispatch

from uimnet import __DEBUG__


def parse_arguments():
  """Build the command-line interface and parse ``sys.argv``.

  Returns:
    argparse.Namespace with ``cfg_path`` (str, required) and ``sweep_dir``
    (str, or None when ``-s/--sweep_dir`` is not given).
  """
  cli = argparse.ArgumentParser(description='runs clustering workflow')
  cli.add_argument('-c', '--cfg_path',
                   type=str,
                   required=True,
                   help='Path to YAML config.')
  cli.add_argument('-s', '--sweep_dir',
                   type=str,
                   default=None,
                   help='Path to sweep dir. Overrides YAML cfg.')
  parsed = cli.parse_args()
  return parsed

def get_encoder(name, num_classes):
  """Instantiate a pretrained torchvision model and strip its classifier head.

  Args:
    name: name of a model constructor looked up in
      ``torchvision.models.__dict__`` (e.g. a ResNet variant).
    num_classes: forwarded to the constructor as ``num_classes``.

  Returns:
    The model with its ``fc`` attribute replaced by ``utils.Identity()``.
    This assumes the architecture exposes its head as ``fc`` (ResNet-style)
    -- TODO confirm for other encoders.
  """
  # NOTE(review): with pretrained=True, torchvision's weights expect the
  # 1000-class head; confirm num_classes matches for the chosen model.
  model_ctor = torchvision.models.__dict__[name]
  backbone = model_ctor(num_classes=num_classes, pretrained=True)
  # Replace the classification layer so only the feature extractor remains.
  backbone.fc = utils.Identity()
  return backbone


@utils.checkpoint()
def extract_embeddings(cfg, encoder, datanode):
  """Run the embeddings-extraction worker and concatenate its outputs.

  Args:
    cfg: experiment config; ``cfg.sweep_dir`` is passed to the dispatcher as
      the worker's output directory.
    encoder: feature extractor handed to ``workers.EmbeddingsExtractor``.
    datanode: data source handed to the worker.

  Returns:
    The result of ``utils.map_dict`` applied to ``query['data'][0]`` with
    ``torch.cat(..., dim=0)``. This presumably yields one concatenated tensor
    per field (e.g. embeddings/targets/indices) -- TODO confirm the worker's
    output schema.
  """
  query = dispatch.submit(cfg=cfg,
                          output_dir=cfg.sweep_dir,
                          worker=workers.EmbeddingsExtractor(),
                          encoder=encoder,
                          datanode=datanode,
                          )
  # No interactive breakpoint here: this function must run unattended.
  return utils.map_dict(query['data'][0],
                        op=functools.partial(torch.cat, dim=0))


@utils.checkpoint()
def compute_class_conditional_statistics(results):
  """Group embeddings by target and attach a per-class mean embedding.

  Args:
    results: mapping with tensors ``'targets'``, ``'embeddings'`` and
      ``'indices'``. Their first dimensions presumably align, one row per
      sample -- TODO confirm against the extractor output.

  Returns:
    dict keyed by ``int(target)``, in the order produced by ``torch.unique``.
    Each value is ``utils.map_dict`` applied with ``torch.cat(dim=0)`` to a
    one-element list holding that class's ``embeddings``,
    ``mean_embeddings`` (mean over dim 0 with the dim kept) and ``indices``.
  """
  concat = functools.partial(torch.cat, dim=0)
  labels = results['targets']
  per_class = {}
  for label in torch.unique(labels):
    selected = labels == label
    feats = results['embeddings'][selected]
    record = dict(embeddings=feats,
                  mean_embeddings=feats.mean(dim=0, keepdims=True),
                  indices=results['indices'][selected])
    per_class.setdefault(int(label), []).append(record)
  return {key: utils.map_dict(records, op=concat)
          for key, records in per_class.items()}



def _split_by_sorted_distances(embeddings, candidate_idx, metric):
  """Greedily split ``candidate_idx`` so that close classes land on opposite sides.

  Candidate pairs ``(i, j)``, with ``i`` before ``j`` in ``candidate_idx``,
  are visited from the smallest to the largest ``metric`` distance between
  their embedding rows. ``i`` joins "in" unless it is already "difficult";
  ``j`` joins "difficult" unless it is already "in".

  Args:
    embeddings: 2-D array of per-class embeddings, one row per class.
    candidate_idx: 1-D integer array of row indices to split. It must have at
      least two entries so that at least one pair exists.
    metric: distance metric name forwarded to ``scipy.spatial.distance.pdist``.

  Returns:
    ``(in_idx, difficult_idx)`` int64 arrays of row indices into ``embeddings``.
  """
  pair_dists = scipy.spatial.distance.pdist(
      embeddings[candidate_idx], metric=metric)
  # np.triu_indices(n, k=1) enumerates (i, j) with i < j in row-major order,
  # which is exactly the order of pdist's condensed distance vector.
  rows, cols = np.triu_indices(len(candidate_idx), k=1)
  pairs = np.stack([candidate_idx[rows], candidate_idx[cols]],
                   axis=1).astype('int64')

  ins, difficults = set(), set()
  # Both sets are empty for the closest pair, so both endpoints get assigned.
  # Afterwards the membership checks keep the two sets disjoint.
  for first, second in pairs[np.argsort(pair_dists)]:
    if first not in difficults:
      ins.add(first)
    if second not in ins:
      difficults.add(second)

  # Every candidate appears in at least one pair, so every one gets assigned.
  assert len(ins) + len(difficults) == len(candidate_idx)
  return (np.array(list(ins)).astype('int64'),
          np.array(list(difficults)).astype('int64'))


@utils.checkpoint()
def partition_data(cond_stats, metric, method, seed):
  """Split classes into in-domain, easy-OOD and difficult-OOD partitions.

  The per-class mean embeddings are clustered into two clusters with Ward
  linkage. Cluster 2 becomes the "easy" out-of-domain set. The classes of
  cluster 1 are divided between "in" and "difficult" according to ``method``.

  Args:
    cond_stats: dict keyed by class target whose values hold at least a
      ``'mean_embeddings'`` tensor (one row per class).
    metric: distance metric name forwarded to ``scipy.spatial.distance.pdist``.
    method: ``'random'`` shuffles cluster 1 and cuts it in two; "in" gets
      the extra class when the count is odd. ``'sorted_distances'`` greedily
      puts the endpoints of the closest pairs on opposite sides.
    seed: seed of the generator used by the ``'random'`` method.

  Returns:
    dict with two entries:
      ``partitions``: maps ``'in'``, ``'easy'`` and ``'difficult'`` to the
        ``utils.map_dict``-concatenated ``cond_stats`` entries of that
        partition, plus ``'targets'``, the list of its ``cond_stats`` keys.
      ``linked``: the linkage matrix.

  Raises:
    ValueError: if ``method`` is unknown, or if cluster 1 holds fewer than
      two classes, since it then cannot be split into "in" and "difficult".
  """
  # Fail fast, before the (potentially expensive) clustering.
  if method not in ('random', 'sorted_distances'):
    err_msg = f'Unrecognise selection method {method}'
    raise ValueError(err_msg)

  # I - computing conditional embeddings. Row i belongs to the i-th key.
  class_keys = list(cond_stats.keys())
  mean_embeddings = torch.cat(
      [s['mean_embeddings'] for s in cond_stats.values()], dim=0)
  embeddings = mean_embeddings.detach().cpu().numpy()
  rng = npr.default_rng(seed=seed)

  # NOTE(review): scipy documents 'ward' linkage as correct only for
  # Euclidean distances; with another `metric` the result is a heuristic.
  dists = scipy.spatial.distance.pdist(
      embeddings, metric=metric)  # In triangular condensed form
  linked = hierarchy.linkage(dists, 'ward')

  num_clusters = 2
  clusters_idx = hierarchy.fcluster(
      linked, criterion='maxclust', t=num_clusters)
  assert len(np.unique(clusters_idx)) == num_clusters

  # WLOG cluster 1 is to hold the in-domain classes
  # TODO(XXX): in domain class chosen to be the smallest one
  # TODO(XXX): move equalize paritions here
  easy_ood_idx = np.where(clusters_idx == 2)[0]
  # in-domain + difficult out-of-domain idx
  in_difficult_ood_idx = np.where(clusters_idx == 1)[0]
  assert len(easy_ood_idx) + len(in_difficult_ood_idx) == len(clusters_idx)
  assert np.unique(np.concatenate(
    [easy_ood_idx, in_difficult_ood_idx], axis=0)).shape[0] == len(clusters_idx)

  if len(in_difficult_ood_idx) < 2:
    raise ValueError(
        'Cluster 1 must contain at least two classes to be split into '
        f'"in" and "difficult" partitions, got {len(in_difficult_ood_idx)}.')

  if method == 'random':
    shuffled = in_difficult_ood_idx[rng.permutation(len(in_difficult_ood_idx))]
    # array_split also accepts an odd count, unlike np.split, which raised.
    in_idx, difficult_ood_idx = np.array_split(shuffled, 2)
  else:  # 'sorted_distances'
    in_idx, difficult_ood_idx = _split_by_sorted_distances(
        embeddings, in_difficult_ood_idx, metric)

  row_partitions = {'in': in_idx,
                    'easy': easy_ood_idx,
                    'difficult': difficult_ood_idx}
  _cat = functools.partial(torch.cat, dim=0)
  partitions = {}
  for name, rows in row_partitions.items():
    # Translate embedding rows back into cond_stats keys. Indexing cond_stats
    # by the raw row position is only right when its keys are 0..K-1 in order.
    keys = [class_keys[r] for r in np.asarray(rows).astype('int64').tolist()]
    res = utils.map_dict([cond_stats[k] for k in keys], op=_cat)
    res.update(targets=[int(k) for k in keys])
    partitions[name] = res

  return dict(partitions=partitions, linked=linked)


def get_partition_color(label, labels_partition):
  """Return the matplotlib colour code of the partition containing ``label``.

  Partitions are checked in order: ``'in'`` is green ``'g'``, ``'easy'`` is
  blue ``'b'`` and ``'difficult'`` is red ``'r'``. A label found in none of
  them is magenta ``'m'``. A partition is only looked up when every earlier
  one did not contain the label.

  Args:
    label: value tested with ``in`` against each partition's ``'targets'``.
    labels_partition: mapping of partition name to a dict holding a
      ``'targets'`` container.
  """
  for partition_name, color in (('in', 'g'), ('easy', 'b'), ('difficult', 'r')):
    if label in labels_partition[partition_name]['targets']:
      return color
  return 'm'


@utils.checkpoint()
def vizualize_partitions(dataset, partitions, linked):
  """Draw the class dendrogram with leaf labels coloured by partition.

  Args:
    dataset: object exposing ``available_targets``. Its targets are turned
      into leaf labels through ``datasets.IMAGENET_CLASSES``, and are assumed
      to be in the same order as the rows clustered into ``linked``.
    partitions: partition dict as returned in ``partition_data``'s
      ``'partitions'``. Each leaf is coloured by ``get_partition_color``
      using the leaf's row position in ``linked``.
    linked: scipy linkage matrix.

  Returns:
    dict holding the new figure under ``'dendrogram'``.
  """
  targets = list(dataset.available_targets)
  classes = [datasets.IMAGENET_CLASSES[t] for t in targets]

  # Use a dedicated figure and axes. The old plt.clf() wiped whatever figure
  # was current, such as one returned by an earlier call.
  fig, ax = plt.subplots(figsize=(10, 100))
  dendro = hierarchy.dendrogram(linked,
                                labels=classes,
                                leaf_font_size=8.,
                                orientation='left',
                                ax=ax)
  # dendro['leaves'][k] is the original row index of the k-th leaf, in the
  # same order as the axis tick labels. Matching by row index avoids
  # mis-colouring classes that share a display name (looking up by label
  # text collapses such duplicates).
  for tick_label, leaf in zip(ax.get_ymajorticklabels(), dendro['leaves']):
    tick_label.set_color(get_partition_color(leaf, partitions))
  return dict(dendrogram=fig)


@utils.timeit
def run_clustering(cfg):
  """Run the full clustering workflow described by ``cfg``.

  The workflow builds the evaluation dataset and data node, then loads the
  encoder. It then extracts embeddings, computes per-class statistics,
  partitions the classes and draws the dendrogram.

  Args:
    cfg: config providing ``dataset.{root,split,seed}`` and
      ``clustering.{encoder,metric,method,seed}``. When the ``DATASETS_ROOT``
      environment variable is set, it takes precedence over
      ``cfg.dataset.root``.

  Returns:
    dict with ``cfg``, ``figs``, ``partitions`` and ``report`` (always None).
  """
  dataset = datasets.ImageNat(
    root=os.getenv('DATASETS_ROOT', cfg.dataset.root),
    transform=datasets.TRANSFORMS['eval'],
    split=cfg.dataset.split,
  )
  datanode = datasets.SimpleDataNode(dataset,
                                     transforms=datasets.TRANSFORMS,
                                     seed=cfg.dataset.seed)
  encoder = get_encoder(name=cfg.clustering.encoder,
                        num_classes=datanode.dataset.num_classes)

  # I - Extract embeddings
  utils.message('Extracting embeddings.')
  embedding_results = extract_embeddings(cfg, encoder=encoder, datanode=datanode)

  # II - Class-conditional statistics
  cond_stats = compute_class_conditional_statistics(embedding_results)

  # III - Partition classes
  partitioning = partition_data(cond_stats=cond_stats,
                                metric=cfg.clustering.metric,
                                method=cfg.clustering.method,
                                seed=cfg.clustering.seed)
  partitions = partitioning['partitions']

  # IV - Figures
  figs = vizualize_partitions(dataset,
                              partitions=partitions,
                              linked=partitioning['linked'])

  # TODO(XXX): make partitions text files
  return dict(cfg=cfg, figs=figs, partitions=partitions, report=None)

def main(cfg_path, sweep_dir):
  """Load the config, run the clustering workflow and persist its outputs.

  Each figure is written to ``<sweep>/figs/<name>.png``. The whole output
  dict is pickled to ``<sweep>/clustering_results.pkl``.

  Args:
    cfg_path: path to the YAML config.
    sweep_dir: output directory overriding the config. When None (the CLI
      default), the config's ``sweep_dir`` is used instead.

  Returns:
    The dict returned by ``run_clustering``.
  """
  # Load config
  cfg = utils.load_cfg(cfg_path, sweep_dir)

  # Resolve the output directory before the long run. Path(None) would raise
  # TypeError only after all the work was done. cfg.sweep_dir is also where
  # the embedding extraction writes.
  sweep_path = Path(sweep_dir if sweep_dir is not None else cfg.sweep_dir).absolute()

  output = run_clustering(cfg)

  figs_path = sweep_path / 'figs'
  figs_path.mkdir(parents=True, exist_ok=True)
  for name, fig in output['figs'].items():
    fig.savefig(figs_path / f'{name}.png')

  with open(sweep_path / 'clustering_results.pkl', 'wb') as fp:
    pickle.dump(output, fp, protocol=pickle.HIGHEST_PROTOCOL)

  return output



if __name__ == '__main__':
  # Script entry point: parse CLI flags and run the clustering workflow.
  cli_args = parse_arguments()
  output = main(cfg_path=cli_args.cfg_path, sweep_dir=cli_args.sweep_dir)
