"""Lightning module for training the DIFUSCO TSP model."""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from lightning.pytorch.utilities import rank_zero_info

from difusco.co_datasets.tsp_graph_dataset import TSPGraphDataset
from difusco.pl_meta_model import COMetaModel
from difusco.utils.diffusion_schedulers import InferenceSchedule
from difusco.utils.tsp_utils import (
    TSPEvaluator,
    batched_two_opt_torch,
    merge_tours,
    make_tour_to_graph,
)

from difusco.guides import ValueGuide
import time


class TSPReward_Weighted_Model(COMetaModel):
    def __init__(self, param_args=None, classifier=None):
        """Initialise the model and load the train/test/validation datasets.

        Args:
            param_args: Argument namespace forwarded to ``COMetaModel``. The
                fields ``storage_path``, ``training_split``, ``test_split``,
                ``validation_split`` and ``sparse_factor`` are read back via
                ``self.args`` (presumably set by the parent constructor).
            classifier: Optional object stored as ``self.classifier``. When
                ``None`` the attribute is not assigned here.
        """
        super().__init__(param_args=param_args, node_feature_only=False)

        def load_split(split_file):
            # Every split lives under the same storage root and shares the
            # same sparsification setting.
            return TSPGraphDataset(
                data_file=os.path.join(self.args.storage_path, split_file),
                sparse_factor=self.args.sparse_factor,
            )

        self.train_dataset = load_split(self.args.training_split)
        self.test_dataset = load_split(self.args.test_split)
        self.validation_dataset = load_split(self.args.validation_split)

        if classifier is not None:
            self.classifier = classifier

    def forward(self, x, adj, t, edge_index):
        """Run the wrapped denoising network.

        The argument order differs from ``self.model``: this method takes
        ``(x, adj, t, edge_index)`` but calls the model with
        ``(x, t, adj, edge_index)``. Callers in this class pass the node
        coordinates as ``x`` and the current noisy edge state as ``adj``.
        """
        model_inputs = (x, t, adj, edge_index)
        return self.model(*model_inputs)

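    # Reward-weighted training step: supervises the denoiser with both the
    # ground-truth tour and a 2-opt-refined tour decoded from its own prediction.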

    def reward_weighted_training_step(self, batch, batch_idx):
        """Compute the reward-weighted training loss for one sparse batch.

        The ground-truth edge labels are noised to a random timestep and
        denoised by the model. The prediction is decoded into tours with
        ``merge_tours`` and refined with a short (2-iteration) 2-opt pass. The
        model is trained with cross-entropy against both the refined tours
        and the original labels.

        Args:
            batch: Sparse batch ``(_, graph_data, point_indicator,
                edge_indicator, _)``.
            batch_idx: Unused; kept for the Lightning hook signature.

        Returns:
            The scalar training loss, also logged as ``train/loss``.

        Raises:
            NotImplementedError: If the model is configured for dense graphs.
                This step needs the per-batch edge list and batch size that
                only sparse batches provide.
        """
        if not self.sparse:
            # The dense path previously died with UnboundLocalError
            # (batch_size / num_edges / np_points were only set for sparse
            # batches) and would also have reshaped a None edge_index.
            raise NotImplementedError(
                "reward_weighted_training_step only supports sparse graphs"
            )

        _, graph_data, point_indicator, edge_indicator, _ = batch
        batch_size = point_indicator.shape[0]
        # One independent diffusion timestep per graph in the batch.
        t = np.random.randint(1, self.diffusion.T + 1, batch_size).astype(int)

        route_edge_flags = graph_data.edge_attr
        points = graph_data.x
        edge_index = graph_data.edge_index
        np_points = points.cpu().numpy()
        num_edges = edge_index.shape[1]
        graph_size = points.shape[0] // batch_size
        adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))

        # Sample a noisy edge state x_t from the ground-truth labels.
        adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
        adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)
        xt = self.diffusion.sample(adj_matrix_onehot, t)
        xt = xt * 2 - 1
        # The jitter factor lies in [1, 1.05), so it never changes the sign
        # and therefore never changes the thresholded result below.
        xt = xt * (1.0 + 0.05 * torch.rand_like(xt))
        xt = (xt > 0).long().reshape(-1)

        # Flatten to the per-edge layout used by the sparse model.
        t = torch.from_numpy(t).float()
        t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
        adj_matrix = adj_matrix.reshape(-1)
        points = points.reshape(-1, 2)
        edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)

        # Denoise
        x0_pred = self.forward(
            points.float().to(adj_matrix.device),
            xt.float().to(adj_matrix.device),
            t.float().to(adj_matrix.device),
            edge_index,
        )

        # Edge heatmap = predicted probability of class 1 ("edge in tour").
        x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(dim=-1)
        heatmap = x0_pred_prob[:, :, :, 1:].clamp(0, 1).reshape(-1)
        adj_mat = heatmap.cpu().detach().numpy() + 1e-6

        # Convert batched (global) node ids back to per-graph local ids by
        # subtracting graph_size * graph_position from every endpoint.
        edge_index = edge_index.reshape(2, batch_size, -1)
        graph_offsets = (
            torch.arange(0, batch_size)[None, :, None]
            .repeat(edge_index.shape[0], 1, edge_index.shape[2])
            .to(edge_index.device)
        )
        np_edge_index = (edge_index - graph_size * graph_offsets).cpu().numpy()
        np_edge_index = np_edge_index.reshape(2, -1).astype("int64")

        tours, _ = merge_tours(
            adj_mat,
            np_points,
            np_edge_index,
            sparse_graph=self.sparse,
            batch_size=batch_size,
            guided=True,
        )
        solved_tours, _ = batched_two_opt_torch(
            np_points.astype("float64"),
            np.array(tours).astype("int64"),
            max_iterations=2,
            device=adj_matrix.device,
            guided=True,
        )
        # Normalise the 2-opt output to a single ndarray.
        solved_tours = np.concatenate([solved_tours], axis=0)

        # Rebuild per-edge labels from the refined tours.
        refined_graph, _, _, _ = make_tour_to_graph(
            np_points, solved_tours, self.args.sparse_factor
        )
        new_adj_matrix = refined_graph.edge_attr.reshape(
            (batch_size, num_edges // batch_size)
        )
        new_adj_matrix = new_adj_matrix.to(adj_matrix.device).reshape(-1)

        loss_func = nn.CrossEntropyLoss()
        loss_refined = loss_func(x0_pred, new_adj_matrix.long())
        loss_ground_truth = loss_func(x0_pred, adj_matrix.long())
        loss = loss_refined + loss_ground_truth

        self.log("train/loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; delegates to the reward-weighted step."""
        step_fn = self.reward_weighted_training_step
        return step_fn(batch, batch_idx)

    def categorical_denoise_step(
        self, points, xt, t, device, edge_index=None, target_t=None, classifier=None
    ):
        """Run one reverse step of the categorical diffusion.

        Args:
            points: Node coordinates fed to the model.
            xt: Current discrete edge state.
            t: Current timestep as a length-1 NumPy integer array.
            device: Device the model inputs are moved to.
            edge_index: Sparse edge list, or ``None`` for dense graphs.
            target_t: Timestep to step to; forwarded to
                ``categorical_posterior``.
            classifier: Optional additive offset applied to the model's edge
                input and forwarded to ``categorical_posterior``. With the
                default ``None`` no offset is applied.

        Returns:
            The new edge state returned by ``categorical_posterior``.
        """
        with torch.no_grad():
            t = torch.from_numpy(t).view(1)
            model_xt = xt.float().to(device)
            if classifier is not None:
                # Previously added unconditionally, so the default None
                # raised TypeError on every unguided call.
                model_xt = model_xt + classifier
            x0_pred = self.forward(
                points.float().to(device),
                model_xt,
                t.float().to(device),
                edge_index.long().to(device) if edge_index is not None else None,
            )

            if not self.sparse:
                # Move the class axis (dim 1) last before the softmax.
                x0_pred_prob = (
                    x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
                )
            else:
                x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(
                    dim=-1
                )
            xt = self.categorical_posterior(
                target_t, t, x0_pred_prob, xt, classifier=classifier
            )
            return xt

    def gaussian_denoise_step(
        self, points, xt, t, device, edge_index=None, target_t=None
    ):
        """Run one reverse step of the Gaussian diffusion without gradients.

        Args:
            points: Node coordinates fed to the model.
            xt: Current continuous edge state.
            t: Current timestep as a length-1 NumPy integer array.
            device: Device the model inputs are moved to.
            edge_index: Sparse edge list, or ``None`` for dense graphs.
            target_t: Timestep to step to; forwarded to ``gaussian_posterior``.

        Returns:
            The new edge state returned by ``gaussian_posterior``.
        """
        with torch.no_grad():
            t_tensor = torch.from_numpy(t).view(1)
            model_edges = None if edge_index is None else edge_index.long().to(device)
            # Drop the singleton channel dimension of the model output.
            prediction = self.forward(
                points.float().to(device),
                xt.float().to(device),
                t_tensor.float().to(device),
                model_edges,
            ).squeeze(1)
            return self.gaussian_posterior(target_t, t_tensor, prediction, xt)

    def test_step(self, batch, batch_idx, split="test"):
        """Lightning test hook; delegates to ``denoise_test_step``."""
        return self.denoise_test_step(batch, batch_idx, split=split)

    def denoise_test_step(self, batch, batch_idx, split="test"):
        """Sample tours with the diffusion model and log evaluation metrics.

        Each of ``sequential_sampling`` rounds does the following:
        ``parallel_sampling`` copies start from fresh noise, are denoised for
        ``inference_diffusion_steps`` steps, are decoded with ``merge_tours``
        and are then refined with 2-opt. The lowest refined cost among the
        first ``parallel_sampling * sequential_sampling`` refined tours is
        logged as ``{split}/solved_cost``.

        Args:
            batch: Dense ``(idx, points, adj_matrix, gt_tour, _)`` or sparse
                ``(idx, graph_data, point_indicator, edge_indicator, gt_tour,
                _)`` tuple.
            batch_idx: Unused; kept for the Lightning hook signature.
            split: Prefix for the logged metric names.

        Returns:
            Dict of the timing / cost metrics that were logged. Durations are
            summed over all sequential rounds. ``wo_2opt_cost`` is the cost of
            the first decoded tour of the first round, before 2-opt.
        """
        edge_index = None
        np_edge_index = None
        device = batch[-1].device
        if not self.sparse:
            real_batch_idx, points, adj_matrix, gt_tour, _ = batch
            np_points = points.cpu().numpy()[0]
            np_gt_tour = gt_tour.cpu().numpy()[0]
        else:
            (
                real_batch_idx,
                graph_data,
                point_indicator,
                edge_indicator,
                gt_tour,
                _,
            ) = batch
            route_edge_flags = graph_data.edge_attr
            points = graph_data.x
            edge_index = graph_data.edge_index
            num_edges = edge_index.shape[1]
            batch_size = point_indicator.shape[0]
            adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))
            points = points.reshape((-1, 2))
            edge_index = edge_index.reshape((2, -1))
            np_points = points.cpu().numpy()
            np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
            np_edge_index = edge_index.cpu().numpy()

        if self.args.parallel_sampling > 1:
            if not self.sparse:
                points = points.repeat(self.args.parallel_sampling, 1, 1)
            else:
                points = points.repeat(self.args.parallel_sampling, 1)
                edge_index = self.duplicate_edge_index(
                    edge_index, np_points.shape[0], device
                )

        tsp_solver = TSPEvaluator(np_points)
        stacked_tours = []
        wo_2opt_costs = None
        diffusion_duration = 0.0
        sorting_duration = 0.0
        two_opt_duration = 0.0

        for _ in range(self.args.sequential_sampling):
            round_start = time.time()
            adj_mat = self._sample_edge_heatmap(points, adj_matrix, edge_index, device)
            if self.args.save_numpy_heatmap:
                self.run_save_numpy_heatmap(adj_mat, np_points, real_batch_idx, split)
            diffusion_end = time.time()

            # Decode every round; previously only the last round was decoded,
            # which also made the cost loop below index past the end.
            tours, _ = merge_tours(
                adj_mat,
                np_points,
                np_edge_index,
                sparse_graph=self.sparse,
                batch_size=self.args.batch_size,
            )
            greedy_end = time.time()

            if wo_2opt_costs is None:
                wo_2opt_costs = tsp_solver.evaluate(tours[0])

            refined_tours, _ = batched_two_opt_torch(
                np_points.astype("float64"),
                np.array(tours).astype("int64"),
                max_iterations=self.args.two_opt_iterations,
                device=device,
            )
            stacked_tours.append(refined_tours)
            opt_end = time.time()

            diffusion_duration += diffusion_end - round_start
            sorting_duration += greedy_end - diffusion_end
            two_opt_duration += opt_end - greedy_end

        solved_tours = np.concatenate(stacked_tours, axis=0)
        gt_cost = tsp_solver.evaluate(np_gt_tour)

        total_sampling = self.args.parallel_sampling * self.args.sequential_sampling
        all_solved_costs = [
            tsp_solver.evaluate(solved_tours[i]) for i in range(total_sampling)
        ]
        best_solved_cost = np.min(all_solved_costs)

        metrics = {
            f"{split}/Diffusion inference duration": diffusion_duration,
            f"{split}/Heatmap Sorting duration": sorting_duration,
            f"{split}/2opt duration": two_opt_duration,
            f"{split}/wo_2opt_cost": wo_2opt_costs,
            f"{split}/gt_cost": gt_cost,
        }
        for k, v in metrics.items():
            self.log(k, v, on_epoch=True, sync_dist=True)
        self.log(
            f"{split}/solved_cost",
            best_solved_cost,
            prog_bar=True,
            on_epoch=True,
            sync_dist=True,
        )
        return metrics

    def _sample_edge_heatmap(self, points, adj_matrix, edge_index, device):
        """Run one full reverse-diffusion chain and return the heatmap as NumPy."""
        xt = torch.randn_like(adj_matrix.float())
        if self.args.parallel_sampling > 1:
            if not self.sparse:
                xt = xt.repeat(self.args.parallel_sampling, 1, 1)
            else:
                xt = xt.repeat(self.args.parallel_sampling, 1)
            # Redraw so every parallel copy starts from independent noise;
            # previously only the sparse branch did this, so dense copies
            # were identical.
            xt = torch.randn_like(xt)

        if self.diffusion_type == "gaussian":
            xt.requires_grad = True
        else:
            xt = (xt > 0).long()

        if self.sparse:
            xt = xt.reshape(-1)

        steps = self.args.inference_diffusion_steps
        time_schedule = InferenceSchedule(
            inference_schedule=self.args.inference_schedule,
            T=self.diffusion.T,
            inference_T=steps,
        )

        # Diffusion iterations
        for i in range(steps):
            t1, t2 = time_schedule(i)
            t1 = np.array([t1]).astype(int)
            t2 = np.array([t2]).astype(int)

            if self.diffusion_type == "gaussian":
                xt = self.gaussian_denoise_step(
                    points, xt, t1, device, edge_index, target_t=t2
                )
            else:
                xt = self.categorical_denoise_step(
                    points, xt, t1, device, edge_index, target_t=t2
                )

        if self.diffusion_type == "gaussian":
            return xt.cpu().detach().numpy() * 0.5 + 0.5
        return xt.float().cpu().detach().numpy() + 1e-6

    def run_save_numpy_heatmap(self, adj_mat, np_points, real_batch_idx, split):
        """Save the predicted heatmap and node coordinates as ``.npy`` files.

        Files are written to ``<save_dir>/<name>/<version>/numpy_heatmap/`` as
        ``{split}-heatmap-{idx}.npy`` and ``{split}-points-{idx}.npy``, where
        ``idx`` is the first entry of ``real_batch_idx``.

        Raises:
            NotImplementedError: If parallel or sequential sampling draws more
                than one sample per instance.
        """
        if self.args.parallel_sampling > 1 or self.args.sequential_sampling > 1:
            raise NotImplementedError("Save numpy heatmap only support single sampling")
        # Lightning loggers may report an integer version (auto-incremented);
        # os.path.join raises TypeError on ints, so coerce to str.
        exp_save_dir = os.path.join(
            self.logger.save_dir, self.logger.name, str(self.logger.version)
        )
        heatmap_path = os.path.join(exp_save_dir, "numpy_heatmap")
        rank_zero_info(f"Saving heatmap to {heatmap_path}")
        os.makedirs(heatmap_path, exist_ok=True)
        real_batch_idx = real_batch_idx.cpu().numpy().reshape(-1)[0]
        np.save(
            os.path.join(heatmap_path, f"{split}-heatmap-{real_batch_idx}.npy"), adj_mat
        )
        np.save(
            os.path.join(heatmap_path, f"{split}-points-{real_batch_idx}.npy"),
            np_points,
        )

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook, logging under the ``val`` split.

        The hook dispatches to one of three steps:

        - ``classifier_test_step`` when a classifier is attached;
        - ``value_function_vaild_step`` when ``diffusion_type`` is
          ``"classifier"``;
        - ``denoise_test_step`` otherwise.
        """
        # Re-enables autograd globally for this thread; presumably needed by
        # the guided branches -- NOTE(review): confirm this is intended.
        torch.set_grad_enabled(True)
        # __init__ only assigns self.classifier when one is passed, so fall
        # back to None instead of raising AttributeError.
        classifier = getattr(self, "classifier", None)
        if classifier is not None:
            return self.classifier_test_step(
                classifier, batch, batch_idx, split="val"
            )
        if self.diffusion_type == "classifier":
            return self.value_function_vaild_step(batch, batch_idx, split="val")
        return self.denoise_test_step(batch, batch_idx, split="val")
