try:
    from rdkit import Chem
    from dgd.analysis.rdkit_functions import BasicMolecularMetrics
    use_rdkit = True
except ModuleNotFoundError as e:
    use_rdkit = False
import torch
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sp_stats
import wandb
import matplotlib
matplotlib.use('Agg')

# 'atom_decoder': ['H', 'B', 'C', 'N', 'O', 'F', 'Al', 'Si', 'P', 'S', 'Cl', 'As', 'Br', 'I', 'Hg', 'Bi'],


# Bond counts accepted for each element symbol by the stability check below.
# An int is the single accepted count; a list enumerates every accepted count.
allowed_bonds = {
    'H': 1,
    'C': 4,
    'N': 3,
    'O': 2,
    'F': 1,
    'B': 3,
    'Al': 3,
    'Si': 4,
    'P': [3, 5],
    'S': 4,
    'Cl': 1,
    'As': 3,
    'Br': 1,
    'I': 1,
    'Hg': [1, 2],
    'Bi': [3, 5],
    'Se': [2, 4, 6],
}


class Histogram_discrete:
    """Frequency table over discrete values, with a bar-chart helper.

    ``self.bins`` is a dict mapping every value seen so far to the number of
    times it was added (or to its share of the total once ``normalize`` has
    been called). ``self.name`` is used as the plot title.
    """

    def __init__(self, name='histogram'):
        self.name = name
        self.bins = {}

    def add(self, elements):
        """Increase the count of each item in ``elements`` by one.

        Items are used directly as dictionary keys, so they must be hashable.
        """
        counts = self.bins
        for item in elements:
            counts[item] = counts.get(item, 0) + 1

    def normalize(self):
        """Divide every stored count, in place, by the sum of all counts."""
        total = 0.
        for count in self.bins.values():
            total += count
        # Only existing keys are reassigned, so iterating the dict is safe.
        for key, count in self.bins.items():
            self.bins[key] = count / total

    def plot(self, save_path=None):
        """Draw the table as a bar chart titled with ``self.name``.

        The figure is written to ``save_path`` when one is given and shown
        interactively otherwise; it is closed afterwards in both cases.
        """
        bar_width = 1
        fig, ax = plt.subplots()
        labels = list(self.bins.keys())
        heights = list(self.bins.values())

        ax.bar(labels, heights, bar_width)
        plt.title(self.name)
        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path)
        plt.close()


class Histogram_cont:
    """Fixed-width histogram over a continuous interval.

    ``self.bins`` is a list of ``num_bins`` integer counts that split
    ``self.range = (low, high)`` into equal-width bins. ``self.name`` is used
    as the plot title. When ``ignore_zeros`` is true, values that are not
    greater than 1e-8 are skipped by ``add``.
    """

    def __init__(self, num_bins=100, range=(0., 13.), name='histogram', ignore_zeros=False):
        self.name = name
        self.bins = [0] * num_bins
        self.range = range
        self.ignore_zeros = ignore_zeros

    def add(self, elements):
        """Count each value of ``elements`` in the bin that contains it.

        The bin index is computed relative to the lower bound of
        ``self.range``. Values outside the range are clamped into the first
        or last bin, so a value below the range can no longer produce a
        negative index that would wrap around to the end of the list.
        """
        n_bins = len(self.bins)
        low = self.range[0]
        span = self.range[1] - self.range[0]
        for e in elements:
            if self.ignore_zeros and not e > 1e-8:
                continue
            i = int((float(e) - low) / span * n_bins)
            # Clamp on both sides: out-of-range values land in an edge bin.
            i = max(0, min(i, n_bins - 1))
            self.bins[i] += 1

    def plot(self, save_path=None):
        """Draw the counts as a bar chart titled with ``self.name``.

        The figure is written to ``save_path`` when one is given and shown
        interactively otherwise; it is closed afterwards in both cases.
        """
        width = (self.range[1] - self.range[0]) / len(self.bins)  # the width of the bars
        fig, ax = plt.subplots()

        # Bar centres: the left bin edges shifted by half a bin width.
        x = np.linspace(self.range[0], self.range[1], num=len(self.bins) + 1)[:-1] + width / 2
        ax.bar(x, self.bins, width)
        plt.title(self.name)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()

    def plot_both(self, hist_b, save_path=None, wandb=None):
        """Plot this histogram and ``hist_b`` as normalized step curves.

        ``hist_b`` is drawn first and labelled 'True'; this histogram is
        drawn second and labelled 'Learned'. Both are normalized with
        ``normalize_histogram`` and plotted against this histogram's left
        bin edges.

        When ``save_path`` is given the figure is saved there and, if a
        ``wandb`` object is also passed, the saved image is logged through
        it under the key ``save_path``. Without ``save_path`` the figure is
        shown and nothing is logged.
        """
        # TODO: check that the relation between bins and linspace is correct.
        # NOTE(review): ax.step is called with its default ``where`` setting
        # while x holds the left bin edges - confirm this alignment is intended.
        hist_a = normalize_histogram(self.bins)
        hist_b = normalize_histogram(hist_b)

        fig, ax = plt.subplots()
        x = np.linspace(self.range[0], self.range[1], num=len(self.bins) + 1)[:-1]
        ax.step(x, hist_b)
        ax.step(x, hist_a)
        ax.legend(['True', 'Learned'])
        plt.title(self.name)

        if save_path is not None:
            plt.savefig(save_path)
            if wandb is not None:
                # Log the image that was just written to disk.
                im = plt.imread(save_path)
                wandb.log({save_path: [wandb.Image(im, caption=save_path)]})
        else:
            plt.show()
        plt.close()


def normalize_histogram(hist):
    """Return ``hist`` as a NumPy array divided by the sum of its entries.

    No guard is applied for a histogram whose entries sum to zero.
    """
    counts = np.array(hist)
    return counts / np.sum(counts)



def earth_mover_distance(h1, h2):
    """Return the earth mover's (1-Wasserstein) distance between two histograms.

    ``h1`` and ``h2`` are sequences of non-negative bin counts or masses.
    Bin indices ``0 .. len(h) - 1`` are used as the positions of the mass and
    the histogram entries as the weights at those positions, so the result is
    expressed in units of bins. SciPy rescales the weights to sum to one, so
    the inputs do not need to be normalized beforehand.

    Previously the normalized bin masses were passed as *sample values*,
    which made the result independent of where the mass was located (for
    example ``[1, 0, 0]`` and ``[0, 0, 1]`` had distance 0).

    Raises:
        ValueError: raised by SciPy when a histogram contains negative
            entries or its total is not positive and finite.
    """
    weights_1 = np.asarray(h1, dtype=float)
    weights_2 = np.asarray(h2, dtype=float)
    # Each histogram gets its own support so differing lengths still work.
    positions_1 = np.arange(len(weights_1))
    positions_2 = np.arange(len(weights_2))
    distance = sp_stats.wasserstein_distance(positions_1, positions_2,
                                             u_weights=weights_1, v_weights=weights_2)
    return distance


def kl_divergence(p1, p2):
    """Return ``sum(p1 * log(p1 / p2))``, the KL divergence of p1 from p2.

    The inputs are used as given: no normalization or smoothing is applied
    here, so callers are responsible for avoiding zero entries.
    """
    log_ratio = np.log(p1 / p2)
    return np.sum(p1 * log_ratio)


def kl_divergence_sym(h1, h2):
    """Return the mean of KL(p1 || p2) and KL(p2 || p1) for two histograms.

    Both histograms are normalized and then shifted by 1e-10 so that empty
    bins do not produce a zero inside the logarithm or the division.
    """
    eps = 1e-10
    q1 = normalize_histogram(h1) + eps
    q2 = normalize_histogram(h2) + eps
    forward = kl_divergence(q1, q2)
    backward = kl_divergence(q2, q1)
    return (forward + backward) / 2.


def js_divergence(h1, h2):
    """Return the Jensen-Shannon divergence between two histograms.

    Both histograms are normalized and shifted by 1e-10; the result is the
    mean of the KL divergences of each one from their element-wise average.
    """
    eps = 1e-10
    q1 = normalize_histogram(h1) + eps
    q2 = normalize_histogram(h2) + eps
    mixture = (q1 + q2) / 2
    from_first = kl_divergence(q1, mixture)
    from_second = kl_divergence(q2, mixture)
    return (from_first + from_second) / 2


# def main_analyze_qm9(remove_h: bool, dataset_name='qm9', n_atoms=None):
#     class DataLoaderConfig(object):
#         def __init__(self):
#             self.batch_size = 128
#             self.remove_h = remove_h
#             self.filter_n_atoms = n_atoms
#             self.num_workers = 0
#             self.include_charges = True
#             self.dataset = dataset_name  #could be qm9, qm9_first_half or qm9_second_half
#             self.datadir = 'qm9/temp'
#
#     cfg = DataLoaderConfig()
#
#     dataloaders, charge_scale = dataset.retrieve_dataloaders(cfg)
#
#     hist_nodes = Histogram_discrete('Histogram # nodes')
#     hist_atom_type = Histogram_discrete('Histogram of atom types')
#     hist_dist = Histogram_cont(name='Histogram relative distances', ignore_zeros=True)
#
#     for i, data in enumerate(dataloaders['train']):
#         print(i * cfg.batch_size)
#
#         # Histogram num_nodes
#         num_nodes = torch.sum(data['atom_mask'], dim=1)
#         num_nodes = list(num_nodes.numpy())
#         hist_nodes.add(num_nodes)
#
#         #Histogram edge distances
#         x = data['positions'] * data['atom_mask'].unsqueeze(2)
#         dist = coord2distances(x)
#         hist_dist.add(list(dist.numpy()))
#
#         # Histogram of atom types
#         one_hot = data['one_hot'].double()
#         atom = torch.argmax(one_hot, 2)
#         atom = atom.flatten()
#         mask = data['atom_mask'].flatten()
#         masked_atoms = list(atom[mask].numpy())
#         hist_atom_type.add(masked_atoms)
#
#     hist_dist.plot()
#     hist_dist.plot_both(hist_dist.bins[::-1])
#     print("KL divergence A %.4f" % kl_divergence_sym(hist_dist.bins, hist_dist.bins[::-1]))
#     print("KL divergence B %.4f" % kl_divergence_sym(hist_dist.bins, hist_dist.bins))
#     print(hist_dist.bins)
#     hist_nodes.plot()
#     print("Histogram of the number of nodes", hist_nodes.bins)
#     hist_atom_type.plot()
#     print(" Histogram of the atom types (H (optional), C, N, O, F)", hist_atom_type.bins)


############################
# Validity and bond analysis
def check_stability(atom_types, edge_types, dataset_info, debug=False):
    """Check every atom's bond count against ``allowed_bonds``.

    Args:
        atom_types: sequence of indices into ``dataset_info.atom_decoder``.
        edge_types: square structure indexed as ``edge_types[i, j]``; the
            entries for (i, j) and (j, i) are averaged for each atom pair.
        dataset_info: object exposing ``atom_decoder`` (index -> symbol).
        debug: when true, print a message for every atom that fails.

    Returns:
        ``(molecule_stable, n_stable_atoms, n_atoms)`` where
        ``molecule_stable`` is true only if every atom passed.
    """
    atom_decoder = dataset_info.atom_decoder
    n_atoms = len(atom_types)

    # NOTE(review): the accumulator has an integer dtype, so a fractional
    # pair contribution (possible only if edge_types is not symmetric) is
    # presumably truncated when stored - confirm this is acceptable.
    n_bonds = np.zeros(n_atoms, dtype='int')

    # TODO: not sure that this is correct

    # Visit each unordered pair once and credit the bond to both atoms.
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            pair_bonds = abs((edge_types[i, j] + edge_types[j, i])/2)
            n_bonds[i] += pair_bonds
            n_bonds[j] += pair_bonds

    n_stable_bonds = 0
    for atom_type, atom_n_bond in zip(atom_types, n_bonds):
        possible_bonds = allowed_bonds[atom_decoder[atom_type]]
        # An int is the only accepted count; otherwise any listed count passes.
        if isinstance(possible_bonds, int):
            is_stable = possible_bonds == atom_n_bond
        else:
            is_stable = atom_n_bond in possible_bonds
        if debug and not is_stable:
            print("Invalid bonds for molecule %s with %d bonds" % (atom_decoder[atom_type], atom_n_bond))
        n_stable_bonds += int(is_stable)

    molecule_stable = n_stable_bonds == n_atoms
    return molecule_stable, n_stable_bonds, n_atoms


def process_loader(dataloader, n_atom_types=5):
    """Mask out padded atoms and return positions and atom types per molecule.

    Args:
        dataloader: iterable of batches; each batch is a mapping with the
            tensors ``'positions'``, ``'one_hot'`` and ``'atom_mask'``,
            indexed by molecule along their first dimension.
        n_atom_types: width of the one-hot atom encoding. Defaults to 5, the
            value that used to be hard-coded, so existing callers are
            unaffected; pass another value for datasets with a different
            number of atom types.

    Returns:
        A list with one ``(positions, atom_type)`` tuple per molecule, where
        ``positions`` holds the rows kept by the mask (three columns) and
        ``atom_type`` is the argmax of the kept one-hot rows.
    """
    out = []
    for data in dataloader:
        for i in range(data['positions'].size(0)):
            mask = data['atom_mask'][i].flatten()
            positions = data['positions'][i].view(-1, 3)[mask]
            one_hot = data['one_hot'][i].view(-1, n_atom_types).type(torch.float32)[mask]
            atom_type = torch.argmax(one_hot, dim=1)
            out.append((positions, atom_type))
    return out


def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
    """Compute stability and RDKit-based metrics and log them to wandb.

    Args:
        molecule_list: sequence of ``(atom_types, edge_types)`` pairs; it is
            also handed unchanged to ``BasicMolecularMetrics.evaluate``.
        train_smiles: passed through to ``BasicMolecularMetrics``.
        dataset_info: object providing ``remove_h`` plus whatever
            ``check_stability`` and ``BasicMolecularMetrics`` read from it.

    Returns:
        ``(validity_dict, rdkit_metrics)``. ``validity_dict`` holds the
        fractions of stable molecules and stable atoms, or -1 for both when
        ``dataset_info.remove_h`` is set and the stability check is skipped.
    """
    if dataset_info.remove_h:
        # Stability is not evaluated in this mode; -1 marks "not computed".
        validity_dict = {'mol_stable': -1, 'atm_stable': -1}
    else:
        print('Analyzing molecule stability...')

        n_molecules = len(molecule_list)
        stable_molecules = 0
        stable_atoms = 0
        total_atoms = 0

        for atom_types, edge_types in molecule_list:
            mol_ok, n_ok, n_total = check_stability(atom_types, edge_types, dataset_info)
            stable_molecules += int(mol_ok)
            stable_atoms += int(n_ok)
            total_atoms += int(n_total)

        # Validity
        validity_dict = {
            'mol_stable': stable_molecules / float(n_molecules),
            'atm_stable': stable_atoms / float(total_atoms),
        }
        wandb.log(validity_dict, commit=False)

    rdkit_metrics = BasicMolecularMetrics(dataset_info, train_smiles).evaluate(molecule_list)
    summary = rdkit_metrics[0]
    wandb.log({'Validity': summary[0], 'Relaxed Validity': summary[1],
               'Uniqueness': summary[2], 'Novelty': summary[3]}, commit=False)
    return validity_dict, rdkit_metrics


def analyze_node_distribution(mol_list, save_path):
    """Print count tables of molecule sizes and atom types for ``mol_list``.

    Each entry of ``mol_list`` is unpacked as ``(positions, atom_type)``;
    ``positions.shape[0]`` is counted as the molecule size and every element
    of ``atom_type`` is counted as an atom type. ``save_path`` is accepted
    but not used, and nothing is returned.
    """
    size_hist = Histogram_discrete('Histogram # nodes (stable molecules)')
    type_hist = Histogram_discrete('Histogram of atom types')

    for positions, atom_type in mol_list:
        size_hist.add([positions.shape[0]])
        type_hist.add(atom_type)

    print("Histogram of #nodes")
    print(size_hist.bins)
    print("Histogram of # atom types")
    print(type_hist.bins)

    # Normalization happens after printing, so the printed values are counts.
    size_hist.normalize()


