import logging
import os
import time
from typing import Any, Dict

import torch
import torch_sparse

from torch_geometric.graphgym.checkpoint import (
    clean_ckpt,
    load_ckpt,
    save_ckpt,
)
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.register import register_train
from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch
from torch_geometric.data import Data, Batch
from tqdm import tqdm
import torch.nn.functional as F
from ..utils import flatten_dict, make_wandb_name, cfg_to_dict, reset_slice_dict_edges, add_full_rrwp, nodes_per_batch_sample
from ..inference_utils import eval_inference, eval_inference_sizing
import numpy as np
from torch.distributions.categorical import Categorical

import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import linear_sum_assignment


def train_epoch(logger, loader, model, optimizer, scheduler):
    """Run one training epoch of the flow-matching denoiser.

    For each batch: optionally pad graphs with prunable nodes, sample the
    node / edge / feature times, noise the batch, add RRWP encodings, run the
    model and optimise ``dfm_loss``.

    Args:
        logger: Object exposing ``update_stats``; called once per batch.
        loader: Iterable yielding batches.
        model: Denoiser, called as ``model(batch, unconditional_prop=...)``.
        optimizer: Optimizer stepped once per batch with a non-NaN loss.
        scheduler: LR scheduler, stepped once at the end of the epoch.

    Returns:
        Mean of the non-NaN batch losses, or ``float('nan')`` when no batch
        produced a usable loss (empty loader or only NaN losses).
    """
    device = torch.device(cfg.device)
    model.train()
    model.to(device)
    time_start = time.time()
    losses = []
    for idx, batch in enumerate(loader):
        batch = add_prunable_nodes(batch)
        batch.to(device)
        # Count number of nodes per sample
        npbs = nodes_per_batch_sample(batch)

        optimizer.zero_grad()

        # Sample time
        t_x, t_e, t_f = sample_t(batch, npbs)

        # Sample x_t
        noised_batch = noise_batch(batch.clone(), npbs, t_x, t_e, t_f)

        # Add RRWP & update batch t_x and t_e
        noised_batch = add_full_rrwp(noised_batch.clone(), walk_length=cfg.posenc_RRWP.ksteps)
        noised_batch.t_x = t_x
        noised_batch.t_f = t_f
        # Account for lower triangular indices of A: repeating t_e keeps it
        # ordered identically to out_idx in RRWPLinearEdgeEncoder.
        noised_batch.t_e = t_e.repeat(2)

        # Denoising
        denoised_batch = model(noised_batch.clone(), unconditional_prop=cfg.train.ratio_cf_guidance)

        # Loss
        loss, loss_topo, loss_sizing = dfm_loss(batch, denoised_batch)
        if torch.isnan(loss):
            # Back-propagating a NaN loss would write NaN into every
            # parameter; skip the update for this batch instead.
            logging.warning('NaN loss on training batch %d; skipping optimizer step.', idx)
        else:
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        logger.update_stats(true=batch.detach().cpu(),
                            pred=denoised_batch.detach().cpu(), loss=loss_topo.item(),
                            loss_sizing=loss_sizing.item(),
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()
    scheduler.step()
    # Guard against an empty loader / all-NaN epoch (was a ZeroDivisionError).
    return sum(losses) / len(losses) if losses else float('nan')


@torch.no_grad()
def eval_epoch(logger, loader, model, sizing_mode=False):
    """Evaluate the denoiser on one data split.

    Runs under ``torch.no_grad()`` so no autograd graph is built during
    evaluation.

    Args:
        logger: Object exposing ``update_stats``; called once per batch.
        loader: Iterable yielding batches.
        model: Denoiser, called as ``model(batch)``.
        sizing_mode: When True, skip the single-step denoising loss and run
            ``eval_inference_sizing`` on every batch instead.

    Returns:
        ``(mean_loss, None)`` in the default mode (``mean_loss`` is
        ``float('nan')`` if the loader yielded nothing), or
        ``(0.0, metrics_dict)`` in sizing mode, where ``metrics_dict`` holds
        the first three averaged outputs of ``eval_inference_sizing``.
    """
    device = torch.device(cfg.device)
    model.eval()
    model.to(device)
    time_start = time.time()
    losses, sim_out = [], []
    for idx, batch in enumerate(loader):
        batch = add_prunable_nodes(batch)
        batch.to(device)

        # If we learn to denoise only, then directly run the full denoising inference.
        if sizing_mode:
            sim_out.append(eval_inference_sizing(model, batch.clone(), euler_steps=20, n_pow=1))
            logger.update_stats(true=batch.detach().cpu(), pred=batch.detach().cpu(), loss=0.0, time_used=time.time() - time_start,
                                params=cfg.params)
            time_start = time.time()
            continue

        # Count number of nodes per sample
        npbs = nodes_per_batch_sample(batch)

        # Sample time
        t_x, t_e, t_f = sample_t(batch, npbs)

        # Sample x_t
        noised_batch = noise_batch(batch.clone(), npbs, t_x, t_e, t_f)

        # Add rrwp & update batch t_x and t_e
        noised_batch = add_full_rrwp(noised_batch, walk_length=cfg.posenc_RRWP.ksteps)
        noised_batch.t_x = t_x
        noised_batch.t_f = t_f
        # Account for lower triangular indices of A: repeating t_e keeps it
        # ordered identically to out_idx in RRWPLinearEdgeEncoder.
        noised_batch.t_e = t_e.repeat(2)

        # Denoising
        denoised_batch = model(noised_batch.clone())
        loss, loss_topo, loss_sizing = dfm_loss(batch, denoised_batch)
        losses.append(loss.item())

        logger.update_stats(true=batch.detach().cpu(),
                            pred=denoised_batch.detach().cpu(), loss=loss_topo.item(),
                            loss_sizing=loss_sizing.item(),
                            lr=0,
                            time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()

    if sizing_mode:
        sim_out = np.stack(sim_out).mean(axis=0)
        reformat = lambda x: np.round(float(x), 4)
        return 0.0, {
            "Avg Gain" : reformat(sim_out[0]),
            "Avg Phase Margin" : reformat(sim_out[1]),
            "Avg UGain Freq" : reformat(sim_out[2])
        }

    # Guard against an empty loader (was a ZeroDivisionError).
    mean_loss = sum(losses) / len(losses) if losses else float('nan')
    return mean_loss, None


def add_prunable_nodes(batch):
    """Randomly pad graphs of a batch with extra "prunable" nodes.

    Does nothing when ``cfg.gt.node_pruning == 0``. Otherwise each graph is
    cloned and, with probability 1/2, extended with extra nodes; every graph
    receives a ``prunable`` flag vector (1 on added nodes, 0 elsewhere).

    Args:
        batch: Batch supporting ``num_graphs`` and ``get_example``; graphs
            are expected to carry ``x`` and ``x_features``.

    Returns:
        The input batch (pruning disabled) or a new batch rebuilt with
        ``Batch.from_data_list`` from the padded graphs.
    """
    pruning_mode = cfg.gt.node_pruning
    if pruning_mode == 0:
        return batch

    padded_graphs = []
    for graph_idx in range(batch.num_graphs):
        graph = batch.get_example(graph_idx).clone()
        n_real = graph.num_nodes

        # Coin flip: only half of the graphs get padded.
        if np.random.randint(2) != 1:
            graph.prunable = graph.x.new_zeros(n_real)
            padded_graphs.append(graph)
            continue

        # Draw 5..14 extra nodes, capped so the graph never exceeds 30 nodes.
        n_extra = max(0, min(30 - n_real, np.random.randint(5, 15)))

        if pruning_mode == 1:
            # Extra nodes get a random type from a fixed pool
            # (pin types 6 and 7 only when the dataset uses pins).
            type_pool = [0, 1, 2, 3, 4, 5, 10]
            if cfg.dataset.use_pins:
                type_pool.extend([6, 7])
            drawn_types = np.random.choice(type_pool, n_extra, replace=True)
            extra_types = torch.from_numpy(drawn_types)[:, None]
        else:
            # Extra nodes all get the last node-type index.
            extra_types = torch.full((n_extra, 1), fill_value=cfg.dataset.nnode_types - 1)

        # Random feature index for each extra node.
        drawn_feats = np.random.choice(cfg.dataset.nnode_features, n_extra, replace=True)
        extra_feats = torch.from_numpy(drawn_feats.reshape(-1, 1))

        # Flag the added nodes so the loss can treat them separately.
        graph.prunable = torch.cat([graph.x.new_zeros(n_real), graph.x.new_ones(n_extra)])

        device = graph.x.device
        graph.x = torch.cat([graph.x, extra_types.to(device)], dim=0)
        graph.x_features = torch.cat([graph.x_features, extra_feats.to(device)], dim=0)

        # Rebuild the list of upper-triangular node pairs for the new size.
        n_total = len(graph.x)
        src, dst = torch.triu_indices(n_total, n_total, offset=1)
        graph.triu_edge_index = torch.stack((src, dst), dim=0).to(device)

        padded_graphs.append(graph)

    return Batch.from_data_list(padded_graphs)


@register_train('dfm')
def train_example(loggers, loaders, model, optimizer, scheduler):
    """Full training loop registered under the GraphGym name ``'dfm'``.

    Optionally resumes from a checkpoint, trains for the remaining epochs,
    evaluates the other splits on eval epochs, logs to Weights & Biases when
    enabled, tracks the best epoch and writes checkpoints.

    Args:
        loggers: One logger per split; index 0 is the training split. The
            best-epoch report reads indices 1 and 2 as validation and test.
        loaders: One data loader per split, in the same order as ``loggers``.
        model: Model to train.
        optimizer: Optimizer passed to ``train_epoch`` and the checkpoints.
        scheduler: LR scheduler passed to ``train_epoch`` and the checkpoints.

    Raises:
        ImportError: If W&B logging is enabled but ``wandb`` cannot be imported.
    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler,
                                cfg.train.epoch_resume)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch %s', start_epoch)

    if cfg.wandb.use:
        try:
            import wandb
        except ImportError as err:
            # Only a genuine import failure means "not installed"; any other
            # error raised while importing wandb now surfaces unchanged.
            raise ImportError('WandB is not installed.') from err
        if cfg.wandb.name == '':
            wandb_name = make_wandb_name(cfg)
        else:
            wandb_name = cfg.wandb.name
        run = wandb.init(entity=cfg.wandb.entity, project=cfg.wandb.project,
                         name=wandb_name)
        run.config.update(cfg_to_dict(cfg))

    num_splits = len(loggers)
    perf = [[] for _ in range(num_splits)]
    full_epoch_times = []
    # Shown in the progress bar on every epoch; must exist before the first
    # eval epoch (and when there is a single split), otherwise it is a NameError.
    eval_loss = float('nan')

    with tqdm(range(start_epoch, cfg.optim.max_epoch), desc="Training") as pbar:
        for cur_epoch in pbar:
            start_time = time.perf_counter()
            train_loss = train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
            perf[0].append(loggers[0].write_epoch(cur_epoch, False))

            if is_eval_epoch(cur_epoch):

                for i in range(1, num_splits):
                    sizing_mode = False # cfg.train.noise_feat_only and i == 2
                    eval_loss, sim_out = eval_epoch(loggers[i], loaders[i], model, sizing_mode=sizing_mode)
                    perf[i].append(loggers[i].write_epoch(cur_epoch, custom_metrics=not sizing_mode and cfg.dataset.name != "AnalogGenie"))
                    if sizing_mode:
                        perf[i][-1].update(sim_out)

                # Add inference stats in test metrics dict. The dataset name
                # is now spelled consistently ("AnalogGenie"), so the
                # conditioning labels are no longer sampled for a dataset on
                # which inference is skipped anyway.
                if cfg.dataset.name != "AnalogGenie": # and (not cfg.train.noise_feat_only):
                    inf_kwargs = {'num_samples': 200, 'euler_steps': 20, 'noise': 0, 'n_pow_e': cfg.train.distortion_pow_e, 
                                  'n_pow_x': cfg.train.distortion_pow_n}
                    y_test = loaders[-1].dataset.y if cfg.gt.conditional_gen else None
                    cond_y = y_test[np.random.choice(np.arange(len(y_test)), inf_kwargs['num_samples'])] if y_test is not None else None
                    inf_kwargs.update({'cond_y': cond_y})
                    perf[-1][-1].update({f'inference_{k}': v for (k, v) in eval_inference(model, **inf_kwargs).items()})

            val_perf = perf[1]
            full_epoch_times.append(time.perf_counter() - start_time)

            pbar.set_postfix(train_loss=train_loss, eval_loss=eval_loss)
            if is_ckpt_epoch(cur_epoch):
                save_ckpt(model, optimizer, scheduler, cur_epoch)

            if cfg.wandb.use:
                run.log(flatten_dict(perf), step=cur_epoch)

            # Log current best stats on eval epoch.
            if is_eval_epoch(cur_epoch):
                best_epoch = np.array([vp['loss'] for vp in val_perf]).argmin()
                best_epoch_loss = best_epoch

                best_train = best_val = best_test = ""
                if cfg.metric_best != 'auto':
                    # Select again based on val perf of `cfg.metric_best`.
                    m = cfg.metric_best
                    best_epoch = getattr(np.array([vp[m] for vp in val_perf]),
                                        cfg.metric_agg)()

                    if cfg.best_by_loss:
                        best_epoch = best_epoch_loss

                    if m in perf[0][best_epoch]:
                        best_train = f"train_{m}: {perf[0][best_epoch][m]:.4f}"
                    else:
                        best_train = f"train_{m}: {0:.4f}"
                    best_val = f"val_{m}: {perf[1][best_epoch][m]:.4f}"
                    best_test = f"test_{m}: {perf[2][best_epoch][m]:.4f}"

                    if cfg.wandb.use:
                        bstats = {"best/epoch": best_epoch}
                        for i, s in enumerate(['train', 'val', 'test']):
                            bstats[f"best/{s}_loss"] = perf[i][best_epoch]['loss']
                            if m in perf[i][best_epoch]:
                                bstats[f"best/{s}_{m}"] = perf[i][best_epoch][m]

                                run.summary[f"best_{s}_perf"] = \
                                    perf[i][best_epoch][m]


                            for x in ['hits@1', 'hits@3', 'hits@10', 'mrr']:
                                if x in perf[i][best_epoch]:
                                    bstats[f"best/{s}_{x}"] = perf[i][best_epoch][x]

                        run.log(bstats, step=cur_epoch)
                        run.summary["full_epoch_time_avg"] = np.mean(full_epoch_times)
                        run.summary["full_epoch_time_sum"] = np.sum(full_epoch_times)

                # Checkpoint the best epoch params (if enabled).
                if cfg.train.enable_ckpt and cfg.train.ckpt_best and \
                        best_epoch == cur_epoch:
                    # No best-checkpoint is written during warmup epochs.
                    if cur_epoch < cfg.optim.num_warmup_epochs:
                        pass
                    else:
                        save_ckpt(model, optimizer, scheduler, cur_epoch)
                    if cfg.train.ckpt_clean:  # Delete old ckpt each time.
                        clean_ckpt()
                logging.info(
                    f"> Epoch {cur_epoch}: took {full_epoch_times[-1]:.1f}s "
                    f"(avg {np.mean(full_epoch_times):.1f}s) | "
                    f"Best so far: epoch {best_epoch}\t"
                    f"train_loss: {perf[0][best_epoch]['loss']:.4f} {best_train}\t"
                    f"val_loss: {perf[1][best_epoch]['loss']:.4f} {best_val}\t" 
                    f"test_loss: {perf[2][best_epoch]['loss']:.4f} {best_test}\n"
                    f"-----------------------------------------------------------"
                )
                
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    if cfg.train.save_final_model:
        # Save the final weights (plus optimizer/scheduler state when present)
        # under a timestamped file name in cfg.out_dir.
        ckpt: Dict[str, Any] = {}
        ckpt["model_state"] = model.state_dict()
        if optimizer is not None:
            ckpt["optimizer_state"] = optimizer.state_dict()
        if scheduler is not None:
            ckpt["scheduler_state"] = scheduler.state_dict()

        os.makedirs(cfg.out_dir, exist_ok=True)
        date = time.strftime("%Y%m%d-%H%M%S")
        torch.save(ckpt, cfg.out_dir + f"/{date}_last.ckpt")

    logging.info('Task done, results saved in %s', cfg.run_dir)


# def sample_edge_time(nodes_for_graph, t, eval=False, sample_same_t=False):
#     edges_for_graph = (nodes_for_graph * (nodes_for_graph - 1)) // 2

#     if not eval and cfg.gt.sample_separate_t and not sample_same_t:
#         t_e = torch.rand((edges_for_graph.sum(),)).to(torch.device(cfg.device))
#         if cfg.train.t_sample_distortion == 'pow':
#             t_e = 1 - ((1 - t_e) ** cfg.train.distortion_pow)
#         elif cfg.train.t_sample_distortion == 'norm':
#             t_e = (torch.randn(edges_for_graph.sum(),) * 0.2 + 0.4).clip(min=1e-5, max=1 - 1e-5).to(torch.device(cfg.device))
#     else:
#         # Repeat each element of t according to the corresponding value in edge_for_graph
#         t_e = t.repeat_interleave(edges_for_graph)
    
#     return t_e


def sample_edge_time(nodes_for_graph, batch, t):
    """Sample one flow-matching time per upper-triangular node pair.

    Args:
        nodes_for_graph: Number of nodes of each graph in the batch (used as
            a tensor: ``n * (n - 1) // 2`` pairs per graph).
        batch: Batch providing ``num_graphs``.
        t: Optional per-graph times. When None, times are drawn uniformly,
            either one per pair (half of the time, if
            ``cfg.gt.sample_separate_t`` is set) or one per graph.

    Returns:
        1-D tensor with one time value per node pair of the batch.
    """
    device = torch.device(cfg.device)
    edges_for_graph = (nodes_for_graph * (nodes_for_graph - 1)) // 2
    # T_e is either sampled independently for each edge (50% of the time when
    # `sample_separate_t`), or for each graph. The coin is drawn on every call.
    sep_sampling = (t is None) & cfg.gt.sample_separate_t & (np.random.randint(2) == 0)
    t_sample_size = edges_for_graph.sum() if sep_sampling else batch.num_graphs

    t_e = torch.rand((t_sample_size,)).to(device) if t is None else t

    # NOTE(review): the distortion is only applied to an externally supplied
    # t, never to freshly sampled times - confirm this is intended.
    if t is not None:
        if cfg.train.t_sample_distortion_e == 'pow':
            t_e = 1 - ((1 - t_e) ** cfg.train.distortion_pow_e)
        elif cfg.train.t_sample_distortion_e == 'norm':
            # One value per graph (t_sample_size == num_graphs here). Drawing
            # one value per edge, as before, made the repeat_interleave below
            # fail on a size mismatch.
            t_e = (torch.randn(t_sample_size) * 0.2 + 0.4).clip(min=1e-5, max=1 - 1e-5).to(device)

    # If one t is sampled for each graph, expand it over that graph's pairs.
    if not sep_sampling:
        t_e = t_e.repeat_interleave(edges_for_graph)

    return t_e


def sample_node_time(batch, t):
    """Sample one flow-matching time per node for the node-type channel.

    Args:
        batch: Batch providing ``num_nodes``, ``num_graphs`` and ``batch``
            (the node-to-graph assignment used for expansion).
        t: Optional per-graph times. When None, uniform times are drawn,
            either per node (half of the time, if
            ``cfg.gt.sample_separate_t`` is set) or per graph.

    Returns:
        Tensor of times indexed per node.
    """
    device = torch.device(cfg.device)
    # The coin is drawn on every call, even when `t` is supplied.
    per_node = (t is None) & cfg.gt.sample_separate_t & (np.random.randint(2) == 0)
    n_draws = batch.num_nodes if per_node else batch.num_graphs

    if t is None:
        t_x = torch.rand((n_draws,)).to(device)
    else:
        t_x = t
        # Time distortion is only applied to externally supplied times.
        if cfg.train.t_sample_distortion_n == 'pow':
            t_x = 1 - ((1 - t_x) ** cfg.train.distortion_pow_n)
        elif cfg.train.t_sample_distortion_n == 'norm':
            gaussian = torch.randn(n_draws,) * 0.2 + 0.4
            t_x = gaussian.clip(min=1e-5, max=1 - 1e-5).to(device)

    # Per-graph times are broadcast to the nodes of each graph.
    return t_x if per_node else t_x[batch.batch]


def sample_feat_time(batch, t):
    """Sample one flow-matching time per node for the node-feature channel.

    Mirrors ``sample_node_time``; the only difference is that the distortion
    is read from the optional ``cfg.train.t_sample_distortion_f`` setting.

    Args:
        batch: Batch providing ``num_nodes``, ``num_graphs`` and ``batch``.
        t: Optional per-graph times; when None, uniform times are drawn.

    Returns:
        Tensor of times indexed per node.
    """
    device = torch.device(cfg.device)
    # The coin is drawn on every call, even when `t` is supplied.
    per_node = (t is None) & cfg.gt.sample_separate_t & (np.random.randint(2) == 0)
    n_draws = batch.num_nodes if per_node else batch.num_graphs

    if t is None:
        t_f = torch.rand((n_draws,)).to(device)
    else:
        t_f = t
        # A missing setting means no distortion at all.
        distortion = getattr(cfg.train, 't_sample_distortion_f', None)
        if distortion == 'pow':
            t_f = 1 - ((1 - t_f) ** cfg.train.distortion_pow_f)
        elif distortion == 'norm':
            gaussian = torch.randn(n_draws,) * 0.2 + 0.4
            t_f = gaussian.clip(min=1e-5, max=1 - 1e-5).to(device)

    # Per-graph times are broadcast to the nodes of each graph.
    return t_f if per_node else t_f[batch.batch]


def sample_t(batch, npbs):
    """Sample the node, edge and feature times for a batch.

    One time in four, a single uniform time per graph is drawn and shared by
    the three samplers; otherwise each sampler draws its own times.

    Args:
        batch: Batch to sample times for.
        npbs: Number of nodes per graph, forwarded to ``sample_edge_time``.

    Returns:
        Tuple ``(t_x, t_e, t_f)``.
    """
    share_graph_time = np.random.randint(4) == 0
    if share_graph_time:
        graph_t = torch.rand((batch.num_graphs,)).to(torch.device(cfg.device))
    else:
        graph_t = None

    t_x = sample_node_time(batch, graph_t)
    t_e = sample_edge_time(npbs, batch, graph_t)
    t_f = sample_feat_time(batch, graph_t)

    if cfg.dataset.get("task_type", '') == 'pin_prediction':
        # Only pairs flagged by triu_learnable_edge_attr keep their sampled
        # time; every other pair, all nodes and all features are set to t = 1.
        learnable = batch.triu_learnable_edge_attr
        t_e = t_e * learnable + t_e.new_ones(len(t_e)) * (1 - learnable)
        t_x = t_x.new_ones(len(t_x))
        t_f = t_f.new_ones(len(t_f))

    return t_x, t_e, t_f


def noising_node(batch, t_x):
    """Dispatch node-type noising according to the configured framework/prior.

    Args:
        batch: Batch whose node types are noised.
        t_x: Per-node times.

    Returns:
        The noised batch.

    Raises:
        ValueError: If ``cfg.framework.type`` is not ``'vfm'`` and
            ``cfg.train.prior`` is neither ``'marginal'`` nor ``'masked'``
            (this case used to fall through and return None).
    """
    if cfg.framework.type == 'vfm':
        return noising_node_vfm(batch, t_x)

    prior = cfg.train.prior
    if prior == 'marginal':
        return noising_node_marginal(batch, t_x)
    if prior == 'masked':
        return noising_node_mask(batch, t_x)

    raise ValueError(
        f"Unsupported node noising prior {prior!r}; expected 'marginal' or 'masked'."
    )
    

def noising_edge(batch, nodes_for_graph, t_e):
    """Dispatch edge noising according to the configured framework/prior.

    Args:
        batch: Batch whose edges are noised.
        nodes_for_graph: Number of nodes per graph (only used by the masked
            prior).
        t_e: Per-node-pair times.

    Returns:
        The noised batch.

    Raises:
        ValueError: If ``cfg.framework.type`` is not ``'vfm'`` and
            ``cfg.train.prior`` is neither ``'marginal'`` nor ``'masked'``
            (this case used to fall through and return None).
    """
    if cfg.framework.type == 'vfm':
        return noising_edge_vfm(batch, t_e)

    prior = cfg.train.prior
    if prior == 'marginal':
        return noising_edge_marginal(batch, t_e)
    if prior == 'masked':
        return noising_edge_mask(batch, nodes_for_graph, t_e)

    raise ValueError(
        f"Unsupported edge noising prior {prior!r}; expected 'marginal' or 'masked'."
    )
    

def noise_batch(batch, nodes_per_graph, t_x, t_e, t_f):
    """Noise a batch: node types first, then edges, then node features.

    Args:
        batch: Batch to noise.
        nodes_per_graph: Number of nodes per graph, forwarded to edge noising.
        t_x: Times for the node types.
        t_e: Times for the edges.
        t_f: Times for the node features.

    Returns:
        The noised batch.
    """
    with_noisy_types = noising_node(batch, t_x)
    with_noisy_edges = noising_edge(with_noisy_types, nodes_per_graph, t_e)
    return noising_node_features(with_noisy_edges, t_f)


# def broadcast_t_e(batch, t_e):

#     # Before noising, broadcast
#     t_e_idx, t_e_val = torch_sparse.coalesce(
#         torch.cat([batch.triu_edge_index, torch.flip(batch.triu_edge_index, dims=[0])], dim=1), 
#         torch.cat([t_e, t_e.new_zeros(t_e.shape)]), batch.num_nodes, batch.num_nodes,
#         op="add"
#     )

#     # t_e_idx is the same as out_idx in RRWPLinearEdgeEncoder

#     batch.t_e = t_e_val

#     return batch
    

def noising_node_vfm(batch, t_x):
    """Noise node types for the 'vfm' framework.

    Interpolates, per node, between the one-hot encoding of the true type
    (weight ``t_x``) and Gaussian noise (weight ``1 - t_x``) and stores the
    result in ``batch.xt_logits``; ``batch.x`` itself is left untouched.

    Args:
        batch: Batch whose ``x[:, 0]`` holds the node-type indices.
        t_x: Per-node times.

    Returns:
        The same batch with ``xt_logits`` of shape (num_nodes, n_classes).
    """
    n_types = cfg.dataset.nnode_types

    # Gaussian noise scaled by 2 and shifted by 1: neither the noise nor the
    # interpolated values are constrained to the probability simplex.
    gaussian_prior = torch.randn((len(batch.x), n_types), device=cfg.device) * 2 + 1
    clean_one_hot = F.one_hot(batch.x[:, 0], num_classes=n_types)

    # Node types are represented as real values in the vfm framework.
    batch.xt_logits = t_x[:, None] * clean_one_hot + gaussian_prior * (1 - t_x)[:, None]

    return batch
    

def noising_node_marginal(batch, t_x):
    """Noise node types by sampling from a time-interpolated categorical.

    For each node, the class probabilities are ``t_x`` times the one-hot true
    type plus ``(1 - t_x)`` times ``cfg.node_type_pmf``; a noisy type is then
    sampled and written to ``batch.x``.

    Args:
        batch: Batch whose ``x[:, 0]`` holds the node-type indices.
        t_x: Per-node times.

    Returns:
        The same batch with ``x`` replaced by the sampled types, shape
        (num_nodes, 1).
    """
    clean_one_hot = F.one_hot(batch.x[:, 0], num_classes=cfg.dataset.nnode_types)
    # The pmf broadcasts over nodes; no need to materialise one copy per node.
    type_pmf = torch.tensor(cfg.node_type_pmf).to(cfg.device)

    # Interpolate pmf between GT and marginals. Shape (num_nodes, n_classes).
    probs = t_x[:, None] * clean_one_hot + type_pmf * (1 - t_x)[:, None]

    # sample() returns exactly one value per node. The previous
    # sample((1,)).squeeze() collapsed to a 0-d tensor for a single-node
    # batch, which broke the [:, None] below.
    node_type_samples = Categorical(probs=probs).sample()

    batch.x = node_type_samples[:, None]  # Shape (num_nodes, 1)

    return batch


def noising_node_features(batch, t_f):
    """Noise the continuous node features (only when ``cfg.gt.sizing`` is set).

    Features are cast to float and linearly interpolated with uniform noise
    drawn in ``[0, cfg.dataset.nnode_features)``; the result is clipped from
    below at 0.1.

    Args:
        batch: Batch carrying ``x_features``.
        t_f: Per-node times (weight of the clean features).

    Returns:
        The same batch, with ``x_features`` replaced when sizing is enabled.
    """
    sizing_enabled = cfg.gt.sizing
    if not sizing_enabled:
        return batch

    batch.x_features = batch.x_features.float()
    n_rows = len(batch.x_features)
    uniform_prior = torch.rand((n_rows, 1), device=cfg.device) * cfg.dataset.nnode_features

    interpolated = batch.x_features * t_f[:, None] + uniform_prior * (1 - t_f[:, None])
    batch.x_features = interpolated.clip(min=1e-1)

    return batch


def noising_edge_vfm(batch, t_e):
    """Noise edges for the 'vfm' framework.

    Builds one entry per upper-triangular node pair (attribute 0 when the
    pair is not an edge), interpolates its one-hot encoding with Gaussian
    noise and writes both directions of every pair back to the batch.

    Args:
        batch: Batch with ``edge_index``, ``edge_attr`` and
            ``triu_edge_index``.
        t_e: Per-pair times; must line up with the coalesced pair ordering.

    Returns:
        The batch with dense ``edge_index`` and real-valued ``edge_attr`` of
        shape (num_edges, 2), after ``reset_slice_dict_edges``.
    """
    # Keep one direction of each existing edge, then merge with all candidate
    # pairs: "max" keeps the existing attribute where a pair is a real edge.
    forward_only = batch.edge_index[0] < batch.edge_index[1]
    n_candidates = batch.triu_edge_index.size(1)
    dense_idx, dense_attr = torch_sparse.coalesce(
        torch.cat([batch.edge_index[:, forward_only], batch.triu_edge_index], dim=1),
        torch.cat([batch.edge_attr[forward_only], batch.edge_attr.new_zeros(n_candidates)], dim=0),
        batch.num_nodes, batch.num_nodes,
        op="max"
    )

    # Gaussian noise (scaled by 2, shifted by 1); the interpolated attributes
    # are real-valued and not restricted to [0, 1].
    gaussian_prior = (torch.randn((len(t_e), 2)) * 2 + 1).to(cfg.device)
    noisy_attr = F.one_hot(dense_attr, num_classes=2) * t_e[:, None] \
            + (1 - t_e)[:, None] * gaussian_prior

    # Duplicate every pair and flip every other column so that both
    # directions are present, interleaved in the same order.
    both_dirs_idx = dense_idx.repeat_interleave(2, dim=1)
    both_dirs_idx[:, ::2] = torch.flip(both_dirs_idx[:, ::2], dims=[0])

    batch.edge_index = both_dirs_idx
    # Contrary to the discrete case, edge_attr has shape (num_edges, 2) here.
    batch.edge_attr = noisy_attr.repeat_interleave(2, dim=0)

    # The number of edges per sample changed, so refresh the slice dict.
    return reset_slice_dict_edges(batch)


def noising_edge_marginal(batch, t_e):
    """Noise edges by Bernoulli sampling from a time-interpolated probability.

    Each upper-triangular node pair is kept as an edge with probability
    ``attr * t_e + (1 - t_e) * cfg.edge_ratio``, where ``attr`` is the pair's
    current attribute (0 when it is not an edge). Sampled edges are written
    back in both directions.

    Args:
        batch: Batch with ``edge_index``, ``edge_attr`` and
            ``triu_edge_index``.
        t_e: Per-pair times; must line up with the coalesced pair ordering.

    Returns:
        The batch with the sampled sparse edges, after
        ``reset_slice_dict_edges``.
    """
    # Keep one direction of each existing edge, then merge with all candidate
    # pairs: "max" keeps the existing attribute where a pair is a real edge.
    forward_only = batch.edge_index[0] < batch.edge_index[1]
    n_candidates = batch.triu_edge_index.size(1)
    dense_idx, dense_attr = torch_sparse.coalesce(
        torch.cat([batch.edge_index[:, forward_only], batch.triu_edge_index], dim=1),
        torch.cat([batch.edge_attr[forward_only], batch.edge_attr.new_zeros(n_candidates)], dim=0),
        batch.num_nodes, batch.num_nodes,
        op="max"
    )

    edge_prob = dense_attr * t_e + (1 - t_e) * cfg.edge_ratio
    sampled_attr = (torch.rand(len(edge_prob), device=cfg.device) < edge_prob).long()

    # Drop the pairs sampled as "no edge".
    is_edge = sampled_attr > 0
    sparse_idx = dense_idx[:, is_edge]
    sparse_attr = sampled_attr[is_edge]

    # Duplicate every kept pair and flip every other column so that both
    # directions are present, interleaved in the same order.
    both_dirs_idx = sparse_idx.repeat_interleave(2, dim=1)
    both_dirs_idx[:, ::2] = torch.flip(both_dirs_idx[:, ::2], dims=[0])

    batch.edge_index = both_dirs_idx
    batch.edge_attr = sparse_attr.repeat_interleave(2, dim=0)

    # The number of edges per sample changed, so refresh the slice dict.
    return reset_slice_dict_edges(batch)


def noising_node_mask(batch, t):
    """Noise nodes by replacing whole rows of ``batch.x`` with mask tokens.

    Each node is masked independently with probability ``1 - t``. A masked
    row gets the node-type mask index in its first column and the feature
    mask index in every remaining column.

    Args:
        batch: Batch with a 2-D ``x``.
        t: Per-node times.

    Returns:
        The same batch with ``x`` replaced by the corrupted copy.
    """
    device = batch.x.device
    n_nodes = batch.num_nodes
    n_cols = batch.x.shape[1]

    # Draw the mask directly on the device of batch.x.
    is_masked = torch.rand(n_nodes, device=device) < (1 - t)

    # The mask tokens are the last index of each vocabulary.
    type_token = cfg.dataset.nnode_types - 1
    feature_token = cfg.dataset.nnode_features - 1
    mask_row = torch.cat([
        torch.tensor([type_token]),
        torch.full((n_cols - 1,), feature_token),
    ]).to(device)

    # Work on a copy so the tensor originally held by the batch is untouched.
    noised_x = batch.x.clone()
    noised_x[is_masked] = mask_row
    batch.x = noised_x

    return batch


# def deterministic_noising_edge(batch, t, node_mask):

#     # Ensure proper tensor placement
#     device = batch.edge_attr.device
#     edge_size = batch.edge_attr.shape[0]
    
#     # Get indices where node_mask is True
#     node_masked_indices = torch.nonzero(node_mask, as_tuple=True)[0]

#     edge_mask = (node_masked_indices.unsqueeze(1) == batch.edge_index[0]).any(dim=0) | \
#             (node_masked_indices.unsqueeze(1) == batch.edge_index[1]).any(dim=0)

#     edge_mask = edge_mask.to(device)
#     # Generate mask on the same device as batch.edge_index
#     # mask = torch.rand((edge_size), device=device) < (1 - t)
    
#     # Create a copy to safely modify batch.edge_index without autograd interference
#     corrupted_edge_index = batch.edge_attr
#     corrupted_edge_index[edge_mask] = cfg.dataset.nedge_types - 1    
#     # Update batch.edge_index with the corrupted version
#     batch.edge_attr = corrupted_edge_index
    
#     return batch



def noising_edge_mask(batch, nodes_for_graph, t):
    """Noise edges by marking node pairs with the edge mask token.

    Each upper-triangular node pair is masked independently with probability
    ``1 - t``. Masked pairs are added in both directions with attribute
    ``cfg.dataset.nedge_types - 1`` and merged with the existing edges.

    Args:
        batch: Batch with ``edge_index``, ``edge_attr`` and
            ``triu_edge_index``.
        nodes_for_graph: Number of nodes of each graph in the batch.
        t: Per-pair times.

    Returns:
        The batch holding existing and masked edges, after
        ``reset_slice_dict_edges``.
    """
    device = batch.edge_index.device

    # Number of unordered node pairs in each graph.
    pairs_per_graph = torch.tensor([n * (n - 1) * 0.5 for n in nodes_for_graph]).long()
    n_pairs = pairs_per_graph.sum().item()

    # Draw the mask directly on the device of batch.edge_index.
    is_masked = torch.rand(n_pairs, device=device) < (1 - t)
    masked_pairs = batch.triu_edge_index[:, is_masked]
    # The pairs are directed so far: append the reversed direction.
    masked_pairs = torch.cat([masked_pairs, torch.flip(masked_pairs, dims=[0])], dim=1)
    mask_token = cfg.dataset.nedge_types - 1
    masked_attr = batch.edge_attr.new_full((masked_pairs.size(1),), mask_token)

    # Merge existing and masked edges; "max" makes the mask token win when a
    # pair appears on both sides.
    merged_idx, merged_attr = torch_sparse.coalesce(
        torch.cat([batch.edge_index, masked_pairs], dim=1),
        torch.cat([batch.edge_attr, masked_attr], dim=0),
        batch.num_nodes, batch.num_nodes,
        op="max"
    )

    # Only actual and masked edges remain in edge_index / edge_attr.
    batch.edge_index = merged_idx
    batch.edge_attr = merged_attr

    # The number of edges per sample changed, so refresh the slice dict.
    return reset_slice_dict_edges(batch)


def dfm_loss(batch, denoised_batch):
    """Compute the training loss between a clean batch and the model output.

    Args:
        batch: Clean batch (targets).
        denoised_batch: Model output; ``x`` is used as node-type logits and
            ``x_features`` as predicted features.

    Returns:
        Tuple ``(total, topology, sizing)``. For the 'pin_prediction' task
        both total and topology are the edge loss and sizing is 0. Otherwise
        topology is node cross-entropy plus the weighted edge loss, sizing is
        the weighted L1 feature loss (0 when disabled) and total is their sum.
    """
    if cfg.dataset.get("task_type", '') == 'pin_prediction':
        edge_only_loss = marginal_loss_edges(batch, denoised_batch)
        return edge_only_loss, edge_only_loss, torch.tensor(0.0, device=cfg.device)

    # Feature sizing loss; node types 8, 9 and 10 (circuit In, Out & nets)
    # are excluded.
    loss_sizing = torch.tensor(0.0, device=cfg.device)
    if (cfg.dataset.node_features_dim > 1) or cfg.gt.get("sizing", False):
        sized_nodes = (batch.x[:, 0] != 8) & (batch.x[:, 0] != 9) & (batch.x[:, 0] != 10)
        predicted_feats = denoised_batch.x_features[sized_nodes].flatten()
        target_feats = batch.x_features[sized_nodes].flatten()
        feature_l1 = F.l1_loss(predicted_feats, target_feats, reduction='mean')
        loss_sizing = cfg.loss.feature_weight * feature_l1

    # Node type loss
    type_targets = batch.x[:, 0].clone()
    if cfg.gt.node_pruning == 1:
        # Added (prunable) nodes carry no meaningful type: ignore them.
        type_targets[batch.prunable == 1] = -1
    batch.to(torch.device(cfg.device))
    loss_node_types = F.cross_entropy(denoised_batch.x, type_targets, reduction='mean', ignore_index=-1)

    # Edge loss, weighted by cfg.loss.edge_weight (defog paper suggests > 1)
    loss_edges = marginal_loss_edges(batch, denoised_batch)
    loss_topo = loss_node_types + cfg.loss.edge_weight * loss_edges

    return loss_sizing + loss_topo, loss_topo, loss_sizing


# def pin_loss(batch):

#     pred_logits_idx = []
#     for i in range(batch.learnable_edge_index.size(1)):
#         edge = batch.learnable_edge_index[:, i].T
#         found = batch.edge_index.T == edge
#         found = torch.nonzero(found[:, 0] & found[:, 1])
#         pred_logits_idx.append(found[0].item())

#     pred_logits = batch.edge_attr[pred_logits_idx]
#     return F.cross_entropy(pred_logits, batch.labels.long(), reduction='mean'), torch.tensor(0.0, device=cfg.device), torch.tensor(0.0, device=cfg.device)


def get_edge_mask(edge_index, node):
    """
    Build a boolean mask selecting the edges that touch a given node.

    :param edge_index: Tensor of shape (2, N) representing the graph edges.
    :param node: The node to search for.
    :return: Boolean mask of shape (N,), True where the edge contains the node.
    """
    # An edge matches when the node appears in either of its two rows.
    touches_node = edge_index == node
    return touches_node.any(dim=0)

def get_paired_nodes(edge_index, target_node):
    """
    Return the distinct nodes that share an edge with ``target_node``.

    :param edge_index: Tensor with two rows holding the edge endpoints.
    :param target_node: The node whose neighbours are collected.
    :return: 1-D tensor of unique neighbour ids, as returned by ``torch.unique``.
    """
    sources, destinations = edge_index[0], edge_index[1]

    # The target may sit at either end of an edge; take the opposite end.
    neighbours = torch.cat([
        destinations[sources == target_node],
        sources[destinations == target_node],
    ])

    return torch.unique(neighbours)


def match_targets(outputs, targets, outputs_edge_attr, target_edge_attribute, edge_index_loss, valid_indices):
    """Find the minimum-cost one-to-one assignment of outputs to targets.

    Builds the pairwise cost matrix with ``compute_cost_matrix`` and solves
    the assignment with ``scipy.optimize.linear_sum_assignment``.

    Returns:
        Tuple ``(row_ind, col_ind)`` of index arrays describing the matching.
    """
    costs = compute_cost_matrix(
        outputs, targets, outputs_edge_attr, target_edge_attribute,
        edge_index_loss, valid_indices,
    )
    # scipy works on NumPy arrays: move the costs to host memory first.
    costs_np = costs.cpu().detach().numpy()
    matched_rows, matched_cols = linear_sum_assignment(costs_np)
    return matched_rows, matched_cols

def compute_cost_matrix(logits, targets, logits_edges, targets_edges, edge_index_loss, valid_indices):
    """
    Computes a cost matrix where entry (i, j) is the negative log-probability
    that logits[i] assigns to targets[j], plus an edge term when the two
    nodes have identical neighbour sets.

    Args:
        logits (torch.Tensor): Shape (N, d), raw logits.
        targets (torch.Tensor): Shape (N,), class indices (0 to d-1).
        logits_edges (torch.Tensor): Shape (E, d_edge), raw logits for edge attributes.
        targets_edges (torch.Tensor): Shape (E,), class indices for edge attributes.
        edge_index_loss (torch.Tensor): Shape (2, E), edge indices for computing loss.
        valid_indices: Sequence of length N giving, for row/column k, the id
            of the corresponding node in ``edge_index_loss``.
    Returns:
        torch.Tensor: Cost matrix of shape (N, N)
    """
    N = logits.shape[0]

    # Precompute log softmax once
    log_softmax = F.log_softmax(logits, dim=1)  # Shape: (N, d)
    log_softmax_edges = F.log_softmax(logits_edges, dim=1)

    # Every entry is overwritten below; the large value is only a default.
    large_value = 1e9
    cost_matrix = torch.ones(N, N, device=logits.device) * large_value

    # The edge context of a node (incident-edge mask and neighbour set) only
    # depends on the node itself, so compute it once per node instead of once
    # per (i, j) pair.
    edge_masks, has_edges, neighbours = [], [], []
    for k in range(N):
        node = valid_indices[k]
        mask = get_edge_mask(edge_index_loss, node)
        edge_masks.append(mask)
        has_edges.append(bool(mask.any()))
        neighbours.append(get_paired_nodes(edge_index_loss[:, mask], node))

    for i in range(N):
        for j in range(N):
            # Node term: cross entropy between prediction i and target j.
            node_loss = -log_softmax[i, targets[j]]

            # Edge term: only when both nodes have edges and exactly the same
            # neighbours. The neighbour tensors come from torch.unique, whose
            # output is sorted, so equal sets always compare equal with
            # torch.equal; the former "edge order mismatch" branch could
            # never trigger and was dropped.
            edge_loss = 0.0
            if has_edges[i] and has_edges[j] and torch.equal(neighbours[i], neighbours[j]):
                # Pairs the edge logits incident to i with the edge targets
                # incident to j.
                edge_loss = -log_softmax_edges[edge_masks[i], targets_edges[edge_masks[j]]].mean()

            # Total cost is node loss plus edge loss
            cost_matrix[i, j] = node_loss + edge_loss

    return cost_matrix  # Shape: (N, N)

def hungarian_loss_node(outputs, targets, outputs_edge_attr, target_edge_attribute, edge_index_loss, return_permutation=False):
    """
    Cross-entropy between node predictions and targets after a one-to-one matching.

    Only nodes whose target differs from -1 take part in the matching; all
    other nodes keep their position in the returned permutation.

    Args:
        outputs (torch.Tensor): Per-node logits; the first dimension indexes nodes.
        targets (torch.Tensor): Per-node class indices, -1 marks nodes to ignore.
        outputs_edge_attr: Forwarded unchanged to ``match_targets``.
        target_edge_attribute: Forwarded unchanged to ``match_targets``.
        edge_index_loss: Forwarded unchanged to ``match_targets``.
        return_permutation (bool): If True, also return the node permutation.

    Returns:
        torch.Tensor, or (torch.Tensor, numpy.ndarray) when
        ``return_permutation`` is True. With no valid target the loss is a
        zero scalar and the permutation is the identity.
    """
    num_nodes = outputs.shape[0]

    is_valid = targets != -1
    valid_indices = torch.where(is_valid)[0].cpu().numpy()

    # Nothing to match: zero loss, every node stays where it is.
    if len(valid_indices) == 0:
        zero = torch.tensor(0.0, device=outputs.device)
        return (zero, np.arange(num_nodes)) if return_permutation else zero

    # Matching is computed on the valid subset only; the indices it returns
    # are relative to that subset (NOTE(review): presumably an optimal
    # assignment over a square cost matrix - confirm in match_targets).
    local_rows, local_cols = match_targets(
        outputs[is_valid],
        targets[is_valid],
        outputs_edge_attr,
        target_edge_attribute,
        edge_index_loss,
        valid_indices,
    )

    # Translate subset-relative indices back to positions in the full graph.
    row_ind = valid_indices[local_rows]
    col_ind = valid_indices[local_cols]

    loss = F.cross_entropy(outputs[row_ind], targets[col_ind], reduction='mean', ignore_index=-1)
    if not return_permutation:
        return loss

    # Start from the identity, then for each valid target position (in
    # ascending order) store the index of the output matched to it.
    permutation = np.arange(num_nodes)
    permutation[valid_indices] = row_ind[np.argsort(col_ind)]
    return loss, permutation

def rewrite_edge_index(edge_index, permutation):
    """
    Rewrite edge_index based on node permutation.

    ``permutation`` maps new position -> original position. Edge endpoints
    are stored as original positions, so each one is relabelled through the
    inverse mapping (original position -> new position).

    Args:
        edge_index (torch.Tensor): Edge index tensor of shape (2, num_edges)
        permutation (numpy.ndarray or torch.Tensor): Permutation array from
            Hungarian matching. Must contain every node index exactly once.

    Returns:
        torch.Tensor: Rewritten edge_index with updated node indices. It is a
        new tensor with the shape and dtype of ``edge_index``; the input is
        left untouched.
    """
    # Accept arrays, lists and tensors alike, and make sure the permutation
    # lives on the same device as the indices it will be gathered with.
    permutation = torch.as_tensor(permutation, dtype=torch.long, device=edge_index.device)

    # Invert with a single scatter: inverse_perm[permutation[new]] = new.
    inverse_perm = torch.empty_like(permutation)
    inverse_perm[permutation] = torch.arange(permutation.numel(), device=permutation.device)

    # Gather both rows at once; advanced indexing already yields a fresh tensor.
    return inverse_perm[edge_index].to(edge_index.dtype)

def get_mask(edge_index, edge_set):
    """
    Build a boolean mask over the columns of ``edge_index``.

    Args:
        edge_index (torch.Tensor): Tensor whose rows 0 and 1 hold the two
            endpoints of each edge.
        edge_set: Container of ``(source, target)`` tuples supporting ``in``.

    Returns:
        torch.Tensor: 1-D bool tensor on ``edge_index.device``; entry ``k`` is
        True when ``(edge_index[0, k], edge_index[1, k])`` is in ``edge_set``.
    """
    edges = zip(edge_index[0].tolist(), edge_index[1].tolist())
    # dtype is forced: with zero edges torch.tensor([]) would default to
    # float32, which cannot be used as a boolean mask.
    return torch.tensor(
        [edge in edge_set for edge in edges],
        dtype=torch.bool,
        device=edge_index.device,
    )

def hungarian_loss(batch, noised_batch, denoised_batch):
    """
    Permutation-matched training loss, averaged over the graphs of a batch.

    Each graph is handled on its own: predicted nodes are matched to target
    nodes via ``hungarian_loss_node``, the node cross-entropy is evaluated
    under that permutation, the denoised and noised graphs are relabelled
    accordingly, and an edge loss is added with weight ``cfg.loss.edge_weight``.

    Args:
        batch: Clean batch; must expose ``num_graphs`` and ``get_example(i)``.
        noised_batch: Noised counterpart of ``batch``.
        denoised_batch: Model output; ``x`` and ``edge_attr`` are used as logits.

    Returns:
        torch.Tensor: Mean over the graphs that produced at least one non-NaN
        loss term, or a NaN scalar when no graph did.
    """
    total_loss = 0
    valid_loss = 0  # number of graphs that contributed a non-NaN term

    # Process each graph in the batch separately
    for i in range(batch.num_graphs):
        orig_graph = batch.get_example(i)
        noised_graph = noised_batch.get_example(i)
        denoised_graph = denoised_batch.get_example(i)

        # Extract node types
        original_node_types = orig_graph.x[:, 0].clone()
        noised_node_types = noised_graph.x[:, 0]
        predicted_node_types = denoised_graph.x.clone()

        if cfg.train.prior == 'masked':
            # Don't compute loss on already revealed dimensions
            original_node_types[noised_node_types != cfg.dataset.nnode_types - 1] = -1

        # Keep one direction (source < target) of every target edge.
        mask_duplicates = orig_graph.edge_index[0, :] < orig_graph.edge_index[1, :]
        filtered_batch_edge_index = orig_graph.edge_index[:, mask_duplicates]
        filtered_batch_edge_attr = orig_graph.edge_attr[mask_duplicates]

        filtered_denoised_edge_index = denoised_graph.edge_index
        filtered_denoised_edge_attr = denoised_graph.edge_attr

        denoised_edges = set(map(tuple, filtered_denoised_edge_index.t().tolist()))
        batch_edges = set(map(tuple, filtered_batch_edge_index.t().tolist()))

        difference = denoised_edges - batch_edges

        # Predicted edges that are absent from the target are appended to the
        # target with attribute 0, so both edge lists cover the same pairs.
        if difference:
            difference_tensor = torch.tensor(list(difference), device=filtered_batch_edge_index.device).t()
            filtered_batch_edge_index = torch.cat([filtered_batch_edge_index, difference_tensor], dim=1)
            filtered_batch_edge_attr = torch.cat([
                filtered_batch_edge_attr,
                torch.zeros(
                    difference_tensor.size(1),
                    dtype=filtered_batch_edge_attr.dtype,
                    device=filtered_batch_edge_attr.device,
                ),
            ])

        # Bring both edge lists into the same canonical order.
        filtered_batch_edge_index, filtered_batch_edge_attr = sort_edge_index(filtered_batch_edge_index, filtered_batch_edge_attr)
        filtered_denoised_edge_index, filtered_denoised_edge_attr = sort_edge_index(filtered_denoised_edge_index, filtered_denoised_edge_attr)

        assert torch.equal(filtered_batch_edge_index, filtered_denoised_edge_index), "Edge index not equal"

        # Only the permutation is needed here; the node loss is recomputed
        # below over all nodes under that permutation.
        _, permutation = hungarian_loss_node(
            predicted_node_types,
            original_node_types,
            filtered_denoised_edge_attr,
            filtered_batch_edge_attr,
            filtered_batch_edge_index,
            return_permutation=True,
        )

        loss_node = F.cross_entropy(predicted_node_types[permutation], original_node_types, reduction='mean', ignore_index=-1)

        # Apply permutation to the denoised graph
        denoised_graph.x = denoised_graph.x[permutation]

        denoised_graph.edge_index = rewrite_edge_index(denoised_graph.edge_index, permutation)
        noised_graph.edge_index = rewrite_edge_index(noised_graph.edge_index, permutation)

        # Compute edge loss for this graph
        if cfg.train.prior == 'masked':
            loss_edges = mask_loss_edges(orig_graph, noised_graph, denoised_graph)
        else:
            loss_edges = marginal_loss_edges(orig_graph, denoised_graph)

        # A term can be NaN (e.g. a mean over nothing); accumulate whichever
        # terms are usable and count the graph once if any of them was.
        node_ok = not torch.isnan(loss_node)
        edges_ok = not torch.isnan(loss_edges)
        if node_ok and edges_ok:
            total_loss += loss_node + cfg.loss.edge_weight * loss_edges
        elif edges_ok:
            total_loss += cfg.loss.edge_weight * loss_edges
        elif node_ok:
            total_loss += loss_node
        if node_ok or edges_ok:
            valid_loss += 1

    if valid_loss == 0:
        return torch.tensor(float('nan'), device=denoised_batch.x.device)
    # Average loss over the graphs that contributed
    return total_loss / valid_loss
    

def mask_loss_edges(batch, noised_batch, denoised_batch):
    """
    Cross-entropy on the edges whose noised attribute equals the last edge class.

    Args:
        batch: Clean graph(s); supplies the target edge classes.
        noised_batch: Noised graph(s); selects which edges enter the loss.
        denoised_batch: Model output; ``edge_attr`` rows are used as logits.

    Returns:
        torch.Tensor: Mean cross-entropy over the selected edges.
    """

    def _find(edge_index, edge):
        # Column positions of edge_index whose two endpoints both equal `edge`.
        hits = edge_index.T == edge
        return torch.nonzero(hits[:, 0] & hits[:, 1])

    # Edges carrying class `nedge_types - 1` in the noised graph
    # (presumably the mask class - confirm against the noise prior),
    # keeping a single direction (source < target) of each.
    last_class = cfg.dataset.nedge_types - 1
    masked_edges = noised_batch.edge_index[:, noised_batch.edge_attr == last_class]
    masked_edges = masked_edges[:, masked_edges[0, :] < masked_edges[1, :]]

    num_masked = masked_edges.size(1)
    # Target defaults to class 0 for pairs that are not edges of the clean graph.
    gt_classes = batch.edge_attr.new_zeros(num_masked)
    pred_rows = []
    for k in range(num_masked):
        edge = masked_edges[:, k]

        clean_hits = _find(batch.edge_index, edge)
        if len(clean_hits) > 0:
            gt_classes[k] = batch.edge_attr[clean_hits[0].item()]

        pred_hits = _find(denoised_batch.edge_index, edge)
        if len(pred_hits) == 0:
            # After the Hungarian relabelling an edge may be stored with its
            # endpoints swapped, so retry with the reversed pair.
            pred_hits = _find(denoised_batch.edge_index, torch.flip(edge, dims=[0]))
        pred_rows.append(pred_hits[0].item())

    selected_logits = denoised_batch.edge_attr[pred_rows]
    return F.cross_entropy(selected_logits, gt_classes, reduction='mean')


def marginal_loss_edges(batch, denoised_batch):
    """
    Cross-entropy between predicted edge logits and densified target edges.

    Args:
        batch: Clean graph(s) with ``edge_index``, ``edge_attr``,
            ``triu_edge_index`` and ``num_nodes``.
        denoised_batch: Model output; ``edge_attr`` holds the edge logits.
            Edge indices of denoised_batch have already been sorted in the
            output layer of the model.

    Returns:
        torch.Tensor: Mean cross-entropy over the edge entries.
    """
    # One direction (source < target) of the ground-truth edges ...
    upper = batch.edge_index[0] < batch.edge_index[1]
    # ... concatenated with every pair in triu_edge_index carrying attribute 0.
    stacked_index = torch.cat([batch.edge_index[:, upper], batch.triu_edge_index], dim=1)
    zero_attr = batch.edge_attr.new_zeros(batch.triu_edge_index.size(1))
    stacked_attr = torch.cat([batch.edge_attr[upper], zero_attr], dim=0)

    # Duplicated pairs are merged with "max" (presumably so a real edge class
    # wins over the 0 filler - assumes classes are non-negative).
    _, target = torch_sparse.coalesce(
        stacked_index,
        stacked_attr,
        batch.num_nodes,
        batch.num_nodes,
        op="max",
    )

    logits = denoised_batch.edge_attr

    if cfg.dataset.get('task_type', '') != 'pin_prediction':
        return F.cross_entropy(logits, target, reduction='mean')

    # Pin prediction: entries flagged 0 in triu_learnable_edge_attr are
    # excluded from the loss through ignore_index.
    target[batch.triu_learnable_edge_attr == 0] = -1
    return F.cross_entropy(logits, target.long(), reduction='mean', ignore_index=-1)


def sort_edge_index(edge_index, edge_attr=None):
    """
    Orient every edge as (min, max) and sort the edges lexicographically.

    Edge attributes are only reordered alongside their edges; their values
    are not changed when an edge is re-oriented. The input tensors are never
    modified in place.

    Args:
        edge_index (torch.Tensor): Edge index tensor of shape (2, num_edges).
        edge_attr (torch.Tensor, optional): Per-edge attributes, indexed
            along the first dimension.

    Returns:
        (torch.Tensor, torch.Tensor or None): Sorted edge index and the
        correspondingly reordered attributes. With zero edges the inputs are
        returned as they are.
    """
    if edge_index.shape[1] == 0:
        return edge_index, edge_attr

    # Step 1: ensure each edge pair (u, v) has u <= v. New tensors are built
    # instead of flipping columns in place, so the caller's edge_index is
    # left untouched.
    lower = torch.minimum(edge_index[0], edge_index[1])
    upper = torch.maximum(edge_index[0], edge_index[1])

    # Step 2: sort edges lexicographically through the scalar key
    # lower * num_nodes + upper, where num_nodes exceeds every node id.
    num_nodes = upper.max().item() + 1
    # stable=True keeps duplicate edges in their input order, so the
    # attribute order is deterministic.
    _, sorted_indices = torch.sort(lower * num_nodes + upper, stable=True)

    edge_index = torch.stack([lower, upper], dim=0)[:, sorted_indices]
    if edge_attr is not None:
        edge_attr = edge_attr[sorted_indices]

    return edge_index, edge_attr