"""Module to generate networkx graphs."""
"""Implementation based on the template of Matformer."""
from multiprocessing.context import ForkContext
from re import X
import numpy as np
import pandas as pd
from jarvis.core.specie import chem_data, get_node_attributes

# from jarvis.core.atoms import Atoms
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph
from torch_geometric.data.batch import Batch
import itertools

try:
    import torch
    from tqdm import tqdm
except Exception as exp:
    print("torch/tqdm is not installed.", exp)
    pass


def angle_from_array(a, b, lattice):
    """Return the angle in degrees between two lattice-coordinate vectors.

    Both vectors are mapped to Cartesian space with ``np.dot(v, lattice)``
    (row-vector convention) before the angle is measured.

    Args:
        a: first vector, expressed in lattice (fractional) coordinates.
        b: second vector, expressed in lattice (fractional) coordinates.
        lattice: matrix used to convert ``a`` and ``b`` to Cartesian space.

    Returns:
        Angle between the two Cartesian vectors in degrees, in [0, 180] for
        non-zero vectors.
    """
    a_new = np.dot(a, lattice)
    b_new = np.dot(b, lattice)
    # The lattice must map a vector to a vector of the same shape.
    assert a_new.shape == a.shape
    value = sum(a_new * b_new)
    length = (sum(a_new ** 2) ** 0.5) * (sum(b_new ** 2) ** 0.5)
    # Rounding can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], where arccos returns NaN; clip it back into the valid domain.
    cos = np.clip(value / length, -1.0, 1.0)
    angle = np.arccos(cos)
    return angle / np.pi * 180.0

def correct_coord_sys(a, b, c, lattice):
    """Check whether three lattice-coordinate vectors form a right-handed system.

    The vectors are mapped to Cartesian space with ``np.dot(v, lattice)``.
    The system counts as right-handed when ``c`` makes an angle of at most
    90 degrees with the plane normal ``a x b``. That holds exactly when the
    scalar triple product ``(a x b) . c`` is non-negative.

    Args:
        a, b, c: vectors expressed in lattice (fractional) coordinates.
        lattice: matrix used to convert the vectors to Cartesian space.

    Returns:
        bool: True for a right-handed (or degenerate, zero triple product)
        system, False otherwise.
    """
    a_new = np.dot(a, lattice)
    b_new = np.dot(b, lattice)
    c_new = np.dot(c, lattice)
    # The lattice must map a vector to a vector of the same shape.
    assert a_new.shape == a.shape
    plane_vec = np.cross(a_new, b_new)
    # angle <= 90 degrees  <=>  cos >= 0  <=>  triple product >= 0.
    # Testing the sign directly avoids arccos returning NaN when rounding
    # pushes the cosine slightly above 1. NaN would have reported False and
    # caused every lattice vector to be flipped.
    return bool(sum(plane_vec * c_new) >= 0.0)

def same_line(a, b):
    """Return True when vectors ``a`` and ``b`` are parallel or anti-parallel.

    Both vectors are normalised; they count as collinear when the dot product
    of the unit vectors lies within 1e-5 of +1 or of -1.
    """
    unit_a = a / (sum(a ** 2) ** 0.5)
    unit_b = b / (sum(b ** 2) ** 0.5)
    cosine = sum(unit_a * unit_b)
    return bool(abs(cosine - 1.0) < 1e-5 or abs(cosine + 1.0) < 1e-5)

def same_plane(a, b, c):
    """Return True when ``c`` lies in the plane spanned by ``a`` and ``b``.

    Uses the scalar triple product ``(a x b) . c`` with an absolute
    tolerance of 1e-5.
    """
    triple_product = np.dot(np.cross(a, b), c)
    return bool(abs(triple_product) < 1e-5)


# pyg dataset
class PygStructureDataset(torch.utils.data.Dataset):
    """Dataset of crystal PyG graphs paired with standardized target values."""

    def __init__(
        self,
        df: pd.DataFrame,
        graphs: Sequence[Data],
        target: str,
        atom_features="atomic_number",
        transform=None,
        line_graph=False,
        classification=False,
        id_tag="jid",
        neighbor_strategy="",
        nolinegraph=False,
        mean_train=None,
        std_train=None,
    ):
        """Pytorch Dataset for atomistic graphs.

        Args:
            df: pandas dataframe, e.g. from jarvis.db.figshare.data.
            graphs: PyG ``Data`` graphs corresponding to rows in ``df``. Each
                graph's ``x`` is expected to hold atomic numbers. It is
                replaced in place by the looked-up ``atom_features`` vectors,
                and the original values are kept in ``g.atomic_number``.
            target: key for the label column in ``df``.
            atom_features: feature set passed to ``get_node_attributes``.
            transform: stored on the dataset; not applied by this class.
            line_graph: if True, samples are returned as ``(g, g, g, label)``
                and ``prepare_batch`` is ``prepare_pyg_line_graph_batch``.
            classification, neighbor_strategy, nolinegraph: accepted for
                interface compatibility; unused here.
            id_tag: column of ``df`` holding sample identifiers.
            mean_train, std_train: statistics used to standardize the labels.
                When ``mean_train`` is None, the labels are standardized with
                their own mean and std instead.
        """
        self.df = df
        self.graphs = graphs
        self.target = target
        self.line_graph = line_graph

        self.ids = self.df[id_tag]
        self.labels = torch.tensor(self.df[target]).type(
            torch.get_default_dtype()
        )
        print("mean %f std %f" % (self.labels.mean(), self.labels.std()))
        # Compare to None by identity (PEP 8): `==` is overloaded by tensor
        # and array types and need not return a plain bool.
        if mean_train is None:
            mean = self.labels.mean()
            std = self.labels.std()
            self.labels = (self.labels - mean) / std
            print("normalize using training mean but shall not be used here %f and std %f" % (mean, std))
        else:
            self.labels = (self.labels - mean_train) / std_train
            print("normalize using training mean %f and std %f" % (mean_train, std_train))

        self.transform = transform

        features = self._get_attribute_lookup(atom_features)

        # Replace the atomic numbers stored in g.x by feature vectors (in
        # place), keeping the raw values in g.atomic_number.
        for g in graphs:
            z = g.x
            g.atomic_number = z
            z = z.type(torch.IntTensor).squeeze()
            f = torch.tensor(features[z]).type(torch.FloatTensor)
            # squeeze() turns a single-atom graph's index into a scalar, so
            # the looked-up row needs its node axis restored.
            if g.x.size(0) == 1:
                f = f.unsqueeze(0)
            g.x = f

        self.prepare_batch = prepare_pyg_batch
        if line_graph:
            self.prepare_batch = prepare_pyg_line_graph_batch

    @staticmethod
    def _get_attribute_lookup(atom_features: str = "cgcnn"):
        """Build a lookup array indexed by atomic number.

        Row ``Z`` holds the ``atom_features`` vector of the element with
        atomic number ``Z``. Elements without features keep a zero row.
        """
        # Largest atomic number present in chem_data.
        max_z = max(v["Z"] for v in chem_data.values())

        # Get the feature width (referencing Carbon).
        template = get_node_attributes("C", atom_features)

        features = np.zeros((1 + max_z, len(template)))

        for element, v in chem_data.items():
            z = v["Z"]
            x = get_node_attributes(element, atom_features)

            if x is not None:
                features[z, :] = x

        return features

    def __len__(self):
        """Get length."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(graph, label)``, or ``(graph, graph, graph, label)`` in line-graph mode."""
        g = self.graphs[idx]
        label = self.labels[idx]

        if self.line_graph:
            return g, g, g, label

        return g, label

    @staticmethod
    def _stack_labels(labels):
        """Combine per-sample labels into one batch tensor.

        Labels with one or more dimensions (e.g. multi-target) are stacked
        along a new leading axis. Scalar labels are gathered with
        ``torch.tensor``, as before.
        """
        first = labels[0]
        if isinstance(first, torch.Tensor) and len(first.size()) > 0:
            return torch.stack(labels)
        return torch.tensor(labels)

    @staticmethod
    def collate(samples: List[Tuple[Data, torch.Tensor]]):
        """Dataloader helper to batch graphs across `samples`."""
        graphs, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        return batched_graph, PygStructureDataset._stack_labels(labels)

    @staticmethod
    def collate_line_graph(
        samples: List[Tuple[Data, Data, torch.Tensor, torch.Tensor]]
    ):
        """Dataloader helper to batch graphs across `samples`.

        Returns ``(graphs, line_graphs, line_graphs, labels)``. The third
        sample element is ignored, and the batched line graph is returned
        in its place.
        """
        graphs, line_graphs, lattice, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        batched_line_graph = Batch.from_data_list(line_graphs)
        return (
            batched_graph,
            batched_line_graph,
            batched_line_graph,
            PygStructureDataset._stack_labels(labels),
        )

def canonize_edge(
    src_id,
    dst_id,
    src_image,
    dst_image,
):
    """Return the canonical form of a periodic edge.

    The endpoints are ordered so that ``src_id <= dst_id`` (images travel with
    their ids). Both images are then translated by the source image, so the
    source always sits in the home cell ``(0, 0, 0)``.

    Returns:
        ``(src_id, dst_id, src_image, dst_image)`` in canonical form.
    """
    # Keep the smaller index as the source; the images swap along with the ids.
    if src_id > dst_id:
        src_id, src_image, dst_id, dst_image = dst_id, dst_image, src_id, src_image

    # Translate both endpoints by the same lattice offset so src lands in (0,0,0).
    if not np.array_equal(src_image, (0, 0, 0)):
        offset = src_image
        dst_image = tuple(np.subtract(dst_image, offset))
        src_image = tuple(np.subtract(src_image, offset))

    assert src_image == (0, 0, 0)
    return src_id, dst_id, src_image, dst_image


def nearest_neighbor_edges_submit(
    atoms=None,
    cutoff=8,
    max_neighbors=12,
    id=None,
    use_canonize=False,
    use_lattice=False,
    use_angle=False,
):
    """Construct a k-NN edge list plus lattice-direction self-loops.

    Neighbor entries returned by ``atoms.get_all_neighbors`` are indexed below
    as ``nbr[1]`` (neighbor site index), ``nbr[2]`` (distance, the sort key)
    and ``nbr[3]`` (periodic image).

    Args:
        atoms: structure providing ``lattice``, ``coords`` and
            ``get_all_neighbors(r=...)``.
        cutoff: initial neighbor search radius. While some site has fewer
            than 92 neighbors, the radius is enlarged and the search retried.
        max_neighbors: k. Every neighbor at a distance <= the k-th smallest
            distance is kept, so ties at the k-th distance are all included.
        id: forwarded unchanged on retries; not otherwise used here.
        use_canonize: store edges in canonical (src <= dst) form keyed with
            the shifted destination image.
        use_lattice: add one self-loop per site for each of the three chosen
            lattice-direction images.
        use_angle: forwarded unchanged on retries; not otherwise used here.

    Returns:
        ``(edges, lat1, lat2, lat3, reciprocal_min_dist)``:
            edges: defaultdict mapping ``(src_id, dst_id)`` to a set of image
                tuples.
            lat1, lat2, lat3: periodic images of site 0 chosen as a
                right-handed set of lattice directions.
            reciprocal_min_dist: float32 tensor of shape ``(n_atoms, 92)``
                holding each site's 92 smallest neighbor distances. These are
                plain distances, not reciprocals, despite the name.
    """
    # Number of nearest-neighbor distances recorded per site; the search
    # radius is grown until every site has at least this many neighbors.
    num_dist = 92

    lat = atoms.lattice
    all_neighbors_now = atoms.get_all_neighbors(r=cutoff)
    min_nbrs = min(len(neighborlist) for neighborlist in all_neighbors_now)

    if min_nbrs < num_dist:
        # Grow the radius (first to the longest lattice parameter, then by
        # doubling) and retry with otherwise identical settings.
        if cutoff < max(lat.a, lat.b, lat.c):
            r_cut = max(lat.a, lat.b, lat.c)
        else:
            r_cut = 2 * cutoff
        return nearest_neighbor_edges_submit(
            atoms=atoms,
            use_canonize=use_canonize,
            cutoff=r_cut,
            max_neighbors=max_neighbors,
            id=id,
            use_lattice=use_lattice,
            use_angle=use_angle,
        )

    # --- Lattice correction: pick three independent lattice directions ---
    # Periodic images of site 0 itself (within the longest lattice parameter)
    # are the candidate lattice vectors, ordered by distance.
    r_cut = max(lat.a, lat.b, lat.c) + 1e-2
    all_neighbors = atoms.get_all_neighbors(r=r_cut)
    neighborlist = sorted(all_neighbors[0], key=lambda x: x[2])
    ids = np.array([nbr[1] for nbr in neighborlist])
    images = np.array([nbr[3] for nbr in neighborlist])
    images = images[ids == 0]
    lat1 = images[0]
    # lat2: the first candidate not collinear with lat1.
    # NOTE(review): if none exists, the last candidate is kept silently.
    start = 1
    for i in range(start, len(images)):
        lat2 = images[i]
        if not same_line(lat1, lat2):
            start = i
            break
    # lat3: the first candidate (from lat2's index on) not coplanar with lat1, lat2.
    for i in range(start, len(images)):
        lat3 = images[i]
        if not same_plane(lat1, lat2, lat3):
            break
    # Flip lat2 / lat3 so each makes an angle of at most 90 degrees with lat1.
    if angle_from_array(lat1, lat2, lat.matrix) > 90.0:
        lat2 = -lat2
    if angle_from_array(lat1, lat3, lat.matrix) > 90.0:
        lat3 = -lat3
    # Negate all three when (lat1, lat2, lat3) is not right-handed.
    if not correct_coord_sys(lat1, lat2, lat3, lat.matrix):
        lat1 = -lat1
        lat2 = -lat2
        lat3 = -lat3

    # --- Per-site distance features and k-NN edges ---
    n_atoms = len(atoms.coords)
    min_dist = np.zeros((n_atoms, num_dist))
    edges = defaultdict(set)

    for site_idx, neighborlist in enumerate(all_neighbors_now):
        # Sort once on distance; reused for both the distance row and the edges.
        neighborlist = sorted(neighborlist, key=lambda x: x[2])
        distances = np.array([nbr[2] for nbr in neighborlist])
        ids = np.array([nbr[1] for nbr in neighborlist])
        images = np.array([nbr[3] for nbr in neighborlist])

        min_dist[site_idx][:] = distances[:num_dist]

        # Keep every neighbor up to (and including) the k-th nearest distance.
        max_dist = distances[max_neighbors - 1]
        keep = distances <= max_dist
        for dst, image in zip(ids[keep], images[keep]):
            if use_canonize:
                src_id, dst_id, src_image, dst_image = canonize_edge(
                    site_idx, dst, (0, 0, 0), tuple(image)
                )
                edges[(src_id, dst_id)].add(dst_image)
            else:
                edges[(site_idx, dst)].add(tuple(image))

        if use_lattice:
            edges[(site_idx, site_idx)].add(tuple(lat1))
            edges[(site_idx, site_idx)].add(tuple(lat2))
            edges[(site_idx, site_idx)].add(tuple(lat3))

    reciprocal_min_dist = torch.tensor(min_dist.astype(np.float32))

    return edges, lat1, lat2, lat3, reciprocal_min_dist


def compute_bond_cosine(v1, v2):
    """Return the cosine of the angle between two bond vectors, clamped to [-1, 1].

    Inputs are converted with ``torch.tensor`` and cast to the default dtype.
    """
    dtype = torch.get_default_dtype()
    first, second = (torch.tensor(vec).type(dtype) for vec in (v1, v2))
    norm_product = torch.norm(first) * torch.norm(second)
    return torch.clamp(torch.sum(first * second) / norm_product, -1, 1)


def build_undirected_edgedata(
    atoms=None,
    edges=None,
    a=None,
    b=None,
    c=None,
):
    """Build undirected graph data from an edge set.

    Every stored edge ``(src_id, dst_id, dst_image)`` is emitted twice, once
    in each direction, with opposite displacement vectors.

    Args:
        atoms: structure providing ``lattice.cart_coords`` and ``frac_coords``.
        edges: mapping of ``(src_id, dst_id)`` to a set of destination images
            (fractional offsets). ``None`` (the default) means no edges.
        a, b, c: fractional lattice-direction vectors. They are converted to
            Cartesian with ``atoms.lattice.cart_coords`` and attached to every
            edge and to the graph.

    Returns:
        ``(u, v, r, l, nei, atom_lat)``:
            u, v: source and destination index of each directed edge.
            r: Cartesian displacement from ``u`` to ``v`` for each edge.
            l: empty int tensor (edge types are not computed here).
            nei: the Cartesian vectors ``[a, b, c]`` repeated for every edge.
            atom_lat: the Cartesian vectors ``[a, b, c]`` with a leading axis
                of size 1.
    """
    # Avoid a shared mutable default argument.
    if edges is None:
        edges = {}

    u, v, r, nei = [], [], [], []
    v1 = atoms.lattice.cart_coords(a)
    v2 = atoms.lattice.cart_coords(b)
    v3 = atoms.lattice.cart_coords(c)
    lattice_vectors = [v1, v2, v3]

    for (src_id, dst_id), images in edges.items():
        for dst_image in images:
            # Fractional coordinate of the periodic image of dst.
            dst_coord = atoms.frac_coords[dst_id] + dst_image
            # Cartesian displacement vector pointing from src -> dst.
            d = atoms.lattice.cart_coords(
                dst_coord - atoms.frac_coords[src_id]
            )
            for uu, vv, dd in [(src_id, dst_id, d), (dst_id, src_id, -d)]:
                u.append(uu)
                v.append(vv)
                r.append(dd)
                nei.append(lattice_vectors)

    u = torch.tensor(u)  # edge source indices
    v = torch.tensor(v)  # edge destination indices
    r = torch.tensor(np.array(r)).type(torch.get_default_dtype())  # displacements
    l = torch.tensor([]).type(torch.int)  # edge types: currently always empty
    nei = torch.tensor(np.array(nei)).type(torch.get_default_dtype())
    atom_lat = torch.tensor(np.array([lattice_vectors])).type(
        torch.get_default_dtype()
    )
    return u, v, r, l, nei, atom_lat

def distance(pos1, pos2, lattice_matrix):
    """Return the distance between two Cartesian positions under periodicity.

    The displacement is converted to fractional coordinates (row-vector
    convention, ``frac = cart @ inv(lattice_matrix)``). Each component is
    wrapped to its nearest integer image with ``np.round`` and converted back
    to Cartesian before its norm is taken.

    NOTE(review): component-wise rounding gives the true minimum-image
    distance for orthogonal cells; for strongly skewed cells it may not.
    """
    frac_delta = (pos2 - pos1) @ np.linalg.inv(lattice_matrix)
    frac_delta = frac_delta - np.round(frac_delta)
    return np.linalg.norm(frac_delta @ lattice_matrix)

class PygGraph(object):
    """Generate a graph object."""

    def __init__(
        self,
        nodes=None,
        node_attributes=None,
        edges=None,
        edge_attributes=None,
        color_map=None,
        labels=None,
    ):
        """
        Initialize the graph object.

        Args:
            nodes: IDs of the graph nodes as integer array.

            node_attributes: node features as multi-dimensional array.

            edges: connectivity as a (u,v) pair where u is
                   the source index and v the destination ID.

            edge_attributes: attributes for each connectivity.
                             as simple as euclidean distances.

            color_map: stored as given.

            labels: stored as given.

        List-valued arguments default to a fresh empty list per instance.
        """
        # A mutable default (`nodes=[]`) would be one list object shared by
        # every default-constructed graph; create a new list per instance.
        self.nodes = [] if nodes is None else nodes
        self.node_attributes = [] if node_attributes is None else node_attributes
        self.edges = [] if edges is None else edges
        self.edge_attributes = [] if edge_attributes is None else edge_attributes
        self.color_map = color_map
        self.labels = labels

    @staticmethod
    def atom_dgl_multigraph(
        atoms=None,
        neighbor_strategy="k-nearest",
        cutoff=4.0,
        max_neighbors=12,
        atom_features="cgcnn",
        max_attempts=3,
        id: Optional[str] = None,
        compute_line_graph: bool = True,
        use_canonize: bool = False,
        use_lattice: bool = False,
        use_angle: bool = False,
    ):
        """Build a PyG ``Data`` crystal graph from a structure.

        Only ``neighbor_strategy="k-nearest"`` is supported; any other value
        raises ``ValueError``. ``max_attempts`` and ``compute_line_graph`` are
        accepted for interface compatibility but are unused.

        Returns:
            ``Data`` with fields:
                x: per-atom features from ``get_node_attributes``.
                edge_index: long tensor stacking edge sources and destinations.
                edge_attr: per-edge Cartesian displacement vectors.
                edge_type: the ``l`` tensor from ``build_undirected_edgedata``.
                edge_nei: lattice-direction vectors attached to every edge.
                atom_lat: lattice-direction vectors repeated once per atom.
                adj: per-site nearest-neighbor distance matrix from
                    ``nearest_neighbor_edges_submit``.
        """
        if neighbor_strategy == "k-nearest":
            edges, a, b, c, reciprocal_min_dist = nearest_neighbor_edges_submit(
                atoms=atoms,
                cutoff=cutoff,
                max_neighbors=max_neighbors,
                id=id,
                use_canonize=use_canonize,
                use_lattice=use_lattice,
                use_angle=use_angle,
            )
            u, v, r, l, nei, atom_lat = build_undirected_edgedata(atoms, edges, a, b, c)
        else:
            raise ValueError("Not implemented yet", neighbor_strategy)

        # Build the atom feature tensor, one row per element in the structure.
        sps_features = np.array(
            [
                list(get_node_attributes(s, atom_features=atom_features))
                for s in atoms.elements
            ]
        )
        node_features = torch.tensor(sps_features).type(
            torch.get_default_dtype()
        )

        # One copy of the lattice-direction vectors per atom.
        atom_lat = atom_lat.repeat(node_features.shape[0], 1, 1)
        edge_index = torch.cat((u.unsqueeze(0), v.unsqueeze(0)), dim=0).long()
        g = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=r,
            edge_type=l,
            edge_nei=nei,
            atom_lat=atom_lat,
            adj=reciprocal_min_dist,
        )

        return g


def prepare_pyg_batch(
    batch: Tuple[Data, torch.Tensor], device=None, non_blocking=False
):
    """Move a ``(graph, target)`` batch onto ``device``.

    ``non_blocking`` is forwarded only to the target tensor's ``.to`` call.

    Returns:
        A new ``(graph, target)`` tuple with both items on ``device``.
    """
    graph, target = batch
    moved_graph = graph.to(device)
    moved_target = target.to(device, non_blocking=non_blocking)
    return moved_graph, moved_target


def prepare_pyg_line_graph_batch(
    batch: Tuple[Data, Data, Data, torch.Tensor],
    device=None,
    non_blocking=False,
):
    """Send a line-graph batch to ``device``.

    Args:
        batch: flat 4-tuple ``(g, lg, lattice, t)``, the shape returned by
            ``PygStructureDataset.collate_line_graph``. That collate function
            passes the batched line graph as the third element. Only ``.to``
            is called on each item, so a tensor also works there.
        device: target device passed to every ``.to`` call.
        non_blocking: forwarded to the ``.to`` calls of ``lattice`` and ``t``
            only.

    Returns:
        Nested tuple ``((g, lg, lattice), t)`` with every item on ``device``.
    """
    g, lg, lattice, t = batch
    batch = (
        (
            g.to(device),
            lg.to(device),
            lattice.to(device, non_blocking=non_blocking),
        ),
        t.to(device, non_blocking=non_blocking),
    )

    return batch

