import torch
import numpy as np
import abc
from tqdm import trange

from losses import get_score_fn
from utils.graph_utils import mask_adjs, mask_x, gen_noise
from sde import VPSDE, subVPSDE

# from projop import project
from projop.project_bisection import project, drifted_project, satisfies

def get_cuda_memory(device):
  """Return a one-line summary of CUDA memory usage for `device`.

  Every figure is converted to MiB with integer division by 1024 ** 2.
  "Free" is reserved minus allocated memory, i.e. the unused part of what
  is already reserved, not the free memory of the whole device.
  """
  mib = 1024 ** 2
  total = torch.cuda.get_device_properties(device).total_memory
  reserved = torch.cuda.memory_reserved(device)
  allocated = torch.cuda.memory_allocated(device)
  free_in_reserved = reserved - allocated
  total_mb, reserved_mb, allocated_mb, free_mb = (
      value // mib for value in (total, reserved, allocated, free_in_reserved))
  return f"Total: {total_mb}, Reserved: {reserved_mb}, Allocated: {allocated_mb}, Free: {free_mb}"


class Predictor(abc.ABC):
  """Abstract base class for predictor algorithms.

  Keeps the forward SDE, the score function and the reverse-time SDE/ODE
  returned by `sde.reverse(score_fn, probability_flow)`.
  """

  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__()
    self.sde, self.score_fn = sde, score_fn
    # Reverse-time process built from the forward SDE and the score function.
    self.rsde = sde.reverse(score_fn, probability_flow)

  @abc.abstractmethod
  def update_fn(self, x, t, flags):
    """Perform one predictor update; concrete subclasses must override."""


class Corrector(abc.ABC):
  """Abstract base class for corrector algorithms.

  The constructor only records its arguments as attributes of the same name.
  """

  def __init__(self, sde, score_fn, snr, scale_eps, n_steps):
    super().__init__()
    self.sde, self.score_fn = sde, score_fn
    self.snr, self.scale_eps, self.n_steps = snr, scale_eps, n_steps

  @abc.abstractmethod
  def update_fn(self, x, t, flags):
    """Perform one corrector update; concrete subclasses must override."""


class EulerMaruyamaPredictor(Predictor):
  """Euler-Maruyama step of the reverse process.

  `obj` selects which quantity is updated: 'x' or 'adj'.
  """

  def __init__(self, obj, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)
    self.obj = obj

  def update_fn(self, x, adj, flags, t):
    """Advance the selected quantity by one step of size -1 / rsde.N.

    Returns:
      (sample, mean): the noisy update and the same update without the
      noise term, for whichever of `x` / `adj` this predictor handles.

    Raises:
      NotImplementedError: if `obj` is neither 'x' nor 'adj'.
    """
    step = -1. / self.rsde.N

    # Noise is drawn before the reverse SDE is evaluated, for both targets.
    if self.obj == 'x':
      current, is_adj = x, False
      noise = gen_noise(x, flags, sym=False)
    elif self.obj == 'adj':
      current, is_adj = adj, True
      noise = gen_noise(adj, flags)
    else:
      raise NotImplementedError(f"obj {self.obj} not yet supported.")

    drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=is_adj)
    mean = current + drift * step
    sample = mean + diffusion[:, None, None] * np.sqrt(-step) * noise
    return sample, mean


class ReverseDiffusionPredictor(Predictor):
  """Predictor based on the discretisation returned by `rsde.discretize`.

  `obj` selects which quantity is updated: 'x' or 'adj'.
  """

  def __init__(self, obj, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)
    self.obj = obj

  def update_fn(self, x, adj, flags, t):
    """Apply one reverse-diffusion update to the selected quantity.

    Returns:
      (sample, mean): `mean` is the current value minus the discretised
      drift `f`; `sample` adds `G`-scaled noise on top of it.

    Raises:
      NotImplementedError: if `obj` is neither 'x' nor 'adj'.
    """
    target = self.obj

    # The discretisation is evaluated before the noise is drawn.
    if target == 'x':
      f, G = self.rsde.discretize(x, adj, flags, t, is_adj=False)
      noise = gen_noise(x, flags, sym=False)
      current = x
    elif target == 'adj':
      f, G = self.rsde.discretize(x, adj, flags, t, is_adj=True)
      noise = gen_noise(adj, flags)
      current = adj
    else:
      raise NotImplementedError(f"obj {self.obj} not yet supported.")

    mean = current - f
    sample = mean + G[:, None, None] * noise
    return sample, mean


class NoneCorrector(Corrector):
  """An empty corrector that does nothing.

  It accepts the same constructor arguments as `LangevinCorrector`
  (including the optional `constr_op` / `constr_config` keywords) so the two
  classes can be instantiated interchangeably.
  """

  def __init__(self, obj, sde, score_fn, snr, scale_eps, n_steps,
               constr_op='None', constr_config=None):
    super().__init__(sde, score_fn, snr, scale_eps, n_steps)
    self.obj = obj
    # Stored for interface parity only; update_fn never reads them.
    self.constr_op = constr_op
    self.constr_config = constr_config

  def update_fn(self, x, adj, flags, t):
    """Return the selected quantity unchanged, as a (sample, mean) pair.

    Raises:
      NotImplementedError: if `obj` is neither 'x' nor 'adj'.
    """
    if self.obj == 'x':
      return x, x
    if self.obj == 'adj':
      return adj, adj
    raise NotImplementedError(f"obj {self.obj} not yet supported.")


class LangevinCorrector(Corrector):
  """Langevin corrector for node features ('x') or adjacency ('adj').

  When `constr_op` is 'cond_bump' or 'cond_guid', a conditioning gradient
  from the matching `cond_baseline` module is added to the score.
  """

  def __init__(self, obj, sde, score_fn, snr, scale_eps, n_steps, constr_op='None', constr_config=None):
    super().__init__(sde, score_fn, snr, scale_eps, n_steps)
    self.obj = obj
    self.constr_op = constr_op
    self.constr_config = constr_config

  def _conditioning_grads(self, x, adj):
    """Return (grad_x, grad_adj) for conditioning ops, or None otherwise."""
    op = self.constr_op
    if op == 'cond_bump':
      from cond_baseline.bump import projection_flow
      grad_cond_x, grad_cond_adj = projection_flow (x, adj, self.constr_config)
      return grad_cond_x, grad_cond_adj
    if op == 'cond_guid':
      from cond_baseline.guidance import guidance_grad
      grad_cond_x, grad_cond_adj = guidance_grad (x, adj, self.constr_config)
      return grad_cond_x, grad_cond_adj
    if 'cond' in op:
      # Previously this surfaced as a NameError deep inside the update loop.
      raise NotImplementedError(f"constr_op {op} not yet supported")
    return None

  def update_fn(self, x, adj, flags, t):
    """Run `n_steps` Langevin updates on the quantity selected by `obj`.

    Returns:
      (sample, mean): the value after the last step with and without its
      noise term. With `n_steps == 0` the input is returned unchanged as
      both elements.

    Raises:
      NotImplementedError: if `obj` is neither 'x' nor 'adj', or if
        `constr_op` names an unknown conditioning operator.
    """
    if self.obj not in ('x', 'adj'):
      raise NotImplementedError(f"obj {self.obj} not yet supported")
    is_adj = self.obj == 'adj'

    sde = self.sde
    score_fn = self.score_fn
    target_snr = self.snr
    seps = self.scale_eps

    # The conditioning gradient is evaluated once, at the incoming state,
    # and reused for every Langevin step below.
    cond_grads = self._conditioning_grads(x, adj)
    cond_grad = None if cond_grads is None else cond_grads[1 if is_adj else 0]

    if isinstance(sde, (VPSDE, subVPSDE)):
      timestep = (t * (sde.N - 1) / sde.T).long()
      alpha = sde.alphas.to(t.device)[timestep]
    else:
      alpha = torch.ones_like(t)

    state = adj if is_adj else x
    state_mean = state  # returned as-is when n_steps == 0
    for _ in range(self.n_steps):
      # Only the selected quantity evolves; the other one stays fixed.
      if is_adj:
        grad = score_fn(x, state, flags, t)
      else:
        grad = score_fn(state, adj, flags, t)
      if cond_grad is not None:
        grad += cond_grad
      noise = gen_noise(state, flags) if is_adj else gen_noise(state, flags, sym=False)

      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
      state_mean = state + step_size[:, None, None] * grad
      state = state_mean + torch.sqrt(step_size * 2)[:, None, None] * noise * seps
    return state, state_mean

# -------- PC sampler --------
def get_pc_sampler(sde_x, sde_adj, shape_x, shape_adj, predictor='Euler', corrector='None', 
                   snr=0.1, scale_eps=1.0, n_steps=1, 
                   probability_flow=False, continuous=False,
                   denoise=True, eps=1e-3, constr_config=None, config=None,
                   device='cuda'):
  """Build a predictor-corrector sampler with an optional projection step.

  Args:
    sde_x, sde_adj: SDE objects for node features and adjacency.
    shape_x, shape_adj: shapes passed to the prior samplers.
    predictor: 'Reverse' selects ReverseDiffusionPredictor; anything else
      selects EulerMaruyamaPredictor.
    corrector: 'Langevin' selects LangevinCorrector; anything else selects
      NoneCorrector.
    snr, scale_eps, n_steps: forwarded to the correctors.
    probability_flow: forwarded to the predictors.
    continuous: forwarded to `get_score_fn`.
    denoise: if true, the returned tensors are the noise-free means of the
      last update instead of the noisy samples.
    eps: final time of the sampling grid.
    constr_config: constraint configuration. Despite the default of None it
      is dereferenced unconditionally (`add_diff_step`, `method`,
      `schedule`, `constraint`), so a real config must be supplied.
    config: not used by this function.
    device: device for the prior samples and the time grid.

  Returns:
    A function `pc_sampler(model_x, model_adj, init_flags)`.
  """

  def pc_sampler(model_x, model_adj, init_flags):
    """Draw samples; returns (x, adj, diff_steps * (n_steps + 1))."""

    score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)
    score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)

    predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor
    corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector

    predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)
    corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)

    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)
    # Only the Langevin corrector consumes the constraint keywords; passing
    # them to any other corrector class would fail at construction time.
    if corrector_fn is LangevinCorrector:
      corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps,
                                        constr_op=constr_config.method.op, constr_config=constr_config)
    else:
      corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)

    # Each stage maps (x, adj, x_mean, adj_mean, vec_t) to the updated
    # (x, adj, x_mean, adj_mean). `flags` is bound further down, before the
    # first call.
    def corrector_input (x, adj, x_mean, adj_mean, vec_t):
      _x = x  # the adjacency update sees the features from before this stage
      x, x_mean = corrector_obj_x.update_fn(x, adj, flags, vec_t)
      adj, adj_mean = corrector_obj_adj.update_fn(_x, adj, flags, vec_t)
      return x, adj, x_mean, adj_mean

    def predictor_input (x, adj, x_mean, adj_mean, vec_t):
      _x = x  # the adjacency update sees the features from before this stage
      x, x_mean = predictor_obj_x.update_fn(x, adj, flags, vec_t)
      adj, adj_mean = predictor_obj_adj.update_fn(_x, adj, flags, vec_t)
      return x, adj, x_mean, adj_mean

    def project_input (x, adj, x_mean, adj_mean, vec_t):
      if constr_config.constraint != 'None':
        if "method" not in constr_config or constr_config.method.op == "proj":
          x, adj = drifted_project (x, adj, constr_config)
          x_mean, adj_mean = drifted_project (x_mean, adj_mean, constr_config)
        elif "method" in constr_config and constr_config.method.op == "prox":
          raise NotImplementedError("prox not implemented yet")
      return x, adj, x_mean, adj_mean

    # One letter per stage: 'c' corrector, 'p' predictor, 'j' projection.
    funcs = {'c': corrector_input, 'p': predictor_input, 'j': project_input}

    with torch.no_grad():
      # -------- Initial sample --------
      x = sde_x.prior_sampling(shape_x).to(device)
      adj = sde_adj.prior_sampling_sym(shape_adj).to(device)
      flags = init_flags
      x = mask_x(x, flags)
      adj = mask_adjs(adj, flags)
      diff_steps = sde_adj.N + constr_config.add_diff_step
      timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)

      # Default stage order; loop-invariant, so it is set once up front.
      if 'solve_order' not in constr_config.method:
        constr_config.method["solve_order"] = 'cpj'

      # -------- Reverse diffusion process --------
      for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape_adj[0], device=t.device) * t

        # Start from the configured order so `solve_order` is always bound,
        # including configs with neither "burnin" nor a cyclical schedule.
        solve_order = constr_config.method.solve_order
        # During burn-in the projection stage is skipped.
        burning = "burnin" in constr_config and i < constr_config.burnin
        if burning:
          solve_order = "cp"

        if "gamma" in constr_config.schedule and not burning:
          if constr_config.schedule.gamma == "cyclical":
            # Project only on every params[0]-th step.
            if i % constr_config.schedule.params[0] == 0:
              solve_order = constr_config.method.solve_order
            else:
              solve_order = "cp"
          elif constr_config.schedule.gamma == "annealing":
            # check.
            pass
          elif constr_config.schedule.gamma == "steprise":
            # not useful as with gamma_mulp = 2, 0.5 ---> 1 in one step.
            # not smooth gradation
            gamma_init, gamma_mulp, gamma_steps = constr_config.schedule.params
            constr_config.method.gamma = min(1., gamma_init * gamma_mulp * (i//gamma_steps + 1))
          elif constr_config.schedule.gamma == "poly":
            # smooth (1-gamma_init) * t^k + gamma_init
            # assumed to reach 1 at the end (i.e., t = 1)
            gamma_init, time_pow = constr_config.schedule.params
            constr_config.method.gamma = (1 - gamma_init) * (i/diff_steps)**time_pow + gamma_init
          elif constr_config.schedule.gamma == "polystep":
            # same polynomial, but i is first rounded down to a multiple of gamma_step
            gamma_init, time_pow, gamma_step = constr_config.schedule.params
            constr_config.method.gamma = (1 - gamma_init) * (i//gamma_step*(gamma_step/diff_steps))**time_pow + gamma_init
          elif constr_config.schedule.gamma == "polymid":
            # smooth (1-gamma_init) * t^k + gamma_init
            # reaches 1 at some step in the middle
            gamma_init, time_pow, one_step = constr_config.schedule.params
            if i >= one_step:
              constr_config.method.gamma = 1.
            else:
              constr_config.method.gamma = (1 - gamma_init) * (i/diff_steps)**time_pow/((one_step/diff_steps)**time_pow) + gamma_init
          elif constr_config.schedule.gamma == "fixed":
            constr_config.method.gamma = constr_config.schedule.params[0]

        # The means are reset every step; the stages below fill them in.
        x_mean, adj_mean = None, None
        for stage in solve_order:
          x, adj, x_mean, adj_mean = funcs[stage] (x, adj, x_mean, adj_mean, vec_t)

      return (x_mean if denoise else x), (adj_mean if denoise else adj), diff_steps * (n_steps + 1)
  return pc_sampler


# -------- S4 solver --------
def S4_solver(sde_x, sde_adj, shape_x, shape_adj, predictor='None', corrector='None', 
              snr=0.1, scale_eps=1.0, n_steps=1, 
              probability_flow=False, continuous=False,
              denoise=True, eps=1e-3, constr_config=None, 
              config=None, adj_vals=[0,1], feat_vals=[],
              device='cuda'):
  """Build the S4 sampler closure for node features `x` and adjacency `adj`.

  Every sampling step computes both scores (plus an optional conditioning
  gradient), applies one Langevin-style correction to `x` and `adj`, then a
  prediction made of two `sde.transition` half-steps (dt / 2 each) with the
  score drift added in between, and finally an optional projection.

  `predictor`, `corrector`, `n_steps`, `probability_flow`, `config`,
  `adj_vals` and `feat_vals` are accepted but not read by this function.
  `constr_config` defaults to None yet is dereferenced unconditionally
  (`add_diff_step`, `method.op`, `constraint`), so a config is required.

  Returns:
    A function `s4_solver(model_x, model_adj, init_flags)` that returns
    `(x, adj, 0)`; with `denoise` true the transition means are returned in
    place of the noisy samples.
  """

  def s4_solver(model_x, model_adj, init_flags):

    score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)
    score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)

    with torch.no_grad():
      # -------- Initial sample --------
      x = sde_x.prior_sampling(shape_x).to(device) 
      adj = sde_adj.prior_sampling_sym(shape_adj).to(device) 
      flags = init_flags
      x = mask_x(x, flags)
      adj = mask_adjs(adj, flags)
      # The adjacency SDE's step count (plus the configured extra steps)
      # drives the time grid for both x and adj.
      diff_steps = sde_adj.N + constr_config.add_diff_step
      timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)
      dt = -1. / diff_steps  # negative: time runs from T down to eps

      # -------- Reverse diffusion process --------
      for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape_adj[0], device=t.device) * t
        vec_dt = torch.ones(shape_adj[0], device=t.device) * (dt/2)  # half step per transition

        # -------- Score computation --------
        score_x = score_fn_x(x, adj, flags, vec_t)
        score_adj = score_fn_adj(x, adj, flags, vec_t)
        # Optional conditioning gradient, added in place to both scores.
        if constr_config.method.op == 'cond_bump':
          from cond_baseline.bump import projection_flow
          grad_cond_x, grad_cond_adj = projection_flow (x, adj, constr_config)
          score_x += grad_cond_x
          score_adj += grad_cond_adj
        elif constr_config.method.op == 'cond_guid':
          from cond_baseline.guidance import guidance_grad
          grad_cond_x, grad_cond_adj = guidance_grad (x, adj, constr_config)
          score_x += grad_cond_x
          score_adj += grad_cond_adj

        # Score part of the reverse drift: -g(t)^2 * score, where g is the
        # second value returned by sde.sde. It is evaluated at the state
        # before the correction step and applied later, in the prediction step.
        Sdrift_x = -sde_x.sde(x, vec_t)[1][:, None, None] ** 2 * score_x
        Sdrift_adj  = -sde_adj.sde(adj, vec_t)[1][:, None, None] ** 2 * score_adj

        # -------- Correction step --------
        # NOTE(review): the index is derived from sde_x.N / sde_x.T but is
        # also used on sde_adj.alphas below — confirm both SDEs share N and T.
        timestep = (vec_t * (sde_x.N - 1) / sde_x.T).long()

        noise = gen_noise(x, flags, sym=False)
        grad_norm = torch.norm(score_x.reshape(score_x.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        # NOTE(review): only VPSDE is special-cased here, whereas
        # LangevinCorrector also handles subVPSDE — confirm this is intended.
        if isinstance(sde_x, VPSDE):
          alpha = sde_x.alphas.to(vec_t.device)[timestep]
        else:
          alpha = torch.ones_like(vec_t)
      
        step_size = (snr * noise_norm / grad_norm) ** 2 * 2 * alpha
        x_mean = x + step_size[:, None, None] * score_x
        x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise * scale_eps

        noise = gen_noise(adj, flags)
        grad_norm = torch.norm(score_adj.reshape(score_adj.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        if isinstance(sde_adj, VPSDE):
          alpha = sde_adj.alphas.to(vec_t.device)[timestep] # VP
        else:
          alpha = torch.ones_like(vec_t) # VE
        step_size = (snr * noise_norm / grad_norm) ** 2 * 2 * alpha
        adj_mean = adj + step_size[:, None, None] * score_adj
        adj = adj_mean + torch.sqrt(step_size * 2)[:, None, None] * noise * scale_eps

        # -------- Prediction step --------
        # These two assignments are overwritten by the transition means at
        # the end of the prediction step.
        x_mean = x
        adj_mean = adj
        # First half-step of the transition kernel.
        mu_x, sigma_x = sde_x.transition(x, vec_t, vec_dt)
        mu_adj, sigma_adj = sde_adj.transition(adj, vec_t, vec_dt) 
        x = mu_x + sigma_x[:, None, None] * gen_noise(x, flags, sym=False)
        adj = mu_adj + sigma_adj[:, None, None] * gen_noise(adj, flags)
        
        # Full-step score drift, applied between the two half-steps.
        x = x + Sdrift_x * dt
        adj = adj + Sdrift_adj * dt

        # Second half-step, starting at t + dt/2.
        mu_x, sigma_x = sde_x.transition(x, vec_t + vec_dt, vec_dt) 
        mu_adj, sigma_adj = sde_adj.transition(adj, vec_t + vec_dt, vec_dt) 
        x = mu_x + sigma_x[:, None, None] * gen_noise(x, flags, sym=False)
        adj = mu_adj + sigma_adj[:, None, None] * gen_noise(adj, flags)

        # Noise-free outputs of this step (returned when `denoise` is true).
        x_mean = mu_x
        adj_mean = mu_adj

        # Projection 
        # Samples and means are projected separately, every step.
        if constr_config.constraint != 'None':
          if "method" not in constr_config or constr_config.method.op == "proj":
            x, adj = drifted_project (x, adj, constr_config)
            x_mean, adj_mean = drifted_project (x_mean, adj_mean, constr_config)
          elif "method" in constr_config and constr_config.method.op == "prox":
            raise NotImplementedError("prox not implemented yet")

      # print(' ')
      return (x_mean if denoise else x), (adj_mean if denoise else adj), 0
  return s4_solver


# Predictor for x
class EulerMaruyamaPredictorSeq(Predictor):
  """Euler-Maruyama predictor for a single tensor.

  Unlike `EulerMaruyamaPredictor`, the score function here is called as
  `score_fn(x, flags, t)` and the reverse drift is assembled locally in
  `rev_sde` instead of going through `self.rsde.sde`.
  """

  def __init__(self, obj, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)
    self.obj = obj
    self.score_fn = score_fn
    # rev_sde reads this flag, so it has to be kept on the instance.
    self.probability_flow = probability_flow

  def rev_sde(self, x, flags, t):
    """Return (drift, diffusion) of the reverse-time SDE/ODE at (x, t)."""
    # Elsewhere in this file the forward coefficients are read through
    # `sde.sde(x, t)`; fall back to calling the object itself if it has no
    # such method.
    forward_sde = getattr(self.sde, 'sde', self.sde)
    drift, diffusion = forward_sde(x, t)
    score = self.score_fn(x, flags, t)
    drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
    if self.probability_flow:
      # Keep a tensor (not the float 0.) so update_fn can still index it.
      diffusion = torch.zeros_like(diffusion)
    return drift, diffusion

  def update_fn(self, x, flags, t):
    """Advance `x` by one step of size -1 / rsde.N; returns (x, x_mean)."""
    dt = -1. / self.rsde.N
    z = gen_noise(x, flags, sym=False)
    drift, diffusion = self.rev_sde(x, flags, t)
    x_mean = x + drift * dt
    x = x_mean + diffusion[:, None, None] * np.sqrt(-dt) * z
    return x, x_mean

# Corrector for x (the sequential sampler below uses it as its node-feature corrector)
class LangevinCorrectorSeq(Corrector):
  """Langevin corrector for a single tensor.

  The score function is called as `score_fn(x, flags, t)`.
  """

  def __init__(self, obj, sde, score_fn, snr, scale_eps, n_steps):
    super().__init__(sde, score_fn, snr, scale_eps, n_steps)
    self.obj = obj

  def update_fn(self, x, flags, t):
    """Run `n_steps` Langevin updates on `x`.

    Returns:
      (x, x_mean): the value after the last step with and without its noise
      term. With `n_steps == 0` the input is returned unchanged as both.
    """
    sde = self.sde
    score_fn = self.score_fn
    target_snr = self.snr
    seps = self.scale_eps

    if isinstance(sde, (VPSDE, subVPSDE)):
      timestep = (t * (sde.N - 1) / sde.T).long()
      alpha = sde.alphas.to(t.device)[timestep]
    else:
      alpha = torch.ones_like(t)

    x_mean = x  # returned as-is when n_steps == 0
    for _ in range(self.n_steps):
      grad = score_fn(x, flags, t)
      noise = gen_noise(x, flags, sym=False)
      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
      x_mean = x + step_size[:, None, None] * grad
      x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise * seps

    # Return only after all steps have run (the return used to sit inside
    # the loop, so every step after the first was skipped).
    return x, x_mean


def get_pc_sampler_sequential(sde_x, sde_adj, shape_x, shape_adj, predictor='Euler', corrector='None', 
                                snr=0.1, scale_eps=1.0, n_steps=1, 
                                sampling_steps=1,
                                probability_flow=False, continuous=False,
                                denoise=True, eps=1e-3, constr_config=None, 
                                config=None, adj_vals=[0,1], feat_vals=[],
                                device='cuda'):
  """Build a sampler that generates `x` first and then `adj` given that `x`.

  Node features always use `LangevinCorrectorSeq` and
  `EulerMaruyamaPredictorSeq`. The adjacency classes are selected by
  `predictor` ('Reverse' or anything else for Euler-Maruyama) and
  `corrector` ('Langevin' or anything else for no correction).

  `sampling_steps`, `config`, `adj_vals` and `feat_vals` are accepted but
  not read by this function. When `constr_config` is None the final
  projection is skipped.

  Returns:
    A function `pc_sampler_sequential(model_x, model_adj, init_flags)` that
    returns `(x, adj)`; with `denoise` true the noise-free means of the last
    update are returned instead of the noisy samples.
  """

  def pc_sampler_sequential(model_x, model_adj, init_flags):

    score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)
    score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)

    predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor
    corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector

    predictor_obj_x = EulerMaruyamaPredictorSeq('x', sde_x, score_fn_x, probability_flow)
    corrector_obj_x = LangevinCorrectorSeq('x', sde_x, score_fn_x, snr, scale_eps, n_steps)

    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)
    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)

    with torch.no_grad():
      # Initial sample
      x = sde_x.prior_sampling(shape_x).to(device)
      adj = sde_adj.prior_sampling_sym(shape_adj).to(device)

      flags = init_flags

      x = mask_x(x, flags)
      adj = mask_adjs(adj, flags)
      # The adjacency SDE's step count drives the time grid of both passes.
      diff_steps = sde_adj.N
      timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device)

      # X first
      for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape_adj[0], device=t.device) * t
        x, x_mean = corrector_obj_x.update_fn(x, flags, vec_t)
        x, x_mean = predictor_obj_x.update_fn(x, flags, vec_t)
      print(' ')

      # Features that the adjacency pass is conditioned on.
      x_gen = x_mean if denoise else x

      # Adj
      for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):
        t = timesteps[i]
        vec_t = torch.ones(shape_adj[0], device=t.device) * t
        adj, adj_mean = corrector_obj_adj.update_fn(x_gen, adj, flags, vec_t)
        adj, adj_mean = predictor_obj_adj.update_fn(x_gen, adj, flags, vec_t)
      print(' ')

      # Projection, applied once at the very end. A missing config means
      # "no constraint" rather than an AttributeError after sampling is done.
      if constr_config is not None and constr_config.constraint != 'None':
        if "method" not in constr_config or constr_config.method.op == "proj":
          x, adj = drifted_project (x, adj, constr_config)
          x_mean, adj_mean = drifted_project (x_mean, adj_mean, constr_config)
        elif "method" in constr_config and constr_config.method.op == "prox":
          raise NotImplementedError("prox not implemented yet")

      # Re-select after projection, which may have replaced x / x_mean.
      x_gen = x_mean if denoise else x

      return x_gen, (adj_mean if denoise else adj) #, sampling_steps * (n_steps + 1)

  return pc_sampler_sequential