import logging
import resource
import os
import re
import numpy as np

import glob
from prettytable import PrettyTable
import torch
import torch.multiprocessing as mp

mp.set_sharing_strategy("file_system")
import torch.distributed as dist

from torch.distributed.utils import _cast_forward_inputs

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.join import Join
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (
    cfg,
    set_cfg,
    load_cfg,
)
from torch_geometric import seed_everything
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.register import train_dict


from torch_geometric.graphgym.optim import (
    create_scheduler,
    OptimizerConfig,
)

from custom_modules.optimizer.extra_optimizers import (
    create_optimizer,
    ExtendedSchedulerConfig,
)

from custom_modules.loader.utils import (
    load_dataset_from_pt,
    load_dataset_from_pt_syn,
    GraphSAINTRandomWalkSampler_custom,
)
from custom_modules.loader.custom_loaders import create_loader_distributed
from custom_modules.loader.node_loader import NodeDataset
from custom_modules.utils import cfg_to_dict

logger = logging.getLogger(__name__)


class CustomDDP(DDP):
    """``DistributedDataParallel`` with a replacement ``_pre_forward`` hook.

    ``_pre_forward`` is a private DDP method. This override performs the
    reducer / Join / buffer-sync / mixed-precision steps shown below and then
    returns the forward inputs.

    NOTE(review): this looks like a copy of upstream
    ``DistributedDataParallel._pre_forward`` with the step that moves inputs
    onto ``device_ids[0]`` removed, presumably so the training loop controls
    device placement of the batches itself — confirm against the installed
    torch version. It relies on private DDP attributes (``_lazy_init_ran``,
    ``reducer``, ``_join_config``, ...), so it may break on torch upgrades.
    """

    def _pre_forward(self, *inputs, **kwargs):
        """Run DDP's pre-forward bookkeeping and return the forward inputs.

        Args:
            *inputs: Positional arguments for the wrapped module's forward.
            **kwargs: Keyword arguments for the wrapped module's forward.

        Returns:
            tuple: ``(inputs, kwargs)``. They are cast to
            ``self.mixed_precision.param_dtype`` when ``self.mixed_precision``
            is set; otherwise they are returned as received.
        """
        # Complete DDP's deferred initialisation on the first call.
        if not self._lazy_init_ran:
            self._lazy_init()
        # All-reduce of all params is delayed: skip the bookkeeping below and
        # pass the inputs through unchanged.
        if self._delay_all_reduce_all_params:
            return inputs, kwargs

        # Prepare the reducer only when grad mode is on and
        # require_backward_grad_sync is set for this iteration.
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            assert self.logger is not None
            self.logger.set_runtime_stats_and_log()
            self.reducer.prepare_for_forward()

        # Notify an active Join context (if any) that this rank has not
        # joined; a returned work handle is handed to the reducer.
        work = Join.notify_join_context(self)
        if work:
            self.reducer._set_forward_pass_work_handle(
                work, self._divide_by_initial_world_size  # type: ignore[arg-type]
            )

        # A bucket rebuild is logged through the module-level `logger`, not
        # DDP's own `self.logger`.
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
            logger.info("Reducer buckets have been rebuilt in this iteration.")
            self._has_rebuilt_buckets = True

        # Sync buffers before forward if _check_sync_bufs_pre_fwd() asks for it.
        if self._check_sync_bufs_pre_fwd():
            self._sync_buffers()

        if self._join_config.enable:
            # Notify joined ranks whether they should sync in backwards pass or not.
            self._check_global_requires_backward_grad_sync(is_joined_rank=False)

        # Optional mixed precision: cast the inputs to the parameter dtype.
        if self.mixed_precision is not None:
            inputs, kwargs = _cast_forward_inputs(
                self.mixed_precision.param_dtype,
                *inputs,
                **kwargs,
            )
        return inputs, kwargs


def new_optimizer_config(cfg):
    """Translate the ``cfg.optim`` section into an :class:`OptimizerConfig`.

    Args:
        cfg (CfgNode): Configuration node; only ``cfg.optim`` is read.

    Returns:
        OptimizerConfig: optimizer name, base learning rate, weight decay and
        momentum copied verbatim from ``cfg.optim``.
    """
    optim_section = cfg.optim
    return OptimizerConfig(
        optimizer=optim_section.optimizer,
        base_lr=optim_section.base_lr,
        weight_decay=optim_section.weight_decay,
        momentum=optim_section.momentum,
    )


def new_scheduler_config(cfg):
    """Build the :class:`ExtendedSchedulerConfig` consumed by ``create_scheduler``.

    Args:
        cfg (CfgNode): Configuration node; schedule settings are read from
            ``cfg.optim`` and ``mode`` / ``eval_period`` from ``cfg.train``.

    Returns:
        ExtendedSchedulerConfig: scheduler settings copied from ``cfg``.
    """
    optim_cfg = cfg.optim
    train_cfg = cfg.train
    return ExtendedSchedulerConfig(
        scheduler=optim_cfg.scheduler,
        steps=optim_cfg.steps,
        lr_decay=optim_cfg.lr_decay,
        max_epoch=optim_cfg.max_epoch,
        reduce_factor=optim_cfg.reduce_factor,
        schedule_patience=optim_cfg.schedule_patience,
        min_lr=optim_cfg.min_lr,
        num_warmup_epochs=optim_cfg.num_warmup_epochs,
        train_mode=train_cfg.mode,
        eval_period=train_cfg.eval_period,
    )


def custom_set_out_dir(cfg, cfg_fname, name_tag):
    """Append a per-config sub-directory to ``cfg.out_dir`` (in place).

    The sub-directory is the config file's base name without its extension,
    suffixed with ``-<name_tag>`` when ``name_tag`` is truthy.

    Args:
        cfg (CfgNode): Configuration node whose ``out_dir`` is updated.
        cfg_fname (str): Path of the YAML configuration file.
        name_tag (str): Optional extra tag identifying this execution of the
            configuration (``cfg.name_tag``); empty/None adds no suffix.
    """
    stem, _extension = os.path.splitext(os.path.basename(cfg_fname))
    run_name = f"{stem}-{name_tag}" if name_tag else stem
    cfg.out_dir = os.path.join(cfg.out_dir, run_name)


def custom_set_run_dir(cfg, run_id):
    """Point ``cfg.run_dir`` at ``<cfg.out_dir>/<run_id>`` and create it.

    Args:
        cfg (CfgNode): Configuration node; ``run_dir`` is assigned in place.
        run_id (int): Main for-loop iteration id (the random seed or dataset
            split); converted with ``str``.
    """
    run_dir = os.path.join(cfg.out_dir, str(run_id))
    cfg.run_dir = run_dir
    # exist_ok: re-running the same id reuses the directory.
    os.makedirs(run_dir, exist_ok=True)


def params_count_no_mlp(model):
    """Count parameter elements, skipping feature-embedding and readout weights.

    A parameter is skipped when its dotted name contains ``"feat_emb"`` or
    ``"readout"`` anywhere (substring match). Every other parameter yielded by
    ``model.named_parameters()`` is counted, whether or not it requires grad.

    Args:
        model (nn.Module): PyTorch model.

    Returns:
        int: Total number of elements in the counted parameters (0 if none).
    """
    excluded_tags = ("feat_emb", "readout")
    total = 0
    for param_name, tensor in model.named_parameters():
        if any(tag in param_name for tag in excluded_tags):
            continue
        total += tensor.numel()
    return total


def calculate_lr(
    num_nodes,
    interpolation="linear",
    min_nodes=256,
    max_nodes=20000,
    min_lr=0.0001,
    max_lr=0.001,
):
    """Map a graph size to a learning rate that decreases as the graph grows.

    Graphs with at most ``min_nodes`` nodes get ``max_lr``. Graphs with at
    least ``max_nodes`` nodes get ``min_lr``. Sizes strictly in between are
    interpolated from ``max_lr`` down to ``min_lr``.

    Args:
        num_nodes: Number of nodes in the graph.
        interpolation (str): ``"linear"`` interpolates linearly in
            ``num_nodes``; ``"log"`` interpolates linearly in
            ``log(num_nodes)``. Only checked when ``num_nodes`` falls strictly
            between ``min_nodes`` and ``max_nodes``.
        min_nodes: Size at or below which ``max_lr`` is returned.
        max_nodes: Size at or above which ``min_lr`` is returned.
        min_lr (float): Learning rate for the largest graphs.
        max_lr (float): Learning rate for the smallest graphs.

    Returns:
        float: The interpolated learning rate.

    Raises:
        ValueError: If ``interpolation`` is not ``"linear"`` or ``"log"``.
            This includes ``"custom"``, which was never implemented and
            previously failed with ``UnboundLocalError``. Also raised if
            ``"log"`` is used with ``min_nodes <= 0``.
    """
    # Outside the interpolation range the learning rate is clamped.
    if num_nodes <= min_nodes:
        return max_lr
    if num_nodes >= max_nodes:
        return min_lr

    if interpolation == "linear":
        return max_lr - (
            (num_nodes - min_nodes) * (max_lr - min_lr) / (max_nodes - min_nodes)
        )
    if interpolation == "log":
        # log of a non-positive size is -inf/NaN and would yield a NaN LR.
        if min_nodes <= 0:
            raise ValueError(
                f"log interpolation requires min_nodes > 0, got {min_nodes}"
            )
        log_ratio = (np.log(num_nodes) - np.log(min_nodes)) / (
            np.log(max_nodes) - np.log(min_nodes)
        )
        return max_lr - (log_ratio * (max_lr - min_lr))
    raise ValueError(
        f"Unsupported interpolation {interpolation!r}; expected 'linear' or 'log'"
    )


def create_param_groups(
    model,
    nodes_num_dict,
    multi_graph=True,
    min_lr=0.0001,
    max_lr=0.001,
    base_lr=0.0001,
    lr_dict=None,
):
    """Build one optimizer param group per trainable parameter of ``model``.

    With ``multi_graph`` enabled, a parameter whose name contains
    ``"<dataset>."`` for some key of ``nodes_num_dict`` (first match in dict
    order wins) goes into a dataset-specific group named after that dataset,
    with ``"sparse": True``. Its learning rate is ``lr_dict[dataset]`` when
    ``lr_dict`` contains that dataset. Otherwise it is ``calculate_lr`` with
    log interpolation between the smallest and largest node counts in
    ``nodes_num_dict``. Every other trainable parameter goes into a
    ``"base_perceiver"`` group with ``base_lr`` and ``"sparse": False``. With
    ``multi_graph`` disabled, that applies to all trainable parameters.

    Args:
        model (nn.Module): Model whose ``named_parameters()`` are grouped.
        nodes_num_dict (dict): Dataset name -> number of nodes.
        multi_graph (bool): Whether to create dataset-specific groups.
        min_lr (float): Lower bound for the size-based fallback LR.
        max_lr (float): Upper bound for the size-based fallback LR.
        base_lr (float): LR for parameters not tied to a dataset.
        lr_dict (dict | None): Optional dataset name -> manual LR override.

    Returns:
        list[dict]: Param-group dicts with keys ``params``, ``lr``, ``name``,
        ``sparse`` and ``parameter_name``, in ``named_parameters()`` order.
        Parameters with ``requires_grad=False`` are skipped.
    """

    def _group(param, lr, group_name, sparse, parameter_name):
        # One group per tensor; the extra keys travel with the group dict.
        return {
            "params": param,
            "lr": lr,
            "name": group_name,
            "sparse": sparse,
            "parameter_name": parameter_name,
        }

    param_groups = []
    if not multi_graph:
        for name, param in model.named_parameters():
            if param.requires_grad:
                param_groups.append(
                    _group(param, base_lr, "base_perceiver", False, name)
                )
        return param_groups

    dataset_names = list(nodes_num_dict.keys())
    # Size range for the calculate_lr fallback, computed once rather than per
    # parameter. np.min/np.max reject empty input, but with no datasets no
    # parameter can match, so the range is never needed.
    if dataset_names:
        num_nodes_list = [nodes_num_dict[d] for d in dataset_names]
        min_nodes = np.min(num_nodes_list)
        max_nodes = np.max(num_nodes_list)
    else:
        min_nodes = max_nodes = None

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        data_name = next((d for d in dataset_names if d + "." in name), None)
        if data_name is None:
            param_groups.append(_group(param, base_lr, "base_perceiver", False, name))
            continue
        # Explicit lookup instead of the previous bare ``except:``, which also
        # hid any unrelated error raised while building the group.
        if lr_dict is not None and data_name in lr_dict:
            lr = lr_dict[data_name]
        else:
            lr = calculate_lr(
                nodes_num_dict[data_name],
                interpolation="log",
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                min_lr=min_lr,
                max_lr=max_lr,
            )
        param_groups.append(_group(param, lr, data_name, True, name))
    return param_groups


def cleanup():
    """Destroy the default ``torch.distributed`` process group.

    Counterpart of ``setup``; each worker calls it once after training.
    """
    dist.destroy_process_group()


def setup(rank, world_size, master_addr=None, master_port=None):
    """Initialise the NCCL process group for this worker and select its GPU.

    Rendezvous is configured via the ``MASTER_ADDR`` / ``MASTER_PORT``
    environment variables. Explicit ``master_addr`` / ``master_port`` arguments
    override them. Otherwise, values already in the environment are kept, and
    ``"localhost"`` / ``"12366"`` are used only as fallbacks. Previously these
    were hard-coded and overwrote the environment, so two jobs on one host
    always collided on port 12366.

    Args:
        rank (int): Rank of this process; also used as its CUDA device index.
        world_size (int): Total number of processes in the group.
        master_addr (str, optional): Rendezvous host override.
        master_port (str | int, optional): Rendezvous port override.
    """
    if master_addr is not None:
        os.environ["MASTER_ADDR"] = str(master_addr)
    else:
        os.environ.setdefault("MASTER_ADDR", "localhost")
    if master_port is not None:
        os.environ["MASTER_PORT"] = str(master_port)
    else:
        os.environ.setdefault("MASTER_PORT", "12366")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def main(
    rank,
    world_size,
    shared_graph_dataset_list,
    key_to_idx,
    saint_samplers,
    node_dataset_loaded_list_train,
    nodes_num_dict,
    syn_graph_config,
    optim_lr_dict=None,
):
    """Per-process DDP training entry point (the target of ``mp.spawn``).

    Re-parses the command-line config in this worker and initialises the
    process group. It then builds the distributed loader, the
    ``CustomDDP``-wrapped model, the optimizer and the scheduler, and runs
    ``train_dict[cfg.train.mode]``.

    Args:
        rank (int): Process rank supplied by ``mp.spawn``; also the CUDA device.
        world_size (int): Number of processes.
        shared_graph_dataset_list: Graphs shared by the parent process (a
            manager list proxy when launched from this script).
        key_to_idx (dict): Graph key -> index into ``shared_graph_dataset_list``.
        saint_samplers: Graph key -> GraphSAINT sampler (a manager dict proxy
            when launched from this script).
        node_dataset_loaded_list_train (list): Per-graph ``NodeDataset``
            objects used for training.
        nodes_num_dict (dict): Dataset name -> node count, used for the
            per-dataset learning rates.
        syn_graph_config (list): Synthetic-graph configs, stored into
            ``cfg.dataset_multi.syn_graph_config`` when synthetic data is on.
        optim_lr_dict (dict | None): Dataset name -> manual (unscaled) LR.
            ``None`` means no manual overrides. The dict is not modified.
    """
    from torch_geometric.graphgym.config import (
        cfg,
        dump_cfg,
        set_cfg,
        load_cfg,
    )

    # Set this worker's soft open-file limit to 4096, capped at the hard limit
    # (setrlimit raises ValueError if soft > hard; an unlimited hard limit is
    # reported as RLIM_INFINITY).
    hard_nofile = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
    soft_nofile = (
        4096 if hard_nofile == resource.RLIM_INFINITY else min(4096, hard_nofile)
    )
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_nofile, hard_nofile))

    # Each spawned worker re-parses the command line and reloads the config.
    args = parse_args()
    set_cfg(cfg)
    load_cfg(cfg, args)
    custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag)

    # setup the process groups
    setup(rank, world_size)

    # Single fixed run: split 0, seed 0, run id 0.
    cfg.dataset.split_index = 0
    cfg.seed = 0
    cfg.run_id = 0
    cfg.rank = rank
    cfg.world_size = world_size
    if cfg.dataset_multi.use_synthetic:
        cfg.dataset_multi.syn_graph_config = syn_graph_config

    seed_everything(cfg.seed)

    if cfg.rank == 0:
        if cfg.wandb.use:
            try:
                import wandb
            except ImportError as err:
                raise ImportError("WandB is not installed.") from err

            custom_dir = f"{cfg.out_dir}/distributed_results/wandblogging"
            os.makedirs(custom_dir, exist_ok=True)
            wandb.init(
                project=cfg.wandb.project,
                dir=custom_dir,
            )
            wandb.config.update(cfg_to_dict(cfg))
            # Separate local folder for saving the model, kept outside
            # wandb.run.dir so it is not uploaded to wandb.
            # NOTE(review): run_name[0] is the whole match ("run-<id>/files"),
            # not the captured <id>, and a non-matching wandb.run.dir raises
            # TypeError here — confirm this path layout is intended.
            pattern = r"run-([^/]+)/files"
            run_name = re.search(pattern, wandb.run.dir)
            cfg.run_dir = f"{cfg.out_dir}/distributed_results/local_folder/wandb/{run_name[0]}"
            cfg.out_dir = wandb.run.dir

            # define our custom x axis metric
            wandb.define_metric("epoch")
            # define which metrics will be plotted against it
            wandb.define_metric("train_loss", step_metric="epoch")
            wandb.define_metric("LR/*", step_metric="epoch")
            wandb.define_metric("ooms", step_metric="epoch")
        else:
            custom_set_run_dir(cfg, cfg.run_id)

    dump_cfg(cfg)
    train_loader = create_loader_distributed(
        rank,
        world_size,
        shared_graph_dataset_list,
        key_to_idx,
        saint_samplers,
        node_dataset_loaded_list_train,
    )

    model = create_model(to_device=False).to(rank)
    model = CustomDDP(model, device_ids=[rank], find_unused_parameters=True)

    def _scaled(lr):
        # Every learning rate is multiplied by world_size * per-process batch
        # size (same multiplication order as before).
        return lr * world_size * cfg.train.batch_size

    # Scaled copy of the manual per-dataset LRs. The caller's dict is left
    # untouched, and None (the default) no longer crashes: it means no
    # manual overrides.
    scaled_lr_dict = {
        key: _scaled(lr) for key, lr in (optim_lr_dict or {}).items()
    }

    optimizer = create_optimizer(
        create_param_groups(
            model,
            nodes_num_dict,
            True,
            min_lr=_scaled(cfg.optim.dataset_min_lr),
            max_lr=_scaled(cfg.optim.dataset_max_lr),
            base_lr=_scaled(cfg.optim.base_lr),
            lr_dict=scaled_lr_dict,
        ),
        new_optimizer_config(cfg),
    )

    if cfg.rank == 0:
        # Print the LR of each param-group name once (first group seen wins).
        lr_log = {}
        for group in optimizer.param_groups:
            group_name = group.get("name", "base_perceiver")
            lr_log.setdefault(f"LR/{group_name}", group["lr"])
        table = PrettyTable()
        table.field_names = ["Group Name", "Learning Rate"]
        for name, lr in lr_log.items():
            table.add_row([name.split("/")[-1], lr])
        print(table)

    scheduler = create_scheduler(optimizer, new_scheduler_config(cfg))

    if rank == 0:
        cfg.params = params_count(model)
        print("Number of parameters: ", cfg.params)
        print("Number of parameters without MLP: ", params_count_no_mlp(model))
    if len(cfg.dataset_multi.name_list) > 0:
        train_dict[cfg.train.mode](
            train_loader, model, optimizer, scheduler, start_epoch=0
        )
    else:
        train_dict[cfg.train.mode](train_loader, model, optimizer, scheduler)
    cleanup()


if __name__ == "__main__":
    # Raise the launcher's soft open-file limit to 32368, capped at the hard
    # limit: setrlimit raises ValueError if soft > hard, and an unlimited hard
    # limit is reported as RLIM_INFINITY.
    hard_nofile = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
    soft_nofile = (
        32368 if hard_nofile == resource.RLIM_INFINITY else min(32368, hard_nofile)
    )
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_nofile, hard_nofile))
    torch.multiprocessing.set_sharing_strategy("file_system")

    # One training process per visible GPU.
    world_size = torch.cuda.device_count()

    # Load cmd line args and config file
    args = parse_args()
    set_cfg(cfg)
    load_cfg(cfg, args)

    # Manager process hosting the shared graph list and sampler dict that are
    # passed to every spawned worker.
    manager = mp.Manager()

    # Collect the per-dataset config sections listed in dataset_multi.name_list.
    cfg.dataset_config_list = []
    graph_dataset_dict = {}
    graph_ratio_in_epoch = {}
    optim_lr_dict = {}
    for dataset_name in cfg.dataset_multi.name_list:
        dataset_cfg = getattr(cfg, dataset_name)
        dataset_cfg.enable = True
        dataset_cfg.dir = f"{cfg.out_dir}/graph-datasets/real_datasets"
        cfg.dataset_config_list.append(dataset_cfg)
        graph_dataset_dict[dataset_cfg.dataset_name] = dataset_cfg
        graph_ratio_in_epoch[dataset_cfg.dataset_name] = (
            dataset_cfg.graph_ratio_in_epoch
        )
        optim_lr_dict[dataset_cfg.dataset_name] = dataset_cfg.manual_lr

    # Graphs go into a manager-backed list. shared_graph_dataset_keys is
    # parallel to it, and key_to_idx (built below) maps keys to indices.
    shared_graph_dataset_list = manager.list()
    shared_graph_dataset_keys = []
    nodes_num_dict = {}
    shared_graph_ratio_in_epoch = {}

    # Table of per-dataset statistics printed before training.
    table = PrettyTable()
    table.field_names = [
        "Index",
        "Dataset Name",
        "Number of Graphs",
        "Number of Nodes",
        "Feat Dim",
        "Num Classes",
        "Graph Ratio",
    ]
    counter = 0
    # load all datasets in shared memory
    for key, value in graph_dataset_dict.items():
        graph_dataset = load_dataset_from_pt(value)
        if value.format == "Network_repository":
            # Multi-graph dataset: each graph is stored as "<key>_graph_<idx>".
            num_graphs = len(graph_dataset)
            total_nodes = 0
            for idx in range(num_graphs):
                single_graph = graph_dataset[idx]
                single_graph.x = single_graph.x.float()
                if isinstance(single_graph.dataset_name, str):
                    single_graph.dataset_name = [single_graph.dataset_name]
                    single_graph.dataset_task_name = [single_graph.dataset_task_name]
                if len(single_graph.x.shape) == 1:
                    single_graph.x = single_graph.x.unsqueeze(-1)

                total_nodes += single_graph.x.shape[0]
                graph_key = f"{key}_graph_{idx}"
                shared_graph_dataset_keys.append(graph_key)
                shared_graph_dataset_list.append(single_graph)
                shared_graph_ratio_in_epoch[graph_key] = graph_ratio_in_epoch[key]
            nodes_num_dict[key] = total_nodes
            # The table shows the mean number of nodes per graph.
            avg_nodes = total_nodes / num_graphs if num_graphs else 0
            counter += 1
            table.add_row(
                [
                    counter,
                    value.dataset_name,
                    num_graphs,
                    avg_nodes,
                    value.feat_dim,
                    value.task_dim,
                    value.graph_ratio_in_epoch,
                ]
            )
        else:
            # Single-graph dataset, stored under the dataset key itself. The
            # graph count is 1; previously this read a `num_graphs` that was
            # undefined here or stale from an earlier dataset.
            single_graph = graph_dataset.data
            feat_dim = single_graph.x.shape[1]
            task_dim = single_graph.y.max().item() + 1
            counter += 1
            table.add_row(
                [
                    counter,
                    value.dataset_name,
                    1,
                    single_graph.num_nodes,
                    feat_dim,
                    task_dim,
                    value.graph_ratio_in_epoch,
                ]
            )
            nodes_num_dict[key] = single_graph.num_nodes
            shared_graph_dataset_list.append(single_graph)
            shared_graph_dataset_keys.append(f"{key}")
            shared_graph_ratio_in_epoch[key] = graph_ratio_in_epoch[key]

    syn_graph_config = []
    # Load synthetic graphs: each processed file becomes one single-graph
    # dataset with a generated config entry.
    if cfg.dataset_multi.use_synthetic:
        synthetic_pt_files = glob.glob(
            cfg.dataset_multi.synthetic_data_dir[0] + "*/eigen_processed/*"
        )
        print("loading synthetic graphs: ", len(synthetic_pt_files))
        for filename in synthetic_pt_files:
            dataset = load_dataset_from_pt_syn(filename)
            dataset_name = dataset.data.dataset_name[0]
            dataset.data.x = dataset.data.x.float()
            dataset.data.pos_type = torch.full(
                (len(dataset.data.eigvecs_sn),), 1, dtype=torch.long
            )
            syn_graph_config.append(
                {
                    "dataset_name": dataset_name,
                    "task": "node",
                    "task_type": "classification",
                    "loss_fun": "cross_entropy",
                    "task_dim": dataset.data.config["num_clusters"],
                    "feat_dim": dataset.data.config["feature_dim"],
                    "node_feat_encoder_name": "LinearNode",
                    "graph_ratio_in_epoch": 0.05,
                }
            )
            shared_graph_dataset_list.append(dataset.data)
            shared_graph_dataset_keys.append(f"{dataset_name}")
            shared_graph_ratio_in_epoch[dataset_name] = 0.05
            nodes_num_dict[dataset_name] = dataset.data.num_nodes

    cfg.dataset_multi.syn_graph_config = syn_graph_config

    # Print the table
    print(table)
    print(f"Total number of graphs: {len(shared_graph_dataset_list)}")

    # With duplicate keys the last index wins (same as before).
    key_to_idx = {key: idx for idx, key in enumerate(shared_graph_dataset_keys)}
    saint_samplers = manager.dict(
        {
            key: GraphSAINTRandomWalkSampler_custom(
                shared_graph_dataset_list[idx],
                batch_size=cfg.model.hop_cutoff,
                walk_length=cfg.model.hop_cutoff,
                num_steps=5,
                sample_coverage=0,
                save_dir=None,
            )
            for key, idx in key_to_idx.items()
        }
    )

    node_dataset_loaded_list_train = []
    for key in shared_graph_dataset_keys:
        # Fetch once: each index on the manager list proxy is a round-trip to
        # the manager process that returns a fresh copy of the whole graph.
        graph = shared_graph_dataset_list[key_to_idx[key]]
        # Train on every labelled node: all nodes when the graph has a
        # train_mask and 2-D labels, otherwise every node whose label != -1.
        if hasattr(graph, "train_mask") and len(graph.y.shape) == 2:
            dataset_mask = torch.ones_like(graph["train_mask"], dtype=torch.bool)
        else:
            dataset_mask = torch.ones_like(graph.y, dtype=torch.bool)
            dataset_mask[graph.y == -1] = False

        labels = graph.y[dataset_mask]
        if len(labels) > 0:
            dict_for_node_dataset = {
                "node_id": graph.node_id[dataset_mask],
                "y": labels,
                "node_dataset_name": graph.dataset_name[0],
                "dataset_task_name": graph.dataset_task_name[0],
                "num_nodes": graph.num_nodes,
                "main_graph_dict_key": key,
                "graph_ratio_in_epoch": shared_graph_ratio_in_epoch[key],
            }
            node_dataset_loaded_list_train.append(
                NodeDataset(dict_for_node_dataset, mask="train_mask")
            )

    mp.spawn(
        main,
        args=(
            world_size,
            shared_graph_dataset_list,
            key_to_idx,
            saint_samplers,
            node_dataset_loaded_list_train,
            nodes_num_dict,
            syn_graph_config,
            optim_lr_dict,
        ),
        nprocs=world_size,
    )
