import json
from this import d
import yaml
import itertools
import os

from datetime import datetime

import numpy as np
import jax.numpy as jnp
from jax import random
from jax.ops import index_update

from tqdm import tqdm

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.nn import KLDivLoss
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T

import logging

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from utils import data
from utils import plotting
from utils import utils

import pdb

import seml
from sacred import Experiment


# Sacred experiment object; `optimize` and `main` further down in this file are
# registered on it via `@ex.capture` / `@ex.automain`.
ex = Experiment()

# Default configuration file for local runs.
# NOTE(review): the original note here said to "uncomment" this line to run
# locally, but the line is currently active. Presumably it should be commented
# out when the configuration is supplied externally (e.g. by seml) -- confirm.
ex.add_config("configs/iv_general.yaml")


logger = logging.getLogger("mylogger")
logger.setLevel("INFO")

# Attach the logger to the experiment.
ex.logger = logger



class Objective:
  """
  Monte-Carlo estimate of the mean outcome at a fixed treatment value.

  Instances are callable: base noise is drawn from a standard normal, pushed
  through `y_transform` conditioned on `xstar`, and the generated outcomes are
  averaged.
  """

  def __init__(self, dim_x, store_data, sample_size=500):
    """Set up the base distribution and remember the sampling settings.

    Params:
      dim_x: Dimension of the treatment; also the dimension of the standard
          normal base distribution.
      store_data: Stored on the instance; not read anywhere in this class.
      sample_size: Number of base samples drawn per call.
    """
    self.dim_x = dim_x
    self.store_data = store_data
    self.sample_size = sample_size
    mean = torch.zeros(dim_x)
    covariance = torch.eye(dim_x)
    self.distr = MultivariateNormal(mean, covariance)

  def __call__(self, y_transform, xstar):
    """Return the mean of the outcomes generated at `xstar` (a scalar tensor).

    Params:
      y_transform: Object exposing `condition(context)`, which must return a
          callable that maps base noise to outcomes.
      xstar: Treatment value; broadcast against a
          `(sample_size, dim_x)` tensor of zeros to build the context.
    """
    n_draws = self.sample_size
    # Repeat xstar once per draw by broadcasting against zeros.
    context = torch.zeros((n_draws, self.dim_x)) + xstar

    noise = self.distr.rsample(sample_shape=torch.Size([n_draws]))
    conditioned = y_transform.condition(context)
    # Only the first noise coordinate is fed to the outcome transform.
    outcomes = conditioned(noise[:, :1])
    return outcomes.mean()

class Constraints:
  """
  A class that represents the constraints.

  Instances are callable: they compare samples generated by the current
  transforms against a subsample of the observed data using an MMD with an
  RBF kernel.
  """

  # Strategies understood by `__call__`.
  _STRATEGIES = ('inequality', 'equality')

  def __init__(self, data, slack, lagrangian_strategy, num_samples_mmd, save_rhs_every_step, store_data):
    """Initialize the Constraints class.

    Params:
      data: Dictionary with tensor entries 'x', 'y' and 'z'. 'x' must be 2D
          (its second dimension is stored as `dim_x`); 'y' is unsqueezed
          along dim 1 before concatenation, so it is presumably 1D -- confirm
          against the data loader.
      slack: Slack used by the 'inequality' strategy
          (constraint value is `slack - |MMD|`).
      lagrangian_strategy: Either 'inequality' or 'equality'.
      num_samples_mmd: Number of samples passed to `utils.get_indices` and
          `utils.generative_model_IV` for the MMD estimate.
      save_rhs_every_step: Stored on the instance; not read in this class.
      store_data: Stored on the instance; not read in this class.
    """
    self.n_constraints = 1
    self.slack = slack
    self.lagrangian_strategy = lagrangian_strategy
    self.save_rhs_every_step = save_rhs_every_step
    self.store_data = store_data
    self.data = data
    self.dim_x = data['x'].shape[1]
    self.num_samples_mmd = num_samples_mmd

  def __call__(self, transforms, results, opt_stp):
    """Evaluate the constraint for the current transforms.

    Params:
      transforms: Pair `(x_transform, y_transform)`.
      results: Results dictionary; currently unused, kept for the interface.
      opt_stp: Current optimization step; currently unused, kept for the
          interface.

    Returns:
      `slack - |MMD|` for the 'inequality' strategy, the raw MMD for the
      'equality' strategy.

    Raises:
      ValueError: If `lagrangian_strategy` is neither 'inequality' nor
          'equality'.
    """
    # Fail early with a clear message instead of an UnboundLocalError after
    # the (expensive) MMD computation.
    if self.lagrangian_strategy not in self._STRATEGIES:
      raise ValueError(
          f"Unknown lagrangian_strategy {self.lagrangian_strategy!r}; "
          f"expected one of {self._STRATEGIES}.")

    # Observed samples as rows (z, x, y), restricted to a subsample.
    true_data = torch.cat((self.data['z'], self.data['x'], self.data['y'].unsqueeze(1)), dim=1)
    idx_keep = utils.get_indices(self.data, self.num_samples_mmd)
    true_data = true_data[idx_keep, :]

    # Generated samples for the same instrument values, same column order.
    x_transform, y_transform = transforms
    gen_x, gen_y, gen_z = utils.generative_model_IV(
        self.data['z'][idx_keep, :], self.dim_x,
        x_transform, y_transform, self.num_samples_mmd)
    gen_data = torch.cat((gen_z, gen_x, gen_y), dim=1)

    constr = utils.MMD(gen_data, true_data, 'rbf')

    if self.lagrangian_strategy == 'inequality':
      return self.slack - torch.abs(constr)
    return constr


class AugmentedLagrangian:
  """
  Augmented Lagrangian method for a scalar objective subject to a single
  constraint, treated either as an inequality or as an equality depending on
  the `lagrangian_strategy` config entry.
  """

  def __init__(self,
               data,
               objective,
               constraints,
               log_dir,
               tau_init=10,
               alpha=100):
    """Initialize the augmented Lagrangian method.

    Params:
      data: The dataset. It is only stored on the instance; `optimize` does
          not read it.
      objective: A callable invoked as `objective(y_transform, xstar)` that
          returns a scalar tensor.
      constraints: A callable invoked as
          `constraints((x_transform, y_transform), results, opt_stp)` that
          returns the constraint value as a single-element tensor.
      log_dir: Directory below which each `optimize` call creates its own
          sub-directory for tensorboard logs, checkpoints and results.
      tau_init: The initial value (real) of the penalty parameter `tau`.
      alpha: The factor by which `tau` is multiplied when the constraint
          check at the end of a round fails.
    """
    self.data = data
    self.objective = objective
    self.constraints = constraints
    self.tau_init = tau_init
    self.alpha = alpha
    self.log_dir = log_dir

  @ex.capture
  def optimize(self,
               bound,
               xstar,
               xstar_name,
               _config,
               _log,
               _run,
               n_rounds=100,
               opt_steps=50,
               lr=0.01):
    """Run the augmented Lagrangian optimization.

    Params:
      bound: "lower" minimizes the objective; any other value minimizes the
          negated objective (i.e. computes an upper bound).
      xstar: Treatment value handed to the objective.
      xstar_name: Label for `xstar`, used in the output directory name and
          stored with the results.
      _config, _log, _run: Injected by sacred through `ex.capture`. Config
          keys read here: 'store_data', 'seed_method', 'dim_x', 'dim_z',
          'slack', 'lagrangian_strategy', 'clip_grad', 'grad_clip_value',
          'global_conv_check', 'eta_min', 'tau_max'.
      n_rounds: Maximum number of outer rounds (multiplier/penalty updates).
      opt_steps: Number of SGD steps per round.
      lr: SGD learning rate.

    Returns:
      Tuple `(fin_i, fin_obj)`: the number of non-NaN per-round objective
      values minus one, and the objective recorded at that index (NaN if no
      round was run).

    Raises:
      ValueError: If 'lagrangian_strategy' is neither 'inequality' nor
          'equality', or if `opt_steps` is smaller than 1.
    """
    strategy = _config['lagrangian_strategy']
    if strategy not in ('inequality', 'equality'):
      raise ValueError(
          f"Unknown lagrangian_strategy {strategy!r}; "
          "expected 'inequality' or 'equality'.")
    if opt_steps < 1:
      raise ValueError(f"opt_steps must be at least 1, got {opt_steps}.")
    slack = _config['slack']

    out_dir = os.path.join(self.log_dir, f"{bound}-xstar_{xstar_name}")
    if _config['store_data']:
      _log.info(f"Current run output directory: {out_dir}...")
    os.makedirs(out_dir, exist_ok=True)

    _log.info(f"Evaluate at xstar={xstar}...")
    _log.info(f"Evaluate {bound} bound...")

    _log.info("Initialize model and weights...")
    torch.manual_seed(int(_config['seed_method']))
    x_transform = T.conditional_affine_coupling(_config['dim_x'], context_dim=_config['dim_z'])
    y_transform = T.conditional_affine_coupling(1, context_dim=_config['dim_x'])

    _log.info("Initialize dictionary for results...")
    results = {
        "rhs": [],
        "all_rhs": [],
        "objective": [],
        "objective_every_step": [],
        "lhsmrhs_mean": [],
        "lhsmrhs_max": [],
        "lhsmrhs_min": [],
        "lhsmrhs_norm": [],
        "satis_frac": []
    }
    _log.info("Store xstar value for easier cumulative analysis...")
    results['xstar'] = xstar
    results['xstar_name'] = xstar_name
    results['bound'] = bound

    sign = 1 if bound == "lower" else -1
    tau = self.tau_init
    alpha = self.alpha
    # NOTE(review): this evaluates as (1 / 4) * tau ** 0.1, which differs from
    # the `1 / tau ** 0.1` used when the inequality branch tightens the
    # tolerance below -- confirm which form is intended.
    eta = 1 / 4 * tau ** 0.1
    lmbda = torch.tensor(1).float()

    # Setup optimizer
    params = list(x_transform.nn.parameters()) + list(y_transform.nn.parameters())
    # NOTE(review): both networks are logged under "GradNorm/<name>"; if they
    # share parameter names the tensorboard tags collide.
    named_params = (list(x_transform.nn.named_parameters())
                    + list(y_transform.nn.named_parameters()))
    optimizer = optim.SGD(params, lr=lr)

    writer = SummaryWriter(log_dir=out_dir)
    try:
      # Main optimization loop
      for rnd in range(n_rounds):

        # Find approximate solution of subproblem at fixed lmbda
        for opt_stp in range(opt_steps):
          iter_idx = opt_steps * rnd + opt_stp

          # Compute augmented Lagrangian
          obj = self.objective(y_transform, xstar)
          constr = self.constraints((x_transform, y_transform), results, opt_stp)
          slack_gap = slack - constr.item()

          results['objective_every_step'].append(obj.item())
          # Only meaningful for the inequality strategy.
          results['lhsmrhs_mean'].append(slack_gap)

          case1 = - lmbda * constr + 0.5 * tau * constr**2
          if strategy == 'inequality':
            case2 = - 0.5 * lmbda**2 / tau
            psi = torch.where(tau * constr <= lmbda, case1, case2)
          else:
            psi = case1

          psisum = torch.sum(psi)
          lagrangian = sign*obj + psisum

          # Calculate gradients
          optimizer.zero_grad()
          lagrangian.backward()

          # Some tensorboard logging
          writer.add_scalar("Optimization/objective", obj.item(), iter_idx)
          writer.add_scalar("Optimization/psisum", psisum.item(), iter_idx)
          writer.add_scalar("Optimization/lagrangian", lagrangian.item(), iter_idx)
          writer.add_scalar("Optimization/mmd", slack_gap, iter_idx)
          writer.add_scalar("Optimization/tau", tau, iter_idx)
          writer.add_scalar("Optimization/eta", eta, iter_idx)
          for name, param in named_params:
            # Parameters that did not take part in the graph have no gradient.
            if param.grad is not None:
              writer.add_scalar(f"GradNorm/{name}", param.grad.norm(), iter_idx)

          # Backprop
          if _config['clip_grad']:
            torch.nn.utils.clip_grad_value_(params, _config['grad_clip_value'])
          optimizer.step()

        # Check current solution to subproblem
        results['objective'].append(obj.item())
        to_log = ['Round', rnd, ':', obj.item(), slack_gap, lagrangian.item()]
        _log.info(' '.join(str(_) for _ in to_log))
        detached_constr = constr.detach().clone()
        if torch.isnan(obj):
          _log.info("Objective value is now NaN, stopping optimization...")
          break

        if strategy == 'inequality':
          # slack - constr; equals |MMD| when constr = slack - |MMD|.
          violation = slack - detached_constr
          cnorm = torch.linalg.norm(violation)
          is_satisfied = np.all(violation.numpy() <= slack)
        else:
          cnorm = torch.linalg.norm(detached_constr)
          # numel() instead of len(): len() raises on a 0-d tensor.
          is_satisfied = np.all(
              slack / detached_constr.numel() - np.abs(detached_constr.numpy()) >= 0)
        writer.add_scalar("Optimization/LHSmRHS_norm", cnorm.item(), iter_idx)

        if cnorm < eta or is_satisfied:
          # Global convergence check; needs two recorded rounds to compare.
          # The flag only controls this early stop, for both strategies; the
          # multiplier update below always runs.
          if rnd >= 1 and _config['global_conv_check']:
            if abs(results['objective'][-1] - results['objective'][-2]) <= 0.005:
              _log.info("Global convergence passed")
              break
          # Update the multiplier and tighten the tolerance.
          lmbda = torch.maximum(torch.tensor(0), lmbda.float() - tau * detached_constr)
          eta = torch.max(torch.tensor([eta / tau ** 0.5, _config['eta_min']]))
        else:
          # Increase penalty parameter, tighten tolerance
          tau = torch.min(torch.tensor([alpha * tau, _config['tau_max']]))
          if strategy == 'inequality':
            eta = torch.max(torch.tensor([1 / tau ** 0.1, _config['eta_min']]))
          else:
            # NOTE(review): grows with tau and is not clamped by 'eta_min',
            # unlike the inequality branch -- confirm this is intended.
            eta = 1 / 4 * tau ** 0.1
    finally:
      # Flush and release the event file even if a step raised.
      writer.close()

    _log.info("Finished optimization loop...")
    torch.save(x_transform.nn.state_dict(), os.path.join(out_dir, 'x_transform.ckpt'))
    torch.save(y_transform.nn.state_dict(), os.path.join(out_dir, 'y_transform.ckpt'))

    _log.info("Convert all results to numpy arrays...")
    results = {k: np.array(v) for k, v in results.items()}

    objectives = results["objective"]
    fin_i = np.sum(~np.isnan(objectives)) - 1
    _log.info(f"Final non-nan objective at {fin_i}.")
    # With n_rounds == 0 nothing was recorded; report NaN instead of failing.
    fin_obj = objectives[fin_i] if objectives.size else np.nan

    if _config['store_data']:
      _log.info("Save result data to...")
      result_path = os.path.join(out_dir, "results.npz")
      np.savez(result_path, **results)

    _log.info("Finished run.")

    return fin_i, fin_obj

@ex.automain
def main(_config, _log):
  """Load a dataset and compute lower/upper bounds on a grid of xstar values.

  Creates a timestamped output directory below `_config['output_dir']`, dumps
  the config there, loads the dataset named by `_config['dataset']`, and runs
  `AugmentedLagrangian.optimize` once per (xstar, bound) pair. If
  `_config['plot_final']` is set, the final bounds are plotted.

  Params:
    _config: Sacred configuration (injected).
    _log: Logger (injected).

  Raises:
    ValueError: If `_config['dataset']` is not one of the known datasets.
  """
  # NOTE(review): hash() requires every config value to be hashable and, for
  # strings, is presumably randomized per process -- the suffix only serves to
  # make the directory name unique.
  out_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_' + str(hash(tuple(_config.items())))
  out_dir = os.path.join(os.path.abspath(_config['output_dir']), out_name)
  _log.info(f"Save all output to {out_dir}...")
  os.makedirs(out_dir, exist_ok=True)

  with open(os.path.join(out_dir, 'config.yaml'), 'w') as fp:
    yaml.dump(_config, fp)

  _log.info(f"Get dataset: {_config['dataset']}")
  key_data = random.PRNGKey(int(_config['seed_data']))
  # Every dataset branch consumes exactly one subkey.
  key_data, subkey_data = random.split(key_data)
  if _config['dataset'] == "demand_design":
    (dat, data_xstar, data_ystar, xstar_grid,
     xstar_grid_plotting, xstar_plotting) = data.get_hllt(
         subkey_data, _config['num_data'], 0.5,
         _config['hllt_feature'], _config['num_xstar'], _config['hllt_additive'])
  elif _config['dataset'] == "mz":
    (dat, data_xstar, data_ystar, xstar_grid,
     xstar_grid_plotting, xstar_plotting) = data.get_mz(
         subkey_data, _config['num_data'],
         _config['hllt_feature'], _config['num_xstar'])
  elif _config['dataset'] == "scalar":
    dat, data_xstar, data_ystar = data.get_synth_data(
        subkey_data, _config['num_data'], _config['equations'],
        disconnect_instrument=False)
    xstar_plotting = data_xstar
    # Grid of bin midpoints between the smallest and largest observed x.
    xmin, xmax = np.min(dat['x']), np.max(dat['x'])
    xstar_grid = np.linspace(xmin, xmax, _config['num_xstar'] + 1)
    xstar_grid = (xstar_grid[:-1] + xstar_grid[1:]) / 2
    xstar_grid_plotting = xstar_grid
    data_xstar = data_xstar[:, np.newaxis]
  elif _config['dataset'] == "synth-2d":
    (dat, data_xstar, data_ystar, xstar_grid,
     xstar_grid_plotting, xstar_plotting) = data.get_synth_data_2d(
         subkey_data, _config['num_data'], _config['equations'],
         _config['xstar_axis'], _config['num_xstar'])
  elif _config['dataset'] == "synth-3d":
    (dat, data_xstar, data_ystar, xstar_grid,
     xstar_grid_plotting, xstar_plotting) = data.get_synth_data_3d(
         subkey_data, _config['num_data'], _config['equations'],
         _config['xstar_axis'], _config['num_xstar'])
  elif _config['dataset'] == "yeast":
    (dat, data_xstar, data_ystar, xstar_grid,
     xstar_grid_plotting, xstar_plotting) = data.get_yeast_data(subkey_data)
  else:
    raise ValueError(f"Unknown dataset {_config['dataset']}")

  for name in ('x', 'y', 'z'):
    dat[name] = utils.jax_to_torch(dat[name])
  dat = utils.standardize_data_shapes(dat)
  # Shapes
  # X=(num_data, dim_x), Z=(num_data, dim_z), Y=(num_data,)

  objective = Objective(_config['dim_x'], _config['store_data'])
  constraints = Constraints(dat, _config['slack'], _config['lagrangian_strategy'],
                            _config['num_samples_mmd'],
                            _config['save_rhs_every_step'], _config['store_data'])
  # Named `solver` so the imported `torch.optim` module is not shadowed.
  solver = AugmentedLagrangian(dat, objective, constraints, out_dir,
                               _config['tau_init'], _config['tau_factor'])

  results_global = {}
  y_given_x = utils.get_basis(_config['dim_x'], _config['dim_theta'], dat['x'], dat['y'],
                              _config['dataset'], _config['equations'], out_dir)
  expected_y_given_x = y_given_x(torch.tensor(np.array(data_xstar))).detach().numpy()
  if _config['store_data']:
    results_global['expected_y_given_x'] = expected_y_given_x

    result_path = os.path.join(out_dir, "results_global.npz")
    np.savez(result_path, **results_global)

  # Column 0 holds the lower bound, column 1 the upper bound.
  final = {
      "indices": jnp.zeros((_config['num_xstar'], 2), dtype=jnp.int32),
      "objective": jnp.zeros((_config['num_xstar'], 2)),
      "maxabsdiff": np.zeros((_config['num_xstar'], 2)),
  }
  # ---------------------------------------------------------------------------
  # Main loops over xstar and bounds
  # ---------------------------------------------------------------------------
  for i, xstar in enumerate(xstar_grid):
    xstar = utils.jax_to_torch(xstar).type(torch.FloatTensor)
    if xstar.ndim == 0:
      # Promote a scalar to a 1-element vector.
      xstar = xstar.reshape(1)
    for j, bound in enumerate(["lower", "upper"]):
      _log.info(f"Run xstar={xstar_grid_plotting[i]}, bound={bound}...")
      vis = "=" * 10
      _log.info(f"{vis} {i * 2 + j + 1}/{2 * _config['num_xstar']} {vis}")
      fin_i, fin_obj = solver.optimize(
          bound, xstar, xstar_grid_plotting[i],
          n_rounds=_config['num_rounds'], opt_steps=_config['opt_steps'],
          lr=_config['lr'])
      # `.at[...].set(...)` is the replacement for the deprecated
      # `jax.ops.index_update`; it returns an updated copy.
      final["indices"] = final["indices"].at[i, j].set(fin_i)
      final["objective"] = final["objective"].at[i, j].set(fin_obj)

  if _config['plot_final']:
    _log.info("Saving the final plots...")
    plotting.plot_final_bounds(xstar_grid_plotting, final["objective"],
                               xstar_plotting, data_ystar, expected_y_given_x, out_dir)

