import logging
import time

import numpy as np
import torch
from torch_geometric.graphgym.checkpoint import load_ckpt, save_ckpt, clean_ckpt
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_train
from torch_geometric.graphgym.utils.epoch import is_eval_epoch, is_ckpt_epoch
import torch_geometric
from custom_modules.utils import cfg_to_dict, flatten_dict, make_wandb_name
from torch.cuda.amp import autocast, GradScaler

import torch.distributed as dist
from itertools import chain


def move_tensors_dict_to_cpu(tensor_dict):
    """Return a new dict with every tensor of ``tensor_dict`` detached and copied to CPU.

    The copy is requested with ``non_blocking=True``; the input dict is not modified.

    Args:
        tensor_dict: Mapping from keys to ``torch.Tensor`` values.

    Returns:
        dict: Same keys, values detached and moved to the CPU.
    """
    result = {}
    for name, value in tensor_dict.items():
        result[name] = value.detach().to("cpu", non_blocking=True)
    return result


from tqdm import tqdm
import wandb
import psutil
from operator import itemgetter
import copy
import gc
from custom_modules.memory_monitor import MemoryMonitor
import glob
import time


def move_to(data: dict, device):
    """Recursively move every tensor in ``data`` (including nested dicts) to ``device``.

    ``data`` is modified in place and also returned. Values that are neither
    tensors nor dicts are left untouched.

    Args:
        data: Possibly nested dict whose tensor leaves should be moved.
        device: Anything accepted by ``torch.Tensor.to`` (device, rank index, ...).

    Returns:
        dict: The same ``data`` object.
    """
    for key in list(data):
        value = data[key]
        if isinstance(value, dict):
            data[key] = move_to(value, device)
            continue
        if isinstance(value, torch.Tensor):
            data[key] = value.to(device)
    return data


def train_epoch(
    epoch,
    loader,
    model,
    optimizer,
    scheduler,
    batch_accumulation,
    logger=None,
    scaler=None,
):
    """Run one training epoch with gradient accumulation across distributed ranks.

    For every batch, per-graph node features (``x``), ``eigvecs_sn`` positional
    encodings and a node -> graph-position index are packed into
    ``batch["x_dict"]``, ``batch["pos"]`` and ``batch["batch_index"]``. The forward
    pass runs under bfloat16 autocast. Every ``cfg.train.accum_gradient_steps``-th
    batch runs outside ``model.no_sync()`` and is followed by an optimizer step;
    the other batches run inside ``model.no_sync()``.

    A ``RuntimeError`` whose message contains "out of memory" skips the batch. The
    batch contributes a loss of 0, and all gradients accumulated so far in the
    current window are cleared. Any other ``RuntimeError`` is re-raised.

    Args:
        epoch: Epoch index, forwarded to ``loader.batch_sampler.set_epoch``.
        loader: Data loader yielding dict batches with ``graph_names`` and
            ``main_graph_dict`` entries.
        model: Wrapped model exposing ``no_sync()`` and ``_clear_grad_buffer()``,
            whose forward returns a 5-tuple with the loss at index 1.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Unused in this function.
        batch_accumulation: Unused; ``cfg.train.accum_gradient_steps`` is used.
        logger: Unused.
        scaler: Unused.

    Returns:
        float: Mean over batches of the step loss. ``dist.all_reduce`` runs on
        ``loss.detach()``, which shares storage with ``loss``, so each value is
        the loss summed over all ranks. Returns 0 if the loader is empty.
    """
    # NOTE(review): ``monitor`` is never used below; kept in case its
    # constructor has side effects.
    monitor = MemoryMonitor("mem_out_debug/mem_out_shared_sampler.csv")

    model.train()
    optimizer.zero_grad()
    loader.batch_sampler.set_epoch(epoch)

    total_loss = 0.0
    loss = torch.tensor(0.0).to(cfg.rank)
    num_batches = 0
    if cfg.rank == 0:
        loader = tqdm(loader)

    accum_loss = 0.0  # all-rank loss summed over the current accumulation window
    accum_step = 0
    ooms = 0  # counted but currently not reported anywhere

    for batch_idx, batch in enumerate(loader):
        x_dict = {}
        graph_pos = []
        batch_index = []
        for graph_id, graph_name in enumerate(batch["graph_names"]):
            x_dict[graph_name] = batch["main_graph_dict"][graph_name].x
            graph_pos.append(batch["main_graph_dict"][graph_name].eigvecs_sn)
            batch_index.append(
                torch.ones(
                    batch["main_graph_dict"][graph_name].num_nodes, dtype=torch.long
                )
                * graph_id
            )

        graph_pos = torch.cat(graph_pos, dim=0)
        batch_index = torch.cat(batch_index, dim=0)

        batch["x_dict"] = x_dict
        batch["pos"] = graph_pos
        batch["batch_index"] = batch_index
        batch.pop("main_graph_dict")

        # Last batch of an accumulation window: gradients are synchronized and
        # the optimizer steps after this batch.
        is_sync_step = (batch_idx + 1) % cfg.train.accum_gradient_steps == 0
        oom = False
        try:
            if is_sync_step:
                batch = move_to(batch, cfg.rank)
                with autocast(enabled=True, dtype=torch.bfloat16):
                    _, loss, _, _, _ = model(batch)
                loss.backward()
            else:
                # Accumulate gradients locally without cross-rank sync.
                with model.no_sync():
                    batch = move_to(batch, cfg.rank)
                    with autocast(enabled=True, dtype=torch.bfloat16):
                        _, loss, _, _, _ = model(batch)
                    loss.backward()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            print(e)
            print(
                f"| WARNING: ran out of memory, skipping batch, acc step: {accum_step}, gpu rank: {cfg.rank}"
            )
            ooms += 1
            oom = True
            # NOTE(review): if only some ranks OOM on a sync step, the other
            # ranks' gradient all-reduce may wait on this one — confirm.

        if oom:
            # Drop references to the failed step's tensors so empty_cache can
            # release them. Rebinding (instead of `del batch, loss, _`) also
            # works when the forward never returned and `_` is still unbound.
            batch = None
            _ = None
            loss = torch.tensor(0.0).to(cfg.rank)
            model._clear_grad_buffer()
            # set_to_none=True actually frees the gradient tensors (the former
            # set_to_none=None was falsy and only zero-filled them).
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()

        # In-place all-reduce on a detached view: this also overwrites `loss`
        # with the all-rank sum.
        wandb_loss = loss.detach()
        accum_step += 1
        dist.all_reduce(wandb_loss, op=dist.ReduceOp.SUM)
        accum_loss += wandb_loss.cpu().detach().item()

        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()

            if cfg.rank == 0:
                if cfg.wandb.use:
                    wandb.log({"train_loss_per_step": accum_loss})
                else:
                    print(f'"train_loss_per_step": {accum_loss}')
                accum_loss = 0.0
            accum_step = 0

        loss = loss.detach().cpu().item()
        total_loss += loss
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0


@torch.no_grad()
def eval_epoch(loader, model, logger=None):
    """Run ``model`` over every batch of ``loader`` and report results to ``logger``.

    Before the forward pass, the per-task tensors under ``batch_graph`` and
    ``output_values`` and a fixed set of top-level tensors are moved to
    ``cfg.rank``. Targets, predictions and per-task losses are then copied to the
    CPU and passed to ``logger.update_stats`` together with the total loss and the
    wall time since the previous batch was logged.

    NOTE(review): this unpacks a 4-tuple from ``model(batch)``, while
    ``train_epoch`` unpacks a 5-tuple. Confirm this path still matches the
    current model. ``logger`` is used unconditionally despite its ``None``
    default.
    """
    model.eval()
    tic = time.time()
    top_level_keys = (
        "pos_latents",
        "node_dataset_indices",
        "k_hop_neigh_idx",
        "k_hop_seq_mask",
    )
    for batch in loader:
        # Update the nested dicts in place so their identity is preserved.
        for group in ("batch_graph", "output_values"):
            per_task = batch[group]
            per_task.update({task: t.to(cfg.rank) for task, t in per_task.items()})
        for key in top_level_keys:
            batch[key] = batch[key].to(cfg.rank)

        true, loss, losses_taskwise, pred_score = model(batch)

        cpu_true = move_tensors_dict_to_cpu(true)
        cpu_pred = move_tensors_dict_to_cpu(pred_score)
        cpu_losses = move_tensors_dict_to_cpu(losses_taskwise)
        logger.update_stats(
            true=cpu_true,
            pred=cpu_pred,
            batch_size=loader.batch_size,
            losses_taskwise=cpu_losses,
            total_loss=loss.detach().cpu().item(),
            lr=0,
            time_used=time.time() - tic,
            params=cfg.params,
        )
        tic = time.time()


from collections import defaultdict

# def gather_all_keys(local_keys, rank, world_size):
#     # Gather all keys from all processes to find unique keys
#     gathered_keys = [None] * world_size
#     dist.all_gather_object(gathered_keys, local_keys)

#     # Flatten the list of lists and get unique keys
#     all_keys = set(key for sublist in gathered_keys for key in sublist)
#     return all_keys


def gather_all_keys(local_keys, rank, world_size):
    """Return the union of ``local_keys`` from every rank in the default process group.

    Args:
        local_keys: Picklable iterable of hashable keys on this rank; ``None`` is
            treated as an empty list.
        rank: Unused; kept for call-site compatibility.
        world_size: Number of ranks; sizes the ``all_gather_object`` output list.

    Returns:
        set: Every key contributed by any rank (``None`` contributions skipped).
    """
    payload = [] if local_keys is None else local_keys
    per_rank = [None for _ in range(world_size)]
    dist.all_gather_object(per_rank, payload)

    merged = set()
    for keys in per_rank:
        if keys is None:
            continue
        merged.update(keys)
    return merged


def eval_epoch_for_distributed_only_loss(epoch, loader, model, logger=None):
    """Compute per-task and total losses over ``loader`` without updating weights.

    ``model.eval()`` is applied, but every ``BatchNorm1d``/``BatchNorm2d``
    submodule is then switched back to train mode. Those layers therefore
    normalize with per-batch statistics (and may update their running statistics)
    during this pass. The forward runs under ``torch.inference_mode`` with
    bfloat16 autocast.

    Each batch is packed into ``x_dict``/``pos``/``batch_index`` from
    ``main_graph_dict`` before being moved to ``cfg.rank``. Per-task loss sums and
    element counts are all-reduced across ranks.

    Args:
        epoch: Epoch index, forwarded to ``loader.batch_sampler.set_epoch``.
        loader: Data loader yielding dict batches.
        model: Model whose forward returns a 5-tuple
            ``(_, loss, losses_taskwise, _, num_elements_taskwise)``.
        logger: Unused.

    Returns:
        tuple: ``(averaged_losses_taskwise_dict, avg_loss)``. The dict maps
        ``"train_loss_per_dataset/<task>"`` to the all-rank sum of that task's
        losses divided by its all-rank element count (0.0 when the count is 0).
        ``avg_loss`` is the mean per-batch loss on this rank only (not reduced).
    """
    model.eval()
    # NOTE(review): the previous comment here said "set to eval", but the code
    # switches BatchNorm layers to *train* mode — confirm this is intended.
    for name, m in model.named_modules():
        if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):
            print(f"setting {name} to train")
            m.train()
    loader.batch_sampler.set_epoch(epoch)
    if cfg.rank == 0:
        loader = tqdm(loader)

    local_losses = defaultdict(float)
    local_counts = defaultdict(int)
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        x_dict = {}
        graph_pos = []
        batch_index = []
        for graph_id, graph_name in enumerate(batch["graph_names"]):
            x_dict[graph_name] = batch["main_graph_dict"][graph_name].x
            graph_pos.append(batch["main_graph_dict"][graph_name].eigvecs_sn)
            batch_index.append(
                torch.ones(
                    batch["main_graph_dict"][graph_name].num_nodes, dtype=torch.long
                )
                * graph_id
            )

        graph_pos = torch.cat(graph_pos, dim=0)
        batch_index = torch.cat(batch_index, dim=0)

        batch["x_dict"] = x_dict
        batch["pos"] = graph_pos
        batch["batch_index"] = batch_index
        batch.pop("main_graph_dict")
        batch = move_to(batch, cfg.rank)
        with torch.inference_mode():
            with autocast(enabled=True, dtype=torch.bfloat16):
                _, loss, losses_taskwise, _, num_elements_taskwise = model(batch)

        for key, value in losses_taskwise.items():
            local_losses[key] += value.item()
            local_counts[key] += num_elements_taskwise[key]

        total_loss += loss.detach().cpu().item()
        num_batches += 1

    avg_loss = total_loss / num_batches if num_batches > 0 else 0

    # Every rank must reduce the same keys in the same order, so use the sorted
    # union of keys seen on any rank and zero-fill the ones missing locally.
    all_keys = sorted(
        gather_all_keys(list(local_losses.keys()), cfg.rank, cfg.world_size)
    )
    for key in all_keys:
        if key not in local_losses:
            local_losses[key] = 0.0
            local_counts[key] = 0

    # Put the collective's tensors on this rank's device, like every other
    # tensor in this module. The bare "cuda" device resolves to cuda:0 unless
    # torch.cuda.set_device was called for this process, which would place
    # every rank's tensors on the same GPU.
    losses_tensor = torch.tensor(
        [local_losses[key] for key in all_keys], device=cfg.rank
    )
    counts_tensor = torch.tensor(
        [local_counts[key] for key in all_keys], device=cfg.rank
    )

    dist.all_reduce(losses_tensor, op=dist.ReduceOp.SUM)
    dist.all_reduce(counts_tensor, op=dist.ReduceOp.SUM)

    averaged_losses_taskwise_dict = {
        f"train_loss_per_dataset/{key}": (
            (losses_tensor[i] / counts_tensor[i]).item()
            if counts_tensor[i] > 0
            else 0.0
        )
        for i, key in enumerate(all_keys)
    }

    return averaged_losses_taskwise_dict, avg_loss


# @torch.no_grad()
# def eval_epoch_for_distributed_only_loss(epoch,loader, model, logger=None):
#     model.eval()
#     # loader.batch_sampler.set_epoch(epoch)
#     time_start = time.time()
#     losses_taskwise_dict = {}
#     losses_count_dict = {}

#     loader = tqdm(loader)
#     local_losses = defaultdict(float)
#     local_counts = defaultdict(int)
#     total_loss = 0.0
#     num_batches = 0
#     with torch.no_grad():
#         for iter, batch in enumerate(loader):

#             x_dict = {}
#             graph_pos = []
#             batch_index = []
#             for graph_id, graph_name in enumerate(batch["graph_names"]):

#                 x_dict[graph_name] = batch["main_graph_dict"][graph_name].x
#                 graph_pos.append(batch["main_graph_dict"][graph_name].eigvecs_sn)
#                 batch_index.append(
#                     torch.ones(
#                         batch["main_graph_dict"][graph_name].num_nodes, dtype=torch.long
#                     )
#                     * graph_id
#                 )

#             graph_pos = torch.cat(graph_pos, dim=0)
#             batch_index = torch.cat(batch_index, dim=0)

#             batch["x_dict"] = x_dict
#             batch["pos"] = graph_pos
#             batch["batch_index"] = batch_index

#             batch.pop("main_graph_dict")
#             batch = move_to(batch, cfg.rank)
#             with autocast(enabled=True, dtype=torch.bfloat16):
#                 _, loss, losses_taskwise, _ = model(batch)

#             # Update local dictionaries
#             for key, value in losses_taskwise.items():
#                 local_losses[key] += value.item()
#                 local_counts[key] += 1

#             loss = loss.detach().cpu().item()
#             total_loss += loss
#             num_batches += 1

#         num_batches = num_batches
#         avg_loss = total_loss / num_batches if num_batches > 0 else 0

#         # Synchronize all keys across GPUs
#         # all_keys = gather_all_keys(list(local_losses.keys()), cfg.rank, cfg.world_size)
#         all_keys = set(local_losses.keys())

#         # Standardize the local dictionaries to include all keys
#         for key in all_keys:
#             if key not in local_losses:
#                 local_losses[key] = 0.0
#                 local_counts[key] = 0

#         # Convert dictionaries to tensors for reduction
#         losses_tensor = torch.tensor([local_losses[key] for key in all_keys], device='cuda')
#         counts_tensor = torch.tensor([local_counts[key] for key in all_keys], device='cuda')

#         # # Reduce sums across all processes
#         # dist.all_reduce(losses_tensor, op=dist.ReduceOp.SUM)
#         # dist.all_reduce(counts_tensor, op=dist.ReduceOp.SUM)

#         # Calculate average losses
#         averaged_losses_taskwise_dict = {f'train_loss_per_dataset/{key}': (losses_tensor[i] / counts_tensor[i]).item() if counts_tensor[i] > 0 else 0.0
#                         for i, key in enumerate(all_keys)}

#         return averaged_losses_taskwise_dict,avg_loss


@torch.no_grad()
def eval_epoch_for_distributed(loader, model, logger=None):
    """Evaluate ``model`` on ``loader`` on this rank and log per-batch stats.

    Moves the nested per-task tensors (``batch_graph``, ``output_values``) and the
    top-level tensors ``pos_latents``, ``node_dataset_indices``,
    ``k_hop_neigh_idx`` and ``k_hop_seq_mask`` to ``cfg.rank``. It then runs the
    model and hands CPU copies of targets, predictions and per-task losses to
    ``logger.update_stats``.

    NOTE(review): the model output is unpacked as a 4-tuple, whereas
    ``train_epoch`` unpacks 5 values. ``logger`` must be provided even though it
    defaults to ``None``.
    """
    model.eval()
    device = cfg.rank
    last_tick = time.time()
    for batch in loader:
        for nested_key in ("batch_graph", "output_values"):
            nested = batch[nested_key]
            for task in list(nested):
                nested[task] = nested[task].to(device)
        for key in (
            "pos_latents",
            "node_dataset_indices",
            "k_hop_neigh_idx",
            "k_hop_seq_mask",
        ):
            batch[key] = batch[key].to(device)

        true, loss, losses_taskwise, pred_score = model(batch)

        stats = dict(
            true=move_tensors_dict_to_cpu(true),
            pred=move_tensors_dict_to_cpu(pred_score),
            losses_taskwise=move_tensors_dict_to_cpu(losses_taskwise),
            batch_size=loader.batch_size,
            total_loss=loss.detach().cpu().item(),
            lr=0,
        )
        stats["time_used"] = time.time() - last_tick
        stats["params"] = cfg.params
        logger.update_stats(**stats)
        last_tick = time.time()


def load_ckpt(
    model,
    optimizer,
    scheduler,
    epoch: int,
    run_folder,
):
    """Restore model, optimizer and scheduler state from an epoch checkpoint.

    Reads ``{run_folder}/files/ckpt/{epoch}.ckpt`` once, mapped to CPU, and loads
    its ``model_state`` strictly into ``model``. ``optimizer_state`` and
    ``scheduler_state`` are restored only if present in the checkpoint.

    Note: this definition shadows the ``load_ckpt`` imported from
    ``torch_geometric.graphgym.checkpoint`` at the top of the module.

    Args:
        model: Module whose ``load_state_dict`` receives ``model_state``.
        optimizer: Optimizer to restore.
        scheduler: Learning-rate scheduler to restore.
        epoch: Epoch number naming the checkpoint file.
        run_folder: Run directory containing ``files/ckpt``.

    Returns:
        int: ``epoch + 1``, the epoch to resume from.
    """
    path = f"{run_folder}/files/ckpt/{epoch}.ckpt"
    # Deserialize the file once, onto CPU. The earlier version loaded it a
    # second time without map_location, materializing the optimizer state on
    # whatever device it was saved from. Optimizer.load_state_dict already casts
    # state tensors to each parameter's device, so the CPU copy is sufficient.
    ckpt = torch.load(path, map_location="cpu")

    model.load_state_dict(ckpt["model_state"])
    if "optimizer_state" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state"])
    if "scheduler_state" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler_state"])

    return epoch + 1


from prettytable import PrettyTable


def _collect_lr_log(optimizer):
    """Map ``"LR/<group name>"`` to the learning rate of the first param group with that name.

    Param groups without a ``"name"`` entry are reported as ``"base_perceiver"``.
    """
    lr_log = {}
    for group in optimizer.param_groups:
        group_name = group.get("name", "base_perceiver")
        # setdefault keeps the first group's lr when several share a name.
        lr_log.setdefault(f"LR/{group_name}", group["lr"])
    return lr_log


@register_train("custom_multi_dataset_node_class_distributed")
def custom_train(train_loader, model, optimizer, scheduler, start_epoch=0):
    """Customized distributed training pipeline.

    If ``cfg.model.pretrained_model_run_id`` is not the string ``"None"``, the
    matching run folder is located and ``model.module``, optimizer and scheduler
    are restored from epoch ``cfg.model.pretrained_epoch``. Training then runs
    epochs ``[start_epoch, cfg.optim.max_epoch)``, stepping the scheduler once per
    epoch. Rank 0 saves a checkpoint and logs the epoch loss.

    Args:
        train_loader: Training data loader.
        model: Wrapped model exposing ``.module`` (used for checkpointing).
        optimizer: PyTorch optimizer.
        scheduler: PyTorch learning rate scheduler.
        start_epoch: Ignored. Training starts at 0, or at
            ``pretrained_epoch + 1`` when resuming. Kept for interface
            compatibility.

    Raises:
        FileNotFoundError: If no run folder matches the pretrained run id.
    """
    import os  # local import: `os` is not among this module's top-level imports

    if cfg.model.pretrained_model_run_id == "None":
        start_epoch = 0
    else:
        # glob does not expand "~" itself; without expanduser the pattern would
        # only match a literal "~" directory under the working directory.
        pattern = os.path.expanduser(
            f"~/*/*/distributed_results/*/*/*{cfg.model.pretrained_model_run_id}"
        )
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No run folder matches {pattern}")
        run_folder = matches[0]
        print(run_folder)
        start_epoch = load_ckpt(
            model.module,
            optimizer,
            scheduler,
            cfg.model.pretrained_epoch,
            run_folder,
        )
        print(f"loaded model from {run_folder}")

        # NOTE(review): a hard-coded epoch of 75 is passed here, overriding the
        # scheduler state just restored by load_ckpt (an earlier
        # `scheduler.step(start_epoch)` was commented out). Passing an explicit
        # epoch to torch.optim.lr_scheduler schedulers is deprecated — confirm.
        scheduler.step(75)

        if cfg.rank == 0:
            table = PrettyTable()
            table.field_names = ["Group Name", "Learning Rate"]
            for name, lr in _collect_lr_log(optimizer).items():
                table.add_row([name.split("/")[-1], lr])
            print(table)

    logging.info("Start from epoch %s", start_epoch)

    # Enable every dataset listed in cfg.dataset_multi.name_list. The `wandb`
    # module is imported at module level, so no local import check is needed.
    for dataset_name in cfg.dataset_multi.name_list:
        getattr(cfg, dataset_name).enable = True

    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        if cfg.rank == 0 and cfg.wandb.use:
            lr_log = _collect_lr_log(optimizer)
            lr_log["epoch"] = cur_epoch
            wandb.log(lr_log)

        avg_loss = train_epoch(
            cur_epoch,
            train_loader,
            model,
            optimizer,
            scheduler,
            cfg.optim.batch_accumulation,
            None,
            None,
        )
        scheduler.step()
        if cfg.rank == 0:
            save_ckpt(model.module, optimizer, scheduler, cur_epoch)
            if cfg.wandb.use:
                wandb.log({"train_loss": avg_loss, "epoch": cur_epoch})
            print(f"train_loss: {avg_loss:.4f}")

    if cfg.rank == 0 and cfg.wandb.use:
        wandb.finish()

    logging.info("Task done, results saved in %s", cfg.run_dir)
