"""Lightning module for training the DIFUSCO TSP model."""

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from lightning.pytorch.utilities import rank_zero_info
from difusco.co_datasets.tsp_graph_dataset import TSPGraphDataset
from difusco.pl_meta_model import COMetaModel
from difusco.utils.diffusion_schedulers import InferenceSchedule
# from difusco.utils.tsp_utils import TSPEvaluator, batched_two_opt_torch, merge_tours
from difusco.utils.pctsp_utils import PCTSPEvaluator, merge_tours_pctsp
from difusco.pl_pctsp_model import PCTSPModel
import time
import torch._dynamo
torch._dynamo.config.suppress_errors = True

import pdb

class PCTSPModelFreeGuide(PCTSPModel):
    """Cost-conditioned training variant of ``PCTSPModel``.

    Every training step passes the per-instance tour cost to the denoiser as
    ``returns``.  When ``args.train_unconditional`` is set, a second forward
    pass with ``force_dropout=True`` is run on the same inputs and the two
    losses are averaged.  When ``relabel_epoch > 0``, training labels are
    periodically overwritten with cheaper tours decoded from the model's own
    predictions (see ``relabel_dataset``).
    """

    def __init__(self, param_args=None):
        super(PCTSPModelFreeGuide, self).__init__(param_args=param_args)
        # Cost statistics come from the training dataset; the normalisation
        # window is mean +/- 2 standard deviations.
        self.cost_mean = self.train_dataset.cost_mean
        self.cost_std = self.train_dataset.cost_std
        self.cost_min = self.cost_mean - 2 * self.cost_std
        self.cost_max = self.cost_mean + 2 * self.cost_std
        # Number of dataset entries replaced by relabel_dataset since the
        # counter was last reset.
        self.relabel_count = 0

    def cost_normalize(self, cost):
        """Linearly rescale ``cost`` so that [cost_min, cost_max] maps to [-1, 0].

        Values outside the mean +/- 2*std window are not clipped, so they land
        outside [-1, 0].
        """
        # Min-max scale to [0, 1] ...
        cost = (cost - self.cost_min) / (self.cost_max - self.cost_min)
        # ... then shift to [-1, 0].
        cost = cost - 1
        return cost

    def relabel_dataset(self, x0_pred, xt_, t, points, batch_idx, cost, edge_index=None):
        """Replace training labels with cheaper tours decoded from ``x0_pred``.

        A posterior sample is drawn from the predicted edge distribution,
        decoded into tours with ``merge_tours_pctsp`` and scored with
        ``PCTSPEvaluator``.  Whenever a decoded tour is cheaper than the stored
        cost, the tour and cost of that dataset entry are overwritten in place
        and ``self.relabel_count`` is incremented.

        Args:
            x0_pred: Denoiser output (logits over the two edge classes).
            xt_: The noisy sample the prediction was made from, as produced by
                ``self.diffusion.sample`` (before rescaling).
            t: Diffusion timesteps (tensor; converted with ``.long()``).
            points: Node coordinates as a numpy array.
            batch_idx: Dataset indices as a numpy array; indexed as
                ``batch_idx[i][0]`` (assumes shape ``(batch, 1)`` - TODO confirm).
            cost: Stored label costs as a numpy array; indexed as
                ``cost[i][0]`` (assumes shape ``(batch, 1)`` - TODO confirm).
            edge_index: Edge index of the sparse graph (tensor or array).
                Required when ``self.sparse`` is true, ignored otherwise.

        Raises:
            ValueError: If the model is sparse and ``edge_index`` is missing.
        """
        np_edge_index = None
        if not self.sparse:
            x0_pred_prob = x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
        else:
            if edge_index is None:
                # Previously this path dereferenced a hard-coded None.
                raise ValueError("relabel_dataset requires edge_index when sparse graphs are used")
            x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(dim=-1)
            if torch.is_tensor(edge_index):
                # The training step stores the index as float; indices are
                # integral, so hand integer indices to the decoder.
                np_edge_index = edge_index.detach().long().cpu().numpy()
            else:
                np_edge_index = np.asarray(edge_index)
        xt = self.categorical_posterior(None, t.long(), x0_pred_prob, xt_, training=True)

        # Small offset keeps every entry strictly positive for the decoder.
        adj_mat = xt.float().cpu().detach().numpy() + 1e-6
        tours, _ = merge_tours_pctsp(
            adj_mat, points, np_edge_index,
            sparse_graph=self.sparse,
            parallel_sampling=points.shape[0],
            guided=True
        )
        for i in range(points.shape[0]):
            tsp_solver = PCTSPEvaluator(points[i])
            wo_2opt_costs = tsp_solver.evaluate(tours[i])
            if wo_2opt_costs < cost[i][0]:
                # Dataset rows are indexed as [.., tour, cost] at positions 1 and 2.
                self.train_dataset.data[batch_idx[i][0]][2] = wo_2opt_costs
                self.train_dataset.data[batch_idx[i][0]][1] = np.array(tours[i])
                self.relabel_count += 1

        return None

    def forward(self, x, adj, t, edge_index, returns, use_dropout, force_dropout, opt_dropout=None):
        """Run the denoiser.

        ``opt_dropout`` defaults to ``None`` so that callers which do not use
        an optimality-dropout mask may omit it.  Note that the wrapped model
        takes ``t`` before ``adj``.
        """
        return self.model(x, t, adj, edge_index, returns, use_dropout, force_dropout, opt_dropout)

    def categorical_training_step(self, batch, batch_idx):
        """Compute the categorical-diffusion training loss for one batch.

        The ``batch_idx`` argument is shadowed by the dataset indices carried
        inside ``batch``.  Returns the cross-entropy loss, averaged with the
        ``force_dropout=True`` pass when ``args.train_unconditional`` is set.
        """
        edge_index = None

        if not self.sparse:
            if self.optimality_dropout:
                batch_idx, points, adj_matrix, _, cost, optimality_dropout = batch
            else:
                batch_idx, points, adj_matrix, _, cost = batch
            t = np.random.randint(1, self.diffusion.T + 1, points.shape[0]).astype(int)
        else:
            # NOTE(review): the sparse batch layout carries no
            # optimality-dropout mask, so sparse + optimality_dropout fails
            # below with NameError - confirm whether it should be supported.
            batch_idx, graph_data, point_indicator, edge_indicator, _, cost = batch
            t = np.random.randint(1, self.diffusion.T + 1, point_indicator.shape[0]).astype(int)
            route_edge_flags = graph_data.edge_attr
            points = graph_data.x
            edge_index = graph_data.edge_index
            num_edges = edge_index.shape[1]
            batch_size = point_indicator.shape[0]
            adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))

        # Keep the raw cost: relabelling compares against unnormalised values.
        original_cost = cost.clone().detach()

        if self.return_condition and not self.args.cost_category:
            cost = self.cost_normalize(cost)

        adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
        if self.sparse:
            adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

        # Sample from diffusion
        xt = self.diffusion.sample(adj_matrix_onehot, t)
        if self.relabel_epoch > 0:
            # Untouched copy of the noisy sample for the posterior step in
            # relabel_dataset.
            xt_ = xt.clone().detach()
        # Map {0, 1} to {-1, 1} and apply a small multiplicative jitter.
        xt = xt * 2 - 1
        xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

        if self.sparse:
            t = torch.from_numpy(t).float()
            t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
            xt = xt.reshape(-1)
            adj_matrix = adj_matrix.reshape(-1)
            points = points.reshape(-1, 2)
            edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)
        else:
            t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

        # Denoise
        if self.optimality_dropout:
            x0_pred = self.forward(
                points.float().to(adj_matrix.device),
                xt.float().to(adj_matrix.device),
                t.float().to(adj_matrix.device),
                edge_index,
                returns=cost.to(adj_matrix.device),
                use_dropout=False,
                force_dropout=False,
                opt_dropout=optimality_dropout,
            )
        else:
            x0_pred = self.forward(
                points.float().to(adj_matrix.device),
                xt.float().to(adj_matrix.device),
                t.float().to(adj_matrix.device),
                edge_index,
                returns=cost.to(adj_matrix.device),
                use_dropout=True,
                force_dropout=False,
                opt_dropout=None,
            )

        # Compute loss
        loss_func = nn.CrossEntropyLoss()
        loss = loss_func(x0_pred, adj_matrix.long())
        if self.relabel_epoch > 0 and (((self.current_epoch + 1) % self.relabel_epoch) == 0):
            self.relabel_dataset(
                x0_pred,
                xt_,
                t,
                points.cpu().numpy(),
                batch_idx.cpu().numpy(),
                original_cost.cpu().numpy(),
                edge_index=edge_index,
            )
        if self.args.train_unconditional:  # add trainable loss
            x0_pred_uncond = self.forward(
                points.float().to(adj_matrix.device),
                xt.float().to(adj_matrix.device),
                t.float().to(adj_matrix.device),
                edge_index,
                returns=cost.to(adj_matrix.device),
                use_dropout=True,
                force_dropout=True,
                opt_dropout=None,
            )
            loss_uncond = loss_func(x0_pred_uncond, adj_matrix.long())
            loss += loss_uncond
            loss *= 0.5

        self.log("train/loss", loss)
        return loss

    def gaussian_training_step(self, batch, batch_idx):
        """Compute the Gaussian-diffusion (noise-prediction) loss for one batch.

        Raises:
            ValueError: If the model is configured for sparse graphs.
        """
        if self.sparse:
            # TODO: Implement Gaussian diffusion with sparse graphs
            raise ValueError("DIFUSCO with sparse graphs are not supported for Gaussian diffusion")
        # NOTE(review): the categorical step receives a six-element batch when
        # optimality_dropout is enabled; this unpacking expects exactly five -
        # confirm the two options are never combined.
        _, points, adj_matrix, _, cost = batch

        if self.return_condition and not self.args.cost_category:
            cost = self.cost_normalize(cost)
        # Map {0, 1} to {-1, 1} and apply a small multiplicative jitter.
        adj_matrix = adj_matrix * 2 - 1
        adj_matrix = adj_matrix * (1.0 + 0.05 * torch.rand_like(adj_matrix))
        # Sample from diffusion
        t = np.random.randint(1, self.diffusion.T + 1, adj_matrix.shape[0]).astype(int)
        xt, epsilon = self.diffusion.sample(adj_matrix, t)

        t = torch.from_numpy(t).float().view(adj_matrix.shape[0])

        # Denoise
        epsilon_pred = self.forward(
            points.float().to(adj_matrix.device),
            xt.float().to(adj_matrix.device),
            t.float().to(adj_matrix.device),
            edge_index=None,
            returns=cost.to(adj_matrix.device),
            use_dropout=True,
            force_dropout=False,
            opt_dropout=None,
        )
        epsilon_pred = epsilon_pred.squeeze(1)

        # Compute loss
        loss = F.mse_loss(epsilon_pred, epsilon.float())

        if self.args.train_unconditional:  # add trainable loss
            epsilon_pred_uncond = self.forward(
                points.float().to(adj_matrix.device),
                xt.float().to(adj_matrix.device),
                t.float().to(adj_matrix.device),
                edge_index=None,
                returns=cost.to(adj_matrix.device),
                use_dropout=True,
                force_dropout=True,
                opt_dropout=None,
            )
            # Same target and shape handling as the conditional pass: the
            # network predicts the injected noise, not the adjacency matrix.
            epsilon_pred_uncond = epsilon_pred_uncond.squeeze(1)
            loss_uncond = F.mse_loss(epsilon_pred_uncond, epsilon.float())
            loss += loss_uncond
            loss *= 0.5

        self.log("train/loss", loss)
        return loss

    def on_train_epoch_end(self):
        """Log and reset the relabel counter at the end of relabelling epochs."""
        # Preserve whatever the parent classes do in this hook.
        super().on_train_epoch_end()
        if self.relabel_epoch > 0 and (((self.current_epoch + 1) % self.relabel_epoch) == 0):
            self.log("train/relabel_count", self.relabel_count, on_epoch=True, sync_dist=True)
            self.relabel_count = 0



def calculate_gap(cost, true_cost):
    """Return ``log(cost / true_cost + 0.01)`` element-wise.

    Both arguments are combined with tensor arithmetic and passed to
    ``torch.log``, so the result is a tensor.
    """
    ratio = cost / true_cost
    return torch.log(ratio + 0.01)

def on_train_epoch_end(self):
    """Log ``self.relabel_count`` and reset it to zero on relabelling epochs.

    Nothing happens unless ``self.relabel_epoch`` is positive and
    ``self.current_epoch + 1`` is a multiple of it.

    NOTE(review): this function lives at module level and takes the module
    instance explicitly as ``self``, so it only runs when called directly.
    """
    period = self.relabel_epoch
    if not period > 0:
        return
    if (self.current_epoch + 1) % period != 0:
        return
    self.log("train/relabel_count", self.relabel_count, on_epoch=True, sync_dist=True)
    self.relabel_count = 0
            
def tour2adj(tour, points, sparse, sparse_factor, edge_index):
    """Convert a node-sequence tour into an edge-label tensor.

    Dense mode returns an ``(N, N)`` float tensor, ``N = points.shape[0]``,
    with a 1 at ``[tour[i], tour[i + 1]]`` for each consecutive pair.

    Sparse mode builds each node's successor, repeats it ``sparse_factor``
    times and compares it element-wise with ``edge_index[1]``, returning an
    int tensor of 0/1 flags.  Nodes that do not appear in ``tour[:-1]`` keep
    successor 0.  Assumes ``edge_index[1]`` holds ``sparse_factor`` entries
    per node in node order - TODO confirm.
    """
    num_nodes = points.shape[0]

    if sparse:
        successor = np.zeros(num_nodes, dtype=np.int64)
        successor[tour[:-1]] = tour[1:]
        targets = torch.from_numpy(successor).reshape((-1, 1))
        targets = targets.repeat(1, sparse_factor).reshape(-1)
        return torch.eq(edge_index[1].cpu(), targets).to(torch.int)

    dense = torch.zeros((num_nodes, num_nodes))
    # Walk consecutive (source, destination) pairs along the tour.
    for src, dst in zip(tour[:-1], tour[1:]):
        dense[src, dst] = 1
    return dense
    