import os
from collections.abc import Iterable
from typing import Optional

import numpy as np
import warnings
import torch
import torch_geometric.utils
import omegaconf
import swanlab
from omegaconf import OmegaConf, open_dict
from torch_geometric.utils import to_dense_adj, to_dense_batch


# Alias name -> object that ensure_legacy_aliases() installs as an attribute
# of the numpy module when that name is absent from the module namespace.
# NOTE(review): these look like the scalar/constant aliases dropped by recent
# NumPy releases (np.float, np.Inf, np.NaN, ...) -- confirm against the NumPy
# versions this project supports.
_NUMPY_ALIAS_MAP = {
    "float": np.float64,
    "float_": np.float64,
    "complex": np.complex128,
    "complex_": np.complex128,
    "bool": np.bool_,
    "bool_": np.bool_,
    "int": np.int64,
    "object": object,
    "Inf": np.inf,
    "Infinity": np.inf,
    "infty": np.inf,
    "NaN": np.nan,
    "NAN": np.nan,
}


def ensure_legacy_aliases():
    """Install every entry of ``_NUMPY_ALIAS_MAP`` that numpy does not define.

    Names already present in the numpy module namespace are left untouched;
    only absent names are added as module attributes.
    """
    # vars(np) is the module's own __dict__, so the membership test never
    # goes through attribute lookup on the module.
    numpy_namespace = vars(np)
    absent = {
        name: replacement
        for name, replacement in _NUMPY_ALIAS_MAP.items()
        if name not in numpy_namespace
    }
    for name, replacement in absent.items():
        setattr(np, name, replacement)


def create_folders(args):
    """Create the ``graphs/<name>`` and ``chains/<name>`` output directories.

    ``<name>`` is read from ``args.general.name``. Paths are relative to the
    current working directory. Each directory is created independently, so
    one that already exists never prevents the other from being created.

    Args:
        args: Object exposing ``args.general.name`` (a string).

    Returns:
        None. Failures other than "already exists" are reported on stdout
        instead of being raised, keeping the call best-effort.
    """
    for root in ("graphs", "chains"):
        target = root + "/" + args.general.name
        try:
            # makedirs also creates the parent ``root`` directory, and
            # exist_ok=True makes repeated calls a no-op.
            os.makedirs(target, exist_ok=True)
        except OSError as err:
            print(f"Warning: could not create directory '{target}': {err}")


def normalize(X, E, y, norm_values, norm_biases, node_mask):
    """Shift and scale ``X``, ``E`` and ``y``, then mask the result.

    Index 0 of ``norm_biases`` / ``norm_values`` applies to ``X``, index 1 to
    ``E`` and index 2 to ``y``: each input becomes ``(value - bias) / scale``.
    Entries of the normalized ``E`` whose dim-1 and dim-2 indices coincide
    are then set to zero.

    Returns:
        A ``PlaceHolder`` holding the normalized values, after
        ``PlaceHolder.mask(node_mask)`` has been applied.
    """
    X, E, y = [
        (value - norm_biases[slot]) / norm_values[slot]
        for slot, value in enumerate((X, E, y))
    ]

    batch_size, num_nodes = E.shape[0], E.shape[1]
    # Boolean identity broadcast over the batch: selects E[b, i, i].
    self_loops = torch.eye(num_nodes, dtype=torch.bool)[None].expand(
        batch_size, -1, -1
    )
    E[self_loops] = 0

    normalized = PlaceHolder(X=X, E=E, y=y)
    return normalized.mask(node_mask)


def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
    """Undo the shift/scale applied by ``normalize`` and mask the result.

    Index 0 of ``norm_values`` / ``norm_biases`` applies to ``X``, index 1 to
    ``E`` and index 2 to ``y``: each input becomes ``value * scale + bias``.

    Returns:
        A ``PlaceHolder`` holding the rescaled values, after
        ``PlaceHolder.mask(node_mask, collapse)`` has been applied.
    """
    X, E, y = [
        value * norm_values[slot] + norm_biases[slot]
        for slot, value in enumerate((X, E, y))
    ]

    restored = PlaceHolder(X=X, E=E, y=y)
    return restored.mask(node_mask, collapse)


def symmetrize_and_mask_diag(E):
    """Mirror the strict upper triangle of ``E`` onto the lower triangle.

    ``E`` is indexed as ``(batch, row, col)`` or ``(batch, row, col, channel)``.
    Everything on or below the diagonal (dims 1 and 2) is discarded, the
    remaining strict upper triangle is added to its own transpose, and the
    diagonal is explicitly zeroed. The input tensor is not modified.

    Returns:
        A new tensor, symmetric in dims 1 and 2, with a zero diagonal.
    """
    rows, cols = torch.triu_indices(row=E.size(1), col=E.size(2), offset=1)

    # 0/1 selector with the same shape and dtype as E.
    keep_upper = torch.zeros_like(E)
    if E.dim() == 4:
        keep_upper[:, rows, cols, :] = 1
    else:
        keep_upper[:, rows, cols] = 1

    strict_upper = E * keep_upper
    symmetric = strict_upper + strict_upper.transpose(1, 2)

    identity = torch.eye(symmetric.shape[1], dtype=torch.bool)
    symmetric[identity[None].expand(symmetric.shape[0], -1, -1)] = 0
    return symmetric


def to_dense(x, edge_index, edge_attr, batch):
    """Convert a sparse torch_geometric batch into dense node/edge tensors.

    Nodes are densified with ``to_dense_batch``. Self-loops are removed from
    the edge list, which is then densified with ``to_dense_adj`` using the
    dense node tensor's dim-1 size as ``max_num_nodes``. The dense edge
    tensor is finally passed through ``encode_no_edge``.

    Returns:
        A tuple ``(PlaceHolder(X, E, y=None), node_mask)`` where ``node_mask``
        is the mask returned by ``to_dense_batch``.
    """
    dense_nodes, node_mask = to_dense_batch(x=x, batch=batch)

    loop_free_index, loop_free_attr = torch_geometric.utils.remove_self_loops(
        edge_index, edge_attr
    )
    dense_edges = to_dense_adj(
        edge_index=loop_free_index,
        batch=batch,
        edge_attr=loop_free_attr,
        max_num_nodes=dense_nodes.size(1),
    )

    placeholder = PlaceHolder(X=dense_nodes, E=encode_no_edge(dense_edges), y=None)
    return placeholder, node_mask


def encode_no_edge(E):
    """Mark empty edge slots in channel 0 of ``E``, in place.

    ``E`` must be 4-dimensional, indexed ``(batch, row, col, channel)``. A
    tensor with zero channels is returned untouched. Otherwise every
    ``(batch, row, col)`` position whose channels sum to zero gets channel 0
    set to 1, and afterwards all channels on the diagonal (row == col) are
    set to 0.

    Returns:
        The same tensor object ``E``, modified in place.
    """
    assert E.dim() == 4
    if E.shape[-1] == 0:
        return E

    missing_edge = E.sum(dim=3) == 0
    # Basic indexing yields a view, so writing through it updates E itself.
    first_channel = E[..., 0]
    first_channel[missing_edge] = 1

    identity = torch.eye(E.shape[1], dtype=torch.bool)
    E[identity[None].expand(E.shape[0], -1, -1)] = 0
    return E


def _copy_missing_keys(section, saved_section):
    """Add to ``section`` every key of ``saved_section`` it does not have yet."""
    # Struct mode is switched on once per section; open_dict then lifts it
    # temporarily so that new keys can be added inside the ``with`` body.
    OmegaConf.set_struct(section, True)
    with open_dict(section):
        for key, val in saved_section.items():
            if key not in section.keys():
                setattr(section, key, val)


def update_config_with_new_keys(cfg, saved_cfg):
    """Fill ``cfg`` with keys that only exist in ``saved_cfg``.

    The ``general``, ``train`` and ``model`` sections are processed in that
    order. Keys already present in ``cfg`` keep their current value; keys
    found only in ``saved_cfg`` are copied over. All three sections are put
    in struct mode, including ``general`` when the saved section is empty.

    Args:
        cfg: Configuration to update in place; must expose ``general``,
            ``train`` and ``model`` sections.
        saved_cfg: Configuration providing the additional keys; must expose
            the same three sections.

    Returns:
        The same ``cfg`` object, updated in place.
    """
    # Read the saved sections up front so a missing section raises before
    # ``cfg`` has been modified.
    saved_sections = (saved_cfg.general, saved_cfg.train, saved_cfg.model)
    current_sections = (cfg.general, cfg.train, cfg.model)

    for section, saved_section in zip(current_sections, saved_sections):
        _copy_missing_keys(section, saved_section)
    return cfg


class PlaceHolder:
    """Mutable bundle of the three tensors that describe a batch of graphs.

    Attributes:
        X: Node-level tensor (or None).
        E: Edge-level tensor (or None).
        y: Graph-level tensor (or None).

    NOTE(review): ``mask`` and ``split`` index the fields as
    ``X[batch, node, ...]`` and ``E[batch, node, node, ...]`` -- confirm the
    exact shapes against the callers.
    """

    def __init__(self, X, E, y):
        self.X = X
        self.E = E
        self.y = y

    @staticmethod
    def _shape_or_value(value):
        """Return ``value.shape`` for tensors, otherwise ``value`` itself."""
        return value.shape if isinstance(value, torch.Tensor) else value

    def to(self, *args, **kwargs):
        """Return a NEW instance whose non-None fields went through ``Tensor.to``.

        All arguments are forwarded unchanged to ``torch.Tensor.to``. The
        current instance is left untouched.
        """
        moved_X = self.X.to(*args, **kwargs) if self.X is not None else None
        moved_E = self.E.to(*args, **kwargs) if self.E is not None else None
        moved_y = self.y.to(*args, **kwargs) if self.y is not None else None

        return self.__class__(X=moved_X, E=moved_E, y=moved_y)

    def type_as(self, x: torch.Tensor):
        """Cast every non-None field to the type of ``x``, in place.

        Fields that are None (for example ``y`` on instances built without
        graph-level features) are left as None instead of raising.

        Returns:
            ``self``, to allow chaining.
        """
        self.X = self.X.type_as(x) if self.X is not None else None
        self.E = self.E.type_as(x) if self.E is not None else None
        self.y = self.y.type_as(x) if self.y is not None else None
        return self

    def to_device(self, device):
        """Move every non-None field to ``device``, in place.

        Returns:
            ``self``, to allow chaining.
        """
        self.X = self.X.to(device) if self.X is not None else None
        self.E = self.E.to(device) if self.E is not None else None
        self.y = self.y.to(device) if self.y is not None else None
        return self

    def mask(self, node_mask, collapse=False):
        """Apply ``node_mask`` to ``X`` and ``E``, in place.

        Args:
            node_mask: Mask over ``(batch, node)``; zero/False marks entries
                to mask out. (Assumed 2-D -- TODO confirm with callers.)
            collapse: When True, ``X`` and ``E`` are first replaced by the
                argmax over their last dimension and masked positions are
                set to -1. When False, masked positions are multiplied by
                zero and ``E`` is asserted to be symmetric in dims 1 and 2.

        Returns:
            ``self``, to allow chaining.
        """
        x_mask = node_mask.unsqueeze(-1)  # node mask with a trailing feature axis
        e_mask1 = x_mask.unsqueeze(2)  # broadcasts over the second node axis
        e_mask2 = x_mask.unsqueeze(1)  # broadcasts over the first node axis

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = -1
            # An edge is kept only when both of its endpoints are kept.
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self

    def __repr__(self):
        return (
            f"X: {self._shape_or_value(self.X)} -- "
            + f"E: {self._shape_or_value(self.E)} -- "
            + f"y: {self._shape_or_value(self.y)}"
        )

    def split(self, node_mask):
        """Split the batch into one ``PlaceHolder`` per graph.

        For batch entry ``i``, ``n`` is the sum of ``node_mask[i]``; the
        resulting graph holds ``X[i, :n]``, ``E[i, :n, :n]`` and ``y[i]``
        (or None when ``y`` is None).

        Returns:
            A list with ``X.shape[0]`` ``PlaceHolder`` objects.
        """
        graph_list = []
        batch_size = self.X.shape[0]
        for i in range(batch_size):
            n = torch.sum(node_mask[i], dim=0)
            x = self.X[i, :n]
            e = self.E[i, :n, :n]
            y = self.y[i] if self.y is not None else None
            graph_list.append(PlaceHolder(X=x, E=e, y=y))
        return graph_list


def setup_swanlab(cfg):
    """Initialise SwanLab experiment tracking for a training run.

    Nothing is initialised when ``cfg.general.test_only`` is set. Otherwise
    the fully resolved config is logged under the run name
    ``cfg.general.name`` in project ``graph_dfm_<cfg.dataset.name>``, using
    the mode given by ``cfg.general.swanlab``. The current working directory
    is recorded as ``general.local_dir`` in the logged config.

    Initialisation is best-effort: a failure is reported on stdout, and
    unless the requested mode was already ``"offline"`` one retry is made in
    offline mode. A failing retry is reported as well; no exception raised
    by ``swanlab.init`` propagates to the caller.

    Args:
        cfg: OmegaConf configuration exposing ``general.test_only``,
            ``general.name``, ``general.swanlab`` and ``dataset.name``.

    Returns:
        None.
    """
    if cfg.general.test_only is not None:
        print("Skipping SwanLab initialization in test mode")
        return

    config_dict = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    config_dict["general"]["local_dir"] = os.getcwd()

    # test_only is None from here on, so the run name is always the plain
    # experiment name.
    kwargs = {
        "name": f"{cfg.general.name}",
        "project": f"graph_dfm_{cfg.dataset.name}",
        "config": config_dict,
        "settings": swanlab.Settings(_disable_stats=True),
        "reinit": True,
        "mode": cfg.general.swanlab,
    }

    try:
        swanlab.init(**kwargs)
        return
    except Exception as e:
        print(f"Warning: Failed to initialize SwanLab: {e}")
        print("Continuing without SwanLab logging...")

    if cfg.general.swanlab == "offline":
        # Already tried offline mode; there is nothing left to fall back to.
        return

    kwargs["mode"] = "offline"
    try:
        swanlab.init(**kwargs)
    except Exception as e:
        # Tracking is optional: report the failure instead of hiding it.
        print(f"Warning: SwanLab offline fallback also failed: {e}")
    else:
        print("SwanLab initialized in offline mode")


