#!/usr/bin/env python3
import numpy as np
import torch
import time
import collections
import functools
import torch.nn.functional as F
import submitit

from uimnet import datasets
from uimnet import utils
from uimnet import workers

from uimnet import __DEBUG__

class Calibrator(workers.Worker):
  """Temperature-scaling calibration worker.

  Given a finished training run, rebuilds the ``Algorithm``, loads its state
  from ``train_cfg.output_dir``, collects logits and targets on the ``'eval'``
  split, and fits the temperature ``tau`` by minimizing
  ``F.cross_entropy(logits / tau, y)`` with ``torch.optim.LBFGS``. Every
  starting value in ``TAUS0`` is combined with every optimizer setting in
  ``OPTIM_KWARGS``; the run with the lowest non-NaN loss wins and the
  calibrated state is saved back to ``train_cfg.output_dir``.
  """

  # Keyword arguments for torch.optim.LBFGS; each set is tried from every tau0.
  OPTIM_KWARGS = [
    dict(lr=1, max_iter=20, line_search_fn=None),
    dict(lr=0.01, max_iter=50, line_search_fn=None),
    dict(lr=0.001, max_iter=200, line_search_fn=None),
  ]

  # Starting temperatures for the search.
  TAUS0 = [0.5, 1., 1.5]
  # Number of optimizer.step() calls per (tau0, OPTIM_KWARGS entry) run.
  NEPOCHS = 20

  def __init__(self):
    """Create an empty worker; the real state is filled in by ``__call__``."""
    super(Calibrator, self).__init__()
    # INITIAL WORKER STATE
    #
    # Kwargs-vals. could use inspect but this saner.
    self.cfg = None
    # Set by __call__; declared here so ``checkpoint`` never hits an
    # AttributeError when it fires before __call__ has stored them.
    self.train_cfg = None
    self.calibration_cfg = None
    self.Algorithm = None
    self.dataset = None

    self.algorithm = None
    self.datanode = None

  def checkpoint(self, *args, **kwargs):
    """Requeue hook: resubmit the calibration with the same arguments.

    On rank 0 (or when not distributed) writes the ``calibration.interrupted``
    trace and returns a ``submitit.helpers.DelayedSubmission`` of a fresh
    ``Calibrator``; other ranks return ``None``.
    """
    if not utils.is_not_distributed_or_is_rank0():
      return None
    new_callable = Calibrator()
    utils.write_trace('calibration.interrupted', dir_=str(self.train_cfg.output_dir))
    # Keyword names must match ``__call__``'s parameters, otherwise the
    # resubmitted job fails with a TypeError before doing any work.
    return submitit.helpers.DelayedSubmission(new_callable,
                                              calibration_cfg=self.calibration_cfg,
                                              train_cfg=self.train_cfg,
                                              Algorithm=self.Algorithm,
                                              dataset=self.dataset)

  def __call__(self, calibration_cfg, train_cfg, Algorithm, dataset):
    """Calibrate the temperature of a trained model and save the result.

    Args:
      calibration_cfg: calibration config; ``experiment.device``,
        ``experiment.num_workers`` and ``dataset.batch_size`` are read here.
      train_cfg: config of the finished training run; its ``output_dir`` holds
        the model state and the trace files.
      Algorithm: algorithm class to instantiate and load.
      dataset: dataset wrapped by ``datasets.SplitDataNode``.

    Returns:
      dict with ``'data'`` (``tau_star`` and ``tau_initial`` as CPU tensors),
      both configs, ``'elapsed_time'`` in seconds and ``'status': 'done'``.

    Raises:
      RuntimeError: if training has not finished (no ``train.done`` trace;
        checked on rank 0 only) or if every calibration run produced a NaN loss.
    """
    start_time = time.time()
    self.train_cfg = train_cfg
    self.calibration_cfg = calibration_cfg
    self.Algorithm = Algorithm
    self.dataset = dataset

    self.setup(calibration_cfg)  # Setup modifies cfg. It needs a state on the worker.
    utils.message(calibration_cfg)

    if utils.is_not_distributed_or_is_rank0():
      if not utils.trace_exists('train.done', dir_=train_cfg.output_dir):
        raise RuntimeError('Training not finished')
      utils.write_trace('calibration.running', dir_=train_cfg.output_dir)

    self.datanode = datasets.SplitDataNode(
      dataset=dataset,
      transforms=datasets.TRANSFORMS,
      splits_props=train_cfg.dataset.splits_props,
      seed=train_cfg.dataset.seed)
    num_classes = self.datanode.splits['train'].num_classes

    self.algorithm = Algorithm(num_classes=num_classes,
                               arch=train_cfg.algorithm.arch,
                               device=calibration_cfg.experiment.device,
                               use_mixed_precision=train_cfg.algorithm.use_mixed_precision,
                               seed=train_cfg.algorithm.seed,
                               sn=train_cfg.algorithm.sn,
                               sn_coef=train_cfg.algorithm.sn_coef,
                               sn_bn=train_cfg.algorithm.sn_bn)
    self.algorithm.initialize()
    utils.message(self.algorithm)

    # NOTE(review): presumably adapts a state saved by a distributed training
    # run so it can be loaded in a non-distributed process — confirm.
    adapt_state = train_cfg.experiment.distributed and not utils.is_distributed()
    self.algorithm.load_state(train_cfg.output_dir, map_location=calibration_cfg.experiment.device,
                              adapt_state=adapt_state)

    # Preparing for calibration
    self.datanode.eval()
    self.algorithm.eval()
    eval_loader = self.datanode.get_loader('eval',
                                           batch_size=calibration_cfg.dataset.batch_size,
                                           shuffle=False,
                                           pin_memory='cuda' in calibration_cfg.experiment.device,
                                           num_workers=calibration_cfg.experiment.num_workers)

    utils.message('Collecting logits and targets')
    collected = self.algorithm.collect_logits_and_targets(eval_loader)
    self.algorithm.cpu()

    utils.message('Reinitializing temperature')
    self.algorithm.temperature.reinitialize_temperature()
    tau_initial = self.algorithm.temperature.tau.data.clone()

    final_loss, final_tau = self._search_temperature(collected, tau_initial)
    utils.message(f'Final loss={final_loss}, at tau = {final_tau}')

    with torch.no_grad():
      self.algorithm.temperature.tau.fill_(final_tau)

    utils.message('Finalizing calibration')
    utils.message('serializing model.')
    if utils.is_not_distributed_or_is_rank0():
      self.algorithm.save_state(train_cfg.output_dir)
    self.finalize(train_cfg)

    return {'data': dict(tau_star=self.algorithm.temperature.tau.detach().cpu(),
                         tau_initial=tau_initial.detach().cpu()),
            'calibration_cfg': calibration_cfg,
            'train_cfg': train_cfg,
            'elapsed_time': time.time() - start_time,
            'status': 'done'}

  def _search_temperature(self, collected, tau_initial):
    """Run ``calibrate`` for every (TAUS0, OPTIM_KWARGS) pair; return the best ``(loss, tau)``.

    Raises:
      RuntimeError: if every run ended with a NaN loss.
    """
    results = []
    for tau0 in self.TAUS0:
      for optim_kwargs in self.OPTIM_KWARGS:
        # Fresh start tensor per run: ``calibrate`` optimizes it in place.
        start = tau_initial.detach().clone().fill_(tau0)
        results.append(self.calibrate(collected, start, optim_kwargs))

    finite = sorted((el for el in results if not np.isnan(el[0])), key=lambda el: el[0])
    utils.message('All calibration results')
    utils.message(finite)
    if not finite:
      raise RuntimeError(f'Temperature calibration failed: every run produced a NaN loss: {results}')
    return finite[0]

  def calibrate(self, collected, tau, optim_kwargs):
    """Fit the temperature ``tau`` with L-BFGS, starting from its current value.

    Args:
      collected: mapping with ``'logits'`` and ``'y'`` tensors, used as
        ``F.cross_entropy(logits / tau, y)``.
      tau: temperature tensor, optimized in place; ``requires_grad`` is switched
        on for the run and off again afterwards.
      optim_kwargs: keyword arguments for ``torch.optim.LBFGS``.

    Returns:
      ``(best_loss, best_tau)`` as Python floats: the lowest loss observed
      (including the starting point) and the temperature that produced it.
    """
    logits, targets = collected['logits'], collected['y']

    def _loss(t):
      return F.cross_entropy(logits / t, targets)

    with torch.no_grad():
      best_loss = _loss(tau)
    best_tau = tau.detach().clone()

    utils.message(f'l({best_tau}) = {best_loss}')

    tau.requires_grad = True
    optimizer = torch.optim.LBFGS([tau], **optim_kwargs)

    def _closure():
      optimizer.zero_grad()
      loss_value = _loss(tau)
      loss_value.backward()
      return loss_value

    for i in range(self.NEPOCHS):
      optimizer.step(_closure)
      # LBFGS.step returns the loss evaluated *before* the update; re-evaluate
      # at the updated tau so every recorded (loss, tau) pair is consistent.
      with torch.no_grad():
        loss = _loss(tau)

      # A NaN loss compares False, so it can never replace the current best.
      if loss < best_loss:
        best_loss = loss.clone()
        best_tau = tau.detach().clone()

      utils.message(f'Epoch {i + 1}, loss({tau}) = {loss}')

    tau.requires_grad = False
    utils.message(f'l(tau_*) = {best_loss}')
    return best_loss.item(), best_tau.item()

  def finalize(self, cfg):
    """Write the ``calibration.done`` trace (rank 0 only) and log completion."""
    if utils.is_not_distributed_or_is_rank0():
      utils.write_trace('calibration.done', dir_=cfg.output_dir)
    utils.message('Calibration completed.')


if __name__ == "__main__":
  # No command-line entry point: this module only defines the Calibrator worker.
  pass
