import logging
from contextlib import nullcontext

import torch
import torch.distributed as dist
import wandb
from torch.cuda.amp import autocast
from torch_geometric.graphgym.checkpoint import save_ckpt
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_train
from tqdm import tqdm


def move_tensors_dict_to_cpu(tensor_dict):
    """Return a new dict whose values are detached CPU copies of the input tensors.

    The input dict is not modified. Copies are issued with ``non_blocking=True``.
    """
    result = {}
    for name, value in tensor_dict.items():
        result[name] = value.detach().to("cpu", non_blocking=True)
    return result


def move_to(data: dict, device):
    """Move every tensor in ``data`` (recursing into nested dicts) to ``device``.

    The dict is updated in place and also returned. Values that are neither
    tensors nor dicts are left untouched.
    """
    for name, value in data.items():
        if isinstance(value, torch.Tensor):
            data[name] = value.to(device)
        elif isinstance(value, dict):
            # Nested dicts are updated in place; the call returns the same object.
            data[name] = move_to(value, device)
    return data


def process_batch(batch):
    """Flatten the graphs referenced by ``batch`` into model-ready inputs.

    Each entry of ``batch["graph_names"]`` is either a string, resolved via
    ``batch["main_graph_dict"][batch["key_to_idx"][name]]``, or a graph object
    used as is. For every graph the attributes ``num_nodes``, ``x``,
    ``eigvecs_sn`` and ``pos_type`` are read.

    Returns:
        Tuple ``(x_dict, graph_pos, graph_pos_type, batch_index, batch_pos_type)``:

        * ``x_dict`` maps ``batch["graph_name_both"][i]`` to graph ``i``'s ``x``.
        * ``graph_pos`` / ``graph_pos_type`` are the graphs' ``eigvecs_sn`` /
          ``pos_type`` concatenated along dim 0.
        * ``batch_index`` is a long tensor with one entry per node (total of
          all ``num_nodes``) holding the index of the node's graph.
        * ``batch_pos_type`` is a long tensor of the same length holding, for
          each node, the last element of its graph's ``pos_type``.
    """
    names = batch["graph_names"]
    out_keys = batch["graph_name_both"]
    store = batch["main_graph_dict"]
    name_to_slot = batch["key_to_idx"]

    graphs = [
        store[name_to_slot[entry]] if isinstance(entry, str) else entry
        for entry in names
    ]
    node_count = sum(graph.num_nodes for graph in graphs)

    node_graph_id = torch.empty(node_count, dtype=torch.long)
    node_pos_type = torch.empty(node_count, dtype=torch.long)
    features = {}
    pos_parts = []
    pos_type_parts = []

    start = 0
    for gid, graph in enumerate(graphs):
        stop = start + graph.num_nodes
        features[out_keys[gid]] = graph.x
        pos_parts.append(graph.eigvecs_sn)
        pos_type_parts.append(graph.pos_type)
        # Every node of this graph gets the graph's index and the graph's
        # final pos_type value.
        node_graph_id[start:stop] = gid
        node_pos_type[start:stop] = graph.pos_type[-1].item()
        start = stop

    return (
        features,
        torch.cat(pos_parts, dim=0),
        torch.cat(pos_type_parts, dim=0),
        node_graph_id,
        node_pos_type,
    )


def train_epoch(
    epoch,
    loader,
    model,
    optimizer,
):
    """Run one training epoch with gradient accumulation under DDP.

    Each batch is expanded by ``process_batch`` into the inputs the model
    consumes, moved to device ``cfg.rank`` and run through ``model`` under
    bfloat16 autocast. Gradients accumulate over
    ``cfg.train.accum_gradient_steps`` batches. Every batch except the last of
    each window runs inside ``model.no_sync()``. After the last batch the
    optimizer steps and the window's loss is logged on rank 0.

    A ``RuntimeError`` whose message contains "out of memory" skips the batch.
    Its loss counts as 0, and the gradients accumulated so far are cleared.
    Any other ``RuntimeError`` is re-raised.

    Args:
        epoch: Epoch index, forwarded to ``loader.batch_sampler.set_epoch``.
        loader: Iterable of batch dicts whose ``batch_sampler`` has ``set_epoch``.
        model: Model wrapper exposing ``no_sync()`` and ``_clear_grad_buffer()``.
            Calling ``model(batch)`` must return a 5-tuple whose second element
            is the scalar loss.
        optimizer: Optimizer over the model's parameters.

    Returns:
        The epoch's mean per-batch loss, or 0 if ``loader`` yielded nothing.
        Each per-batch loss is summed over all ranks (see the all_reduce below).
    """
    model.train()
    optimizer.zero_grad()
    loader.batch_sampler.set_epoch(epoch)

    accum_steps = cfg.train.accum_gradient_steps

    total_loss = 0.0
    num_batches = 0
    # Placeholder so the OOM path and the loss all_reduce always see a tensor.
    loss = torch.tensor(0.0).to(cfg.rank)
    if cfg.rank == 0:
        loader = tqdm(loader)

    accum_loss = 0.0  # loss summed over the current accumulation window
    accum_step = 0
    ooms = 0
    for step, batch in enumerate(loader):
        x_dict, graph_pos, graph_pos_type, batch_index, batch_pos_type = process_batch(
            batch
        )
        batch["x_dict"] = x_dict
        batch["pos"] = graph_pos
        batch["pos_type"] = graph_pos_type
        batch["batch_pos_type"] = batch_pos_type
        batch["edge_index"] = None
        batch["batch_index"] = batch_index
        # Drop the shared graph store before the batch is moved to the device.
        batch.pop("main_graph_dict")

        # Gradients are all-reduced only on the last micro-batch of each
        # accumulation window; earlier micro-batches run under no_sync().
        is_sync_step = (step + 1) % accum_steps == 0

        oom = False
        try:
            sync_ctx = nullcontext() if is_sync_step else model.no_sync()
            with sync_ctx:
                batch = move_to(batch, cfg.rank)
                with autocast(enabled=True, dtype=torch.bfloat16):
                    _, loss, _, _, _ = model(batch)
                loss.backward()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            print(e)
            print(
                f"| WARNING: ran out of memory, skipping batch, acc step: {accum_step}, gpu rank: {cfg.rank}"
            )
            ooms += 1
            oom = True
            print(batch["graph_name_both"])

        if oom:
            # Drop references to the failed batch so its memory can be freed.
            # `_` is rebound rather than deleted: if the OOM hit the very first
            # forward pass it was never bound, and `del _` would raise.
            del batch, loss
            _ = None
            loss = torch.tensor(0.0).to(cfg.rank)
            # This also discards gradients accumulated earlier in the window.
            model._clear_grad_buffer()
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            # NOTE(review): if only some ranks OOM, ranks disagree on whether
            # backward (and DDP's gradient all-reduce) ran, so collectives may
            # be mismatched across ranks — confirm this cannot hang.

        # Sum this batch's loss over all ranks. Both the per-window log and the
        # epoch average use this cross-rank sum.
        wandb_loss = loss.detach()
        accum_step += 1
        dist.all_reduce(wandb_loss, op=dist.ReduceOp.SUM)
        batch_loss = wandb_loss.item()
        accum_loss += batch_loss

        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()
            if cfg.rank == 0:
                if cfg.wandb.use:
                    wandb.log({"train_loss_per_step": accum_loss})
                else:
                    print(f'"train_loss_per_step": {accum_loss}')
            accum_loss = 0.0
            accum_step = 0

        total_loss += batch_loss
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0


@register_train("train_distributed")
def custom_train(train_loader, model, optimizer, scheduler, start_epoch=0):
    """Distributed training driver, registered with GraphGym as ``train_distributed``.

    Marks every dataset named in ``cfg.dataset_multi.name_list`` as enabled in
    the config. Then, for each epoch in ``[start_epoch, cfg.optim.max_epoch)``:

    * logs each param group's learning rate (rank 0, when wandb is enabled);
    * runs :func:`train_epoch` and steps the scheduler;
    * on rank 0, saves a checkpoint of ``model.module`` and logs/prints the
      epoch's training loss.

    Args:
        train_loader: Training data loader, passed to ``train_epoch``.
        model: Wrapped GNN model; its ``.module`` is what gets checkpointed.
        optimizer: PyTorch optimizer.
        scheduler: PyTorch learning-rate scheduler, stepped once per epoch.
        start_epoch: Index of the first epoch to run (e.g. when resuming).

    Raises:
        ImportError: If ``cfg.wandb.use`` is set but wandb cannot be imported.
    """
    if cfg.wandb.use:
        try:
            import wandb
        except ImportError as err:
            raise ImportError("WandB is not installed.") from err

    for dataset_name in cfg.dataset_multi.name_list:
        getattr(cfg, dataset_name).enable = True

    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        if cfg.rank == 0 and cfg.wandb.use:
            lr_log = {}
            for group in optimizer.param_groups:
                # Unnamed groups fall back to "base_perceiver"; when several
                # groups share a name, only the first group's LR is logged.
                group_name = group.get("name", "base_perceiver")
                lr_log.setdefault(f"LR/{group_name}", group["lr"])
            lr_log["epoch"] = cur_epoch
            wandb.log(lr_log)

        print("starting epoch")
        avg_loss = train_epoch(cur_epoch, train_loader, model, optimizer)
        scheduler.step()

        if cfg.rank == 0:
            save_ckpt(model.module, optimizer, scheduler, cur_epoch)
            if cfg.wandb.use:
                wandb.log({"train_loss": avg_loss, "epoch": cur_epoch})
            print(f"train_loss: {avg_loss:.4f}")

    if cfg.rank == 0 and cfg.wandb.use:
        wandb.finish()

    logging.info("Task done, results saved in %s", cfg.run_dir)
