"""A meta PyTorch Lightning model for training and evaluating DISCO models."""

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.utils.data
from torch_geometric.data import DataLoader as GraphDataLoader
from pytorch_lightning.utilities import rank_zero_info

from models.gnn_encoder import GNNEncoder
from utils.lr_schedulers import get_schedule_fn
from utils.rddm_diffusion_schedulers import GaussianDiffusion, CategoricalDiffusion

class COMetaModel(pl.LightningModule):
  def __init__(self, param_args, node_feature_only=False):
    """Set up the diffusion process and the GNN denoising network.

    Args:
      param_args: Configuration namespace. This constructor reads
        ``diffusion_type``, ``diffusion_schedule``, ``diffusion_steps``,
        ``sparse_factor``, ``n_layers``, ``hidden_dim``, ``aggregation``,
        ``norm_scheme``, ``simnorm_dim`` and ``use_activation_checkpoint``.
      node_feature_only: Forwarded to ``GNNEncoder``; when true the model is
        also flagged as sparse regardless of ``sparse_factor``.

    Raises:
      ValueError: If ``diffusion_type`` is neither ``'gaussian'`` nor
        ``'categorical'``.
    """
    super().__init__()
    self.args = param_args
    self.diffusion_type = self.args.diffusion_type
    self.diffusion_schedule = self.args.diffusion_schedule
    self.diffusion_steps = self.args.diffusion_steps
    self.sparse = node_feature_only or self.args.sparse_factor > 0

    # (encoder output channels, diffusion process class) per diffusion type.
    diffusion_variants = {
        'gaussian': (1, GaussianDiffusion),
        'categorical': (2, CategoricalDiffusion),
    }
    if self.diffusion_type not in diffusion_variants:
      raise ValueError(f"Unknown diffusion type {self.diffusion_type}")
    out_channels, diffusion_cls = diffusion_variants[self.diffusion_type]
    self.diffusion = diffusion_cls(
        T=self.diffusion_steps, schedule=self.diffusion_schedule)

    encoder_kwargs = dict(
        n_layers=self.args.n_layers,
        hidden_dim=self.args.hidden_dim,
        out_channels=out_channels,
        aggregation=self.args.aggregation,
        norm=self.args.norm_scheme,
        norm_dim=self.args.simnorm_dim,
        sparse=self.sparse,
        use_activation_checkpoint=self.args.use_activation_checkpoint,
        node_feature_only=node_feature_only,
    )
    self.model = GNNEncoder(**encoder_kwargs)

    # Filled lazily by get_total_num_training_steps().
    self.num_training_steps_cached = None

  def test_epoch_end(self, outputs):
    unmerged_metrics = {}
    for metrics in outputs:
      for k, v in metrics.items():
        if k not in unmerged_metrics:
          unmerged_metrics[k] = []
        unmerged_metrics[k].append(v)

    merged_metrics = {}
    for k, v in unmerged_metrics.items():
      merged_metrics[k] = float(np.mean(v))
    print(merged_metrics)
    self.logger.log_metrics(merged_metrics, step=self.global_step)
    
  def on_predict_epoch_end(self, outputs):
    unmerged_metrics = {}
    for metrics in outputs[0]:
      for k, v in metrics.items():
        if k not in unmerged_metrics:
          unmerged_metrics[k] = []
        unmerged_metrics[k].append(v)

    merged_metrics = {}
    for k, v in unmerged_metrics.items():
      merged_metrics[k] = float(np.mean(v))
    print(merged_metrics)

  def get_total_num_training_steps(self) -> int:
    """Total training steps inferred from datamodule and devices."""
    if self.num_training_steps_cached is not None:
      return self.num_training_steps_cached
    dataset = self.train_dataloader()
    if self.trainer.max_steps and self.trainer.max_steps > 0:
      return self.trainer.max_steps

    dataset_size = (
        self.trainer.limit_train_batches * len(dataset)
        if self.trainer.limit_train_batches != 0
        else len(dataset)
    )

    num_devices = max(1, self.trainer.num_devices)
    effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
    self.num_training_steps_cached = (dataset_size // effective_batch_size) * self.trainer.max_epochs
    return self.num_training_steps_cached

  def configure_optimizers(self):
    """Create the AdamW optimizer and, unless the schedule is constant, its LR scheduler.

    Returns:
      The bare optimizer when ``args.lr_scheduler == "constant"``; otherwise a
      dict with the optimizer and a scheduler config stepped every training
      step (``"interval": "step"``).
    """
    param_count = sum(p.numel() for p in self.model.parameters())
    rank_zero_info('Parameters: %d' % param_count)
    rank_zero_info('Training steps: %d' % self.get_total_num_training_steps())

    optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=self.args.learning_rate,
        weight_decay=self.args.weight_decay,
    )
    if self.args.lr_scheduler == "constant":
      return optimizer

    # get_schedule_fn returns a factory that binds the schedule to an optimizer.
    make_scheduler = get_schedule_fn(
        self.args.lr_scheduler, self.get_total_num_training_steps())
    scheduler_config = {
        "scheduler": make_scheduler(optimizer),
        "interval": "step",
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

  def categorical_posterior(self, target_t, t, x0_pred_prob, xt, x_in):
    """Sample from the categorical posterior for a given time step.

    Builds 2x2 transition matrices from ``t`` and ``target_t``, combines them
    with the one-hot encoding of ``xt`` and the predicted clean-state
    probabilities, and forms ``part1 * part2 / part3``. Only the last class
    channel of that quantity is returned.

    Args:
      target_t: Time step to move to; viewed to ``(xt.shape[0],)`` and then
        reshaped to ``(-1, 1, 1)``.
      t: Current time step; reshaped to ``(-1, 1, 1)`` and cast to float32.
      x0_pred_prob: Predicted probabilities of the clean state; must have a
        trailing dimension of size 2 since it is right-multiplied by 2x2
        matrices.
      xt: Current binary state; one-hot encoded here with 2 classes.
      x_in: Tensor that is scaled by ``t`` and subtracted from the one-hot
        ``xt``. Presumably a conditioning input with a trailing class
        dimension of size 2 -- TODO confirm against the caller.

    Returns:
      A Bernoulli sample of the last-channel value clamped to ``[0, 1]`` when
      every entry of ``target_t`` is non-zero, otherwise the last-channel
      value clamped below at 0 (no sampling). Flattened to 1-D when
      ``self.sparse`` is set.
    """
    target_t = target_t.view(xt.shape[0])

    xt = F.one_hot(xt.long(), num_classes=2).float()
    t = t.reshape(-1, 1, 1).to(torch.float32)
    target_t = target_t.reshape(-1, 1, 1).to(torch.float32)
    # One 2x2 matrix per batch entry: (1 - 2t) * I + ones.
    # NOTE(review): unlike Q_bar_t_target below, the ones-matrix here is NOT
    # scaled by t -- confirm this asymmetry is intentional.
    # NOTE(review): torch.eye/torch.ones are created on the default device, so
    # t and target_t presumably live there as well -- confirm.
    Q_bar_t = torch.eye(2).unsqueeze(0).repeat(t.shape[0], 1, 1)\
       * (1 - 2 * t) + torch.ones(t.shape[0], 2, 2)
    Q_bar_t = Q_bar_t.to(xt.device)
    # One 2x2 matrix per batch entry: (1 - 2 * target_t) * I + target_t * ones.
    Q_bar_t_target = torch.eye(2).unsqueeze(0).repeat(target_t.shape[0], 1, 1)\
       * (1 - 2 * target_t) + torch.ones(target_t.shape[0], 2, 2) * target_t
    Q_bar_t_target = Q_bar_t_target.to(xt.device)

    # Q_t = inv(Q_bar_t_target) @ Q_bar_t, computed with NumPy on the CPU and
    # then moved back to xt's device.
    Q_t = np.linalg.inv(Q_bar_t_target.cpu()) @ Q_bar_t.cpu().numpy()
    Q_t = torch.from_numpy(Q_t).float().to(xt.device)

    # part1: (one_hot(xt) - t * x_in) @ Q_t[0]^T.
    # eye(2) * t is t * I, so the matmul just scales x_in by t.
    # Only index 0 of the batched matrices is used from here on.
    part1_1 = xt.cpu() - torch.matmul(x_in.cpu(), torch.eye(2)*t.cpu()).squeeze()
    part1_2 = Q_t[0].permute((1, 0)).contiguous().cpu()
    part1 = torch.matmul(part1_1, part1_2).to(xt.device)
    # part2: x0_pred_prob @ Q_bar_t_target[0].
    part2_1 = x0_pred_prob
    part2_2 = Q_bar_t_target[0]
    part2 = torch.matmul(part2_1, part2_2)
    # part3 (normalizer): per-element inner product of (x0_pred_prob @ Q_bar_t[0])
    # with (one_hot(xt) - t * x_in).
    part3 = torch.matmul(x0_pred_prob, Q_bar_t[0]).reshape(1, -1, 1, 2).to(xt.device)
    b, h, w, d = part3.shape
    # (h, 1, 2) row vectors, one per element (b == w == 1 after the reshape).
    tmp = part3.flatten(0, 2).unsqueeze(1)
    if self.sparse:
      tmp2 = xt.cpu() - torch.matmul(x_in.cpu(), torch.eye(2)*t.cpu()).squeeze()
      # Sparse mode brings part1/part2 to the same (b, h, w, d) layout as part3.
      part1 = part1.reshape(b, h, w, d)
      part2 = part2.reshape(b, h, w, d)
    else:
      # Same tmp2 as the sparse branch; part1/part2 keep their own shape here.
      # NOTE(review): the dense path then relies on broadcasting against
      # part3's (1, h, 1, 1) shape -- confirm the shapes line up as intended.
      tmp2 = xt.cpu() - torch.matmul(x_in.cpu(), torch.eye(2)*t.cpu()).squeeze()
    # (h, 2, 1) column vectors so that tmp @ tmp2 is a batched dot product.
    tmp2 = tmp2.reshape(h, w, d).permute(0, 2, 1).contiguous()
    tmp2 = tmp2.to(xt.device)
    part3 = torch.matmul(tmp, tmp2).squeeze(-1)
    part3 = part3.reshape(b, h, w, 1)
    sum_x_t_target_prob = part1 * part2 / part3
    # target_t.all() is a boolean scalar tensor, so `> 0` is true exactly when
    # every entry of target_t is non-zero.
    if target_t.all() > 0:
      # Intermediate step: sample the next binary state from the last channel.
      xt = torch.bernoulli(sum_x_t_target_prob.clamp(0, 1)[..., -1])
    else:
      # Some target step is 0: return the (non-negative) value without sampling.
      xt = sum_x_t_target_prob.clamp(min=0)[..., -1]

    if self.sparse:
      xt = xt.reshape(-1)
    return xt

  def gaussian_posterior(self, target_t, t, noise_pred, x_res_pred, xt):
    """Sample (or deterministically denoise) from the Gaussian posterior for a given time step.
       See https://arxiv.org/abs/2306.13720 for details.
    """
    diffusion = self.diffusion
    if target_t is None:
      target_t = t - 1
    else:
      target_t = target_t.view(xt.shape[0])

    # Use DDM posterior
    s = t - target_t
    time = t.reshape(x_res_pred.shape[0], *((1,) * (len(x_res_pred.shape) - 1)))
    s = s.reshape(x_res_pred.shape[0], *((1,) * (len(x_res_pred.shape) - 1)))
    mean = xt + x_res_pred * (time - s) - x_res_pred * time - s / torch.sqrt(time) * noise_pred
    epsilon = torch.randn_like(mean, device=xt.device)
    sigma = torch.sqrt(s * (time - s) / time)
    xt_target = mean + sigma * epsilon
    return xt_target

  def duplicate_edge_index(self, edge_index, num_nodes, device):
    """Duplicate the edge index (in sparse graphs) for parallel sampling."""
    edge_index = edge_index.reshape((2, 1, -1))
    edge_index_indent = torch.arange(0, self.args.parallel_sampling).view(1, -1, 1).to(device)
    edge_index_indent = edge_index_indent * num_nodes
    edge_index = edge_index + edge_index_indent
    edge_index = edge_index.reshape((2, -1))
    return edge_index

  def train_dataloader(self):
    """Build the shuffled training loader over ``self.train_dataset``.

    Uses ``args.batch_size`` and ``args.num_workers``, pins memory and drops
    the last incomplete batch. Worker processes are kept alive between epochs
    only when there are workers to keep.
    """
    num_workers = self.args.num_workers
    return GraphDataLoader(
        self.train_dataset, batch_size=self.args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
        # torch's DataLoader raises ValueError if persistent_workers=True
        # while num_workers == 0, so enable it only with real workers.
        persistent_workers=num_workers > 0, drop_last=True)

  def test_dataloader(self):
    """Print the test set size and return an unshuffled loader with batch size 1."""
    print("Test dataset size:", len(self.test_dataset))
    return GraphDataLoader(self.test_dataset, batch_size=1, shuffle=False)

  def val_dataloader(self):
    """Return an unshuffled, batch-size-1 loader over the first validation examples.

    Only the first ``args.validation_examples`` items of
    ``self.validation_dataset`` are used; the subset size is printed.
    """
    full_dataset = self.validation_dataset
    first_indices = range(self.args.validation_examples)
    subset = torch.utils.data.Subset(full_dataset, first_indices)
    print("Validation dataset size:", len(subset))
    return GraphDataLoader(subset, batch_size=1, shuffle=False)

  def predict_dataloader(self):
    """Return an unshuffled, batch-size-1 loader over the first prediction examples.

    Only the first ``args.prediction_examples`` items of
    ``self.predict_dataset`` are used; the subset size is printed.
    """
    full_dataset = self.predict_dataset
    first_indices = range(self.args.prediction_examples)
    subset = torch.utils.data.Subset(full_dataset, first_indices)
    print("Inference dataset size: ", len(subset))
    return GraphDataLoader(subset, batch_size=1, shuffle=False)