import torch
import sys
import torch
from DIFUSCO.difusco.utils.diffusion_schedulers import GaussianDiffusion, CategoricalDiffusion, InferenceSchedule
from DIFUSCO.difusco.models.gnn_encoder import GNNEncoder
import numpy as np
import scipy
import argparse
import torch_geometric.utils as pyg_utils

from utils.data_utils import CliqueDataset
import pickle


from torch_geometric.data import Data as GraphData

def cliqueness(indicator, edge_index):
    """Return whether the selected nodes form a clique.

    Args:
        indicator: per-node selection tensor, 1. for selected nodes and 0.
            otherwise.
        edge_index: ``(2, E)`` edge index. Each undirected edge is assumed to be
            listed in both directions (hence the halving below) and without
            self-loops.

    Returns:
        ``True`` if every pair of selected nodes is connected. Selections of
        zero or one node are trivially cliques.
    """
    row, col = edge_index
    # Edges with both endpoints selected, counted once per undirected edge.
    weight = (indicator[row] * indicator[col]).sum() / 2.
    num_nodes = indicator.sum()
    # No pairs to check. Returning early also avoids the 0/0 (NaN) ratio an
    # empty selection would otherwise produce, which reported "not a clique".
    if num_nodes <= 1:
        return True
    max_weight = num_nodes * (num_nodes - 1) * 0.5
    # Comparing the counts directly is equivalent to weight / max_weight >= 1.
    return bool(weight >= max_weight)

def edge_count(indicator, edge_index):
    """Count the edges whose two endpoints are both selected.

    Every undirected edge is assumed to appear in both directions in
    ``edge_index``, so the raw sum is halved.
    """
    src, dst = edge_index
    return (indicator[src] * indicator[dst]).sum() / 2.

def cardinality(indicator):
    """Return the number of selected nodes, i.e. the sum of ``indicator``."""
    return indicator.sum()
    
def mc_decode_np(x, edge_index):
    """Greedily decode a clique from per-node scores and return its size.

    Nodes are visited in decreasing score order. A node is kept only if the
    induced edge count does not drop and the selected set is still a clique
    (checked with ``cliqueness``); otherwise it is dropped again.

    Args:
        x: 1-D NumPy array of per-node scores. NumPy is required, because
            ``argsort()[::-1]`` uses a negative slice step that torch tensors
            do not support.
        edge_index: ``(2, E)`` edge index tensor. Self-loops are removed first.

    Returns:
        The number of selected nodes as a Python float (``.item()`` of a float
        tensor).
    """
    if pyg_utils.contains_self_loops(edge_index):
        edge_index = pyg_utils.remove_self_loops(edge_index)[0]
    order = x.argsort()[::-1]  # highest score first

    solution = torch.zeros(len(x), device=edge_index.device)
    curr_objective = edge_count(solution, edge_index)
    for k in order:
        solution[k] = 1.
        new_objective = edge_count(solution, edge_index)
        # Commit the objective only once the node is actually accepted.
        # Updating it before the clique check would leave a stale (too high)
        # value after a rejection, which could wrongly veto later nodes.
        if new_objective < curr_objective or not cliqueness(solution, edge_index):
            solution[k] = 0.
        else:
            curr_objective = new_objective

    return cardinality(solution).item()

def fetch_difusco_args(argv=None):
    """Build the DIFUSCO argument namespace filled with its default values.

    The parser appears to be carried over from a DIFUSCO TSP training script
    (its description and several defaults still mention TSP). Callers in this
    file overwrite the fields they need afterwards.

    Args:
        argv: Optional list of command-line style overrides, e.g.
            ``['--n_layers', '6']``. ``None`` (the default) parses an empty
            list, so ``sys.argv`` is never consulted.

    Returns:
        ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description='Train a Pytorch-Lightning diffusion model on a TSP dataset.')
    parser.add_argument('--training_split', type=str, default='data/tsp/tsp50_train_concorde.txt')
    parser.add_argument('--training_split_label_dir', type=str, default=None,
                        help="Directory containing labels for training split (used for MIS).")
    parser.add_argument('--validation_split', type=str, default='data/tsp/tsp50_test_concorde.txt')
    parser.add_argument('--test_split', type=str, default='data/tsp/tsp50_test_concorde.txt')
    parser.add_argument('--validation_examples', type=int, default=8)

    # Optimisation.
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--lr_scheduler', type=str, default='constant')

    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--use_activation_checkpoint', action='store_true')

    # Diffusion process and inference sampling.
    parser.add_argument('--diffusion_type', type=str, default='categorical')
    parser.add_argument('--diffusion_schedule', type=str, default='linear')
    parser.add_argument('--diffusion_steps', type=int, default=1000)
    parser.add_argument('--inference_diffusion_steps', type=int, default=50)
    parser.add_argument('--inference_schedule', type=str, default='cosine')
    parser.add_argument('--inference_trick', type=str, default="ddim")
    parser.add_argument('--sequential_sampling', type=int, default=1)
    parser.add_argument('--parallel_sampling', type=int, default=1)

    # GNN encoder.
    parser.add_argument('--n_layers', type=int, default=12)
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--sparse_factor', type=int, default=-1)
    parser.add_argument('--aggregation', type=str, default='sum')
    parser.add_argument('--two_opt_iterations', type=int, default=1000)
    parser.add_argument('--save_numpy_heatmap', action='store_true')

    # Logging / checkpointing.
    parser.add_argument('--project_name', type=str, default='tsp_diffusion')
    parser.add_argument('--wandb_entity', type=str, default=None)
    parser.add_argument('--wandb_logger_name', type=str, default=None)
    parser.add_argument("--resume_id", type=str, default=None, help="Resume training on wandb.")
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--resume_weight_only', action='store_true')

    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--do_valid_only', action='store_true')

    # Never fall back to sys.argv: this is used from library code, not a CLI.
    args = parser.parse_args(args=[] if argv is None else list(argv))

    return args



def min_max_norm(x):
    """Rescale ``x`` into roughly [0, 1] using its minimum and maximum.

    A 1e-6 term in the denominator prevents division by zero, so constant
    input maps to all zeros.
    """
    lowest = np.min(x)
    highest = np.max(x)
    return (x - lowest) / (highest + 1e-6 - lowest)


def harden_probs(x):
    """Threshold probabilities at 0.5: entries > 0.5 become 1, all others 0.

    The thresholding is applied to a copy, so ``x`` is left unchanged. Writing
    into ``x`` would also alter any buffer it is a view of (e.g. an
    ``np.split`` result). Torch input yields a tensor of the same dtype;
    anything else is converted to a NumPy array. NaN entries stay NaN.
    """
    if isinstance(x, torch.Tensor):
        hard_indicators = x.clone()
    else:
        hard_indicators = np.array(x, copy=True)
    hard_indicators[hard_indicators > 0.5] = 1
    hard_indicators[hard_indicators <= 0.5] = 0
    return hard_indicators

class DIFUSCODataset(CliqueDataset):
    """Clique dataset yielding ``(index, graph, node count)`` triples for DIFUSCO.

    Adds per-node ground-truth labels, loaded from
    ``{path}/{dataset}/clique_labels/{mode}.pkl``, on top of ``CliqueDataset``.
    ``self.path``, ``self.dataset``, ``self.mode``, ``self.corpus_graphs``,
    ``self.graph_data_list`` and ``self.corpus_graph_node_sizes`` are
    presumably set by ``CliqueDataset.__init__`` (not visible here).
    """

    def __init__(self, conf, mode='train', printer=None):
        super(DIFUSCODataset, self).__init__(conf, mode, printer)
        label_path = f'{self.path}/{self.dataset}/clique_labels/{self.mode}.pkl'
        # NOTE: pickle.load can execute arbitrary code; only load trusted label files.
        # The context manager closes the handle, which a bare open() would leak.
        with open(label_path, 'rb') as label_file:
            self.node_level_ground_truth = pickle.load(label_file)
        assert len(self.corpus_graphs) == len(self.node_level_ground_truth), (
            f"{label_path}: {len(self.node_level_ground_truth)} label entries "
            f"for {len(self.corpus_graphs)} graphs")
        # Edge indices are cached on CPU once so __getitem__ does no device work.
        self.all_graph_data_edge_index = [x.edge_index.cpu() for x in self.graph_data_list]

    def __len__(self):
        return len(self.corpus_graphs)

    def __getitem__(self, idx):
        """Return ``(idx as a 1-element LongTensor, GraphData, node-count LongTensor)``.

        ``GraphData.x`` holds the node labels. The node count is used by the
        model to repeat each graph's timestep across that graph's nodes.
        """
        num_nodes = self.corpus_graph_node_sizes[idx]
        node_labels = self.node_level_ground_truth[idx]
        edge_index = self.all_graph_data_edge_index[idx]
        graph_data = GraphData(x=torch.tensor(node_labels),
                               edge_index=edge_index)

        point_indicator = torch.tensor([num_nodes]).long()
        return (
            torch.LongTensor(np.array([idx], dtype=np.int64)),
            graph_data,
            point_indicator,
        )






class MetaModel(torch.nn.Module):
    """Shared base for the DIFUSCO-style graph diffusion models in this file.

    Builds the DIFUSCO argument namespace, overrides it from ``conf``, then
    creates the diffusion process (``self.diffusion``) and the GNN denoiser
    (``self.model``).
    """

    def __init__(self, conf, node_feature_only=False):
        super().__init__()
        self.conf = conf
        self.args = fetch_difusco_args()
        self.overwrite_revelant_args()

        args = self.args
        self.diffusion_type = args.diffusion_type
        self.diffusion_schedule = args.diffusion_schedule
        self.diffusion_steps = args.diffusion_steps
        # Node-feature-only models always run on the sparse (edge-index) path.
        self.sparse = args.sparse_factor > 0 or node_feature_only

        # Gaussian diffusion predicts one value per element; categorical
        # diffusion predicts logits for two states.
        if self.diffusion_type == 'gaussian':
            diffusion_cls, out_channels = GaussianDiffusion, 1
        elif self.diffusion_type == 'categorical':
            diffusion_cls, out_channels = CategoricalDiffusion, 2
        else:
            raise ValueError(f"Unknown diffusion type {self.diffusion_type}")
        self.diffusion = diffusion_cls(T=self.diffusion_steps, schedule=self.diffusion_schedule)

        self.model = GNNEncoder(
            n_layers=args.n_layers,
            hidden_dim=args.hidden_dim,
            norm_dim=args.norm_dim,
            out_channels=out_channels,
            aggregation=args.aggregation,
            sparse=self.sparse,
            use_activation_checkpoint=args.use_activation_checkpoint,
            node_feature_only=node_feature_only,
        )
        self.num_training_steps_cached = None

    def overwrite_revelant_args(self):
        """Copy the project-config values into ``self.args`` and set ``self.device``.

        With ``conf.model.EQ`` set, a much smaller encoder is used
        (5 layers, hidden/norm dim 10). (The method name's spelling is kept for
        existing callers.)
        """
        training = self.conf.training
        for field in ('num_epochs', 'batch_size', 'learning_rate', 'weight_decay'):
            setattr(self.args, field, getattr(training, field))

        model_conf = self.conf.model
        self.args.diffusion_type = model_conf.diffusion_type
        self.args.parallel_sampling = model_conf.parallel_sampling
        self.args.norm_dim = 32
        if model_conf.EQ:
            self.args.n_layers, self.args.hidden_dim, self.args.norm_dim = 5, 10, 10
        self.args.num_workers = 1
        self.device = self.conf.training.device

    def get_total_num_training_steps(self, dataset) -> int:
        """Total training steps inferred from the dataset size (cached after the first call)."""
        cached = self.num_training_steps_cached
        if cached is None:
            num_devices = 1
            effective_batch_size = 1 * num_devices  # no gradient accumulation
            # NOTE(review): the epoch count is hard-coded to 50 rather than
            # self.args.num_epochs, and len(dataset) counts examples, not
            # batches. Confirm this matches the training loop.
            cached = (len(dataset) // effective_batch_size) * 50
            self.num_training_steps_cached = cached
        return cached



class DIFUSCO_MCModel(MetaModel):
    """DIFUSCO diffusion model over per-node labels for the maximum-clique task.

    The GNN denoiser is built with ``node_feature_only=True`` (so ``self.sparse``
    is true). Node states are diffused with Gaussian or categorical diffusion,
    depending on ``self.diffusion_type``.
    """

    def __init__(self, conf):
        super(DIFUSCO_MCModel, self).__init__(conf, node_feature_only=True)

    def forward_gnn(self, x, t, edge_index):
        """Run the GNN denoiser on flattened node states ``x`` at timesteps ``t``."""
        return self.model(x, t, edge_index=edge_index)

    def forward(self, batch, batch_idx):
        """Return the training loss of ``batch`` for the configured diffusion type.

        Raises:
            NotImplementedError: if ``self.diffusion_type`` is neither
                ``'gaussian'`` nor ``'categorical'``.
        """
        if self.diffusion_type == 'gaussian':
            return self.gaussian_forward(batch, batch_idx)
        elif self.diffusion_type == 'categorical':
            return self.categorical_forward(batch, batch_idx)
        else:
            raise NotImplementedError(f"Unknown diffusion type {self.diffusion_type}")

    def gaussian_forward(self, batch, batch_idx):
        """Noise-prediction (epsilon) MSE loss for Gaussian diffusion.

        Args:
            batch: ``(graph indices, graph_data, point_indicator)``.
                ``graph_data.x`` holds node labels (presumably 0/1 — confirm);
                ``point_indicator`` holds the node count of each graph.
            batch_idx: unused.
        """
        _, graph_data, point_indicator = batch
        # One random timestep in [1, T] per graph.
        t = np.random.randint(1, self.diffusion.T + 1, point_indicator.shape[0]).astype(int)
        node_labels = graph_data.x
        edge_index = graph_data.edge_index
        device = node_labels.device

        # Map labels {0, 1} -> {-1, 1} and add up to 5% multiplicative jitter.
        node_labels = node_labels.float() * 2 - 1
        node_labels = node_labels * (1.0 + 0.05 * torch.rand_like(node_labels))
        node_labels = node_labels.unsqueeze(1).unsqueeze(1)

        t = torch.from_numpy(t).long()
        # Give every node the timestep of the graph it belongs to.
        t = t.repeat_interleave(point_indicator.reshape(-1).cpu(), dim=0).numpy()

        xt, epsilon = self.diffusion.sample(node_labels, t)

        t = torch.from_numpy(t).float()
        t = t.reshape(-1)
        xt = xt.reshape(-1)
        edge_index = edge_index.to(device).reshape(2, -1)
        epsilon = epsilon.reshape(-1)

        # Denoise: the network predicts the injected noise.
        epsilon_pred = self.forward_gnn(
            xt.float().to(device),
            t.float().to(device),
            edge_index,
        )
        epsilon_pred = epsilon_pred.squeeze(1)
        return torch.nn.functional.mse_loss(epsilon_pred, epsilon)

    def categorical_forward(self, batch, batch_idx):
        """Cross-entropy loss of the predicted clean labels for categorical diffusion.

        Args:
            batch: ``(graph indices, graph_data, point_indicator)`` (see
                ``gaussian_forward``).
            batch_idx: unused.
        """
        _, graph_data, point_indicator = batch
        t = np.random.randint(1, self.diffusion.T + 1, point_indicator.shape[0]).astype(int)
        node_labels = graph_data.x
        edge_index = graph_data.edge_index

        # Sample x_t from the forward (noising) process starting at the one-hot labels.
        node_labels_onehot = torch.nn.functional.one_hot(node_labels.long(), num_classes=2).float()
        node_labels_onehot = node_labels_onehot.unsqueeze(1).unsqueeze(1)

        t = torch.from_numpy(t).long()
        t = t.repeat_interleave(point_indicator.reshape(-1).cpu(), dim=0).numpy()

        xt = self.diffusion.sample(node_labels_onehot, t)
        # Map sampled states {0, 1} -> {-1, 1} and add up to 5% multiplicative jitter.
        xt = xt * 2 - 1
        xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

        t = torch.from_numpy(t).float()
        t = t.reshape(-1)
        xt = xt.reshape(-1)
        edge_index = edge_index.to(node_labels.device).reshape(2, -1)

        # Denoise: the network predicts logits of the clean label x_0.
        x0_pred = self.forward_gnn(
            xt.float().to(node_labels.device),
            t.float().to(node_labels.device),
            edge_index,
        )

        loss_func = torch.nn.CrossEntropyLoss()

        loss = loss_func(x0_pred, node_labels.long())
        return loss

    def categorical_training_step(self, batch, batch_idx):
        """Categorical-diffusion training loss; identical to ``categorical_forward``.

        Delegates so that there is a single implementation. A separate copy
        must call ``forward_gnn`` (``forward`` takes ``(batch, batch_idx)``) and
        pass ``long`` targets to ``CrossEntropyLoss``.
        """
        return self.categorical_forward(batch, batch_idx)

    def gaussian_denoise_step(self, xt, t, device, edge_index=None, target_t=None):
        """One gradient-free reverse Gaussian step from ``t`` to ``target_t``; returns the new ``xt``."""
        with torch.no_grad():
            t = torch.from_numpy(t).view(1)
            pred = self.forward_gnn(
                xt.float().to(device),
                t.float().to(device),
                edge_index.long().to(device) if edge_index is not None else None,
            )
            pred = pred.squeeze(1)
            xt = self.gaussian_posterior(target_t, t, pred, xt)
            return xt

    def duplicate_edge_index(self, edge_index, num_nodes, device):
        """Duplicate the edge index (in sparse graphs) for parallel sampling.

        Copy ``i`` of the graph (``0 <= i < parallel_sampling``) has its node
        ids offset by ``i * num_nodes``. The result has shape
        ``(2, parallel_sampling * E)``.
        """
        edge_index = edge_index.reshape((2, 1, -1))
        edge_index_indent = torch.arange(0, self.args.parallel_sampling).view(1, -1, 1).to(device)
        edge_index_indent = edge_index_indent * num_nodes
        edge_index = edge_index + edge_index_indent
        edge_index = edge_index.reshape((2, -1))
        return edge_index

    def gaussian_posterior(self, target_t, t, pred, xt):
        """Sample (or deterministically denoise) from the Gaussian posterior for a given time step.

        ``pred`` is the predicted noise. A DDPM ancestral step is used when no
        inference trick is configured or ``t <= 1``; otherwise a DDIM step is
        taken towards ``target_t``.
        See https://arxiv.org/pdf/2010.02502.pdf for details.
        """
        diffusion = self.diffusion
        if target_t is None:
            target_t = t - 1
        else:
            target_t = torch.from_numpy(target_t).view(1)

        atbar = diffusion.alphabar[t]
        atbar_target = diffusion.alphabar[target_t]

        if self.args.inference_trick is None or t <= 1:
            # Use DDPM posterior
            at = diffusion.alpha[t]
            z = torch.randn_like(xt)
            atbar_prev = diffusion.alphabar[t - 1]
            beta_tilde = diffusion.beta[t - 1] * (1 - atbar_prev) / (1 - atbar)

            xt_target = (1 / np.sqrt(at)).item() * (xt - ((1 - at) / np.sqrt(1 - atbar)).item() * pred)
            xt_target = xt_target + np.sqrt(beta_tilde).item() * z
        elif self.args.inference_trick == 'ddim':
            # Deterministic DDIM update: rescale the implied x_0, then re-noise to target_t.
            xt_target = np.sqrt(atbar_target / atbar).item() * (xt - np.sqrt(1 - atbar).item() * pred)
            xt_target = xt_target + np.sqrt(1 - atbar_target).item() * pred
        else:
            raise ValueError('Unknown inference trick {}'.format(self.args.inference_trick))
        return xt_target

    def test_step(self, batch, batch_idx, draw=False, split='test'):
        """Sample node labels by reverse diffusion and score the samples.

        Runs ``sequential_sampling`` rounds. Each round denoises
        ``parallel_sampling`` independent copies of the graph as one disjoint
        union.

        Args:
            batch: ``(graph indices, graph_data, point_indicator)``; presumably a
                single graph (the decoding below treats ``graph_data`` as one
                graph — confirm for larger batches).
            batch_idx, draw, split: unused; kept for interface compatibility.

        Returns:
            ``(best_solved_cost, best_non_decoder_sol)``:
              - the largest clique size found by ``mc_decode_np`` over all
                samples;
              - the largest count of nodes whose (normalised) score exceeds 0.5,
                over all samples. This count is not checked for cliqueness.
        """
        device = batch[-1].device

        _, graph_data, _ = batch
        node_labels = graph_data.x
        num_nodes = node_labels.shape[0]
        edge_index = graph_data.edge_index.to(node_labels.device).reshape(2, -1)

        # Build the sampling graph once, outside the sequential loop.
        # Re-duplicating an already-duplicated edge index on every round would
        # apply the offsets again and multiply the edges, corrupting the graph
        # whenever sequential_sampling > 1.
        if self.args.parallel_sampling > 1:
            sampling_edge_index = self.duplicate_edge_index(edge_index, num_nodes, device)
        else:
            sampling_edge_index = edge_index

        stacked_predict_labels = []
        for _ in range(self.args.sequential_sampling):
            # Start from pure noise (Gaussian) or random binary states (categorical).
            xt = torch.randn_like(node_labels.float())
            if self.args.parallel_sampling > 1:
                xt = xt.repeat(self.args.parallel_sampling, 1, 1)
                xt = torch.randn_like(xt)

            if self.diffusion_type == 'gaussian':
                xt.requires_grad = True
            else:
                xt = (xt > 0).long()
            xt = xt.reshape(-1)

            batch_size = 1
            steps = self.args.inference_diffusion_steps
            time_schedule = InferenceSchedule(inference_schedule=self.args.inference_schedule,
                                              T=self.diffusion.T, inference_T=steps)

            for i in range(steps):
                t1, t2 = time_schedule(i)
                t1 = np.array([t1 for _ in range(batch_size)]).astype(int)
                t2 = np.array([t2 for _ in range(batch_size)]).astype(int)

                if self.diffusion_type == 'gaussian':
                    xt = self.gaussian_denoise_step(
                        xt, t1, device, sampling_edge_index, target_t=t2)
                else:
                    xt = self.categorical_denoise_step(
                        xt, t1, device, sampling_edge_index, target_t=t2)

            if self.diffusion_type == 'gaussian':
                # Map {-1, 1} back to {0, 1}.
                predict_labels = xt.float().cpu().detach().numpy() * 0.5 + 0.5
            else:
                predict_labels = xt.float().cpu().detach().numpy() + 1e-6
            stacked_predict_labels.append(predict_labels)

        predict_labels = np.concatenate(stacked_predict_labels, axis=0)
        all_sampling = self.args.sequential_sampling * self.args.parallel_sampling

        # One score vector per sample, each of length num_nodes.
        splitted_predict_labels = np.split(predict_labels, all_sampling)
        solved_costs = [mc_decode_np(labels, graph_data.edge_index) for labels in splitted_predict_labels]
        best_solved_cost = np.max(solved_costs)

        if self.diffusion_type == "gaussian":
            probs = [min_max_norm(labels) for labels in splitted_predict_labels]
        else:
            probs = splitted_predict_labels

        hard_probs = [harden_probs(p) for p in probs]

        non_decoder_solutions = [p.sum() for p in hard_probs]

        best_non_decoder_sol = np.max(non_decoder_solutions)

        return best_solved_cost, best_non_decoder_sol

    def categorical_posterior(self, target_t, t, x0_pred_prob, xt):
        """Sample from the categorical posterior for a given time step.

        Computes P(x_target = 1 | x_t) marginalised over the predicted clean
        label: each ``x0 in {0, 1}`` is weighted by ``x0_pred_prob[..., x0]``.
        For ``target_t > 0`` a Bernoulli sample is returned. At ``target_t == 0``
        the (non-negative) probability itself is returned.
        See https://arxiv.org/pdf/2107.03006.pdf for details.
        """
        diffusion = self.diffusion

        if target_t is None:
            target_t = t - 1
        else:
            target_t = torch.from_numpy(target_t).view(1)

        # Effective one-jump transition from target_t to t: Q_bar[target_t]^{-1} Q_bar[t].
        if target_t > 0:
            Q_t = np.linalg.inv(diffusion.Q_bar[target_t]) @ diffusion.Q_bar[t]
            Q_t = torch.from_numpy(Q_t).float().to(x0_pred_prob.device)
        else:
            Q_t = torch.eye(2).float().to(x0_pred_prob.device)
        Q_bar_t_source = torch.from_numpy(diffusion.Q_bar[t]).float().to(x0_pred_prob.device)
        Q_bar_t_target = torch.from_numpy(diffusion.Q_bar[target_t]).float().to(x0_pred_prob.device)

        xt = torch.nn.functional.one_hot(xt.long(), num_classes=2).float()
        xt = xt.reshape(x0_pred_prob.shape)

        # Posterior terms assuming x0 = 0 (row 0 of the cumulative matrices).
        x_t_target_prob_part_1 = torch.matmul(xt, Q_t.permute((1, 0)).contiguous())
        x_t_target_prob_part_2 = Q_bar_t_target[0]
        x_t_target_prob_part_3 = (Q_bar_t_source[0] * xt).sum(dim=-1, keepdim=True)

        x_t_target_prob = (x_t_target_prob_part_1 * x_t_target_prob_part_2) / x_t_target_prob_part_3

        sum_x_t_target_prob = x_t_target_prob[..., 1] * x0_pred_prob[..., 0]
        # Same terms assuming x0 = 1 (row 1), weighted by P(x0 = 1).
        x_t_target_prob_part_2_new = Q_bar_t_target[1]
        x_t_target_prob_part_3_new = (Q_bar_t_source[1] * xt).sum(dim=-1, keepdim=True)

        x_t_source_prob_new = (x_t_target_prob_part_1 * x_t_target_prob_part_2_new) / x_t_target_prob_part_3_new

        sum_x_t_target_prob += x_t_source_prob_new[..., 1] * x0_pred_prob[..., 1]

        if target_t > 0:
            xt = torch.bernoulli(sum_x_t_target_prob.clamp(0, 1))
        else:
            xt = sum_x_t_target_prob.clamp(min=0)

        if self.sparse:
            xt = xt.reshape(-1)
        return xt

    def categorical_denoise_step(self, xt, t, device, edge_index=None, target_t=None):
        """One gradient-free reverse categorical step from ``t`` to ``target_t``; returns the new ``xt``."""
        with torch.no_grad():
            t = torch.from_numpy(t).view(1)
            x0_pred = self.forward_gnn(
                xt.float().to(device),
                t.float().to(device),
                edge_index.long().to(device) if edge_index is not None else None,
            )
            # Per-node probabilities over the two clean states.
            x0_pred_prob = x0_pred.reshape((1, xt.shape[0], -1, 2)).softmax(dim=-1)
            xt = self.categorical_posterior(target_t, t, x0_pred_prob, xt)
            return xt
            

