"""Lightning module for training the TSP model."""

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from pytorch_lightning.utilities import rank_zero_info
from scipy.sparse import coo_matrix

from co_datasets.tsp_graph_dataset import TSPGraphDataset
from pl_meta_model_rddm import COMetaModel
from utils.tsp_utils import TSPEvaluator, batched_two_opt_torch, merge_tours
from utils.cython_farthest_insertion.farthest_insertion import farthest_insertion

import time
import multiprocessing
import itertools
# import matplotlib.pyplot as plt

class TSPModel(COMetaModel):
  def __init__(self,
               param_args=None):
    """Construct the model and load the four dataset splits.

    Args:
      param_args: Run configuration, forwarded unchanged to the parent
        constructor. Dataset paths and the sparsity factor are then read
        from ``self.args`` (expected to be populated by the parent).
    """
    super(TSPModel, self).__init__(param_args=param_args, node_feature_only=False)

    def load_split(split_file):
      # All splits share the same storage root and sparsification factor.
      return TSPGraphDataset(
          data_file=os.path.join(self.args.storage_path, split_file),
          sparse_factor=self.args.sparse_factor,
      )

    self.train_dataset = load_split(self.args.training_split)
    self.test_dataset = load_split(self.args.test_split)
    self.validation_dataset = load_split(self.args.validation_split)
    self.predict_dataset = load_split(self.args.predict_split)

    # Extra bookkeeping lists that only exist in sparse mode.
    if self.sparse:
      self.all_edge_ranks = []
      self.best_sub_heatmaps = []

  def forward(self, x, adj, t, edge_index):
    return self.model(x, t, adj, edge_index)

  def categorical_training_step(self, batch, batch_idx):
    edge_index = None
    if self.sparse:
      _, graph_data, point_indicator, edge_indicator, _ = batch
      route_edge_flags = graph_data.edge_attr
      points = graph_data.x
      edge_index = graph_data.edge_index
      num_edges = edge_index.shape[1]
      batch_size = point_indicator.shape[0]
      gt_adjmat = route_edge_flags.reshape((batch_size, num_edges // batch_size)).float()
    else:
      _, points, gt_adjmat, _ = batch
      batch_size = points.shape[0]
    
    x_in = self.generate_x_in(points,
                              edge_index=edge_index,
                              batchsize=batch_size)
    x_in = (x_in + 1) // 2
    if self.sparse:
      x_in = x_in.reshape((batch_size, num_edges // batch_size))
    x_res = x_in - gt_adjmat
    
    adj_matrix_onehot = F.one_hot(gt_adjmat.long(), num_classes=2).float()
    x_in_onehot = F.one_hot(x_in.long(), num_classes=2).float()
    if self.sparse:
      adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)
      x_in_onehot = x_in_onehot.unsqueeze(1)

    # sample from diffusion
    t = torch.rand(size=(gt_adjmat.shape[0], ), device=gt_adjmat.device) * (1 - self.args.eps) \
        + self.args.eps
    xt = self.diffusion.sample(adj_matrix_onehot, x_in_onehot, t)
    xt = xt * 2 - 1
    xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

    if self.sparse:
      t = t.reshape(-1, 1).repeat(1, gt_adjmat.shape[1]).reshape(-1)
      xt = xt.reshape(-1)
      gt_adjmat = gt_adjmat.reshape(-1)
      points = points.reshape(-1, 2)
      edge_index = edge_index.float().to(gt_adjmat.device).reshape(2, -1)
    else:
      t = t.float().view(gt_adjmat.shape[0])
    
    # Denoise
    x_res_pred  = self.forward(
        points.float().to(gt_adjmat.device),
        xt.float().to(gt_adjmat.device),
        t.float().to(gt_adjmat.device),
        edge_index,
    )

    # Compute loss
    if self.sparse:
      x0_pred = x_in_onehot.reshape(-1, 2) - x_res_pred
    else:
      x0_pred = x_in_onehot.permute(0, 3, 1, 2) - x_res_pred
    loss_func = nn.CrossEntropyLoss()
    loss = loss_func(x0_pred, gt_adjmat.long())

    # log dict
    self.log("train/loss", loss)

    return loss 

  def gaussian_training_step(self, batch, batch_idx):
    edge_index = None
    if self.sparse:
      # Implement Gaussian diffusion with sparse graphs
      _, graph_data, point_indicator, edge_indicator, _ = batch
      route_edge_flags = graph_data.edge_attr
      points = graph_data.x
      edge_index = graph_data.edge_index
      num_edges = edge_index.shape[1]
      batch_size = point_indicator.shape[0]
      gt_adjmat = route_edge_flags.reshape((batch_size, num_edges // batch_size)).float()
    else:
      _, points, gt_adjmat, _ = batch
      edge_index = None
      batch_size = points.shape[0]
    gt_heatmap = gt_adjmat * 2 - 1
    gt_heatmap = gt_heatmap * (1.0 + 0.05 * torch.rand_like(gt_heatmap))

    x_in = self.generate_x_in(points, edge_index=edge_index, batchsize=batch_size)
    if self.sparse:
      x_in = x_in.reshape((batch_size, num_edges // batch_size))
    # Sample from diffusion
    t = torch.rand(size=(gt_heatmap.shape[0], ), device=gt_heatmap.device) * (1 - self.args.eps) \
     + self.args.eps

    x_res = x_in - gt_heatmap
    epsilon = torch.randn_like(gt_heatmap)
    xt = self.diffusion.sample(gt_heatmap, epsilon, t, x_res) # x_t = x_in + epsilon
    # t = torch.from_numpy(t).float().view(gt_heatmap.shape[0])

    # Denoise (using DDM)
    points = points.float().to(gt_heatmap.device)
    xt = xt.float().to(gt_heatmap.device)
    t = t.to(gt_heatmap.device)
    if self.sparse:
      t = t.reshape(-1, 1).repeat(1, gt_adjmat.shape[1]).reshape(-1)
      xt = xt.reshape(-1)
      gt_heatmap = gt_heatmap.reshape(-1)
      points = points.reshape(-1, 2)
      edge_index = edge_index.float().to(gt_heatmap.device).reshape(2, -1)
    # use two separate networks to predict x_res and epsilon
    x_res_pred, noise_pred = self.forward(points, xt, t, edge_index)
    if self.sparse:
      x_res_pred = x_res_pred.reshape((batch_size, num_edges // batch_size))
      noise_pred = noise_pred.reshape((batch_size, num_edges // batch_size))
    
    t_tmp = t.reshape(-1, 1, 1)

    # Compute loss
    noise_loss = F.mse_loss(noise_pred, epsilon.float())
    x_res_loss = F.mse_loss(x_res_pred, x_res.float())
    # only res_loss
    # loss = x_res_loss
    simple_weight1 = (t_tmp ** 2 + (t_tmp - 1) ** 2 + 1) / (t_tmp ** 2 + 1)
    simple_weight2 = (t_tmp ** 2 + (t_tmp - 1) ** 2 + 1) / (1 + (t_tmp - 1) ** 2)
    # loss = 0.5 * x_res_loss + 0.5 * noise_loss
    loss = simple_weight1 * x_res_loss + simple_weight2 * noise_loss
    loss = loss.mean()
    self.log("train/noise loss", noise_loss)
    self.log("train/x_res loss", x_res_loss)
    self.log("train/weight1", simple_weight1[0])
    self.log("train/weight2", simple_weight2[0])
    self.log("train/loss", loss)

    return loss

  def training_step(self, batch, batch_idx):
    if self.diffusion_type == 'gaussian':
      return self.gaussian_training_step(batch, batch_idx)
    elif self.diffusion_type == 'categorical':
      return self.categorical_training_step(batch, batch_idx)
    else:
      raise ValueError(f"Unknown diffusion type {self.diffusion_type}")

  def gaussian_denoise_step(self, points, xt, x_in, t, device, edge_index=None, target_t=None):
    with torch.no_grad():
      t = t.view(xt.shape[0])
      points = points.float().to(device)
      xt = xt.float().to(device)
      x_res_pred, noise_pred = self.forward(
          points,
          xt,
          t,
          edge_index.long().to(device) if edge_index is not None else None,
      )

      xt = self.gaussian_posterior(target_t.cpu(), t.cpu(), noise_pred.cpu(), x_res_pred.cpu(), xt.cpu())

      return xt.to(device)

  def categorical_denoise_step(self, points, xt, x_in_onehot, t, device, edge_index=None, target_t=None):
    with torch.no_grad():
      t = t.view(xt.shape[0]).to(device)
      points = points.float().to(device)
      xt = xt.float().to(device)
      x_res_pred = self.forward(
          points,
          xt,
          t,
          edge_index.long().to(device) if edge_index is not None else None,
      )
      if not self.sparse:
        x_res_pred = x_res_pred.permute(0, 2, 3, 1)
        x_res_pred = x_res_pred.clamp(-1, 1)
        x0_pred_onehot = x_in_onehot - x_res_pred
        x0_pred_prob = x0_pred_onehot.reshape((1, xt.shape[1], -1, 2)).softmax(dim=-1)
        x_in_pred_prob = x_in_onehot.reshape((1, xt.shape[1], -1, 2)).softmax(dim=-1)
      else:
        x_res_pred = x_res_pred.reshape((1, xt.shape[0], -1, 2))
        x_res_pred = x_res_pred.clamp(-1, 1)
        x_in_onehot = x_in_onehot.reshape((1, xt.shape[0], -1, 2))
        x0_pred_onehot = x_in_onehot - x_res_pred
        x0_pred_prob = x0_pred_onehot.reshape((1, xt.shape[0], -1, 2)).softmax(dim=-1)
        x_in_pred_prob = x_in_onehot.reshape((1, xt.shape[0], -1, 2)).softmax(dim=-1)

      xt = self.categorical_posterior(target_t.cpu(), t.cpu(), x0_pred_prob, xt, x_in_pred_prob)
      return xt

  def generate_x_in(self, points, edge_index=None, is_subgraph=False, batchsize=1):
    if self.args.degraded_solution == "in_order":
      if is_subgraph and self.args.use_multi_processing:
        tour = np.arange(0, self.args.sub_graph_size)
        tour = np.append(tour, 0)
        tour_edges = np.zeros(self.args.sub_graph_size, dtype=np.int64)
        tour_edges[tour[:-1]] = tour[1:]
        tour_edges = torch.from_numpy(tour_edges).to(edge_index.device)
        degree = min(self.args.sparse_factor, self.args.sub_graph_size)
        tour_edges = tour_edges.reshape((-1, 1)).repeat(1, degree).reshape(-1)
        tour_edges = tour_edges.repeat(1, len(points)).squeeze()
        edge_index_1 = edge_index[1]
        x_in = torch.eq(edge_index_1, tour_edges).reshape(-1, 1)
        x_in = x_in.to(edge_index.device)
      elif self.sparse:
        tour = np.arange(0, points.shape[0])
        tour = np.append(tour, 0)
        tour_edges = np.zeros(points.shape[0], dtype=np.int64)
        tour_edges[tour[:-1]] = tour[1:]
        tour_edges = torch.from_numpy(tour_edges).to(points.device)
        if is_subgraph:
          degree = min(self.args.sparse_factor, self.args.sub_graph_size)
          tour_edges = tour_edges.reshape((-1, 1)).repeat(1, degree).reshape(-1)
        else:
          tour_edges = tour_edges.reshape((-1, 1)).repeat(1, self.args.sparse_factor).reshape(-1)
        edge_index_1 = edge_index[1]
        x_in = torch.eq(edge_index_1, tour_edges).reshape(-1, 1)
        x_in = x_in.to(points.device)
      else:
        tour = np.arange(0, points.shape[1])
        tour = np.append(tour, 0)
        x_0 = np.zeros((points.shape[1], points.shape[1]))
        for i in range(tour.shape[0]-1):
          x_0[tour[i], tour[i+1]] = 1
        x_in = x_0
        x_in = torch.from_numpy(x_in).float()
        x_in = x_in.unsqueeze(0)
        x_in = x_in.repeat(points.shape[0], 1, 1)
        x_in = x_in.to(points.device)
    elif self.args.degraded_solution == "farthest_insertion":
      if self.sparse:
        if is_subgraph:
          points = np.array([elem.cpu().numpy() for elem in points])
          points = torch.from_numpy(points).to(edge_index.device)
        points = points.reshape(batchsize, -1, 2)
        tours = farthest_insertion(points.cpu().numpy())
        tours = np.array(tours)
        graph_size = points.shape[1]

        if is_subgraph:
          degree = min(self.args.sparse_factor, self.args.sub_graph_size)
          tour_edges = np.zeros((batchsize, self.args.sub_graph_size), dtype=np.int64)
          for i in range(batchsize):
            tour_edges[i][tours[i][:-1]] = tours[i][1:]
          tour_edges = torch.from_numpy(tour_edges).to(edge_index.device)
          tour_edges = tour_edges.reshape((-1, 1)).repeat(1, degree).reshape(-1)
        else:
          tour_edges = np.zeros((batchsize, graph_size), dtype=np.int64)
          for i in range(batchsize):
            # tours[i] += i * graph_size
            tour_edges[i][tours[i][:-1]] = tours[i][1:] + i * graph_size
          tour_edges = torch.from_numpy(tour_edges).to(points.device)
          tour_edges = tour_edges.reshape((-1, 1)).repeat(1, self.args.sparse_factor).reshape(-1)
        edge_index_1 = edge_index[1]
        x_in = torch.eq(edge_index_1, tour_edges).reshape(-1, 1)
        x_in = x_in.to(points.device)
      else:
        tours = farthest_insertion(points.cpu().numpy())
        tours = np.array(tours)
        x_0 = np.zeros((points.shape[1], points.shape[1]))
        x_0 = torch.from_numpy(x_0).float()
        x_0 = x_0.unsqueeze(0)
        x_0 = x_0.repeat(points.shape[0], 1, 1)
        
        for i in range(points.shape[0]):
          for j in range(tours.shape[1]-1):
            x_0[i][tours[i][j], tours[i][j+1]] = 1
        x_in = x_0
        x_in = x_in.to(points.device)


    x_in = x_in * 2 - 1
    return x_in

  def test_step(self, batch, batch_idx, split='test'):
    """Decode tours for one instance, refine them with 2-opt and log the gap.

    The heatmap handed to ``merge_tours`` is built from an all-ones tensor
    shaped like the ground-truth adjacency; this method does not call the
    network. The decoded tours are refined with ``batched_two_opt_torch`` and
    scored against the ground-truth tour with ``TSPEvaluator``.

    Args:
      batch: In dense mode ``(real_batch_idx, points, adj_matrix, gt_tour)``;
        in sparse mode ``(real_batch_idx, graph_data, point_indicator, _,
        gt_tour)``. Only the first instance of a dense batch is used.
      batch_idx: Index of the batch (unused).
      split: Prefix for the logged metric names.

    Returns:
      Dict with the ground-truth cost, the 2-opt and merge iteration counts,
      and the cost / gap of the best tour before 2-opt.
    """
    edge_index = None
    np_edge_index = None
    device = batch[-1].device
    if not self.sparse:
      real_batch_idx, points, adj_matrix, gt_tour = batch
      np_points = points.cpu().numpy()[0]
      np_gt_tour = gt_tour.cpu().numpy()[0]
    else:
      real_batch_idx, graph_data, point_indicator, _, gt_tour = batch
      points = graph_data.x
      edge_index = graph_data.edge_index
      num_edges = edge_index.shape[1]
      batch_size = point_indicator.shape[0]
      adj_matrix = graph_data.edge_attr.reshape((batch_size, num_edges // batch_size))
      points = points.reshape((-1, 2))
      edge_index = edge_index.reshape((2, -1))
      np_points = points.cpu().numpy()
      np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
      np_edge_index = edge_index.cpu().numpy()

    # Constant heatmap: every candidate edge gets the same score.
    xt = torch.ones_like(adj_matrix.float())
    if self.sparse:
      xt = xt.squeeze()
    if self.args.diffusion_type == "categorical":
      adj_mat = xt.float().cpu().detach().numpy() + 1e-6
    else:
      adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5

    if self.args.save_numpy_heatmap:
      if self.sparse:
        # Scatter the per-edge scores into a dense (N, N) matrix for saving.
        tsp_size = points.shape[0]
        adj_mat_dense = coo_matrix(
            (adj_mat, (edge_index[0].cpu().numpy(), edge_index[1].cpu().numpy())),
            shape=(tsp_size, tsp_size),
        ).todense()
      else:
        adj_mat_dense = adj_mat
      self.run_save_numpy_heatmap(adj_mat_dense, np_points, real_batch_idx, split)

    tours, merge_iterations = merge_tours(
        adj_mat, np_points, np_edge_index,
        sparse_graph=self.sparse,
        parallel_sampling=self.args.parallel_sampling,
    )

    # Refine using 2-opt
    solved_tours, ns = batched_two_opt_torch(
        np_points.astype("float64"), np.array(tours).astype('int64'),
        max_iterations=self.args.two_opt_iterations, device=device)
    solved_tours = np.concatenate([solved_tours], axis=0)

    tsp_solver = TSPEvaluator(np_points)
    gt_cost = tsp_solver.evaluate(np_gt_tour)

    # Score the tours that were actually produced. Indexing by
    # parallel_sampling * sequential_sampling ran past the end whenever
    # sequential_sampling > 1, because only one sampling pass is performed.
    all_solved_costs = [tsp_solver.evaluate(tour) for tour in solved_tours]
    best_solved_cost = np.min(all_solved_costs)
    gap = (best_solved_cost - gt_cost) / gt_cost * 100

    # Cost of the same sample before 2-opt refinement.
    best_idx = np.argmin(all_solved_costs)
    greedy_cost = tsp_solver.evaluate(tours[best_idx])
    greedy_gap = (greedy_cost - gt_cost) / gt_cost * 100

    metrics = {
        f"{split}/gt_cost": gt_cost,
        f"{split}/2opt_iterations": ns,
        f"{split}/merge_iterations": merge_iterations,
        f"{split}/cost_wo_2opt": greedy_cost,
        f"{split}/gap_wo_2opt(%)": greedy_gap,
    }
    for k, v in metrics.items():
      self.log(k, v, on_epoch=True, sync_dist=True)
    self.log(f"{split}/solved_cost", best_solved_cost, prog_bar=True, on_epoch=True, sync_dist=True)
    self.log(f"{split}/gap(%)", gap, prog_bar=True, on_epoch=True, sync_dist=True)

    return metrics

  def run_save_numpy_heatmap(self, adj_mat, np_points, real_batch_idx, split):
    """Save a heatmap and its instance coordinates as ``.npy`` files.

    Files are written to ``<save_dir>/<name>/<version>/numpy_heatmap`` of the
    attached logger and are named after ``split`` and the first entry of
    ``real_batch_idx``.

    Args:
      adj_mat: Heatmap to store.
      np_points: Instance coordinates to store alongside it.
      real_batch_idx: Tensor whose first (flattened) element identifies the instance.
      split: Name prefix for the written files.
    """
    logger = self.logger
    heatmap_path = os.path.join(
        logger.save_dir, logger.name, str(logger.version), 'numpy_heatmap')
    rank_zero_info(f"Saving heatmap to {heatmap_path}")
    os.makedirs(heatmap_path, exist_ok=True)

    # Only the first index is used to name the files.
    real_batch_idx = real_batch_idx.cpu().numpy().reshape(-1)[0]
    heatmap_file = os.path.join(heatmap_path, f"{split}-heatmap-{real_batch_idx}.npy")
    points_file = os.path.join(heatmap_path, f"{split}-points-{real_batch_idx}.npy")
    np.save(heatmap_file, adj_mat)
    np.save(points_file, np_points)

  def validation_step(self, batch, batch_idx):
    return self.test_step(batch, batch_idx, split='val')
