import math
import numpy as np
import torch
import gpytorch
import random
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import sys
from itertools import zip_longest
#sys.path.append("../directionalvi")
sys.path.append("utils")
from RBFKernelDirectionalGrad import RBFKernelDirectionalGrad #.RBFKernelDirectionalGrad
#from DirectionalGradVariationalStrategy import DirectionalGradVariationalStrategy #.DirectionalGradVariationalStrategy
from monotonicdgvs import DirectionalGradVariationalStrategy #.DirectionalGradVariationalStrategy
from CiqDirectionalGradVariationalStrategy import CiqDirectionalGradVariationalStrategy #.DirectionalGradVariationalStrategy
from utils.count_params import count_params
from gpytorch.variational import VariationalStrategy
from virtual_probit import compute_L_virt
try: # import wandb if watch model on weights&biases
  import wandb
except:
  pass
from itertools import cycle



class GPModel(gpytorch.models.ApproximateGP):
    """Approximate GP over function values and directional derivatives.

    The variational posterior covers ``num_inducing`` inducing values plus
    ``num_directions`` directional-derivative values per inducing point.

    Args:
        inducing_points: initial inducing locations (one row per point).
        inducing_directions: inducing directions, ``num_directions`` rows per
            inducing point, stacked point by point.
        dim: problem dimension (accepted for API compatibility; unused here).
        **kwargs: ``variational_distribution="NGD"`` selects
            ``NaturalVariationalDistribution`` (default Cholesky);
            ``variational_strategy="CIQ"`` selects
            ``CiqDirectionalGradVariationalStrategy`` (default
            ``DirectionalGradVariationalStrategy``).
    """

    def __init__(self, inducing_points, inducing_directions, dim, **kwargs):
        self.num_inducing = len(inducing_points)
        # directions per inducing point
        self.num_directions = int(len(inducing_directions) / self.num_inducing)
        n_variational = self.num_inducing + self.num_directions * self.num_inducing

        # variational distribution q(u, g)
        natural = kwargs.get("variational_distribution") == "NGD"
        distribution_cls = (
            gpytorch.variational.NaturalVariationalDistribution
            if natural
            else gpytorch.variational.CholeskyVariationalDistribution
        )
        variational_distribution = distribution_cls(n_variational)

        # variational strategy q(f)
        ciq = kwargs.get("variational_strategy") == "CIQ"
        strategy_cls = (
            CiqDirectionalGradVariationalStrategy
            if ciq
            else DirectionalGradVariationalStrategy
        )
        variational_strategy = strategy_cls(
            self,
            inducing_points,
            inducing_directions,
            variational_distribution,
            learn_inducing_locations=True,
        )
        super().__init__(variational_strategy)

        # prior mean and directional-gradient RBF covariance
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(RBFKernelDirectionalGrad())

    def forward(self, x, **params):
        """Return the prior MVN at ``x``; ``params`` are forwarded to the kernel."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x, **params),
        )
    
    
def sample_direction_indices(dim: int, minibatch_dim: int) -> torch.Tensor:
    """Draw ``minibatch_dim`` distinct derivative indices from ``range(dim)``.

    Uses Python's ``random`` module, so results follow ``random.seed``.

    Args:
        dim: problem dimension.
        minibatch_dim: how many derivative dimensions to draw
            (must satisfy ``1 <= minibatch_dim <= dim``).

    Returns:
        A 1-D ``torch.long`` tensor of the chosen indices in ascending order.
    """
    assert 1 <= minibatch_dim <= dim
    chosen = sorted(random.sample(range(dim), minibatch_dim))
    return torch.tensor(chosen, dtype=torch.long)

def indices_to_directions(indices, dim, device=None, dtype=torch.float32):
    """
    Convert derivative direction indices (e.g. [0, 1])
    into canonical basis vectors (e.g. [[1, 0], [0, 1]]).

    Args:
        indices: int, list of ints, or tensor of indices of active directions.
            A scalar is treated as a single index.
        dim: total problem dimension.
        device: torch device (e.g. 'cuda:0' or 'cpu'). If None and ``indices``
            is a tensor, the result is placed on ``indices``' device so that
            the index assignment never mixes devices.
        dtype: torch dtype of the returned directions.

    Returns:
        directions: tensor of shape (len(indices), dim) whose row k is the
        unit vector e_{indices[k]}.
    """
    # reshape(-1) lets a 0-d tensor / plain int be used as a single index
    indices = torch.as_tensor(indices, dtype=torch.long, device=device).reshape(-1)
    if device is None:
        device = indices.device
    n = indices.numel()
    directions = torch.zeros((n, dim), device=device, dtype=dtype)
    # row and column index tensors must live on the same device as `directions`
    rows = torch.arange(n, device=device)
    directions[rows, indices] = 1.0
    return directions

def select_virtual_directions_by_indices(Xvirt, Vvirt, selected_indices):
    """
    Zero out every derivative column of ``Vvirt`` except the selected ones.

    - Xvirt: virtual input locations; returned unchanged (same object).
    - Vvirt: 2-D tensor with one column per derivative dimension
      (the original author notes it is typically all ones).
    - selected_indices: indices of the columns to keep (e.g. [0, 2]).

    Returns:
        (Xvirt, Vvirt_selected), where Vvirt_selected equals Vvirt on the
        selected columns and 0 elsewhere (same dtype/device as Vvirt).
    """
    # column mask on Vvirt's device/dtype, broadcast across all rows
    keep = Vvirt.new_zeros(Vvirt.shape[1])
    keep[selected_indices] = 1.0
    return Xvirt, Vvirt * keep


def _repeat_loader(loader):
  """Yield batches from ``loader`` forever, re-iterating it on every pass.

  Replaces ``itertools.cycle``: ``cycle`` saves every element of the first
  pass and replays that exact sequence, so a ``shuffle=True`` DataLoader
  would never be reshuffled and all its batches would stay referenced.
  Re-entering the loader each pass avoids both. ``loader`` must yield at
  least one batch per pass, otherwise this spins without yielding.
  """
  while True:
    yield from loader


def _make_lr_schedulers(lr_sched, optimizers, steps_per_epoch, num_epochs, gamma):
  """Return one learning-rate scheduler per optimizer, in ``optimizers`` order.

  Schedulers are stepped once per optimizer step (minibatch), so the
  ``"step_lr"`` milestones (1/3 and 2/3 of training) are counted in steps.
  ``lr_sched=None`` keeps the learning rate constant; any other value is used
  as the ``LambdaLR`` multiplier function.
  """
  if lr_sched == "step_lr":
    total_steps = num_epochs * steps_per_epoch
    milestones = [int(total_steps / 3), int(2 * total_steps / 3)]
    return [torch.optim.lr_scheduler.MultiStepLR(opt, milestones, gamma=gamma)
            for opt in optimizers]
  multiplier = (lambda epoch: 1.0) if lr_sched is None else lr_sched
  return [torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=multiplier)
          for opt in optimizers]


def train_gp(train_dataset, virtual_dataset, num_inducing=128,
  num_directions=2, minibatch_size=1, minibatch_dim=2, num_epochs=1,
  learning_rate_hypers=0.01, learning_rate_ngd=0.1,
  inducing_data_initialization=True,
  use_ngd=False,
  use_ciq=False,
  lr_sched=None,
  mll_type="ELBO",
  num_contour_quadrature=15,
  watch_model=False, gamma=0.1,
  verbose=True,
  fixed_inducing_locations=None,
  nu=1,
  mu=1e-2,
  gh_nodes=12,
  virt_mc_samples=15,
  **args):
  """Train a Derivative GP with the Directional Derivative
  Variational Inference method, plus a weighted virtual-observation term.

  train_dataset: torch Dataset of (x, y) pairs.
  virtual_dataset: torch Dataset of (Xv, Vv) pairs passed to compute_L_virt
      after masking Vv to the sampled derivative dimensions.
  num_inducing: int, number of inducing points
  num_directions: int, number of inducing directions (per inducing point)
  minibatch_size: int, number of data points in a minibatch (both loaders)
  minibatch_dim: int, number of derivative directions sampled per step
                 WARNING: This must equal num_directions until we complete
                 the PR in GpyTorch.
  num_epochs: int, number of epochs; one epoch is one pass over whichever
      loader (real or virtual) has more batches, the other one is repeated.
  learning_rate_hypers, float: learning rate for the hyperparameter Adam
      optimizer (and for the variational Adam optimizer when NGD is off).
  learning_rate_ngd, float: learning rate for the NGD variational optimizer.
  inducing_data_initialization: if True, the inducing points are the first
      num_inducing training inputs; otherwise they are drawn uniformly from
      U[0,1]^d. Inducing directions are always the first num_directions
      canonical basis vectors.
  use_ngd, bool: use a natural variational distribution with NGD
  use_ciq, bool: use the CIQ variational strategy (implies NGD)
  lr_sched: None (constant learning rate), "step_lr" (decay by `gamma` at
      1/3 and 2/3 of all optimizer steps), or a function handle for torch's
      LambdaLR; it is called with the scheduler step count (schedulers are
      stepped once per minibatch) and must return a single multiplier.
  mll_type: "ELBO" (VariationalELBO) or "PLL" (PredictiveLogLikelihood).
  num_contour_quadrature: number of quadrature points used while training
      with CIQ.
  watch_model, bool: log to weights & biases (requires wandb).
  gamma: decay factor for lr_sched="step_lr".
  verbose, bool: print parameter count and periodic progress.
  fixed_inducing_locations, **args: accepted for compatibility; unused.
  nu, gh_nodes, virt_mc_samples: forwarded to compute_L_virt (as nu,
      gh_nodes, mc_samples); per the call-site note gh_nodes is ignored
      when mc_samples > 0.
  mu: weight of the virtual term; rescaled by
      len(virtual_loader) / len(train_loader) before training.

  Returns: (model, likelihood)

  Raises: ValueError for an unsupported mll_type or an empty dataset.
  """
  import contextlib

  assert num_directions == minibatch_dim
  if mll_type not in ("ELBO", "PLL"):
    raise ValueError(f"mll_type must be 'ELBO' or 'PLL', got {mll_type!r}")

  # set up the data loaders
  train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
  virtual_loader = DataLoader(virtual_dataset, batch_size=minibatch_size, shuffle=True)
  if len(train_loader) == 0 or len(virtual_loader) == 0:
    raise ValueError("train_dataset and virtual_dataset must both be non-empty")
  dim = len(train_dataset[0][0])
  n_samples = len(train_dataset)
  # presumably one function value + `dim` derivatives per training point
  num_data = (dim + 1) * n_samples

  # inducing points and (canonical) inducing directions
  inducing_directions = torch.eye(dim)[:num_directions].repeat(num_inducing, 1)
  if inducing_data_initialization is True:
    inducing_points = torch.zeros(num_inducing, dim)
    for ii in range(num_inducing):
      inducing_points[ii] = train_dataset[ii][0]
  else:
    inducing_points = torch.rand(num_inducing, dim)
  if torch.cuda.is_available():
    inducing_points = inducing_points.cuda()
    inducing_directions = inducing_directions.cuda()

  # initialize model
  if use_ciq:
    model = GPModel(inducing_points, inducing_directions, dim,
                    variational_distribution="NGD", variational_strategy="CIQ")
  elif use_ngd:
    model = GPModel(inducing_points, inducing_directions, dim,
                    variational_distribution="NGD")
  else:
    model = GPModel(inducing_points, inducing_directions, dim)
  likelihood = gpytorch.likelihoods.GaussianLikelihood()
  if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()
  if watch_model:
    wandb.watch(model)
  model.train()
  likelihood.train()

  if verbose:
    count_params(model, likelihood)

  # optimizers
  hyperparameter_optimizer = torch.optim.Adam([
      {'params': model.hyperparameters()},
      {'params': likelihood.parameters()},
  ], lr=learning_rate_hypers)
  if use_ngd or use_ciq:
    variational_optimizer = gpytorch.optim.NGD(
        model.variational_parameters(), num_data=num_data, lr=learning_rate_ngd)
  else:
    variational_optimizer = torch.optim.Adam([
        {'params': model.variational_parameters()},
    ], lr=learning_rate_hypers)

  # schedulers step once per minibatch and an epoch iterates the LONGER loader
  steps_per_epoch = max(len(train_loader), len(virtual_loader))
  variational_scheduler, hyperparameter_scheduler = _make_lr_schedulers(
      lr_sched, [variational_optimizer, hyperparameter_optimizer],
      steps_per_epoch, num_epochs, gamma)

  # mll
  if mll_type == "ELBO":
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=num_data)
  else:
    mll = gpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=num_data)

  # weight of the virtual term, rescaled by the loader-size ratio
  mu = mu * len(virtual_loader) / len(train_loader)

  # pair every batch of the longer loader with one from the (repeated) shorter one
  real_is_longer = len(train_loader) >= len(virtual_loader)
  if real_is_longer:
    longer_loader, shorter_loader = train_loader, virtual_loader
  else:
    longer_loader, shorter_loader = virtual_loader, train_loader
  shorter_batches = _repeat_loader(shorter_loader)

  # settings.num_contour_quadrature is a context manager: it only takes
  # effect while entered, so it must wrap the training forward passes.
  quadrature_ctx = (gpytorch.settings.num_contour_quadrature(num_contour_quadrature)
                    if use_ciq else contextlib.nullcontext())

  stride = num_directions + 1
  total_step = 0
  loss = None
  with quadrature_ctx:
    for i in range(num_epochs):
      for longer_batch in longer_loader:
        paired_batch = next(shorter_batches)
        if real_is_longer:
          real_batch, virt_batch = longer_batch, paired_batch
        else:
          real_batch, virt_batch = paired_batch, longer_batch

        # derivative dimensions used for both terms in this step
        selected_idx = sample_direction_indices(dim, minibatch_dim)

        variational_optimizer.zero_grad()
        hyperparameter_optimizer.zero_grad()

        # real-data term
        x_batch, y_batch = real_batch
        if torch.cuda.is_available():
          x_batch = x_batch.cuda()
          y_batch = y_batch.cuda()
        derivative_directions = indices_to_directions(selected_idx, dim, device=x_batch.device)
        mvn, _ = model(x_batch,
                       derivative_directions=derivative_directions.repeat(x_batch.size(0), 1))
        output = likelihood(mvn)
        loss = -mll(output, y_batch)

        # virtual-observation term (masking is done before moving to GPU)
        Xv, Vv = virt_batch
        Xv, Vv = select_virtual_directions_by_indices(Xv, Vv, selected_idx)
        if torch.cuda.is_available():
          Xv = Xv.cuda()
          Vv = Vv.cuda()
        deriv_dirs_virt = indices_to_directions(selected_idx, dim, device=Xv.device)
        deriv_dirs_virt = deriv_dirs_virt.repeat(Xv.size(0), 1)
        Lvirt = compute_L_virt(model, Xv, Vv,
                               nu=nu,
                               derivative_directions=deriv_dirs_virt,
                               gh_nodes=gh_nodes,       # ignored if mc_samples>0
                               mc_samples=virt_mc_samples)
        loss = loss - mu * Lvirt

        if watch_model:
          wandb.log({"loss": loss.item()})
        loss.backward()
        # step optimizers and learning rate schedulers
        variational_optimizer.step()
        variational_scheduler.step()
        hyperparameter_optimizer.step()
        hyperparameter_scheduler.step()

        if verbose and total_step % 50 == 0:
          # NLL of function values only; assumes outputs are interleaved as
          # [f, d_1..d_p] per point — TODO(review): confirm the ordering.
          with torch.no_grad():
            means = output.mean[::stride]
            stds = output.variance.sqrt()[::stride]
            nll = -torch.distributions.Normal(means, stds).log_prob(y_batch[::stride]).mean()
          print(f"Epoch: {i}; total_step: {total_step}, loss: {loss.item()}, nll: {nll}")
          sys.stdout.flush()

        total_step += 1

  if verbose:
    if loss is not None:
      print(f"Done! loss: {loss.item()}")
    print("\nDone Training!")
  return model, likelihood


def eval_gp(test_dataset, model, likelihood,
            mll_type="ELBO", num_directions=1, minibatch_size=1, minibatch_dim=1):
  """Predict the likelihood mean and variance at every test input.

  test_dataset: torch Dataset of (x, y) pairs; y is not used.
  model, likelihood: trained model/likelihood; both are put in eval mode.
      model(x, derivative_directions=...) must return a pair whose first
      element is passed to likelihood.
  mll_type: accepted for API compatibility; unused.
  num_directions / minibatch_dim: number of canonical derivative directions
      requested per point (must be equal).
  minibatch_size: prediction batch size (data kept in order, no shuffling).

  Returns: (means, variances), the per-batch likelihood means/variances
  concatenated along dim 0, computed without autograd tracking.
  """
  assert num_directions == minibatch_dim

  dim = len(test_dataset[0][0])
  test_loader = DataLoader(test_dataset, batch_size=minibatch_size, shuffle=False)
  # first num_directions canonical directions; loop-invariant
  base_directions = torch.eye(dim)[:num_directions]

  model.eval()
  likelihood.eval()

  means_list = []
  vars_list = []
  with torch.no_grad():
    for x_batch, _ in test_loader:
      if torch.cuda.is_available():
        x_batch = x_batch.cuda()
      # re-tile per batch (the last batch may be smaller) on x_batch's
      # device, matching how train_gp builds its directions
      derivative_directions = base_directions.to(x_batch.device).repeat(len(x_batch), 1)
      mvn, _ = model(x_batch, derivative_directions=derivative_directions)
      preds = likelihood(mvn)
      means_list.append(preds.mean)
      vars_list.append(preds.variance)

  means = torch.cat(means_list, dim=0)
  variances = torch.cat(vars_list, dim=0)
  return means, variances