"""Lightning module for training the cost-conditioned DIFUSCO TSP model (TSPModelFreeGuide)."""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from lightning.pytorch.utilities import rank_zero_info
import time
from difusco.co_datasets.tsp_graph_dataset import TSPGraphDataset
from difusco.pl_meta_model import COMetaModel
from difusco.utils.diffusion_schedulers import InferenceSchedule
from difusco.utils.tsp_utils import TSPEvaluator, batched_two_opt_torch, merge_tours
from difusco.pl_tsp_model import TSPModel
import time
import torch._dynamo
# Process-wide side effect of importing this module: let TorchDynamo
# (torch.compile) fall back to eager execution instead of raising when graph
# compilation fails.
torch._dynamo.config.suppress_errors = True

import pdb

class TSPModelFreeGuide(TSPModel):
    """DIFUSCO TSP model whose denoiser is conditioned on the tour cost.

    The label cost of each training instance is passed to the network as the
    ``returns`` argument (normalized by ``cost_normalize`` unless
    ``args.cost_category`` is set). When ``args.train_unconditional`` is set,
    an extra prediction made with ``force_dropout=True`` is trained as the
    unconditional branch. When ``relabel_epoch > 0``, every
    ``relabel_epoch``-th epoch replaces dataset labels with cheaper tours
    decoded from the model's own predictions (see ``relabel_dataset``).
    """

    def __init__(self, param_args=None):
        super().__init__(param_args=param_args)
        # Training-set cost statistics used by cost_normalize.
        self.cost_mean = self.train_dataset.cost_mean
        self.cost_std = self.train_dataset.cost_std
        # Normalization window: mean +/- 2 standard deviations.
        self.cost_min = self.cost_mean - 2 * self.cost_std
        self.cost_max = self.cost_mean + 2 * self.cost_std
        # Labels replaced by relabel_dataset since the last epoch-end log.
        self.relabel_count = 0

    def cost_normalize(self, cost):
        """Shift and scale raw tour costs so the window maps onto [-1, 0].

        Costs inside ``[cost_min, cost_max]`` land in [-1, 0]; values outside
        that window are not clipped.
        """
        # [cost_min, cost_max] -> [0, 1]
        cost = (cost - self.cost_min) / (self.cost_max - self.cost_min)
        # [0, 1] -> [-1, 0]
        cost = cost - 1
        # Alternative (unused): cost = (cost - self.cost_mean) / self.cost_std
        return cost

    def relabel_dataset(self, x0_pred, xt_, t, points, batch_idx, cost):
        """Overwrite training labels with cheaper tours decoded from the model.

        Takes one posterior step from the noisy sample ``xt_`` using the
        model's x0 prediction, then decodes the result into one tour per
        instance with ``merge_tours``. For every instance whose decoded tour
        is strictly cheaper than its current label cost, it stores the tour
        at index 1 and its cost at index 2 of
        ``self.train_dataset.data[batch_idx[i][0]]`` and increments
        ``self.relabel_count``.

        Args:
            x0_pred: Denoiser logits for the batch with the edge classes on
                dim 1 (moved to the last dim before the softmax).
            xt_: Raw noisy sample that ``x0_pred`` was predicted from.
            t: Diffusion timesteps of the batch.
            points: Numpy array of node coordinates, indexed per instance.
            batch_idx: Numpy array of dataset indices, read as ``batch_idx[i][0]``.
            cost: Numpy array of current label costs, read as ``cost[i][0]``.

        Raises:
            NotImplementedError: for sparse graphs. This routine is never
                given the sparse edge index it would need.
        """
        if self.sparse:
            raise NotImplementedError(
                "relabel_dataset is not implemented for sparse graphs")

        # The result only rewrites dataset labels, so no autograd graph is needed.
        with torch.no_grad():
            x0_pred_prob = x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
            xt = self.categorical_posterior(None, t.long(), x0_pred_prob, xt_, training=True)

        # +1e-6 keeps every heatmap entry strictly positive before decoding.
        adj_mat = xt.float().cpu().detach().numpy() + 1e-6
        tours, _ = merge_tours(
            adj_mat, points, None,
            sparse_graph=self.sparse,
            parallel_sampling=points.shape[0],
            guided=True,
        )

        dataset = self.train_dataset.data
        for i in range(points.shape[0]):
            tour_cost = TSPEvaluator(points[i]).evaluate(tours[i])
            if tour_cost < cost[i][0]:
                record = dataset[batch_idx[i][0]]
                record[2] = tour_cost
                record[1] = np.array(tours[i])
                self.relabel_count += 1

    def forward(self, x, adj, t, edge_index, returns, use_dropout, force_dropout, opt_dropout=None):
        """Run the denoiser. Note that ``self.model`` takes ``t`` before ``adj``.

        ``opt_dropout`` defaults to ``None``, which is the value the
        non-optimality-dropout path passes explicitly. Callers that do not use
        optimality dropout (the unconditional and Gaussian branches) may
        therefore omit it.
        """
        return self.model(x, t, adj, edge_index, returns, use_dropout, force_dropout, opt_dropout)

    # NOTE(review): disabled experimental training step, kept commented out
    # for reference.
    # def categorical_training_step_two_opt_target(self, batch, batch_idx):
    #     edge_index = None
    #     if not self.sparse:
    #         _, points, adj_matrix, _, cost = batch
    #         t = np.random.randint(1, self.diffusion.T + 1, points.shape[0]).astype(int)
    #         batch_size = points.shape[0]
    #         np_edge_index = None
    #         original_edge_index = None
    #     else:
    #         _, graph_data, point_indicator, edge_indicator, _, cost= batch
    #         t = np.random.randint(1, self.diffusion.T + 1, point_indicator.shape[0]).astype(int)
    #         route_edge_flags = graph_data.edge_attr
    #         points = graph_data.x
    #         edge_index = graph_data.edge_index
    #         num_edges = edge_index.shape[1]
    #         batch_size = point_indicator.shape[0]
    #         adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))
    #         edge_index = edge_index.reshape((2, -1))
    #         points = points.reshape((-1, 2))
    #         np_edge_index = edge_index.cpu().numpy()
    #         original_edge_index = edge_index.clone()
    #         # original_edge_index    = original_edge_index

    #     # time use for sub trajectory
    #     t_sub = t.copy()
    #     np_points = points.cpu().numpy()

    #     # Sample from diffusion
    #     if self.sparse:
    #             splitted_points = np.split(np_points, batch_size, axis=0)
    #             tsp_solver = TSPEvaluator(np.array(splitted_points), batch=True)
    #     else:
    #             tsp_solver = TSPEvaluator(np_points, batch=True)

    #     opt_xt = adj_matrix.clone()
    #     if self.sparse:
    #             opt_xt = opt_xt.reshape(-1)
    #             # adj_matrix = adj_matrix.reshape(-1)
    #     tours, _ = merge_tours(
    #                     opt_xt.cpu().detach().numpy(),
    #                     np_points,
    #                     np_edge_index,
    #                     sparse_graph=self.sparse,
    #                     parallel_sampling=batch_size,
    #                     guided=True)

    #     solved_tours = torch.tensor(tours)
    #     true_cost = tsp_solver.evaluate(solved_tours)

    #     # calculate the optimal gap
    #     opt_gap = calculate_gap(true_cost, true_cost)
    #     opt_gap= opt_gap.reshape(-1, 1)

    #     adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
    #     if self.sparse:
    #         adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

    #     xt = self.diffusion.sample(adj_matrix_onehot, t)
    #     xt = xt * 2 - 1
    #     xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

    #     if self.sparse:
    #         t = torch.from_numpy(t).float()
    #         t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
    #         xt = xt.reshape(-1)
    #         adj_matrix = adj_matrix.reshape(-1)
    #         points = points.reshape(-1, 2)
    #         edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)
    #     else:
    #         t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

    #     # Denoise
    #     x0_pred = self.forward(
    #             points.float().to(adj_matrix.device),
    #             xt.float().to(adj_matrix.device),
    #             t.float().to(adj_matrix.device),
    #             edge_index,
    #             returns = opt_gap.to(adj_matrix.device),
    #             use_dropout=True,
    #             force_dropout = False,
    #     )
    #     if not self.sparse:
    #             x0_pred_prob = x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
    #     else:
    #             x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(dim=-1)

    #     # sample from the posterior
    #     xt = x0_pred_prob[:,:,:,1].clamp(min=0)
    #     if self.sparse:
    #         xt = xt.reshape(-1)
    #     adj_mat = xt.float() + 1e-6
    #     if not self.sparse:
    #         # batched distance matrix
    #         dis_matrix = tsp_solver.dist_mat
    #         adj_mat = adj_mat * dis_matrix.to(adj_matrix.device)
    #     else:
    #         dis_matrix = torch.sqrt(torch.sum((points[original_edge_index.T[:, 0]] - points[original_edge_index.T[:, 1]]) ** 2, dim=1))
    #         adj_mat = adj_mat * dis_matrix.to(adj_matrix.device)
    #         # dis_matrix = dis_matrix.reshape((1, points.shape[0], -1))

    #     adj_mat = adj_mat.cpu().detach().numpy()

    #     tours, _ = merge_tours(
    #                     adj_mat,
    #                     np_points,
    #                     np_edge_index,
    #                     sparse_graph=self.sparse,
    #                     parallel_sampling=batch_size,
    #                     guided=True)

    #     solved_tours, _ = batched_two_opt_torch(
    #                         np_points.astype("float64"),
    #                         np.array(tours).astype("int64"),
    #                         max_iterations=self.args.two_opt_iterations,
    #                         device=adj_matrix.device,
    #                         batch=True,
    #                 )

    #     stacked_tours = []
    #     stacked_tours.append(solved_tours)
    #     solved_tours = np.concatenate(stacked_tours, axis=0)
    #     solved_tours = torch.tensor(solved_tours)
    #     sub_cost = tsp_solver.evaluate(solved_tours)
    #     sub_adjmatrix = []

    #     if self.sparse:
    #         original_edge_index = original_edge_index % 500
    #         original_edge_index = np.split(original_edge_index, batch_size, axis=1)

    #         for solved_tour, point in zip(solved_tours, splitted_points):
    #             sub_adjmatrix.append(tour2adj(solved_tour, point, sparse=self.sparse, sparse_factor=50, edge_index=original_edge_index%500))
    #     else:
    #         for solved_tour, point in zip(solved_tours, points):
    #             sub_adjmatrix.append(tour2adj(solved_tour, point, sparse=self.sparse, sparse_factor=50, edge_index=original_edge_index))

    #     # Make adjmatrix from current policy
    #     sub_adjmatrix = torch.stack(sub_adjmatrix, dim=0)
    #     sub_adj_matrix_onehot = F.one_hot(sub_adjmatrix.long(), num_classes=2).float()
    #     if self.sparse:
    #         sub_adj_matrix_onehot = sub_adj_matrix_onehot.unsqueeze(1)

    #     # t2 = np.random.randint(1, self.diffusion.T + 1, points.shape[0]).astype(int)
    #     xt2 = self.diffusion.sample(sub_adj_matrix_onehot, t_sub)
    #     xt2 = xt2 * 2 - 1
    #     xt2 = xt2 * (1.0 + 0.05 * torch.rand_like(xt2))

    #     if self.sparse:
    #         t_sub = torch.from_numpy(t_sub).float()
    #         t_sub = t_sub.reshape(-1, 1).repeat(1, sub_adjmatrix.shape[1]).reshape(-1)
    #         xt2 = xt2.reshape(-1)
    #         sub_adjmatrix = sub_adjmatrix.reshape(-1)
    #         # points = points.reshape(-1, 2)
    #         # edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)
    #     else:
    #         t_sub = torch.from_numpy(t_sub).float().view(sub_adjmatrix.shape[0])

    #     gap = calculate_gap(sub_cost, true_cost)
    #     gap = gap.reshape(-1, 1)

    #     x0_pred2 = self.forward(
    #             points.float().to(adj_matrix.device),
    #             xt2.float().to(adj_matrix.device),
    #             t_sub.float().to(adj_matrix.device),
    #             edge_index,
    #             returns = gap.to(adj_matrix.device),
    #             use_dropout=False,
    #             force_dropout = False,
    #     )
    #     total_pred = torch.cat([x0_pred,x0_pred2],dim=0)
    #     total_adj = torch.cat([adj_matrix,sub_adjmatrix.to(adj_matrix.device)],dim=0)

    #     loss_func = nn.CrossEntropyLoss()
    #     loss = loss_func(total_pred, total_adj.long())
    #     self.log("train/opt_gap", (sub_cost/true_cost).mean())
    #     self.log("train/loss", loss)
    #     return loss
    # @profile
    def categorical_training_step(self, batch, batch_idx):
        """Categorical-diffusion training step with cost conditioning.

        Batch layouts:
            dense:  ``(idx, points, adj_matrix, _, cost[, optimality_dropout])``
            sparse: ``(idx, graph_data, point_indicator, edge_indicator, _, cost[, optimality_dropout])``
        The trailing element is only expected when ``self.optimality_dropout``
        is set. For sparse batches that layout is an assumption mirroring the
        dense one (TODO confirm against the sparse dataset).

        The ground-truth adjacency is noised at a random timestep per instance.
        The model predicts x0 conditioned on the (optionally normalized) cost,
        and the loss is the cross-entropy against the ground truth. With
        ``args.train_unconditional`` the loss is averaged with a
        ``force_dropout=True`` prediction. On relabel epochs the batch is
        also used to relabel the dataset.

        Note: ``batch_idx`` is rebound to the dataset indices carried in the
        batch, which ``relabel_dataset`` uses to find the records to update.
        """
        edge_index = None
        optimality_dropout = None

        if not self.sparse:
            if self.optimality_dropout:
                batch_idx, points, adj_matrix, _, cost, optimality_dropout = batch
            else:
                batch_idx, points, adj_matrix, _, cost = batch
            t = np.random.randint(1, self.diffusion.T + 1, points.shape[0]).astype(int)
        else:
            if self.optimality_dropout:
                (batch_idx, graph_data, point_indicator, edge_indicator, _, cost,
                 optimality_dropout) = batch
            else:
                batch_idx, graph_data, point_indicator, edge_indicator, _, cost = batch
            t = np.random.randint(1, self.diffusion.T + 1, point_indicator.shape[0]).astype(int)
            route_edge_flags = graph_data.edge_attr
            points = graph_data.x
            edge_index = graph_data.edge_index
            num_edges = edge_index.shape[1]
            batch_size = point_indicator.shape[0]
            adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))

        # Un-normalized copy: relabel_dataset compares against raw tour costs.
        original_cost = cost.clone().detach()
        if self.return_condition and not self.args.cost_category:
            cost = self.cost_normalize(cost)

        adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
        if self.sparse:
            adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

        relabel_now = (self.relabel_epoch > 0
                       and (self.current_epoch + 1) % self.relabel_epoch == 0)

        xt = self.diffusion.sample(adj_matrix_onehot, t)
        # relabel_dataset needs the raw sample, before the rescaling below.
        xt_ = xt.clone().detach() if relabel_now else None
        # Rescale x -> 2x - 1 (presumably {0, 1} -> {-1, 1}) with small jitter.
        xt = xt * 2 - 1
        xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

        if self.sparse:
            t = torch.from_numpy(t).float()
            t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
            xt = xt.reshape(-1)
            adj_matrix = adj_matrix.reshape(-1)
            points = points.reshape(-1, 2)
            edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)
        else:
            t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

        device = adj_matrix.device
        points_in = points.float().to(device)
        xt_in = xt.float().to(device)
        t_in = t.float().to(device)
        returns = cost.to(device)

        if self.optimality_dropout:
            # Forward the batch's optimality_dropout flags; use_dropout is off.
            use_dropout, opt_dropout = False, optimality_dropout
        else:
            use_dropout, opt_dropout = True, None

        # Denoise
        x0_pred = self.forward(
            points_in, xt_in, t_in, edge_index,
            returns=returns,
            use_dropout=use_dropout,
            force_dropout=False,
            opt_dropout=opt_dropout,
        )
        loss_func = nn.CrossEntropyLoss()
        target = adj_matrix.long()
        loss = loss_func(x0_pred, target)

        if relabel_now:
            self.relabel_dataset(
                x0_pred, xt_, t,
                points.cpu().numpy(),
                batch_idx.cpu().numpy(),
                original_cost.cpu().numpy(),
            )

        if self.args.train_unconditional:
            # force_dropout=True: the unconditional prediction.
            x0_pred_uncond = self.forward(
                points_in, xt_in, t_in, edge_index,
                returns=returns,
                use_dropout=True,
                force_dropout=True,
            )
            loss_uncond = loss_func(x0_pred_uncond, target)
            loss = 0.5 * (loss + loss_uncond)

        self.log("train/loss", loss)
        return loss

    def gaussian_training_step(self, batch, batch_idx):
        """Gaussian-diffusion training step (dense graphs only).

        The adjacency is rescaled to 2x - 1 with jitter and noised at a random
        timestep. The model predicts the added noise ``epsilon``, conditioned
        on the (optionally normalized) cost, and the loss is the MSE against
        ``epsilon``. With ``args.train_unconditional`` the loss is averaged
        with a ``force_dropout=True`` prediction scored against the same
        ``epsilon`` target.

        Raises:
            ValueError: for sparse graphs.
        """
        if self.sparse:
            # TODO: Implement Gaussian diffusion with sparse graphs
            raise ValueError("DIFUSCO with sparse graphs are not supported for Gaussian diffusion")
        _, points, adj_matrix, _, cost = batch

        if self.return_condition and not self.args.cost_category:
            cost = self.cost_normalize(cost)
        adj_matrix = adj_matrix * 2 - 1
        adj_matrix = adj_matrix * (1.0 + 0.05 * torch.rand_like(adj_matrix))
        # Sample from diffusion
        t = np.random.randint(1, self.diffusion.T + 1, adj_matrix.shape[0]).astype(int)
        xt, epsilon = self.diffusion.sample(adj_matrix, t)
        t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

        device = adj_matrix.device
        points_in = points.float().to(device)
        xt_in = xt.float().to(device)
        t_in = t.float().to(device)
        returns = cost.to(device)
        target = epsilon.float()

        # Denoise; use_dropout trains the conditional/unconditional modes together.
        epsilon_pred = self.forward(
            points_in, xt_in, t_in,
            edge_index=None,
            returns=returns,
            use_dropout=True,
            force_dropout=False,
        ).squeeze(1)
        loss = F.mse_loss(epsilon_pred, target)

        if self.args.train_unconditional:
            # Same noise target as the conditional branch.
            epsilon_pred_uncond = self.forward(
                points_in, xt_in, t_in,
                edge_index=None,
                returns=returns,
                use_dropout=True,
                force_dropout=True,
            ).squeeze(1)
            loss_uncond = F.mse_loss(epsilon_pred_uncond, target)
            loss = 0.5 * (loss + loss_uncond)

        self.log("train/loss", loss)
        return loss

    def on_train_epoch_end(self):
        """Lightning hook: log and reset ``relabel_count`` after relabel epochs.

        Acts only when ``relabel_epoch > 0`` and ``current_epoch + 1`` is a
        multiple of ``relabel_epoch``, which matches the epochs on which
        ``categorical_training_step`` calls ``relabel_dataset``.
        """
        super().on_train_epoch_end()
        if self.relabel_epoch > 0 and (self.current_epoch + 1) % self.relabel_epoch == 0:
            self.log("train/relabel_count", self.relabel_count, on_epoch=True, sync_dist=True)
            self.relabel_count = 0
    
    # def categorical_denoise_step(self, points, xt, t, device, edge_index=None, target_t=None, classifier=None,returns=None):
    #     with torch.no_grad():
    #         t = torch.from_numpy(t).view(1)

    #         x0_pred_uncond = self.forward(
    #                 points.float().to(device),
    #                 (xt).float().to(device),
    #                 t.float().to(device),
    #                 edge_index.long().to(device) if edge_index is not None else None,
    #                 returns = returns.to(device),
    #                 use_dropout=False,
    #                 force_dropout = True,
    #         )
    #         x0_pred_cond = self.forward(
    #                 points.float().to(device),
    #                 (xt).float().to(device),
    #                 t.float().to(device),
    #                 edge_index.long().to(device) if edge_index is not None else None,
    #                 returns = returns.to(device),
    #                 use_dropout=False,
    #                 force_dropout = False,
    #         )
    #         x0_pred = x0_pred_uncond + self.condition_guidance_w * (x0_pred_cond - x0_pred_uncond)
    #         if not self.sparse:
    #             x0_pred_prob = x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
    #         else:
    #             x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(dim=-1)

    #         xt = self.categorical_posterior(target_t, t, x0_pred_prob, xt)

    #         return xt#,xt_prob

    # def gaussian_denoise_step(self, points, xt, t, device, edge_index=None, target_t=None,returns=None):
    #     with torch.no_grad():
    #         t = torch.from_numpy(t).view(1)

    #         epsilon_pred_uncond = self.forward(
    #             points.float().to(device),
    #             xt.float().to(device),
    #             t.float().to(device),
    #             edge_index.long().to(device) if edge_index is not None else None,
    #             returns = returns.to(device),
    #             use_dropout=False,
    #             force_dropout = True,
    #     )
    #         #without return
    #         epsilon_pred_cond = self.forward(
    #             points.float().to(device),
    #             xt.float().to(device),
    #             t.float().to(device),
    #             edge_index.long().to(device) if edge_index is not None else None,
    #             returns = returns.to(device),
    #             use_dropout=False,
    #             force_dropout=False,
    #         )
    #         pred = epsilon_pred_uncond + self.condition_guidance_w*(epsilon_pred_cond - epsilon_pred_uncond)
    #         pred = pred.squeeze(1)
    #         xt = self.gaussian_posterior(target_t, t, pred, xt)
    #         return xt
        



def calculate_gap(cost, true_cost):
    """Return ``log(cost / true_cost + 0.01)`` element-wise.

    A cost equal to the reference gives ``log(1.01)``. The ``+ 0.01`` keeps
    the logarithm finite when ``cost`` is 0.
    """
    ratio = cost / true_cost
    shifted = ratio + 0.01
    return torch.log(shifted)

def on_train_epoch_end(self):
    """Log ``relabel_count`` and reset it at the end of a relabeling epoch.

    This is a module-level function, not a method, so Lightning does not
    invoke it as a hook. Call it with the LightningModule as ``self``. It acts
    only when ``relabel_epoch > 0`` and ``current_epoch + 1`` is a multiple of
    ``relabel_epoch``.
    """
    period = self.relabel_epoch
    if not period > 0:
        return
    if (self.current_epoch + 1) % period != 0:
        return
    self.log("train/relabel_count", self.relabel_count, on_epoch=True, sync_dist=True)
    self.relabel_count = 0
            
def tour2adj(tour, points, sparse, sparse_factor, edge_index):
    """Convert a tour (node visiting order) into a 0/1 edge-label tensor.

    Args:
        tour: Node indices in visiting order (numpy array, tensor or list).
            The edge ``tour[i] -> tour[i + 1]`` is marked for every
            consecutive pair.
        points: Node coordinates; only ``points.shape[0]`` (node count) is used.
        sparse: If False, return a dense ``(n, n)`` float matrix with
            ``adj[tour[i], tour[i + 1]] = 1`` (directed).
        sparse_factor: Number of candidate edges per node in the sparse layout.
        edge_index: Sparse edge list. ``edge_index[1]`` must hold the target
            node of each candidate edge, laid out as ``sparse_factor``
            consecutive entries per source node. Unused when ``sparse`` is False.

    Returns:
        torch.Tensor: the dense ``(n, n)`` matrix or, when ``sparse``, a flat
        int tensor flagging which candidate edges point to the tour successor.
    """
    if not sparse:
        n = points.shape[0]
        adj_matrix = torch.zeros((n, n))
        # One vectorized assignment instead of a per-edge Python loop.
        if isinstance(tour, torch.Tensor):
            order = tour.detach().long().cpu()
        else:
            # ascontiguousarray handles negative-stride views like tour[::-1].
            order = torch.as_tensor(np.ascontiguousarray(tour), dtype=torch.long)
        adj_matrix[order[:-1], order[1:]] = 1
        return adj_matrix

    # next_node[v] = successor of v in the tour.
    next_node = np.zeros(points.shape[0], dtype=np.int64)
    next_node[tour[:-1]] = tour[1:]
    next_node = torch.from_numpy(next_node)
    # Repeat each successor once per candidate edge of its source node.
    next_node = next_node.reshape((-1, 1)).repeat(1, sparse_factor).reshape(-1)
    return torch.eq(edge_index[1].cpu(), next_node).to(torch.int)
    