import os
import pickle
import logging

import numpy as np
import ase
from tqdm import tqdm
import torch

from schnetpack.diffusion.bonds import *
from schnetpack import properties


def generate_bonds_data(save_path: str = None, overwrite: bool = False):
    """Build per-element-pair bond-length tables, caching them in a pickle file.

    If ``save_path`` already exists and ``overwrite`` is False, the cached data
    is reloaded and returned instead of being regenerated.

    Args:
        save_path: cache file location; defaults to ``"./bonds.pkl"``. Missing
            parent directories are created when the cache is written.
        overwrite: regenerate and rewrite the cache even if it already exists.

    Returns:
        dict with keys

        * ``"bonds_1"``, ``"bonds_2"``, ``"bonds_3"``: square float arrays indexed
          by position in ``ase.data.chemical_symbols`` on both axes, holding
          ``bonds1[a][b] / 100`` (resp. ``bonds2``, ``bonds3``) and ``-inf``
          where the table has no entry. The /100 is presumably a unit
          conversion (pm -> Angstrom?) - confirm against
          ``schnetpack.diffusion.bonds``.
        * ``"allowed_bonds"``: int32 array over the same symbols holding
          ``allowed_bonds_dict`` values, 0 where the element is absent.
    """
    save_path = save_path or "./bonds.pkl"

    if os.path.exists(save_path) and not overwrite:
        logging.info("Bonds data already exists, skipping generation and reloading...")
        # NOTE(review): unpickling can execute arbitrary code; only reload a
        # cache file that lives in a trusted location.
        with open(save_path, "rb") as f:
            return pickle.load(f)

    symbols = list(ase.data.chemical_symbols)

    # Direct positional indexing replaces a full boolean scan over all symbols
    # for every (element, element) cell.
    allowed_bonds = np.zeros(len(symbols), dtype=np.int32)
    for i, symbol in enumerate(symbols):
        if symbol in allowed_bonds_dict:
            allowed_bonds[i] = allowed_bonds_dict[symbol]

    data = {
        "bonds_1": _bond_length_matrix(bonds1, symbols),
        "bonds_2": _bond_length_matrix(bonds2, symbols),
        "bonds_3": _bond_length_matrix(bonds3, symbols),
        "allowed_bonds": allowed_bonds,
    }

    # Create the target directory so a custom save_path in a new folder works.
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(save_path, "wb") as f:
        pickle.dump(data, f)

    return data


def _bond_length_matrix(table, symbols):
    """Return an (n, n) matrix of ``table[a][b] / 100.0`` over ``symbols``, ``-inf`` elsewhere."""
    n = len(symbols)
    matrix = np.full((n, n), -np.inf)
    for i, at in enumerate(symbols):
        if at not in table:
            continue
        row = table[at]
        for j, at2 in enumerate(symbols):
            if at2 in row:
                matrix[i, j] = row[at2] / 100.0
    return matrix


def squared_euclidean_distance(a, b):
    """Return the matrix of squared Euclidean distances between rows of ``a`` and ``b``.

    Uses the expansion ``|x - y|^2 = |x|^2 - 2 x.y + |y|^2``. Floating-point
    cancellation can make near-zero entries slightly negative, so negative
    entries are replaced by 0; NaN entries pass through unchanged.

    Args:
        a: 2-D array of shape (n, d).
        b: 2-D array of shape (m, d).

    Returns:
        Array of shape (n, m). The zero fill is float64, so float32 or integer
        inputs come back as float64.
    """
    row_sq = np.sum(a**2, axis=1)[:, np.newaxis]
    col_sq = np.sum(b**2, axis=1)[np.newaxis, :]
    cross = np.dot(a, b.T)
    sq_dist = row_sq - 2 * cross + col_sq

    negative = sq_dist < 0
    return np.where(negative, np.zeros(sq_dist.shape), sq_dist)


def check_validity(
    inputs, m_bonds_1, m_bonds_2, m_bonds_3, allowed_bonds, progress_bar=True
):
    """Check valence stability and connectivity of one or more molecules.

    Bond orders are assigned from interatomic distances. A pair gets order 1,
    2 or 3 when its distance is below the matching reference length from
    ``m_bonds_1``/``m_bonds_2``/``m_bonds_3`` plus a tolerance of 0.1, 0.05
    or 0.03 respectively; the highest matching order wins. An atom is
    "stable" when the row sum of its bond orders equals ``allowed_bonds`` for
    its atomic number. A molecule is "connected" when at least one atom
    reaches every other atom through the bond graph.

    Args:
        inputs: mapping holding positions under ``properties.R``, atomic
            numbers under ``properties.Z`` and, optionally, per-atom molecule
            indices under ``properties.idx_m``. If the molecule index is
            missing, all atoms are treated as a single molecule and the
            progress bar is disabled. ``inputs`` is not modified.
        m_bonds_1, m_bonds_2, m_bonds_3: square tables of reference single,
            double and triple bond lengths, indexed by atomic number
            (``-inf`` where no such bond exists), e.g. as produced by
            ``generate_bonds_data``.
        allowed_bonds: per-atomic-number target sum of bond orders.
        progress_bar: show a tqdm progress bar over molecules.

    Returns:
        Tuple ``(bonds, stable_atoms, stable_molecules, connected,
        stable_atoms_wo_h, stable_molecules_wo_h, connected_wo_h)``. Each is a
        list with one entry per molecule, in ``unique()`` order of the
        molecule index. Per molecule the entries are: the (n, n) bond-order
        matrix, the per-atom stability mask, the molecule stability flag, the
        connectivity flag, and the same three checks with hydrogen (Z == 1)
        ignored.
    """
    bonds = []
    stable_atoms = []
    stable_molecules = []
    stable_atoms_wo_h = []
    stable_molecules_wo_h = []
    connected = []
    connected_wo_h = []

    if properties.idx_m in inputs:
        idx_m = inputs[properties.idx_m]
    else:
        # Single molecule: build the index locally rather than writing a new
        # key into the caller's dict.
        idx_m = torch.zeros(len(inputs[properties.Z]), dtype=torch.int32)
        progress_bar = False

    for m in tqdm(idx_m.unique(), disable=not progress_bar):
        mask = idx_m == m
        R = inputs[properties.R][mask]
        Z = inputs[properties.Z][mask]
        if torch.is_tensor(R):
            R = R.detach().cpu().numpy()
        if torch.is_tensor(Z):
            Z = Z.detach().cpu().numpy()

        # Reference lengths per atom pair. Entry [i, j] is m_bonds[Z[j], Z[i]]
        # (transposed lookup); this equals [Z[i], Z[j]] only if the tables are
        # symmetric.
        ex_bonds_1 = m_bonds_1[Z[None], Z[:, None]]
        ex_bonds_2 = m_bonds_2[Z[None], Z[:, None]]
        ex_bonds_3 = m_bonds_3[Z[None], Z[:, None]]

        # Pairwise distance matrix; inf on the diagonal rules out self-bonds.
        dist = squared_euclidean_distance(R, R) ** 0.5
        np.fill_diagonal(dist, np.inf)

        # Later np.where calls override earlier ones, so the highest matching
        # bond order wins.
        bonds_ = np.where(dist < ex_bonds_1 + 0.1, 1, 0)
        bonds_ = np.where(dist < ex_bonds_2 + 0.05, 2, bonds_)
        bonds_ = np.where(dist < ex_bonds_3 + 0.03, 3, bonds_)

        bonds.append(bonds_)

        # Valence check: row sum of bond orders vs. the allowed count.
        total_bonds = bonds_.sum(1)
        stable_at = allowed_bonds[Z] == total_bonds
        stable_atoms.append(stable_at)
        stable_molecules.append(stable_at.all())

        # Same check with hydrogens always counted as stable. This is a new
        # array so the entry just appended to stable_atoms keeps the real
        # per-hydrogen result.
        stable_at_wo_h = np.logical_or(stable_at, Z == 1)
        stable_atoms_wo_h.append(stable_at_wo_h)
        stable_molecules_wo_h.append(stable_at_wo_h.all())

        # Connectivity: the reflexive adjacency's transitive closure gives
        # reachability; connected if some atom reaches every atom.
        adjacency = (bonds_ > 0) | np.eye(bonds_.shape[0], dtype=bool)
        reach = _transitive_closure(adjacency)
        connected.append(reach.all(1).any())

        # Ignore hydrogens by treating them as reachable from every atom.
        reach[:, Z == 1] = True
        connected_wo_h.append(reach.all(1).any())

    return (
        bonds,
        stable_atoms,
        stable_molecules,
        connected,
        stable_atoms_wo_h,
        stable_molecules_wo_h,
        connected_wo_h,
    )


def _transitive_closure(adjacency):
    """Return the boolean reachability matrix of a reflexive boolean adjacency matrix.

    Each boolean self-product doubles the covered path length. Because the
    diagonal is True the matrix only grows, so it stops changing after
    O(log n) products. Always returns a new array.
    """
    reach = adjacency
    while True:
        extended = reach.dot(reach)  # bool @ bool -> logical OR of ANDs
        if np.array_equal(extended, reach):
            return extended
        reach = extended
