"""Lightning module for training the DIFUSCO TSP model."""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from lightning.pytorch.utilities import rank_zero_info

from co_datasets.tsp_graph_dataset import TSPGraphDataset
from pl_meta_model import COMetaModel
from utils.diffusion_schedulers import InferenceSchedule
from utils.tsp_utils import (
    TSPEvaluator,
    batched_two_opt_torch,
    merge_tours,
    make_tour_to_graph,
)

from guides import ValueGuide
import time


class TSPGuided_Graident_Model(COMetaModel):
    def __init__(self, param_args=None, classifier=None):
        """Set up the base meta model and load the three TSP dataset splits.

        Args:
            param_args: Configuration forwarded to ``COMetaModel``; afterwards
                read back through ``self.args``.
            classifier: Accepted for call-site compatibility but not stored or
                used by this constructor.
        """
        super().__init__(param_args=param_args, node_feature_only=False)

        def _load_split(split_file):
            # Every split lives under the same storage root and shares the
            # same sparsification factor.
            return TSPGraphDataset(
                data_file=os.path.join(self.args.storage_path, split_file),
                sparse_factor=self.args.sparse_factor,
            )

        self.train_dataset = _load_split(self.args.training_split)
        self.test_dataset = _load_split(self.args.test_split)
        self.validation_dataset = _load_split(self.args.validation_split)
        # The ``classifier`` argument is intentionally left unassigned here.

    def forward(self, x, adj, t, edge_index):
        return self.model(x, t, adj, edge_index)

    # classifier training step

    def classifier_training_step(self, classifier_model, batch, batch_idx):
        """Compute the edge cross-entropy loss against 2-opt refined tours.

        Steps: draw one diffusion timestep per instance, sample a noisy
        adjacency ``xt`` from ``self.diffusion``, run the network on it,
        decode ``xt`` into tours with ``merge_tours``, refine them with
        ``batched_two_opt_torch``, convert the refined tours into a directed
        adjacency target and score the network output with cross-entropy.

        Args:
            classifier_model: Not used by this method; kept so existing
                callers keep working.
            batch: Dense mode: a 5-tuple whose 2nd and 3rd items are the
                points and the adjacency matrix. Sparse mode: a 6-tuple whose
                2nd and 3rd items are the graph data and the point indicator.
            batch_idx: Not used by this method.

        Returns:
            The loss tensor, which is also logged as ``train/loss``.
        """
        edge_index = None
        np_edge_index = None

        if not self.sparse:
            _, points, adj_matrix, _, _ = batch
            batch_size = points.shape[0]
        else:
            _, graph_data, point_indicator, _, _, _ = batch
            batch_size = point_indicator.shape[0]
            points = graph_data.x
            edge_index = graph_data.edge_index
            num_edges = edge_index.shape[1]
            adj_matrix = graph_data.edge_attr.reshape(
                (batch_size, num_edges // batch_size)
            )

        # One independently drawn timestep in [1, T] per instance.
        t = np.random.randint(1, self.diffusion.T + 1, batch_size).astype(int)
        np_points = points.cpu().numpy()
        device = adj_matrix.device

        adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
        if self.sparse:
            adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

        # Sample from diffusion.
        xt = self.diffusion.sample(adj_matrix_onehot, t)
        # NOTE(review): the sample is binarized before the diffusion_space
        # check, so the non-discrete branch below also sees a 0/1 tensor --
        # confirm this is intended.
        xt = (xt > 0).long()
        if self.diffusion_space == "discrete":
            # Small offset keeps every heatmap entry strictly positive.
            adj_mat = xt.float().cpu().detach().numpy() + 1e-6
        else:
            adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5

        if self.sparse:
            graph_size = points.shape[0] // batch_size
            t = torch.from_numpy(t).float()
            # Repeat each instance's timestep once per edge of that instance.
            t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
            xt = xt.reshape(-1)
            adj_matrix = adj_matrix.reshape(-1)
            points = points.reshape(-1, 2)
            edge_index = edge_index.float().to(device).reshape(2, batch_size, -1)
            # Subtract each instance's node offset so indices become local.
            node_offsets = graph_size * torch.arange(
                0, batch_size, device=edge_index.device
            )[None, :, None]
            np_edge_index = (edge_index - node_offsets).cpu().numpy()
            np_edge_index = np_edge_index.reshape(2, -1).astype("int64")
        else:
            t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

        # Denoise.
        pred_edge = self.forward(
            points.float().to(device),
            xt.float().to(device),
            t.float().to(device),
            edge_index,
        )

        # Decode the noisy sample (not the prediction) into tours.
        tours, _ = merge_tours(
            adj_mat,
            np_points,
            np_edge_index,
            sparse_graph=self.sparse,
            parallel_sampling=batch_size,
            guided=True,
        )

        # Refine using 2-opt.
        solved_tours, _ = batched_two_opt_torch(
            np_points.astype("float64"),
            np.array(tours).astype("int64"),
            max_iterations=self.args.two_opt_iterations,
            device=device,
            batch=True,
        )
        route_indices = torch.from_numpy(np.ascontiguousarray(solved_tours)).to(
            device
        )

        # Mark every consecutive (start -> end) city pair of each refined tour.
        # NOTE(review): the target shape is taken from ``points``, which has
        # been flattened to (-1, 2) in sparse mode -- this construction looks
        # dense-only; confirm before training with sparse graphs.
        start_cities = route_indices[:, :-1]
        end_cities = route_indices[:, 1:]
        batch_indices = torch.arange(route_indices.size(0), device=device)
        opt_adj_matrix = torch.zeros(
            (points.shape[0], points.shape[1], points.shape[1]),
            dtype=torch.long,
            device=device,
        )
        opt_adj_matrix[batch_indices[:, None], start_cities, end_cities] = 1

        # Compute loss.
        loss_func = nn.CrossEntropyLoss()
        loss = loss_func(pred_edge, opt_adj_matrix)

        self.log("train/loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.classifier_training_step(self.classifier_model, batch, batch_idx)

    def categorical_denoise_step(
        self, points, xt, t, device, edge_index=None, target_t=None, classifier=None
    ):
        classifier=None
        with torch.no_grad():
            t = torch.from_numpy(t).view(1)
            x0_pred = self.forward(
                points.float().to(device),
                xt.float().to(device),
                t.float().to(device),
                edge_index.long().to(device) if edge_index is not None else None,
            )
            # if classifier:
            # # xt_hat = xt.clone()
            # # xt_hat2 = (xt_hat + xt_hat.transpose(1,2)) / 2
            #     grad = classifier.gradients(
            #         points.to(device),
            #         t.float().to(device),
            #         xt.float().to(device),
            #         edge_index.long().to(device)
            #         if edge_index is not None
            #         else None,
            # )
        # print(grad.max())
        # x0_pred =  0.5 * x0_pred + 0.5 * grad
        if not self.sparse:
            x0_pred_prob = (
                x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
            )
            # # grad_prob = grad.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
            # xt = grad_prob[..., 1]
        else:
            x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(
                dim=-1
            )
        # classifier = None
        xt = self.categorical_posterior(
            target_t, t, x0_pred_prob, xt, classifier=classifier
        )

        # xt_prime = grad_prob[..., 1]

        # mu = torch.randn(xt.shape[0]).to(device)
        # xt[mu > 0.1,:,:] = xt_prime[mu > 0.5,:,:]
        # if target_t > 0:
        #     xt_prime = xt_prime * 2 - 1
        #     xt = 0.01 *xt_prime + xt
        #     xt = ((xt) > 0).long()
        
        # if target_t > 0:
        #     sum_x_t_target_prob = 0.99 * xt + 0.01 * xt_prime
        #     xt = torch.bernoulli(sum_x_t_target_prob.clamp(0, 1))
        # else:
        #     sum_x_t_target_prob = xt
        #     xt = sum_x_t_target_prob.clamp(min=0)
        return xt
    def gaussian_denoise_step(
        self, points, xt, t, device, edge_index=None, target_t=None
    ):
        with torch.no_grad():
            t = torch.from_numpy(t).view(1)
            pred = self.forward(
                points.float().to(device),
                xt.float().to(device),
                t.float().to(device),
                edge_index.long().to(device) if edge_index is not None else None,
            )
            pred = pred.squeeze(1)
            xt = self.gaussian_posterior(target_t, t, pred, xt)
            return xt

    def test_step(self, batch, batch_idx, split="test"):
        """Test hook: enable autograd, then run the guided test step.

        Gradient tracking is switched on globally before delegating because
        ``classifier_test_step`` calls the guide's ``gradients`` method.
        """
        torch.set_grad_enabled(True)
        guide_model = self.classifier_model
        return self.classifier_test_step(guide_model, batch, batch_idx, split)

    def classifier_test_step(self, classifier_model, batch, batch_idx, split="test"):
        """Sample tours with classifier guidance and log their costs.

        For each of ``args.sequential_sampling`` rounds the method starts from
        random noise, runs ``args.inference_diffusion_steps`` reverse steps,
        decodes the resulting heatmap with ``merge_tours``, refines the tours
        with 2-opt and evaluates them with ``TSPEvaluator``.

        Args:
            classifier_model: Object whose ``.model`` is wrapped in
                ``ValueGuide``; only the wrapper's ``gradients`` method is
                used here.
            batch: Dense mode: 5-tuple ``(real_batch_idx, points, adj_matrix,
                gt_tour, _)``. Sparse mode: 6-tuple starting with
                ``(real_batch_idx, graph_data, point_indicator, ...)``.
            batch_idx: Not used by this method.
            split: Prefix for the logged metric names.

        Returns:
            Dict with ``{split}/wo_2opt_cost`` and ``{split}/gt_cost``.
            ``{split}/solved_cost`` is logged but not part of the dict.

        NOTE(review): in the sparse branch neither ``gt_tour``/``np_gt_tour``
        nor ``np_edge_index`` is assigned, and the guidance gradient is
        permuted as a 4-D tensor, so this method looks dense-only as written.
        """
        edge_index = None
        np_edge_index = None
        device = batch[-1].device

        classifier = ValueGuide(classifier_model.model)

        if not self.sparse:
            real_batch_idx, points, adj_matrix, gt_tour, _ = batch
            # The drawn timesteps are never read; the call is kept only so the
            # NumPy RNG stream is unchanged for seeded runs.
            np.random.randint(1, self.diffusion.T + 1, points.shape[0])
            np_points = points.cpu().numpy()
            batch_size = points.shape[0]
            if batch_size == 1:
                np_gt_tour = gt_tour.cpu().numpy()[0]
            else:
                np_gt_tour = gt_tour
        else:
            real_batch_idx, graph_data, point_indicator, _, _, _ = batch
            # See the note above: kept for RNG-stream compatibility only.
            np.random.randint(1, self.diffusion.T + 1, point_indicator.shape[0])
            points = graph_data.x
            edge_index = graph_data.edge_index
            np_points = points.cpu().numpy()
            num_edges = edge_index.shape[1]
            batch_size = point_indicator.shape[0]
            adj_matrix = graph_data.edge_attr.reshape(
                (batch_size, num_edges // batch_size)
            )

        # Refined tours from every sequential round accumulate here.
        stacked_tours = []

        for _ in range(self.args.sequential_sampling):
            xt = torch.randn_like(adj_matrix.float())
            if self.args.parallel_sampling > 1:
                if not self.sparse:
                    xt = xt.repeat(self.args.parallel_sampling, 1, 1)
                else:
                    xt = xt.repeat(self.args.parallel_sampling, 1)
                # Redraw so the repeated copies are independent noise.
                xt = torch.randn_like(xt)
            if self.diffusion_type == "gaussian":
                xt.requires_grad = True
            else:
                xt = (xt > 0).long()

            if self.sparse:
                xt = xt.reshape(-1)

            steps = self.args.inference_diffusion_steps
            time_schedule = InferenceSchedule(
                inference_schedule=self.args.inference_schedule,
                T=self.diffusion.T,
                inference_T=steps,
            )

            # Blend of the denoised sample and the guidance sample; it is only
            # produced on steps with t2 > 0, so it may legitimately stay None.
            xt_1 = None

            # Diffusion iterations.
            for i in range(steps):
                t1, t2 = time_schedule(i)
                t1 = np.array([t1]).astype(int)
                t2 = np.array([t2]).astype(int)
                if t2 > 0:
                    grad = classifier.gradients(
                        points.to(device),
                        torch.from_numpy(t1).float().to(device),
                        xt.float().to(device),
                        edge_index.long().to(device)
                        if edge_index is not None
                        else None,
                    )
                    grad_prob = grad.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
                    xt_prime = grad_prob[..., 1]
                    xt_prime = torch.bernoulli(xt_prime.clamp(0, 1))
                if self.diffusion_space == "continuous":
                    xt = self.gaussian_denoise_step(
                        points, xt, t1, device, edge_index, target_t=t2
                    )
                else:
                    xt = self.categorical_denoise_step(
                        points,
                        xt,
                        t1,
                        device,
                        edge_index,
                        target_t=t2,
                        classifier=classifier,
                    )
                    if t2 > 0:
                        # Average the sample with the guidance draw mapped
                        # from {0, 1} to {-1, 1}.
                        xt_1 = (xt.float() + (xt_prime * 2 - 1).float()) / 2

            if self.diffusion_space == "continuous":
                adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5
            else:
                # Previously an unassigned ``xt_1`` (no step with t2 > 0, e.g.
                # a single inference step) raised UnboundLocalError here; fall
                # back to the final sample in that case only.
                heat_source = xt if xt_1 is None else xt_1
                adj_mat = heat_source.float().detach().cpu().numpy() + 1e-6

            if self.args.save_numpy_heatmap:
                self.run_save_numpy_heatmap(adj_mat, np_points, real_batch_idx, split)

            tours, _ = merge_tours(
                adj_mat,
                np_points,
                np_edge_index,
                sparse_graph=self.sparse,
                parallel_sampling=batch_size,
                guided=True,
            )

            if batch_size == 1:
                # Cost before 2-opt.
                tsp_solver = TSPEvaluator(np_points[0])
                wo_2opt_costs = tsp_solver.evaluate(tours[0])

                # Refine using 2-opt.
                solved_tours, _ = batched_two_opt_torch(
                    np_points[0].astype("float64"),
                    np.array(tours).astype("int64"),
                    max_iterations=self.args.two_opt_iterations,
                    device=device,
                )

                stacked_tours.append(solved_tours)
                solved_tours = np.concatenate(stacked_tours, axis=0)

                tsp_solver = TSPEvaluator(np_points[0])
                gt_cost = tsp_solver.evaluate(np_gt_tour)

                total_sampling = (
                    self.args.parallel_sampling * self.args.sequential_sampling
                )
                all_solved_costs = [
                    tsp_solver.evaluate(solved_tours[i]) for i in range(total_sampling)
                ]
                best_solved_cost = np.min(all_solved_costs)
            else:
                # Cost before 2-opt, averaged over the batch.
                tsp_solver = TSPEvaluator(np_points, batch=True)
                tours = torch.tensor(tours)
                wo_2opt_costs = tsp_solver.evaluate(tours).mean()

                # Refine using 2-opt.
                solved_tours, _ = batched_two_opt_torch(
                    np_points.astype("float64"),
                    np.array(tours).astype("int64"),
                    max_iterations=self.args.two_opt_iterations,
                    device=device,
                    batch=True,
                )
                stacked_tours.append(solved_tours)
                solved_tours = np.concatenate(stacked_tours, axis=0)
                solved_tours = torch.tensor(solved_tours)

                gt_cost = tsp_solver.evaluate(gt_tour).mean()
                best_solved_cost = tsp_solver.evaluate(solved_tours).mean()

        # Only the values from the last sequential round are logged.
        metrics = {
            f"{split}/wo_2opt_cost": wo_2opt_costs,
            f"{split}/gt_cost": gt_cost,
        }
        for k, v in metrics.items():
            self.log(k, v, on_epoch=True, sync_dist=True)
        self.log(
            f"{split}/solved_cost",
            best_solved_cost,
            prog_bar=True,
            on_epoch=True,
            sync_dist=True,
        )
        return metrics

    def value_function_vaild_step(self, classifier_model, batch, batch_idx, split="val"):
        """Compute the validation loss against 2-opt refined tours.

        Same procedure as the training step in this class: sample a noisy
        adjacency at a random timestep per instance, run the network on it,
        decode the sample into tours, refine them with 2-opt and compare the
        network output with the refined-tour adjacency via cross-entropy.

        Args:
            classifier_model: Not used by this method; kept so existing
                callers keep working.
            batch: Dense mode: a 5-tuple whose 2nd and 3rd items are the
                points and the adjacency matrix. Sparse mode: a 6-tuple whose
                2nd and 3rd items are the graph data and the point indicator.
            batch_idx: Not used by this method.
            split: Prefix of the logged metric name.

        Returns:
            The loss tensor. It is logged under ``{split}/solved_cost`` (the
            key holds a loss here, not a tour cost).
        """
        edge_index = None
        np_edge_index = None

        if not self.sparse:
            _, points, adj_matrix, _, _ = batch
            batch_size = points.shape[0]
        else:
            _, graph_data, point_indicator, _, _, _ = batch
            batch_size = point_indicator.shape[0]
            points = graph_data.x
            edge_index = graph_data.edge_index
            num_edges = edge_index.shape[1]
            adj_matrix = graph_data.edge_attr.reshape(
                (batch_size, num_edges // batch_size)
            )

        # One independently drawn timestep in [1, T] per instance.
        t = np.random.randint(1, self.diffusion.T + 1, batch_size).astype(int)
        np_points = points.cpu().numpy()
        device = adj_matrix.device

        adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
        if self.sparse:
            adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

        # Sample from diffusion.
        xt = self.diffusion.sample(adj_matrix_onehot, t)
        # NOTE(review): the sample is binarized before the diffusion_space
        # check, so the non-discrete branch below also sees a 0/1 tensor --
        # confirm this is intended.
        xt = (xt > 0).long()
        if self.diffusion_space == "discrete":
            # Small offset keeps every heatmap entry strictly positive.
            adj_mat = xt.float().cpu().detach().numpy() + 1e-6
        else:
            adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5

        if self.sparse:
            graph_size = points.shape[0] // batch_size
            t = torch.from_numpy(t).float()
            # Repeat each instance's timestep once per edge of that instance.
            t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
            xt = xt.reshape(-1)
            adj_matrix = adj_matrix.reshape(-1)
            points = points.reshape(-1, 2)
            edge_index = edge_index.float().to(device).reshape(2, batch_size, -1)
            # Subtract each instance's node offset so indices become local.
            node_offsets = graph_size * torch.arange(
                0, batch_size, device=edge_index.device
            )[None, :, None]
            np_edge_index = (edge_index - node_offsets).cpu().numpy()
            np_edge_index = np_edge_index.reshape(2, -1).astype("int64")
        else:
            t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

        # Denoise.
        pred_edge = self.forward(
            points.float().to(device),
            xt.float().to(device),
            t.float().to(device),
            edge_index,
        )

        # Decode the noisy sample (not the prediction) into tours.
        tours, _ = merge_tours(
            adj_mat,
            np_points,
            np_edge_index,
            sparse_graph=self.sparse,
            parallel_sampling=batch_size,
            guided=True,
        )

        # Refine using 2-opt.
        solved_tours, _ = batched_two_opt_torch(
            np_points.astype("float64"),
            np.array(tours).astype("int64"),
            max_iterations=self.args.two_opt_iterations,
            device=device,
            batch=True,
        )
        route_indices = torch.from_numpy(np.ascontiguousarray(solved_tours)).to(
            device
        )

        # Mark every consecutive (start -> end) city pair of each refined tour.
        # NOTE(review): the target shape is taken from ``points``, which has
        # been flattened to (-1, 2) in sparse mode -- this construction looks
        # dense-only; confirm before validating with sparse graphs.
        start_cities = route_indices[:, :-1]
        end_cities = route_indices[:, 1:]
        batch_indices = torch.arange(route_indices.size(0), device=device)
        opt_adj_matrix = torch.zeros(
            (points.shape[0], points.shape[1], points.shape[1]),
            dtype=torch.long,
            device=device,
        )
        opt_adj_matrix[batch_indices[:, None], start_cities, end_cities] = 1

        loss_func = nn.CrossEntropyLoss()
        loss = loss_func(pred_edge, opt_adj_matrix)

        self.log(
            f"{split}/solved_cost", loss, prog_bar=True, on_epoch=True, sync_dist=True
        )
        return loss

    def run_save_numpy_heatmap(self, adj_mat, np_points, real_batch_idx, split):
        if self.args.parallel_sampling > 1 or self.args.sequential_sampling > 1:
            raise NotImplementedError("Save numpy heatmap only support single sampling")
        exp_save_dir = os.path.join(
            self.logger.save_dir, self.logger.name, self.logger.version
        )
        heatmap_path = os.path.join(exp_save_dir, "numpy_heatmap")
        rank_zero_info(f"Saving heatmap to {heatmap_path}")
        os.makedirs(heatmap_path, exist_ok=True)
        real_batch_idx = real_batch_idx.cpu().numpy().reshape(-1)[0]
        np.save(
            os.path.join(heatmap_path, f"{split}-heatmap-{real_batch_idx}.npy"), adj_mat
        )
        np.save(
            os.path.join(heatmap_path, f"{split}-points-{real_batch_idx}.npy"),
            np_points,
        )

    def validation_step(self, batch, batch_idx):
        if self.args.do_train:
            return self.value_function_vaild_step(self.classifier_model, batch, batch_idx, split="val")
        elif self.classifier is not None:
            torch.set_grad_enabled(True)
            return self.classifier_test_step(
                self.classifier, batch, batch_idx, split="val"
            )
        elif self.diffusion_type == "classifier":
            return self.value_function_vaild_step(batch, batch_idx, split="val")
