import os
from os.path import join
import json

import numpy as np
import rdkit.Chem as Chem
import torch
from tqdm import tqdm

from bond_type_prediction.utils import RunningAverage


# Atomic number -> element symbol for every atom type this module handles.
atom_decoder = {
    1: 'H',
    6: 'C',
    7: 'N',
    8: 'O',
    9: 'F',
    15: 'P',
    16: 'S',
    17: 'Cl',
    35: 'Br',
    53: 'I',
}


def get_optimal_bond_distances(data_dir, split='train', use_running_average=False, enforce_recomputation=False):
    """
    Compute (or load from cache) bond-length statistics per atom pair and bond order.

    For every molecule of the split, the bond orders are taken from the
    kekulized RDKit molecule built from its SMILES, and the bond lengths from
    the matching xyz file. The result is cached as
    ``<data_dir>/average_bonds_<split>.json`` and reused on later calls.

    Args:
        data_dir (str): path to data folder. e.g. data/qm9
        split (str): which split to use in order to compute the stats
        use_running_average (bool): forwarded to ``update_bonds_dict`` and
            ``prepare_bonds_dict``; selects RunningAverage accumulation
            instead of collecting every distance in a list.
        enforce_recomputation (bool): if True, ignore an existing cache file
            and recompute (the cache file is then overwritten).

    Returns:
        dict: ``{'bonds1': ..., 'bonds2': ..., 'bonds3': ...}`` where each
        entry maps ``atom_symbol -> atom_symbol -> statistics`` for single,
        double and triple bonds respectively.

    Raises:
        ValueError: if a bond with an order other than 1, 2 or 3 is found.
    """
    json_path = join(data_dir, f'average_bonds_{split}.json')
    if os.path.exists(json_path) and not enforce_recomputation:
        with open(json_path, 'r') as fp:
            return json.load(fp)

    # Read SMILES file: one molecule per line, the line index is the molecule index
    smiles_path = join(data_dir, f'smiles/{split}.txt')
    with open(smiles_path, "rb") as file:
        all_smiles = [raw.decode("utf-8").strip("\n") for raw in file.readlines()]

    # One statistics dict per bond order
    bonds1, bonds2, bonds3 = {}, {}, {}
    bonds_by_order = {1: bonds1, 2: bonds2, 3: bonds3}

    for idx, smiles in enumerate(tqdm(all_smiles)):
        # construct current 2D molecule to get true bonds
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # unparsable SMILES: skip the molecule
                continue
            Chem.Kekulize(mol)
        except Exception:
            # Molecules RDKit cannot parse/kekulize are skipped. Only Exception
            # is caught so that Ctrl-C / SystemExit still abort the run.
            continue
        mol = Chem.AddHs(mol)
        adj_matrix = Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True)

        # read current 3D molecule to compute stats
        # NOTE(review): assumes the atom order of the xyz file matches the
        # RDKit atom order after AddHs -- confirm against the xyz generator.
        xyz_path = join(data_dir, f'synthetic_coords_rdkit/{split}/mol_{idx}.xyz')
        with open(xyz_path, 'r') as file:
            molecule = process_xyz_no_smiles(file)
        positions, atom_type = molecule['positions'], molecule['atomic_numbers']

        # compute pairwise distances
        distance_matrix = torch.cdist(positions, positions, p=2)

        # only iterate through existing edges; every non-zero (i, j) entry of
        # the matrix is visited, so the stats are recorded per ordered pair
        rows, cols = np.where(adj_matrix != 0)
        for i, j in zip(rows, cols):
            atom_i = atom_decoder[atom_type[i].item()]
            atom_j = atom_decoder[atom_type[j].item()]
            bond_type = int(adj_matrix[i, j])
            # multiply by 100: same scale get_bond_order_average_model applies to its input
            dist = distance_matrix[i, j].item() * 100

            if bond_type not in bonds_by_order:
                raise ValueError(
                    f"Unexpected bond order {bond_type} between atoms {i} and {j} "
                    f"of molecule {idx} ({smiles!r})"
                )
            update_bonds_dict(bonds_by_order[bond_type], atom_i, atom_j, dist, use_running_average)

    for bonds in (bonds1, bonds2, bonds3):
        prepare_bonds_dict(bonds, use_running_average)
    combined_bonds = {'bonds1': bonds1, 'bonds2': bonds2, 'bonds3': bonds3}

    with open(json_path, 'w') as fp:
        json.dump(combined_bonds, fp)

    return combined_bonds
        

# Element symbol -> atomic number for the supported atom types.
# Original note: the ZINC250k atom set, which also includes the QM9 one.
symbol_to_atomic_number = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'P': 15, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53}


def process_xyz_no_smiles(datafile, remove_h=False):
    """
    Read an xyz file and return a molecular dict with number of atoms, coordinates and atom types.

    Expected layout: the first line holds the atom count, the second line is
    kept verbatim, and each of the following ``atom count`` lines reads
    ``<symbol> <x> <y> <z> <formal_charge>``. Any lines after the atom block
    are ignored.

    Parameters
    ----------
    datafile : python file object
        Open text file object containing the xyz data.
    remove_h : bool
        If True, hydrogen atoms are dropped from the returned arrays and
        ``num_atoms`` counts only the remaining atoms.

    Returns
    -------
    molecule : dict
        ``'num_atoms'``, ``'atomic_numbers'``, ``'positions'`` and
        ``'formal_charges'`` as torch tensors, plus ``'smiles'``: the raw
        second line of the file (trailing newline included; presumably a
        SMILES string, going by the key name).

    Raises
    ------
    KeyError
        If an atom symbol is not in ``symbol_to_atomic_number``.
    ValueError
        If the header or an atom line cannot be parsed.
    """
    xyz_lines = datafile.readlines()

    # Slice the atom block with the header count BEFORE dropping hydrogens;
    # filtering first would shift trailing non-atom lines into the slice and
    # could even remove the comment line.
    declared_atoms = int(xyz_lines[0])
    atom_lines = xyz_lines[2:declared_atoms + 2]

    atomic_numbers, positions, formal_charges = [], [], []
    for line in atom_lines:
        # '*^' is rewritten to 'e' so that float() accepts the exponent notation
        atom_symbol, posx, posy, posz, formal_charge = line.replace('*^', 'e').split()
        if remove_h and atom_symbol == 'H':
            continue
        atomic_numbers.append(symbol_to_atomic_number[atom_symbol])
        positions.append([float(posx), float(posy), float(posz)])
        formal_charges.append(int(formal_charge))

    # Keep the reported count consistent with the returned arrays
    num_atoms = len(atomic_numbers) if remove_h else declared_atoms

    molecule = {'num_atoms': num_atoms, 'atomic_numbers': atomic_numbers, 'positions': positions, 'formal_charges': formal_charges}
    molecule = {key: torch.tensor(val) for key, val in molecule.items()}
    molecule['smiles'] = xyz_lines[1]

    return molecule


def update_bonds_dict(bonds, atom_1, atom_2, dist, use_running_average=False):
    """
    Record one observed distance for the ordered atom pair (atom_1, atom_2).

    Args:
        bonds (dict): nested dict ``atom_1 -> atom_2 -> accumulator``; modified in place.
        atom_1, atom_2: keys of the pair (missing entries are created).
        dist: the distance to record.
        use_running_average (bool): if True the accumulator is a
            ``RunningAverage`` fed through ``update``; otherwise it is a plain
            list the distance is appended to.
    """
    partners = bonds.setdefault(atom_1, {})
    if atom_2 not in partners:
        # Create the accumulator lazily, on the first observation of the pair
        partners[atom_2] = RunningAverage() if use_running_average else []

    accumulator = partners[atom_2]
    if use_running_average:
        accumulator.update(dist)
    else:
        accumulator.append(dist)


def prepare_bonds_dict(bonds, use_running_average=False):
    """
    Replace, in place, every accumulator of ``bonds`` by its final statistics.

    Args:
        bonds (dict): nested dict ``atom_1 -> atom_2 -> accumulator``.
        use_running_average (bool): if True each accumulator is replaced by
            the value of its ``get()`` method; otherwise each accumulator (a
            collection of distances) is replaced by
            ``{'mean': ..., 'std': ...}``. The std uses ``ddof=1``, so a pair
            with a single observation gets a NaN std.
    """
    for partners in bonds.values():
        for atom_2 in partners:
            collected = partners[atom_2]
            if use_running_average:
                partners[atom_2] = collected.get()
            else:
                partners[atom_2] = {'mean': np.mean(collected), 'std': np.std(collected, ddof=1)}


def get_bond_order_average_model(atom1, atom2, distance, combined_bonds, use_std_as_margins=False, std_coeff=1, margins=None, check_exists=True):
    """
    Predict the bond order between two atoms from their distance.

    The distance is multiplied by 100 and compared against per-pair
    thresholds ``mean + margin`` taken from ``combined_bonds``, going from the
    single-bond threshold down to the triple-bond one.

    Args:
        atom1, atom2: keys of the atom pair in the statistics dicts.
        distance: distance between the two atoms (before the x100 scaling).
        combined_bonds (dict): holds ``'bonds1'``, ``'bonds2'`` and
            ``'bonds3'``, each mapping ``atom1 -> atom2 -> {'mean', 'std'}``.
        use_std_as_margins (bool): if True the margin of each bond order is
            ``std_coeff * std`` of that entry; otherwise ``margins`` is used.
        std_coeff: multiplier of the std when ``use_std_as_margins`` is True.
        margins: triple ``(margin1, margin2, margin3)``, required when
            ``use_std_as_margins`` is False.
        check_exists (bool): if True, return 0 when the pair has no
            single-bond entry instead of failing on the lookup.

    Returns:
        int: 0 (no bond), 1 (single), 2 (double) or 3 (triple).
    """
    scaled = 100 * distance  # We change the metric

    tables = (combined_bonds['bonds1'], combined_bonds['bonds2'], combined_bonds['bonds3'])
    single_bonds = tables[0]

    # Check exists for large molecules where some atom pairs do not have a
    # typical bond length.
    if check_exists and (atom1 not in single_bonds or atom2 not in single_bonds[atom1]):
        return 0

    if use_std_as_margins:
        def cutoff(level):
            stats = tables[level][atom1][atom2]
            return stats['mean'] + std_coeff * stats['std']
    else:
        assert margins is not None, "You need to specify margins in order to run bond prediction with custom margins"
        margin1, margin2, margin3 = margins
        fixed_margins = (margin1, margin2, margin3)

        def cutoff(level):
            return tables[level][atom1][atom2]['mean'] + fixed_margins[level]

    def has_pair(level):
        # Double/triple statistics may be missing for a pair
        table = tables[level]
        return atom1 in table and atom2 in table[atom1]

    # "not (a < b)" rather than "a >= b" keeps the outcome for NaN thresholds
    if not (scaled < cutoff(0)):
        return 0                    # No bond
    if not (has_pair(1) and scaled < cutoff(1)):
        return 1                    # Single
    if not (has_pair(2) and scaled < cutoff(2)):
        return 2                    # Double
    return 3                        # Triple


