import logging
import os
import time
from typing import Any, Dict
import numpy as np
import torch
from torch_geometric.graphgym.checkpoint import (
    clean_ckpt,
    load_ckpt,
    save_ckpt,
)
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.register import register_train
from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch
from tqdm import tqdm
from ..utils import add_full_rrwp, cfg_to_dict, make_wandb_name, flatten_dict
from torch_geometric.data import Batch
# Pin preprocessing imports moved to master_loader.py
# from ..loader.datasets.analogenie_dataset import NAME_TO_ID_PINS, node2pins, ID_TO_NAME_PINS
import torch.nn.functional as F

# def add_pins(batch: Batch, eval = False):
#     """
#     DEPRECATED: Pin preprocessing is now done in the master loader for efficiency.
#     This function has been moved to preprocess_pins() in master_loader.py
#     """
#     updated_graphs = []
#     for i in range(batch.num_graphs):
#         graph = batch.get_example(i).clone()
#         accepted_connections = []
#         # Store the label connections: now that the pin indexes are known, we can
#         # store every (idx_pin, idx_dst) pair.
#         processed_labels = []
#         for node, connections in graph.y:
#             idx_node_type = graph.x[:,0][node].item()
#             node_type = ID_TO_NAME_PINS[idx_node_type]
#             connections_type = node2pins[node_type]
#             # idx_connections holds the node types of the pin nodes to be added
#             idx_connections = [NAME_TO_ID_PINS[conn] for conn in connections_type]
#             # Add new nodes to the graph
#             num_new_nodes = len(idx_connections)
#             new_node_features = torch.zeros((num_new_nodes, graph.x.size(1)), dtype=graph.x.dtype, device=graph.x.device)
#             # add node types
#             new_node_features[:, 0] = torch.tensor(idx_connections, dtype=graph.x.dtype, device=graph.x.device)  # set node type
#             #name number store the number that permits to create the name like -> type NM, number 1 -> name: NM1
#             name_number = graph.x[:,1][node].item()
#             new_node_features[:,1] = torch.tensor([name_number]*num_new_nodes, dtype=graph.x.dtype, device=graph.x.device)
#             # Append new nodes to graph.x
#             graph.x = torch.cat([graph.x, new_node_features], dim=0)

#             # Add edges from the original node to each new node
#             new_node_indices = torch.arange(graph.x.size(0) - num_new_nodes, graph.x.size(0), device=graph.x.device)

#             # store labels processed
#             for idx_pin, idx_dst in zip(new_node_indices.tolist(),connections):
#                 processed_labels.append([idx_pin, idx_dst])
#                 accepted_connections.append(connections)

#             src = torch.full((num_new_nodes,), node, dtype=torch.long, device=graph.x.device)
#             dst = new_node_indices

#             # Add edges from original node to new nodes
#             graph.edge_index = torch.cat([graph.edge_index, torch.stack([src, dst], dim=0)], dim=1)
#             # Add edges from new nodes back to original node to make adjacency symmetric
#             graph.edge_index = torch.cat([graph.edge_index, torch.stack([dst, src], dim=0)], dim=1)

#             # Add new edge attributes for the new edges
#             num_new_edges = 2 * num_new_nodes  # forward and backward edges

#             new_edge_attr = torch.tensor([1]*num_new_edges, dtype=graph.edge_attr.dtype, device=graph.edge_attr.device)
#             graph.edge_attr = torch.cat([graph.edge_attr, new_edge_attr], dim=0)

#         graph.y = processed_labels
#         graph.accepted_connections = accepted_connections
#         updated_graphs.append(graph)

#     batch = Batch.from_data_list(updated_graphs)
#     return batch

def train_epoch(logger, loader, model, optimizer, scheduler):
    """Run one training epoch and return the mean loss over finite batches.

    Each batch is moved to ``cfg.device``, given RRWP encodings via
    ``add_full_rrwp``, passed through ``model`` and scored with ``pin_loss``.
    Pin preprocessing is done by the data loader, not here.

    Args:
        logger: Split logger; ``update_stats`` is called once per batch.
        loader: Iterable of batches for the training split.
        model: Model called as ``model(batch, unconditional_prop=...)``.
        optimizer: Optimizer stepped once per batch whose loss is finite.
        scheduler: LR scheduler, stepped once at the end of the epoch.

    Returns:
        Mean of the finite per-batch losses, or ``float('nan')`` when no
        batch produced a finite loss (e.g. an empty loader).
    """
    device = torch.device(cfg.device)
    model.train()
    model.to(device)
    time_start = time.time()
    losses = []
    for idx, batch in enumerate(loader):
        batch.to(device)
        optimizer.zero_grad()

        # RRWP is added to a copy; the untouched `batch` is logged as `true`.
        noised_batch = add_full_rrwp(batch.clone(), walk_length=cfg.posenc_RRWP.ksteps)
        denoised_batch = model(noised_batch.clone(),
                               unconditional_prop=cfg.train.ratio_cf_guidance)

        loss = pin_loss(denoised_batch)
        loss_value = loss.item()

        if torch.isfinite(loss):
            losses.append(loss_value)
            loss.backward()
            optimizer.step()
        else:
            # Back-propagating a NaN/inf loss would propagate NaN into the
            # parameters through optimizer.step(), so skip this update.
            logging.warning('Non-finite training loss at batch %d; '
                            'skipping optimizer step', idx)

        logger.update_stats(true=batch.detach().cpu(),
                            pred=denoised_batch.detach().cpu(), loss=loss_value,
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - time_start,
                            params=cfg.params)
        time_start = time.time()
    scheduler.step()
    return sum(losses) / len(losses) if losses else float('nan')


def eval_epoch(logger, loader, model):
    """Evaluate ``model`` on one split without gradient tracking.

    Mirrors the forward path of ``train_epoch`` (device move, RRWP encoding,
    model call, ``pin_loss``) but performs no parameter updates.

    Args:
        logger: Split logger; ``update_stats`` is called once per batch.
        loader: Iterable of batches for the split being evaluated.
        model: Model called as ``model(batch, unconditional_prop=...)``.

    Returns:
        Mean per-batch loss, or ``float('nan')`` if ``loader`` is empty.
    """
    device = torch.device(cfg.device)
    model.eval()
    time_start = time.time()
    losses = []
    with torch.no_grad():
        for batch in loader:
            batch.to(device)

            # RRWP is added to a copy; the untouched `batch` is logged as `true`.
            noised_batch = add_full_rrwp(batch.clone(), walk_length=cfg.posenc_RRWP.ksteps)
            # NOTE(review): evaluation reuses the training guidance ratio
            # (cfg.train.ratio_cf_guidance); confirm this is intended.
            denoised_batch = model(noised_batch.clone(),
                                   unconditional_prop=cfg.train.ratio_cf_guidance)

            loss_value = pin_loss(denoised_batch).item()
            losses.append(loss_value)

            logger.update_stats(true=batch.detach().cpu(),
                                pred=denoised_batch.detach().cpu(), loss=loss_value,
                                lr=0, time_used=time.time() - time_start,
                                params=cfg.params)
            time_start = time.time()
    return sum(losses) / len(losses) if losses else float('nan')


def _init_wandb_run():
    # Start a Weights & Biases run named from cfg and upload the config to it.
    try:
        import wandb
    except ImportError as err:
        raise ImportError('WandB is not installed.') from err
    if cfg.wandb.name == '':
        wandb_name = make_wandb_name(cfg)
    else:
        wandb_name = cfg.wandb.name
    run = wandb.init(entity=cfg.wandb.entity, project=cfg.wandb.project,
                     name=wandb_name)
    run.config.update(cfg_to_dict(cfg))
    return run


def _save_final_model(model, optimizer, scheduler):
    # Save model (plus optimizer/scheduler when given) state to a timestamped file in cfg.out_dir.
    ckpt: Dict[str, Any] = {"model_state": model.state_dict()}
    if optimizer is not None:
        ckpt["optimizer_state"] = optimizer.state_dict()
    if scheduler is not None:
        ckpt["scheduler_state"] = scheduler.state_dict()
    os.makedirs(cfg.out_dir, exist_ok=True)
    date = time.strftime("%Y%m%d-%H%M%S")
    torch.save(ckpt, os.path.join(cfg.out_dir, f"{date}_last.ckpt"))


@register_train('train_pin_prediction')
def train_example(loggers, loaders, model, optimizer, scheduler):
    """GraphGym train loop for pin prediction.

    Split 0 of ``loggers``/``loaders`` is trained on every epoch; all other
    splits are evaluated on epochs where ``is_eval_epoch`` is true.
    Checkpoints are saved on ``is_ckpt_epoch`` epochs, optionally resuming
    from one when ``cfg.train.auto_resume`` is set. Metrics go to W&B when
    ``cfg.wandb.use`` is set, and the final state is written to
    ``cfg.out_dir`` when ``cfg.train.save_final_model`` is set.

    Raises:
        ImportError: if W&B logging is enabled but ``wandb`` cannot be imported.
    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler,
                                cfg.train.epoch_resume)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch %s', start_epoch)

    run = _init_wandb_run() if cfg.wandb.use else None

    num_splits = len(loggers)
    perf = [[] for _ in range(num_splits)]
    eval_loss = None

    with tqdm(range(start_epoch, cfg.optim.max_epoch), desc="Training") as pbar:
        for cur_epoch in pbar:
            train_loss = train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)
            perf[0].append(loggers[0].write_epoch(cur_epoch))

            if is_eval_epoch(cur_epoch):
                for i in range(1, num_splits):
                    # eval_loss ends up holding the loss of the last split evaluated.
                    eval_loss = eval_epoch(loggers[i], loaders[i], model)
                    perf[i].append(loggers[i].write_epoch(cur_epoch, custom_metrics=True))

            pbar.set_postfix(train_loss=train_loss, eval_loss=eval_loss)

            if is_ckpt_epoch(cur_epoch):
                save_ckpt(model, optimizer, scheduler, cur_epoch)

            if run is not None:
                # NOTE(review): this logs the whole per-split history every
                # epoch; confirm flatten_dict expects this list-of-lists input.
                run.log(flatten_dict(perf), step=cur_epoch)

    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    if cfg.train.save_final_model:
        _save_final_model(model, optimizer, scheduler)

    if run is not None:
        # Close the W&B run so pending data is flushed before returning.
        run.finish()

    logging.info('Task done, results saved in %s', cfg.run_dir)


def pin_loss(batch):
    """Mean cross-entropy over the logits of the learnable (pin) edges.

    For every column ``(src, dst)`` of ``batch.learnable_edge_index``, the
    first column of ``batch.edge_index`` with the same ``(src, dst)`` is
    located. The row of ``batch.edge_attr`` at that position is used as the
    prediction logits for that edge and scored against ``batch.labels``.

    Args:
        batch: Object exposing ``edge_index`` (2 x E), ``edge_attr`` (per-edge
            logits indexed along dim 0), ``learnable_edge_index`` (2 x L) and
            ``labels`` (L class indices).

    Returns:
        Scalar loss tensor (``F.cross_entropy`` with ``reduction='mean'``).

    Raises:
        IndexError: if a learnable edge is not present in ``batch.edge_index``.
    """
    edge_index = batch.edge_index.long()
    learnable = batch.learnable_edge_index.long()

    # Encode each directed (src, dst) pair as a single integer key so the
    # lookup becomes a vectorized 1-D search instead of a per-edge scan.
    num_nodes = 1
    if edge_index.numel() > 0:
        num_nodes = max(num_nodes, int(edge_index.max()) + 1)
    if learnable.numel() > 0:
        num_nodes = max(num_nodes, int(learnable.max()) + 1)
    edge_keys = edge_index[0] * num_nodes + edge_index[1]
    learn_keys = learnable[0] * num_nodes + learnable[1]

    # A stable sort keeps duplicate keys in their original order, so a
    # left searchsorted lands on the FIRST matching edge, as the scan did.
    sorted_keys, order = torch.sort(edge_keys, stable=True)
    pos = torch.searchsorted(sorted_keys, learn_keys)

    if learn_keys.numel() > 0:
        if sorted_keys.numel() == 0:
            missing = torch.ones_like(learn_keys, dtype=torch.bool)
        else:
            pos = pos.clamp(max=sorted_keys.numel() - 1)
            missing = sorted_keys[pos] != learn_keys
        if bool(missing.any()):
            first = int(missing.nonzero()[0, 0])
            src, dst = learnable[:, first].tolist()
            raise IndexError(
                f"learnable edge ({src}, {dst}) not found in batch.edge_index")

    pred_logits = batch.edge_attr[order[pos]]
    return F.cross_entropy(pred_logits, batch.labels.long(), reduction='mean')

# def test_metric(batch):
#     all_predictions = []
#     all_true_labels = []
#     pred_logits_idx = []
#     for i in range(batch.learnable_edge_index.size(1)):
#         edge = batch.learnable_edge_index[:, i].T
#         found = batch.edge_index.T == edge
#         found = torch.nonzero(found[:, 0] & found[:, 1])
#         pred_logits_idx.append(found[0].item())

#     pred_logits = batch.edge_attr[pred_logits_idx]
    
#     # Get predicted classes (argmax of logits)
#     predicted_classes = torch.argmax(pred_logits, dim=1)
    
#     # Add to collections
#     all_predictions.extend(predicted_classes.cpu().numpy())
#     all_true_labels.extend(batch.labels.long().cpu().numpy())

    # if batch.num_graphs != predicted_batch.num_graphs :
    #     raise ValueError("Number of graphs in batch and predicted_batch do not match.")
    # total_loss = 0.0
    # criterion = torch.nn.BCEWithLogitsLoss()
    # for i in range(batch.num_graphs):
    #     graph = batch.get_example(i)
    #     predicted_graph = predicted_batch.get_example(i)
        
    #     # Convert target edges to tensor for efficient processing
    #     target_edges = torch.tensor(graph.y, dtype=torch.long, device=graph.x.device)  # Shape: [num_target_edges, 2]
    #     accepted_connections = graph.accepted_connections  # All possible connections for each src
        
    #     # Get edge_index from predicted graph (shape: [2, num_edges])
    #     edge_index = predicted_graph.edge_index
        
    #     # Create a mapping from (src, dst) pairs to edge indices
    #     edge_dict = {}
    #     for edge_idx in range(edge_index.size(1)):
    #         src, dst = edge_index[0, edge_idx].item(), edge_index[1, edge_idx].item()
    #         edge_dict[(src, dst)] = edge_idx
        
    #     # Group target edges by source node
    #     src_to_targets = {}
    #     for j, (src, dst) in enumerate(target_edges):
    #         src_item = src.item()
    #         if src_item not in src_to_targets:
    #             src_to_targets[src_item] = []
    #         src_to_targets[src_item].append(dst.item())
        
    #     # Process each source node and its connections
    #     all_predictions = []
    #     all_labels = []
        
    #     for j, (src, dst) in enumerate(target_edges):
    #         src_item, dst_item = src.item(), dst.item()
            
    #         # Get all possible connections for this src from accepted_connections
    #         possible_connections = accepted_connections[j] if j < len(accepted_connections) else []
            
    #         # For each possible connection from this src
    #         for possible_dst in possible_connections:
    #             # Check both directions for the edge
    #             edge_idx = None
    #             if (src_item, possible_dst) in edge_dict:
    #                 edge_idx = edge_dict[(src_item, possible_dst)]
    #             elif (possible_dst, src_item) in edge_dict:
    #                 edge_idx = edge_dict[(possible_dst, src_item)]
                
    #             if edge_idx is not None:
    #                 # Get prediction for this edge
    #                 pred = predicted_graph.edge_attr[edge_idx]
    #                 all_predictions.append(pred)
                    
    #                 # Label should be 1 if this is the target connection, 0 otherwise
    #                 label = 1.0 if possible_dst == dst_item else 0.0
    #                 all_labels.append(label)
    #             else:
    #                 print(f"Warning: Edge ({src_item}, {possible_dst}) not found in predicted graph")
        
    #     if len(all_predictions) == 0:
    #         continue
            
    #     # Convert to tensors and compute loss
    #     all_predictions = torch.stack(all_predictions)
    #     all_labels = torch.tensor(all_labels, dtype=torch.long, device=graph.x.device)
        
    #     # Compute loss for this graph
    #     graph_loss = F.cross_entropy(all_predictions, all_labels, reduction='mean')
    #     total_loss += graph_loss
    
    # return total_loss / batch.num_graphs