import json
import yaml
import itertools
import os

from datetime import datetime


import numpy as np
import jax.numpy as jnp
from jax import random
from jax.ops import index_update

from tqdm import tqdm

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T

import logging

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from utils import data
from utils import plotting
from utils import utils

#import pdb

from sacred import Experiment


ex = Experiment()
ex.add_config("configs/fd_multi_constr.yaml")

# Named logger at INFO level, attached to the experiment.
logger = logging.getLogger("mylogger")
logger.setLevel(logging.INFO)
ex.logger = logger


class Objective:
  """
  Monte Carlo estimate of the objective for the current model.

  Rows of the observed ``x`` are subsampled (with replacement); for each row,
  standard-normal base samples ``n`` are drawn, mapped to
  ``m* = x_transform.condition(xstar)(n)``, and the objective averages
  ``psi(m*) . mu`` where ``mu`` is the model output for ``cat(x_row, n)``.
  Callable via ``__call__``.
  """

  def __init__(self, dim_m, psi_vector, x_transform, store_data,
               sample_size_nx=200, sample_size_nm=50):
    """Set up the base distribution and sampling sizes.

    Args:
      dim_m: dimension of the standard-normal base distribution whose samples
        are fed through ``x_transform`` and concatenated with ``x`` rows.
      psi_vector: basis function applied to the transformed samples.
      x_transform: conditional transform; ``x_transform.condition(xstar)``
        must return a callable mapping base samples to samples of m.
      store_data: stored on the instance; not read by ``__call__``.
      sample_size_nx: number of ``x`` rows drawn per call.
      sample_size_nm: number of base samples drawn per selected ``x`` row.
    """
    self.distr = MultivariateNormal(torch.zeros(dim_m),
                                    torch.eye(dim_m))
    self.sample_size_nx = sample_size_nx
    self.sample_size_nm = sample_size_nm
    self.psi_vector = psi_vector
    self.x_transform = x_transform
    self.flow = dist.ConditionalTransformedDistribution(self.distr, [self.x_transform])
    self.store_data = store_data

  def __call__(self, model, x, xstar, results):
    """Return the Monte Carlo objective for ``model`` at ``xstar``.

    Args:
      model: callable returning a ``(mu, _)`` pair for an input formed by
        concatenating a repeated ``x`` row with base samples along dim 1.
      x: observed treatments; rows are drawn with replacement.
      xstar: conditioning value passed to ``x_transform.condition``.
      results: unused; kept for interface compatibility.

    Returns:
      Scalar tensor: mean over all (x row, base sample) pairs of psi(m*) . mu.
    """
    idx = torch.randint(0, len(x), (self.sample_size_nx, ))
    mu_chunks = []
    psi_chunks = []
    # Gather per-row pieces and concatenate once at the end; calling torch.cat
    # inside the loop re-copied everything accumulated so far (quadratic).
    for nx_ in x[idx]:
      nm_samples = self.distr.sample((self.sample_size_nm, ))
      m_samples = self.x_transform.condition(xstar)(nm_samples)
      # psi(m*) acts as a fixed target: no gradient flows through it.
      psi_chunks.append(self.psi_vector(m_samples).detach())
      model_input = torch.cat((nx_.repeat(self.sample_size_nm, 1), nm_samples), 1)
      mu_, _ = model(model_input)
      mu_chunks.append(mu_)

    mu = torch.cat(mu_chunks, 0)
    psi_mstar = torch.cat(psi_chunks, 0)
    return torch.mean(torch.einsum("bi, bi -> b", psi_mstar, mu))

class Constraints:
  """
  Moment constraints on the model's predicted parameters at a fixed set of
  data indices.

  For every index in ``constr_indices`` two right-hand sides are computed from
  the model output: the first moment psi(m)^T mu and the second moment
  psi(m)^T (Sigma + mu mu^T) psi(m). They are compared elementwise with
  ``lhs`` (all first moments, then all second moments).
  Callable via ``__call__``.
  """

  def __init__(self, lhs, n_inferred, psi_vector, dim_theta, constr_indices, slack,
               lagrangian_strategy, save_rhs_every_step, store_data):
    """Store constraint targets and configuration.

    Args:
      lhs: target values compared with the concatenated right-hand sides
        (presumably of length ``2 * len(constr_indices)`` -- confirm).
      n_inferred: inferred noise, concatenated with ``data['x']`` as model input.
      psi_vector: basis function applied to ``data['m']``.
      dim_theta: dimension passed to ``utils.vec_to_sigma``.
      constr_indices: row indices at which constraints are evaluated.
      slack: tolerance on ``|lhs - rhs|``.
      lagrangian_strategy: 'equality' or 'inequality'.
      save_rhs_every_step: if truthy (and ``store_data``), record rhs every call.
      store_data: if truthy, record diagnostics in the results dict.
    """
    self.lhs = lhs
    self.n_inferred = n_inferred
    self.psi_vector = psi_vector
    self.dim_theta = dim_theta
    self.constr_indices = constr_indices
    # Two constraints (first and second moment) per index.
    self.n_constraints = 2 * len(constr_indices)
    self.slack = slack
    self.lagrangian_strategy = lagrangian_strategy
    self.save_rhs_every_step = save_rhs_every_step
    self.store_data = store_data

  @staticmethod
  def _rhs_from_moments(psi_m, mu_theta, sigma_theta):
    """Return ``cat(first moments, second moments)`` as a 1-D tensor.

    First moment per row: ``psi . mu``; second: ``psi^T (Sigma + mu mu^T) psi``.
    Assumes ``psi_m``/``mu_theta`` are (k, d) and ``sigma_theta`` is (k, d, d)
    -- TODO confirm against utils.vec_to_sigma.
    """
    rhs1 = torch.sum(psi_m * mu_theta, 1).reshape(-1)
    psi_row = psi_m.unsqueeze(1)
    mu_row = mu_theta.unsqueeze(1)
    sig_adj = sigma_theta + torch.matmul(torch.transpose(mu_row, 1, 2), mu_row)
    rhs2 = torch.matmul(torch.matmul(psi_row, sig_adj), torch.transpose(psi_row, 1, 2))
    # reshape(-1) rather than squeeze(): with a single constraint index,
    # squeeze() produced 0-d tensors, which torch.cat cannot concatenate.
    return torch.cat((rhs1, rhs2.reshape(-1)), 0)

  def __call__(self, model, results, opt_stp, data):
    """Evaluate the constraints at the current model parameters.

    Args:
      model: callable returning ``(mu_theta, vec_theta)`` for rows of
        ``cat(data['x'], n_inferred)``.
      results: dict of lists, appended to as a side effect.
      opt_stp: step index within the current round; rhs is appended to
        ``results['rhs']`` only when it is 0.
      data: dict with at least ``'x'`` and ``'m'`` tensors.

    Returns:
      Constraint values divided by ``len(constr_indices)``: ``lhs - rhs`` for
      'equality', ``slack - |lhs - rhs|`` for 'inequality'.

    Raises:
      ValueError: if ``lagrangian_strategy`` is not recognised.
    """
    model_input = torch.cat((data['x'], self.n_inferred), 1)
    mu_theta, vec_theta = model(model_input[self.constr_indices])
    sigma_theta = utils.vec_to_sigma(vec_theta, self.dim_theta)

    # psi(m) depends only on data; gradients flow through the model outputs.
    psi_m = self.psi_vector(data['m'][self.constr_indices]).detach()
    rhs = self._rhs_from_moments(psi_m, mu_theta, sigma_theta)

    constr = self.lhs - rhs
    if self.lagrangian_strategy == 'inequality':
      constrs = self.slack - torch.abs(constr)
    elif self.lagrangian_strategy == 'equality':
      constrs = constr
    else:
      raise ValueError(f"Unknown lagrangian_strategy {self.lagrangian_strategy}")

    rhs_detached = rhs.detach()
    if opt_stp == 0:
      results['rhs'].append(rhs_detached.clone().numpy())
    if self.store_data:
      if self.save_rhs_every_step:
        results['all_rhs'].append(rhs_detached.clone().numpy())
      lhsmrhs = torch.abs(self.lhs - rhs_detached)
      satis_status = lhsmrhs <= self.slack
      results["lhsmrhs_mean"].append(lhsmrhs.mean().item())
      results["lhsmrhs_min"].append(lhsmrhs.min().item())
      results["lhsmrhs_max"].append(lhsmrhs.max().item())
      results["lhsmrhs_norm"].append(torch.linalg.norm(lhsmrhs).item())
      # Fraction of individual constraints within slack. The old divisor
      # 2 * n_constraints double-counted, capping this value at 0.5.
      results["satis_frac"].append(satis_status.sum().item() / satis_status.numel())
      results["lhs"] = self.lhs
    return constrs / len(self.constr_indices)


class AugmentedLagrangian:
  """
  Augmented Lagrangian method with equality or inequality constraints
  for an objective that depends on data, with the option to subsample
  constraints.
  """

  def __init__(self,
               data,
               objective,
               constraints,
               psi_vector,
               n_inferred,
               lhs_all,
               log_dir,
               n_constraints,
               tau_init=10,
               alpha=100):
    """Initialize the augmented Lagrangian method.

    Args:
      data: dict with at least 'x' and 'm' tensors.
      objective: callable ``(model, x, xstar, results) -> scalar tensor``.
      constraints: callable ``(model, results, opt_stp, data) -> tensor`` with
        one entry per multiplier.
      psi_vector: basis function, forwarded to model weight initialisation.
      n_inferred: inferred noise, concatenated with ``data['x']``.
      lhs_all: targets forwarded to model weight initialisation.
      log_dir: parent directory for per-run outputs.
      n_constraints: number of Lagrange multipliers.
      tau_init: initial penalty parameter.
      alpha: factor by which the penalty grows when progress is insufficient.
    """
    self.data = data
    self.objective = objective
    self.constraints = constraints
    self.n_constraints = n_constraints
    self.psi_vector = psi_vector
    self.n_inferred = n_inferred
    self.lhs_all = lhs_all
    self.log_dir = log_dir
    self.tau_init = tau_init
    self.alpha = alpha

  @ex.capture
  def optimize(self,
               bound,
               xstar,
               xstar_name,
               _config,
               _log,
               _run,
               n_rounds=100,
               opt_steps=50,
               lr=0.01):
    """Run the augmented Lagrangian optimization for one bound at one xstar.

    Args:
      bound: "lower" minimises the objective; any other value maximises it.
      xstar: conditioning value forwarded to the objective.
      xstar_name: label used in the output directory and results.
      _config, _log, _run: injected by sacred's ``ex.capture``.
      n_rounds: number of outer (multiplier / penalty update) rounds.
      opt_steps: SGD steps per round at fixed multipliers.
      lr: SGD learning rate.

    Returns:
      ``(fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff)``:
      round index, objective and max ``|lhs - rhs|`` for the last non-NaN
      round and for the last round within tolerance (-1 / NaN when none).

    Raises:
      ValueError: if ``opt_steps`` < 1 or the lagrangian strategy is unknown.
    """
    if opt_steps < 1:
      raise ValueError(f"opt_steps must be at least 1, got {opt_steps}")
    strategy = _config['lagrangian_strategy']
    if strategy not in ('inequality', 'equality'):
      raise ValueError(f"Unknown lagrangian_strategy {strategy}")

    out_dir = os.path.join(self.log_dir, f"{bound}-xstar_{xstar_name}")
    if _config['store_data']:
      _log.info(f"Current run output directory: {out_dir}...")
    os.makedirs(out_dir, exist_ok=True)

    _log.info(f"Evaluate at xstar={xstar}...")
    _log.info(f"Evaluate {bound} bound...")

    _log.info(f"Initialize model and weights...")
    torch.manual_seed(_config['seed_method'])
    model = utils.Gaussian(_config['dim_x'] + _config['dim_m'], _config['dim_theta'])
    n_inferred_w_nx = torch.cat((self.data['x'], self.n_inferred), 1)
    model = utils.initialize_model_weights(self.data['m'], model, self.psi_vector, self.lhs_all,
                                           n_inferred_w_nx, _config['dim_theta'])

    _log.info(f"Initialize dictionary for results...")
    results = {
        "rhs": [],
        "all_rhs": [],
        "objective": [],
        "objective_every_step": [],
        "lhsmrhs_mean": [],
        "lhsmrhs_max": [],
        "lhsmrhs_min": [],
        "lhsmrhs_norm": [],
        "satis_frac": []
    }
    _log.info(f"Store xstar value for easier cumulative analysis...")
    results['xstar'] = xstar
    results['xstar_name'] = xstar_name
    results['bound'] = bound

    sign = 1 if bound == "lower" else -1
    tau = self.tau_init
    alpha = self.alpha
    eta = 1 / tau ** 0.1
    lmbda = torch.ones(self.n_constraints)

    optimizer = optim.SGD(model.parameters(), lr=lr)

    writer = SummaryWriter(log_dir=out_dir)
    try:
      for rnd in range(n_rounds):
        # Approximately solve the subproblem at fixed multipliers lmbda.
        for opt_stp in range(opt_steps):
          iter_idx = opt_steps * rnd + opt_stp

          obj = self.objective(model, self.data['x'], xstar, results)
          constr = self.constraints(model, results, opt_stp, self.data)
          results['objective_every_step'].append(obj.item())

          penalty = -lmbda * constr + 0.5 * tau * constr ** 2
          if strategy == 'inequality':
            # Quadratic penalty where tau * c <= lambda; elsewhere the term is
            # -lambda^2 / (2 tau), which does not depend on the model.
            penalty = torch.where(tau * constr <= lmbda, penalty, -0.5 * lmbda ** 2 / tau)
          psisum = torch.sum(penalty)
          lagrangian = sign * obj + psisum

          optimizer.zero_grad()
          lagrangian.backward()

          # Tensorboard logging
          constr_norm = torch.linalg.norm(constr.detach())
          writer.add_scalar("Optimization/objective", obj.item(), iter_idx)
          writer.add_scalar("Optimization/psisum", psisum.item(), iter_idx)
          writer.add_scalar("Optimization/lagrangian", lagrangian.item(), iter_idx)
          writer.add_scalar("Optimization/constraint_norm", constr_norm, iter_idx)
          writer.add_scalar("Optimization/tau", tau, iter_idx)
          writer.add_scalar("Optimization/eta", eta, iter_idx)
          for i, multiplier in enumerate(lmbda.detach()):
            writer.add_scalar(f"Multipliers/{i}", multiplier, iter_idx)
          for name, param in model.named_parameters():
            # Parameters that did not take part in the graph have no gradient.
            if param.grad is not None:
              writer.add_scalar(f"GradNorm/{name}", param.grad.norm(), iter_idx)

          if _config['clip_grad']:
            torch.nn.utils.clip_grad_value_(model.parameters(), _config['grad_clip_value'])
          optimizer.step()

        # Check current solution to subproblem
        results['objective'].append(obj.item())
        to_log = ['Round', rnd, ':', obj.item(), constr_norm.item(), lagrangian.item()]
        _log.info(' '.join(str(item) for item in to_log))
        detached_constr = constr.detach()
        if torch.isnan(obj):
          _log.info("Objective value is now NaN, stopping optimization...")
          break

        per_constr_slack = _config['slack'] / len(detached_constr)
        if strategy == 'inequality':
          # NOTE(review): with the Constraints class in this file, entries are
          # (slack - |lhs - rhs|) / len(constr_indices) and len(detached_constr)
          # is 2 * len(constr_indices), so this equals
          # (|lhs - rhs| - slack / 2) / len(constr_indices). Confirm this is the
          # intended progress measure compared with eta.
          lhsmrhs_normalized = per_constr_slack - detached_constr
          cnorm = torch.linalg.norm(lhsmrhs_normalized)
          # Inequality constraints c >= 0 hold iff every entry is non-negative
          # (|lhs - rhs| <= slack). The previous test on lhsmrhs_normalized >= 0
          # was true only when every |lhs - rhs| >= slack / 2.
          is_satisfied = bool(torch.all(detached_constr >= 0))
        else:
          cnorm = torch.linalg.norm(detached_constr)
          is_satisfied = bool(torch.all(per_constr_slack - detached_constr.abs() >= 0))
        writer.add_scalar("Optimization/LHSmRHS_norm", cnorm.item(), iter_idx)

        # NOTE(review): when global_conv_check is false this branch is never
        # taken, so multipliers are never updated and only tau grows -- confirm.
        if (cnorm < eta or is_satisfied) and _config['global_conv_check']:
          # Global convergence check: needs two recorded rounds to compare.
          if (len(results['objective']) >= 2
              and np.abs(results['objective'][-1] - results['objective'][-2]) <= 0.005):
            _log.info("Global convergence passed")
            break
          # Good constraint progress: update multipliers, tighten tolerance.
          lmbda = torch.clamp(lmbda - tau * detached_constr, min=0)
          eta = max(eta / tau ** 0.5, _config['eta_min'])
        else:
          # Increase penalty parameter, tighten tolerance
          tau = min(alpha * tau, _config['tau_max'])
          eta = 1 / tau ** 0.1
    finally:
      # Flush and release the event file even if the loop raises.
      writer.close()

    _log.info(f"Finished optimization loop...")
    torch.save(model.state_dict(), os.path.join(out_dir, 'model_theta.ckpt'))

    _log.info(f"Convert all results to numpy arrays...")
    results = {k: np.array(v) for k, v in results.items()}

    if _config['store_data']:
      result_path = os.path.join(out_dir, "results.npz")
      _log.info(f"Save result data to {result_path}...")
      np.savez(result_path, **results)

    summary = self._summarize_results(results, _config['slack'], _config['slack_abs'], _log)
    _log.info("Finished run.")
    return summary

  def _summarize_results(self, results, slack, slack_abs, log):
    """Pick the final round and the last constraint-satisfying round.

    ``results['rhs']`` and ``results['objective']`` are expected to hold one
    entry per completed round. A round counts as satisfied when every
    constraint is within relative tolerance ``slack`` or absolute ``slack_abs``.

    Returns:
      ``(fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff)``.
    """
    lhs = results.get('lhs')
    if lhs is None:
      # 'lhs' is only recorded into results when store_data is set; fall back
      # to the constraint object's own targets.
      lhs = np.asarray(torch.as_tensor(self.constraints.lhs).detach().cpu())

    maxabsdiff = np.array([np.max(np.abs(lhs - r)) for r in results["rhs"]])
    fin_i = np.sum(~np.isnan(results["objective"])) - 1
    if fin_i < 0:
      log.info("No round finished with a non-nan objective.")
      fin_obj, fin_maxabsdiff = np.nan, np.nan
    else:
      log.info(f"Final non-nan objective at {fin_i}.")
      fin_obj = results["objective"][fin_i]
      fin_maxabsdiff = maxabsdiff[fin_i]

    sat_mask = [np.all((np.abs((lhs - r) / lhs) < slack) | (np.abs(lhs - r) < slack_abs))
                for r in results["rhs"]]
    sat_rounds = np.where(sat_mask)[0]

    if len(sat_rounds) > 0:
      sat_i = sat_rounds[-1]
      log.info(f"Final satisfied constraint at {sat_i}.")
      sat_obj = results["objective"][sat_i]
      sat_maxabsdiff = maxabsdiff[sat_i]
    else:
      sat_i = -1
      log.info(f"Constraints were never satisfied.")
      sat_obj, sat_maxabsdiff = np.nan, np.nan

    return fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff


@ex.config
def my_config():
  # Run name: "fd_" followed by the local launch timestamp.
  output_name = f"fd_{datetime.now():%Y-%m-%d_%H-%M-%S}"


def _load_dataset(config):
  """Generate the synthetic dataset selected by ``config['dataset']``.

  Defined before ``main`` because ``@ex.automain`` starts the run as soon as
  the decorator is applied when the file is executed as a script.

  Returns:
    ``(dat, data_xstar, data_ystar, xstar_grid, xstar_grid_plotting,
    xstar_plotting)``. For "scalar" the grid is built here from the midpoints
    of ``num_xstar`` equal-width bins over the observed x range.

  Raises:
    ValueError: if the dataset name is unknown.
  """
  dataset = config['dataset']
  if dataset not in ("scalar", "synth-2d", "synth-3d"):
    raise ValueError(f"Unknown dataset {dataset}")
  # Fixed key, so the synthetic data does not vary between runs.
  _, subkey_data = random.split(random.PRNGKey(0))

  if dataset == "scalar":
    dat, data_xstar, data_ystar = data.fd_get_synth_data(
        subkey_data, config['num_data'], config['equations'],
        disconnect_instrument=False)
    xmin, xmax = np.min(dat['x']), np.max(dat['x'])
    edges = np.linspace(xmin, xmax, config['num_xstar'] + 1)
    xstar_grid = (edges[:-1] + edges[1:]) / 2
    return dat, data_xstar, data_ystar, xstar_grid, xstar_grid, data_xstar

  synth = data.fd_get_synth_data_2d if dataset == "synth-2d" else data.fd_get_synth_data_3d
  return synth(subkey_data, config['num_data'], config['equations'],
               config['xstar_axis'], config['num_xstar'])


@ex.automain
def main(_config, _log):
  """Compute lower and upper bounds for every xstar on the grid.

  Creates a timestamped output directory under ``_config['output_dir']`` and
  dumps the config there. It then builds the dataset, basis functions,
  inferred noise, constraints and objective. Next it runs
  ``AugmentedLagrangian.optimize`` for the lower and upper bound at each
  xstar, and optionally plots the final bounds.
  """
  import hashlib  # local: only needed for the run-directory fingerprint

  # Deterministic config fingerprint. The builtin hash() is salted per process
  # for strings and raises TypeError on unhashable values (lists, dicts).
  config_digest = hashlib.sha1(
      json.dumps(_config, sort_keys=True, default=str).encode("utf-8")).hexdigest()[:12]
  out_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_' + config_digest
  out_dir = os.path.join(os.path.abspath(_config['output_dir']), out_name)
  _log.info(f"Save all output to {out_dir}...")
  os.makedirs(out_dir, exist_ok=True)

  with open(os.path.join(out_dir, 'config.yaml'), 'w', encoding="utf-8") as fp:
    yaml.dump(_config, fp)

  _log.info(f"Get dataset: {_config['dataset']}")
  dat, data_xstar, data_ystar, xstar_grid, xstar_grid_plotting, xstar_plotting = \
      _load_dataset(_config)
  for key in ('x', 'y', 'm'):
    dat[key] = utils.jax_to_torch(dat[key])
  dat = utils.standardize_data_shapes(dat)

  y_given_x = utils.get_basis(_config['dim_x'], _config['dim_theta'], dat['x'], dat['y'],
                              _config['dataset'], _config['equations'], out_dir)

  _log.info(f"Get the basis function...")
  if _config['response_type'] == 'mlp':
    psi_vector = utils.get_basis(_config['dim_m'], _config['dim_theta'], dat['m'], dat['y'],
                                 _config['dataset'], _config['equations'], out_dir).basis
  elif _config['response_type'] == 'polynomial':
    psi_vector = utils.psi_polyn(_config['dim_m'], _config['dim_theta'])
  else:
    raise ValueError(f"Unknown response_type {_config['response_type']}")

  _log.info("Get inferred N...")
  n_inferred, flow = utils.get_n_inferred(_config['dim_x'], _config['dim_m'], dat['x'], dat['m'],
                                          _config['dataset'], _config['equations'],
                                          _config['model_mx'], out_dir, True)
  indices_to_remove = utils.low_likelihood_indices(n_inferred, _config['dim_m'],
                                                   _config['num_to_remove'])

  # The same constraint indices are used for every xstar in the grid.
  _log.info(f"Get the constraint indices to be used for subsampling...")
  constr_indices = utils.get_indices(dat, _config['num_constant_samples'])
  constr_indices = utils.tensor_difference(constr_indices, indices_to_remove)
  # Two constraints (first and second moment) per index.
  n_constraints = 2 * len(constr_indices)

  _log.info("Get lhs...")
  xm = torch.cat((dat['x'], dat['m']), 1)
  lhs, lhs_all = utils.get_lhs(_config['dim_x'] + _config['dim_m'], xm, dat['y'], constr_indices,
                               _config['dataset'], _config['equations'], out_dir)

  # The objective's base samples are concatenated with x rows to form the
  # (dim_x + dim_m)-wide model input, so their dimension must be dim_m.
  objective = Objective(_config['dim_m'], psi_vector, flow, _config['store_data'])
  constraints = Constraints(lhs, n_inferred, psi_vector, _config['dim_theta'], constr_indices,
                            _config['slack'], _config['lagrangian_strategy'],
                            _config['save_rhs_every_step'], _config['store_data'])
  # Named auglag so it does not shadow the torch ``optim`` module.
  auglag = AugmentedLagrangian(dat, objective, constraints, psi_vector, n_inferred, lhs_all,
                               out_dir, n_constraints, _config['tau_init'], _config['tau_factor'])

  expected_y_given_x = y_given_x(torch.tensor(np.array(data_xstar))).detach().numpy()
  if _config['store_data']:
    results_global = {
        'lhs': lhs.clone().detach().numpy(),
        'n_inferred': n_inferred.clone().detach().numpy(),
        'constr_indices': constr_indices.numpy(),
        'expected_y_given_x': expected_y_given_x,
    }
    result_path = os.path.join(out_dir, "results_global.npz")
    np.savez(result_path, **results_global)

  # Rows: xstar grid points; columns: (lower, upper) bound.
  summary_shape = (len(xstar_grid), 2)
  final = {
      "indices": np.zeros(summary_shape, dtype=np.int32),
      "objective": np.zeros(summary_shape),
      "maxabsdiff": np.zeros(summary_shape),
  }
  satis = {
      "indices": np.zeros(summary_shape, dtype=np.int32),
      "objective": np.zeros(summary_shape),
      "maxabsdiff": np.zeros(summary_shape),
  }
  # ---------------------------------------------------------------------------
  # Main loops over xstar and bounds
  # ---------------------------------------------------------------------------
  for i, xstar in enumerate(xstar_grid):
    xstar = utils.jax_to_torch(xstar).type(torch.FloatTensor)
    if xstar.ndim == 0:
      xstar = xstar.reshape(1)
    for j, bound in enumerate(["lower", "upper"]):
      _log.info(f"Run xstar={xstar_grid_plotting[i]}, bound={bound}...")
      vis = "=" * 10
      _log.info(f"{vis} {i * 2 + j + 1}/{2 * _config['num_xstar']} {vis}")
      fin_i, fin_obj, fin_diff, sat_i, sat_obj, sat_diff = auglag.optimize(
          bound, xstar, xstar_grid_plotting[i],
          n_rounds=_config['num_rounds'], opt_steps=_config['opt_steps'],
          lr=_config['lr'])
      # Plain NumPy item assignment: jax.ops.index_update is gone from newer
      # JAX and silently turned these arrays into (float32) JAX arrays.
      final["indices"][i, j] = fin_i
      final["objective"][i, j] = fin_obj
      final["maxabsdiff"][i, j] = fin_diff
      satis["indices"][i, j] = sat_i
      satis["objective"][i, j] = sat_obj
      satis["maxabsdiff"][i, j] = sat_diff

  if _config['plot_final']:
    _log.info(f"Saving the final plots...")
    plotting.plot_final_bounds(xstar_grid_plotting, final["objective"],
                               xstar_plotting, data_ystar, expected_y_given_x, out_dir)

