"""Lightning module for training the TSP model."""

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from pytorch_lightning.utilities import rank_zero_info
from scipy.sparse import coo_matrix

from co_datasets.tsp_graph_dataset import TSPGraphDataset
from pl_meta_model_rddm import COMetaModel
from utils.tsp_utils import TSPLibEvaluator, batched_two_opt_torch, merge_tours

import copy
import time

class TSPModel(COMetaModel):
  def __init__(self,
               param_args=None):
    """Initialise the parent model and load the four dataset splits.

    Args:
      param_args: Passed straight to the parent constructor, together with
        ``node_feature_only=False``.
    """
    super(TSPModel, self).__init__(param_args=param_args, node_feature_only=False)

    # (attribute set on self, attribute of self.args naming the split file),
    # listed in the order in which the datasets are built.
    dataset_specs = (
        ('train_dataset', 'training_split'),
        ('test_dataset', 'test_split'),
        ('validation_dataset', 'validation_split'),
        ('predict_dataset', 'predict_split'),
    )
    for dataset_attr, split_attr in dataset_specs:
      split_file = os.path.join(self.args.storage_path, getattr(self.args, split_attr))
      dataset = TSPGraphDataset(
          data_file=split_file,
          sparse_factor=self.args.sparse_factor,
      )
      setattr(self, dataset_attr, dataset)

    if self.sparse:
      # Empty accumulators; no method visible in this class fills them.
      self.all_edge_ranks, self.best_sub_heatmaps = [], []

  def forward(self, x, adj, t, edge_index):
    return self.model(x, t, adj, edge_index)

  def training_step(self, batch, batch_idx):
    exit(0)

  def gaussian_denoise_step(self, points, xt, x_in, t, device, edge_index=None, target_t=None):
    with torch.no_grad():
      t = t.view(xt.shape[0])
      points = points.float().to(device)
      xt = xt.float().to(device)
      x_res_pred, noise_pred = self.forward(
          points,
          xt,
          t,
          edge_index.long().to(device) if edge_index is not None else None,
      )

      xt = self.gaussian_posterior(target_t.cpu(), t.cpu(), noise_pred.cpu(), x_res_pred.cpu(), xt.cpu())

      return xt.to(device)

  def generate_x_in(self, points, edge_index=None):
    if self.sparse:
      tour = np.arange(0, points.shape[0])
      tour = np.append(tour, 0)
      tour_edges = np.zeros(points.shape[0], dtype=np.int64)
      tour_edges[tour[:-1]] = tour[1:]
      tour_edges = torch.from_numpy(tour_edges).to(points.device)
      tour_edges = tour_edges.reshape((-1, 1)).repeat(1, self.args.sparse_factor).reshape(-1)
      edge_index_1 = edge_index[1]
      x_in = torch.eq(edge_index_1, tour_edges).reshape(-1, 1)
      x_in = x_in.to(points.device)
    else:
      tour = np.arange(0, points.shape[1])
      tour = np.append(tour, 0)
      x_0 = np.zeros((points.shape[1], points.shape[1]))
      for i in range(tour.shape[0]-1):
        x_0[tour[i], tour[i+1]] = 1
      x_in = x_0
      x_in = torch.from_numpy(x_in).float()
      x_in = x_in.unsqueeze(0)
      x_in = x_in.repeat(points.shape[0], 1, 1)
      x_in = x_in.to(points.device)

    x_in = x_in * 2 - 1
    return x_in

  def test_step(self, batch, batch_idx, split='test'):
    """Sample tours for one batch, refine them with 2-opt and report the best cost.

    Each sequential sample starts from the fixed cycle of ``generate_x_in``
    plus Gaussian noise, runs ``inference_diffusion_steps`` denoising steps,
    turns the result into tours with ``merge_tours`` and refines them with
    ``batched_two_opt_torch``. The cheapest tour over all samples is compared
    with the reference cost.

    Args:
      batch: Dense mode: ``(real_batch_idx, points, adj_matrix, gt_tour,
        solution, dataset_name)``. Sparse mode: ``(real_batch_idx, graph_data,
        point_indicator, edge_indicator, gt_tour)``.
      batch_idx: Unused.
      split: Prefix of the metric keys and of the saved heatmap file names.

    Returns:
      dict with the reference cost, the best sampled cost (stored under the
      existing ``avg_solved_cost`` key), the gap in percent and the elapsed
      time divided by the number of sequential samples.
    """
    start_time = time.time()
    edge_index = None
    np_edge_index = None
    np_gt_tour = None
    device = batch[1].device
    if not self.sparse:
      real_batch_idx, points, adj_matrix, gt_tour, solution, dataset_name = batch
      # Only the first instance of the batch is decoded and scored below.
      np_points = points.cpu().numpy()[0]
      solution = solution.cpu().numpy()[0]
    else:
      real_batch_idx, graph_data, point_indicator, edge_indicator, gt_tour = batch
      route_edge_flags = graph_data.edge_attr
      points = graph_data.x
      edge_index = graph_data.edge_index
      num_edges = edge_index.shape[1]
      batch_size = point_indicator.shape[0]
      adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))
      points = points.reshape((-1, 2))
      edge_index = edge_index.reshape((2, -1))
      np_points = points.cpu().numpy()
      np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
      np_edge_index = edge_index.cpu().numpy()
      # The sparse batch has no `solution` / `dataset_name` entries; without
      # these fallbacks the code below raised NameError. The reference cost is
      # derived from the ground-truth tour further down, and the batch indices
      # stand in for the dataset name.
      solution = None
      batch_ids = real_batch_idx.cpu() if torch.is_tensor(real_batch_idx) else real_batch_idx
      dataset_name = np.asarray(batch_ids).reshape(-1)

    stacked_tours = []

    # Normalization. int() truncates, so coordinates whose maximum is below 1
    # would give a divisor of 0 and turn every point into inf/NaN; fall back
    # to 1 (no rescaling) in that case.
    max_coord = int(points.max().cpu()) or 1
    points = points / max_coord
    np_points_ori = copy.deepcopy(np_points)
    np_points = np_points / max_coord

    x_in_ori = self.generate_x_in(points, edge_index)
    if self.sparse:
      x_in_ori = x_in_ori.reshape((batch_size, num_edges // batch_size))
    if self.args.parallel_sampling > 1:
      if not self.sparse:
        points = points.repeat(self.args.parallel_sampling, 1, 1)
      else:
        points = points.repeat(self.args.parallel_sampling, 1)
        edge_index = self.duplicate_edge_index(edge_index, np_points.shape[0], device)

    # NOTE(review): this overrides whatever sequential_sampling value was
    # configured and mutates the shared args object. Kept as-is because
    # removing it would change the reported numbers; confirm it is intended.
    self.args.sequential_sampling = 50
    for _ in range(self.args.sequential_sampling):
      x_in = x_in_ori.clone()
      epsilon = torch.randn_like(adj_matrix.float())
      if self.args.parallel_sampling > 1:
        if self.sparse:
          epsilon = epsilon.repeat(self.args.parallel_sampling, 1)
          x_in = x_in.repeat(self.args.parallel_sampling, 1)
        else:
          epsilon = epsilon.repeat(self.args.parallel_sampling, 1, 1)
          x_in = x_in.repeat(self.args.parallel_sampling, 1, 1)
      xt = x_in + epsilon

      # NOTE(review): no gradient is taken below (the denoise step runs under
      # torch.no_grad()); presumably a leftover, kept to preserve behaviour.
      xt.requires_grad = True

      if self.sparse:
        xt = xt.reshape(-1)
        x_in = x_in.reshape(-1)

      steps = self.args.inference_diffusion_steps
      # Diffusion iterations: walk the time from 1 down to 0 in equal steps.
      cur_time = torch.ones((xt.shape[0],), device=xt.device)
      step = 1.0 / steps
      for i in range(steps):
        s = torch.full((xt.shape[0], ), step, device=xt.device)
        if i == steps-1:
          # Land exactly on 0 on the last step instead of accumulating
          # floating-point error.
          s = cur_time
        t1 = cur_time
        t2 = cur_time - s
        xt = self.gaussian_denoise_step(
            points, xt, x_in, t1, device, edge_index, target_t=t2)
        cur_time = cur_time - s

      if self.sparse:
        xt = xt.squeeze()
      # Map the {-1, +1} encoding back onto [0, 1] edge scores.
      adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5

      tours, _ = merge_tours(
          adj_mat, np_points, np_edge_index,
          sparse_graph=self.sparse,
          parallel_sampling=self.args.parallel_sampling,
      )

      solved_tours, _ = batched_two_opt_torch(
          np_points.astype("float64"), np.array(tours).astype('int64'),
          max_iterations=self.args.two_opt_iterations, device=device)
      stacked_tours.append(solved_tours)

    solved_tours = np.concatenate(stacked_tours, axis=0)

    # Costs are measured on the original, un-normalized coordinates.
    tsp_solver = TSPLibEvaluator(np_points_ori)
    if solution is None:
      gt_cost = tsp_solver.evaluate(np_gt_tour)
    else:
      gt_cost = solution

    total_sampling = self.args.parallel_sampling * self.args.sequential_sampling
    all_solved_costs = [tsp_solver.evaluate(solved_tours[i]) for i in range(total_sampling)]
    # The key says "avg" but the value is the minimum over all samples.
    avg_solved_cost = np.min(all_solved_costs)
    gap = (avg_solved_cost - gt_cost) / gt_cost * 100

    # save tour
    if self.args.save_numpy_heatmap:
      best_tour = solved_tours[np.argmin(all_solved_costs)]
      # Size by node count; points.shape[1] is the coordinate dimension (2)
      # in sparse mode. In dense mode both expressions give the same value.
      num_nodes = np_points.shape[0]
      adj_mat = np.zeros((num_nodes, num_nodes))
      adj_mat[best_tour[:-1], best_tour[1:]] = 1
      self.run_save_numpy_heatmap(adj_mat, np_points, dataset_name, split)

    end_time = time.time()
    metrics = {
        f"{split}/gt_cost": gt_cost,
        f"{split}/avg_solved_cost": avg_solved_cost,
        f"{split}/avg_gap(%)": gap,
        f"{split}/Avg per-instance_runtime": (end_time-start_time)/ self.args.sequential_sampling
    }

    print("Dataset name: ", dataset_name)
    print(metrics)
    print()

    return metrics

  def run_save_numpy_heatmap(self, adj_mat, np_points, real_batch_idx, split):
    """Write ``adj_mat`` and ``np_points`` as ``.npy`` files under the logger directory.

    The target directory is
    ``<logger.save_dir>/<logger.name>/<logger.version>/numpy_heatmap`` and is
    created if missing.

    Args:
      adj_mat: Array saved as ``<split>-heatmap-<id>.npy``.
      np_points: Array saved as ``<split>-points-<id>.npy``.
      real_batch_idx: Indexable; its first element is used as ``<id>``.
      split: Prefix of both file names.
    """
    logger = self.logger
    heatmap_path = os.path.join(
        logger.save_dir, logger.name, str(logger.version), 'numpy_heatmap')
    rank_zero_info(f"Saving heatmap to {heatmap_path}")
    os.makedirs(heatmap_path, exist_ok=True)

    real_batch_idx = real_batch_idx[0]
    heatmap_file = os.path.join(heatmap_path, f"{split}-heatmap-{real_batch_idx}.npy")
    points_file = os.path.join(heatmap_path, f"{split}-points-{real_batch_idx}.npy")
    np.save(heatmap_file, adj_mat)
    np.save(points_file, np_points)

  def validation_step(self, batch, batch_idx):
    return self.test_step(batch, batch_idx, split='val')
