import os
from omegaconf import OmegaConf, open_dict
import torch
import torch_geometric.utils
from torch_geometric.utils import to_dense_adj, to_dense_batch
import numpy as np
import math
import networkx as nx


def create_folders(args):
    """Create the output directories ``graphs/<name>`` and ``chains/<name>``.

    ``<name>`` is ``args.general.name``. Parent directories are created as needed and
    existing directories are accepted, so calling this repeatedly is harmless. Creation
    is best-effort: a directory that cannot be created triggers a warning rather than
    an exception.

    Args:
        args: config object exposing ``args.general.name``.
    """
    import warnings

    for root in ('graphs', 'chains'):
        path = os.path.join(root, args.general.name)
        try:
            # exist_ok is evaluated per directory, so an already-existing
            # graphs/<name> no longer prevents chains/<name> from being created.
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            warnings.warn(f"could not create directory {path!r}: {err}")


def normalize(X, E, y, norm_values, norm_biases, node_mask):
    """Shift and scale node, edge and global features, then clear the edge diagonal and mask.

    Each feature group is mapped to ``(value - bias) / scale`` using index 0 (X),
    1 (E) and 2 (y) of ``norm_biases`` / ``norm_values``. Afterwards every diagonal
    entry ``E[b, i, i]`` is zeroed and the result is masked with ``node_mask``.

    Args:
        X: node features.
        E: edge features; its leading dims are indexed as (bs, n, n).
        y: global features.
        norm_values: per-group scales, ordered [X, E, y].
        norm_biases: per-group biases, same order.
        node_mask: node mask forwarded to ``PlaceHolder.mask``.

    Returns:
        A masked ``PlaceHolder`` holding the normalized tensors.
    """
    X = (X - norm_biases[0]) / norm_values[0]
    E = (E - norm_biases[1]) / norm_values[1]
    y = (y - norm_biases[2]) / norm_values[2]

    # Build the (bs, n, n) diagonal mask directly on E's device; a CPU mask would
    # force a host-side index computation plus a transfer when E lives on a GPU.
    # E is a fresh tensor here (result of the arithmetic above), so the in-place
    # write does not touch the caller's tensor.
    diag = torch.eye(E.shape[1], dtype=torch.bool, device=E.device)
    diag = diag.unsqueeze(0).expand(E.shape[0], -1, -1)
    E[diag] = 0

    return PlaceHolder(X=X, E=E, y=y).mask(node_mask)


def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
    """
    Invert ``normalize``: map features back via ``value * scale + bias`` and mask.

    X : node features
    E : edge features
    y : global features
    norm_values : scales ordered as [X, E, y]
    norm_biases : biases in the same order
    node_mask : node mask forwarded to ``PlaceHolder.mask``
    collapse : forwarded to ``PlaceHolder.mask``; when True the features are
               argmax-collapsed and masked entries are set to -1
    """
    restored = PlaceHolder(
        X=X * norm_values[0] + norm_biases[0],
        E=E * norm_values[1] + norm_biases[1],
        y=y * norm_values[2] + norm_biases[2],
    )
    return restored.mask(node_mask, collapse)


def to_sparse_edges(X, E, batch, separate=False):
    """Convert a dense batch of edge tensors into one concatenated sparse edge list.

    For graph ``b`` an entry (i, j) is kept as an edge when any channel of
    ``E[b, i, j]`` other than channel 0 is non-zero (channel 0 encodes "no bond").
    Local node indices of graph ``b`` are shifted by the total number of nodes of
    graphs ``0..b-1`` so they index into the flat node list described by ``batch``.

    Args:
        X: flat node tensor. Kept for backward compatibility; offsets are now
            derived from ``batch`` alone.
        E: dense edge tensor indexed as ``E[b, i, j, channel]``.
        batch: graph id of every flat node. Assumed sorted/contiguous per graph
            (as PyG batching produces) — TODO confirm for all callers.
        separate: when True, also return the per-graph edge indices before shifting.

    Returns:
        ``(edge_index, edge_attr)``, or ``(edge_index, edge_attr, local_edge_indices)``
        when ``separate`` is True. ``edge_index`` has shape (2, num_edges).
    """
    num_graphs = E.shape[0]
    # Exclusive prefix sum of node counts = position of each graph's first node.
    # Computed once instead of rescanning ``batch`` per graph, and well-defined
    # even when a graph in the batch has no nodes.
    counts = torch.bincount(batch.detach().cpu(), minlength=num_graphs)
    offsets = (torch.cumsum(counts, dim=0) - counts).tolist()

    edge_index_list = []
    edge_attr_list = []
    local_index_list = []
    for b in range(num_graphs):
        row, col = E[b, :, :, 1:].sum(-1).nonzero(as_tuple=True)  # [1,0,0,0,0] is the type of no bond
        local_index = torch.stack([row, col], dim=0)
        edge_attr_list.append(E[b][row, col])
        local_index_list.append(local_index)
        edge_index_list.append(local_index + offsets[b])

    edge_index = torch.cat(edge_index_list, dim=1)
    edge_attr = torch.cat(edge_attr_list, dim=0)
    if separate:
        return edge_index, edge_attr, local_index_list
    return edge_index, edge_attr


def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None):
    """Turn a sparse node/edge batch into padded dense tensors.

    Nodes are padded per graph with ``to_dense_batch``. Self-loops are removed
    before edges are scattered into a dense adjacency with ``to_dense_adj``.
    ``encode_no_edge`` then marks absent edges and clears the diagonal.

    Args:
        x: flat node features.
        edge_index: sparse edge indices.
        edge_attr: per-edge features.
        batch: graph id of every node.
        max_num_nodes: padding size; when None the padded size of ``X`` is used.

    Returns:
        ``(PlaceHolder(X, E, y=None), node_mask)``.
    """
    X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes)
    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
    n_pad = X.size(1) if max_num_nodes is None else max_num_nodes
    dense_edges = to_dense_adj(
        edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=n_pad
    )
    return PlaceHolder(X=X, E=encode_no_edge(dense_edges), y=None), node_mask


def encode_no_edge(E):
    """Mark absent edges explicitly in channel 0 and clear the diagonal, in place.

    Every entry ``E[b, i, j]`` whose feature vector sums to zero is treated as
    "no edge" and gets a 1 in channel 0. Afterwards all diagonal entries
    ``E[b, i, i]`` are zeroed across every channel. A tensor with zero feature
    channels is returned unchanged.

    Args:
        E: 4-D edge tensor whose leading dims are indexed as (bs, n, n).

    Returns:
        ``E`` itself (modified in place).
    """
    assert len(E.shape) == 4
    if E.shape[-1] == 0:
        return E
    no_edge = torch.sum(E, dim=3) == 0
    # E[:, :, :, 0] is a view, so this masked assignment writes straight into E;
    # no copy-back is needed.
    E[:, :, :, 0][no_edge] = 1
    # Build the mask on E's device to avoid a CPU mask + transfer for GPU tensors.
    diag = torch.eye(E.shape[1], dtype=torch.bool, device=E.device)
    diag = diag.unsqueeze(0).expand(E.shape[0], -1, -1)
    E[diag] = 0
    return E


def update_config_with_new_keys(cfg, saved_cfg):
    """Back-fill ``cfg`` with keys that exist only in ``saved_cfg``.

    For each of the ``dataset``, ``general``, ``train`` and ``model`` sections,
    every key in the saved section that is missing from the current section is
    copied over. Keys already present in ``cfg`` keep their current values.
    Each section is put into struct mode and temporarily opened with
    ``open_dict`` so that the new keys can be added.

    All four sections are handled identically. Previously ``dataset``/``general``
    only received the struct flag when the saved section had at least one key.

    Args:
        cfg: current OmegaConf config; mutated in place.
        saved_cfg: previously saved config providing the extra keys.

    Returns:
        ``cfg``.
    """
    section_names = ('dataset', 'general', 'train', 'model')
    # Resolve all saved sections up front so a missing one fails before cfg is touched.
    saved_sections = [(name, getattr(saved_cfg, name)) for name in section_names]

    for name, saved_section in saved_sections:
        section = getattr(cfg, name)
        OmegaConf.set_struct(section, True)
        with open_dict(section):
            for key, val in saved_section.items():
                # .keys() (rather than ``in section``) also counts keys whose value is missing.
                if key not in section.keys():
                    setattr(section, key, val)
    return cfg


class PlaceHolder:
    """Bundle of node features ``X``, edge features ``E`` and global features ``y``.

    Methods update the instance in place and return ``self`` so calls can be chained.
    """

    def __init__(self, X, E=None, y=None):
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor, categorical: bool = False):
        """Cast X and E (and y when ``categorical`` is True) to the dtype/device of ``x``.

        Attributes that are ``None`` are left as ``None``.
        """
        self.X = self.X.type_as(x)
        if self.E is not None:
            self.E = self.E.type_as(x)
        if categorical and self.y is not None:
            self.y = self.y.type_as(x)
        return self

    def mask(self, node_mask, collapse=False):
        """Apply ``node_mask`` to X and E.

        Args:
            node_mask: (bs, n) mask; zero/False marks padding nodes.
            collapse: if True, X and E are replaced by their argmax over the last
                dim, and masked positions are set to -1. Otherwise masked positions
                are zeroed by multiplication.

        Returns:
            ``self``.
        """
        x_mask = node_mask.unsqueeze(-1)          # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            self.X[node_mask == 0] = - 1
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            # Symmetry of E is a soft check: report it but never fail the caller.
            # An explicit test (instead of ``assert`` inside a bare ``except``) still
            # runs under ``python -O`` and no longer swallows unrelated exceptions.
            try:
                symmetric = torch.allclose(self.E, torch.transpose(self.E, 1, 2), rtol=1e-5, atol=1e-8)
            except RuntimeError:
                symmetric = False
            if not symmetric:
                print("allclose not satisfied")
        return self


def mols_to_nx(mols):
    """Convert molecules into labelled undirected networkx graphs.

    Each atom becomes a node keyed by ``atom.GetIdx()`` with ``label`` set to its
    symbol. Each bond becomes an edge whose ``label`` is
    ``int(bond.GetBondTypeAsDouble())``.

    Args:
        mols: iterable of molecule objects exposing ``GetAtoms()``/``GetBonds()``
            (RDKit ``Mol`` presumably — TODO confirm at call sites).

    Returns:
        A list of ``nx.Graph``, one per molecule, in input order.
    """
    import networkx as nx

    def _mol_graph(mol):
        graph = nx.Graph()
        graph.add_nodes_from(
            (atom.GetIdx(), {'label': atom.GetSymbol()}) for atom in mol.GetAtoms()
        )
        graph.add_edges_from(
            (
                bond.GetBeginAtomIdx(),
                bond.GetEndAtomIdx(),
                {'label': int(bond.GetBondTypeAsDouble())},
            )
            for bond in mol.GetBonds()
        )
        return graph

    return [_mol_graph(mol) for mol in mols]


def equal_mass_binning(values, num_bins):
    """Assign each value to one of ``num_bins`` equal-frequency (quantile) bins.

    Bin edges are the 0th, ..., 100th percentiles of ``values`` at ``num_bins + 1``
    evenly spaced quantiles. Bins are left-closed / right-open, except that the
    maximum value is placed in the last bin.

    Args:
        values: array-like of numbers.
        num_bins: number of bins (>= 1).

    Returns:
        Integer ``np.ndarray`` of bin indices in ``[0, num_bins - 1]``, same shape as
        ``values``.
    """
    values = np.asarray(values)
    quantiles = np.linspace(0, 100, num_bins + 1)
    bin_edges = np.percentile(values, quantiles)
    bin_indices = np.digitize(values, bins=bin_edges, right=False) - 1
    # digitize is right-open, so the maximum lands one past the last bin. Nudging the
    # last edge by a fixed epsilon fails for large magnitudes (the epsilon is lost to
    # float rounding), so clamp the indices into range instead.
    return np.clip(bin_indices, 0, num_bins - 1)


class ExponentialTempScheduler:
    """
    Exponential temperature schedule with a lower floor.

    After t calls to ``update`` the temperature is
      tau_t = max(min_temp, init_temp * exp(-decay * t)),
      decay = -ln(min_temp / init_temp) / total_steps,
    so for min_temp < init_temp it decays to min_temp (up to rounding) at
    t == total_steps and is held there afterwards.
    """
    def __init__(self, init_temp: float = 2.0, min_temp: float = 0.8, total_steps: int = 1000):
        # decay chosen so that init_temp * exp(-decay * total_steps) == min_temp
        self.decay = -math.log(min_temp / init_temp) / total_steps
        self.init_temp = init_temp
        self.min_temp = min_temp
        self.current_temp = init_temp
        self.step_count = 0

    def update(self) -> float:
        """Move the schedule forward by one step and return the resulting temperature."""
        self.step_count = self.step_count + 1
        raw = self.init_temp * math.exp(-self.decay * self.step_count)
        self.current_temp = raw if raw > self.min_temp else self.min_temp
        return self.current_temp


def graph_to_nx(X, E, num_nodes):
    """Convert a batch of dense graphs into labelled networkx graphs.

    For graph ``b`` only the first ``num_nodes[b]`` nodes are used. Node ``i`` gets
    ``label = argmax(X[b, i])``. Entry (i, j) becomes an edge when any channel of
    ``E[b, i, j]`` other than channel 0 sums to a positive value (channel 0 means
    "no edge"). The edge is labelled ``argmax(E[b, i, j])``. Entries are visited in
    row-major order, so for an undirected pair the label of (j, i) overwrites that
    of (i, j) when both are present.

    Args:
        X: node features indexed as ``X[b, i, feature]``.
        E: edge features indexed as ``E[b, i, j, channel]``.
        num_nodes: per-graph node counts.

    Returns:
        A list of ``nx.Graph``, one per batch element.
    """
    import networkx as nx
    X = X.detach().cpu()
    E = E.detach().cpu()
    num_nodes = num_nodes.detach().cpu()

    nx_graphs = []
    for b in range(X.shape[0]):
        n = int(num_nodes[b].item())
        G = nx.Graph()
        node_labels = X[b, :n].argmax(dim=-1).tolist()
        G.add_nodes_from((i, {'label': label}) for i, label in enumerate(node_labels))

        # One vectorized pass per graph instead of per-entry tensor ops.
        adj = E[b, :n, :n]
        has_edge = adj[..., 1:].sum(dim=-1) > 0  # first class: no edge
        edge_labels = adj.argmax(dim=-1)
        rows, cols = has_edge.nonzero(as_tuple=True)  # row-major, same order as the i/j loops
        for i, j, label in zip(rows.tolist(), cols.tolist(), edge_labels[rows, cols].tolist()):
            G.add_edge(i, j, label=label)
        nx_graphs.append(G)

    return nx_graphs


def node_flags(adj, eps=1e-5):
    """Return float32 0/1 flags marking nodes whose adjacency row has mass above ``eps``.

    A row's mass is the sum of absolute values over the last dim. For a 4-D
    (multi-channel) adjacency, the flags of channel 0 are returned.
    """
    row_mass = adj.abs().sum(dim=-1)
    flags = (row_mass > eps).float()
    if flags.dim() == 3:
        flags = flags[:, 0, :]
    return flags


def mask_x(x, flags):
    """Multiply the rows of ``x`` by per-node ``flags`` (absent nodes become zero).

    ``flags`` defaults to all ones over the first two dims of ``x`` when None.
    """
    if flags is None:
        flags = torch.ones((x.shape[0], x.shape[1]), device=x.device)
    per_node = flags.unsqueeze(2)
    return x * per_node


def mask_adjs(adjs, flags):
    """
    Zero the rows and columns of ``adjs`` that belong to absent nodes.

    :param adjs:  B x N x N or B x C x N x N
    :param flags: B x N (all ones when None)
    :return: masked adjacency with the same shape as ``adjs``
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)
    if adjs.dim() == 4:
        flags = flags.unsqueeze(1)  # B x 1 x N, broadcast over channels
    row_mask = flags.unsqueeze(-1)
    col_mask = flags.unsqueeze(-2)
    return adjs * row_mask * col_mask


def adjs_to_graphs(adj, is_cuda=False):
    """Build a networkx graph from one adjacency tensor, dropping self-loops and isolated nodes.

    If nothing remains, a single node ``1`` is added so that the result is never empty.
    ``is_cuda`` is not used; it is kept for call-site compatibility.
    """
    graph = nx.from_numpy_array(adj.detach().cpu().numpy())
    graph.remove_edges_from(nx.selfloop_edges(graph))
    isolated = list(nx.isolates(graph))
    graph.remove_nodes_from(isolated)
    if graph.number_of_nodes() == 0:
        graph.add_node(1)
    return graph