import math
import copy

from rdkit import Chem
import torch
from tqdm import tqdm

from qm9.sampling import sample
from qm9.rdkit_functions import build_molecule
from qm9.analyze_joint_training import BasicSmilesMetrics, build_2D_mols, smiles_from_2d_mols_list


def sample_3d_molecules(model, nodes_dist, args, device, dataset_info, n_samples=1000, batch_size=100):
    """Sample ``n_samples`` molecules from the generative model in mini-batches.

    Args:
        model: generative model, forwarded unchanged to ``sample``.
        nodes_dist: object whose ``sample(n)`` returns the per-molecule node
            counts that are forwarded to ``sample`` as ``nodesxsample``.
        args: run configuration, forwarded unchanged to ``sample``.
        device: device, forwarded unchanged to ``sample``.
        dataset_info (dict): forwarded to ``sample``; must contain the key
            ``'max_n_nodes'``.
        n_samples (int): total number of molecules to generate.
        batch_size (int): maximum number of molecules per ``sample`` call.

    Returns:
        dict: keys ``'one_hot'``, ``'x'``, ``'node_mask'``, ``'charges'`` and
        ``'edge_mask'``; each value is the concatenation along dim 0 of the
        detached CPU tensors from every batch. ``'edge_mask'`` is viewed as
        ``(n_mols, max_n_nodes, max_n_nodes)`` per batch before concatenation.

    Raises:
        ValueError: if ``n_samples`` or ``batch_size`` is not positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (or an
    # empty torch.cat) further down.
    if n_samples <= 0 or batch_size <= 0:
        raise ValueError(
            f"n_samples and batch_size must be positive, got "
            f"n_samples={n_samples}, batch_size={batch_size}"
        )

    batch_size = min(batch_size, n_samples)
    max_n_nodes = dataset_info['max_n_nodes']
    molecules = {'one_hot': [], 'x': [], 'node_mask': [], 'charges': [], 'edge_mask': []}

    n_batches = math.ceil(n_samples / batch_size)  # account for remainder
    for i in tqdm(range(n_batches)):
        # The last batch only holds what is left over.
        n_mols = min(batch_size, n_samples - i * batch_size)

        nodesxsample = nodes_dist.sample(n_mols)
        one_hot, charges, x, node_mask, edge_mask = sample(
            args, device, model, dataset_info, nodesxsample=nodesxsample)

        molecules['one_hot'].append(one_hot.detach().cpu())
        molecules['x'].append(x.detach().cpu())
        molecules['node_mask'].append(node_mask.detach().cpu())
        molecules['charges'].append(charges.detach().cpu())

        # Give the edge mask a leading batch dimension so it can be
        # concatenated and sliced per molecule like the other tensors.
        edge_mask = edge_mask.view(n_mols, max_n_nodes, max_n_nodes)
        molecules['edge_mask'].append(edge_mask.detach().cpu())

    return {key: torch.cat(chunks, dim=0) for key, chunks in molecules.items()}


def smiles_from_3d_molecules_with_edge_model(pp_model, molecules_dict, batch_size, dataset_info, return_mols=False):
    """Turn sampled 3D molecules into 2D molecules and SMILES strings.

    The post-processing model is first run batch-wise over ``molecules_dict``
    (see ``run_pp_model_batchwise``), then ``build_2D_mols`` builds the
    molecule list from the enriched dict.

    Args:
        pp_model: model forwarded to ``run_pp_model_batchwise``.
        molecules_dict (dict): sampled molecules; replaced by the dict that
            ``run_pp_model_batchwise`` returns.
        batch_size (int): batch size for the post-processing model.
        dataset_info (dict): forwarded to ``build_2D_mols``.
        return_mols (bool): if True, only the molecule list is returned.

    Returns:
        ``mol_list`` when ``return_mols`` is True, otherwise the tuple
        ``(smiles_list, mol_list, molecules_dict)``.
    """
    molecules_dict = run_pp_model_batchwise(pp_model, molecules_dict, batch_size)
    mol_list = build_2D_mols(molecules_dict, dataset_info)

    if not return_mols:
        # Optional repair passes, currently disabled:
        # mol_list = postprocess_mols(mol_list, issue='N_keklization')
        # mol_list = postprocess_mols(mol_list, issue='two_ring')
        # mol_list = postprocess_mols(mol_list, issue='valency')
        smiles_list = smiles_from_2d_mols_list(mol_list)

        # Disconnected molecules could be dropped here (currently disabled):
        # smiles_list = [smiles for smiles in smiles_list if '.' not in smiles]
        return smiles_list, mol_list, molecules_dict

    return mol_list


def postprocess_mols(mols, issue):
    """Attempt one kind of repair on every molecule that fails validation.

    A molecule counts as valid when it can be written to a non-empty SMILES
    string that parses back into a molecule with at least one atom. Valid
    molecules are copied through unchanged. For an invalid molecule the repair
    selected by ``issue`` is applied to a copy; the repaired copy is kept only
    if it validates, otherwise a copy of the original molecule is kept.

    Args:
        mols (iterable): RDKit molecules. The repairs modify bonds and atoms,
            so the molecules presumably need to be editable (``Chem.RWMol``).
        issue (str): one of ``'N_keklization'``, ``'two_ring'`` or
            ``'valency'``. Any other value applies no repair, so every
            molecule is copied through as is.

    Returns:
        list: one molecule per input molecule, in input order. The input
        molecules themselves are never modified.
    """
    def is_mol_valid(mol):
        # Any RDKit failure during the round trip means "invalid". Only
        # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
        try:
            s = Chem.MolToSmiles(mol)
            m = Chem.MolFromSmiles(s)
            return s != '' and m is not None and m.GetNumAtoms() > 0
        except Exception:
            return False

    postprocessed_mols = []
    for mol in mols:
        if is_mol_valid(mol):
            postprocessed_mols.append(copy.deepcopy(mol))
            continue

        # Work on a private copy so a failed repair never touches the input.
        mol_copy = copy.deepcopy(mol)

        if issue == 'N_keklization':
            # Give the first under-coordinated aromatic nitrogen an explicit
            # hydrogen; only one atom is changed per molecule.
            for atom in mol_copy.GetAtoms():
                if atom.GetSymbol() == 'N' and atom.GetIsAromatic() and atom.GetDegree() < 3:
                    atom.SetNumExplicitHs(1)
                    break
        elif issue == 'two_ring':
            mol_copy = fix_two_ring_problem(mol_copy)
        elif issue == 'valency':
            mol_copy = fix_valency(mol_copy)

        if is_mol_valid(mol_copy):
            # mol_copy is already independent of the input; no second copy needed.
            postprocessed_mols.append(mol_copy)
        else:
            postprocessed_mols.append(copy.deepcopy(mol))

    return postprocessed_mols


def fix_two_ring_problem(mol):
    """Reset part of a fused ring to single bonds.

    All pairs of bond rings are scanned. For every pair that shares at least
    one bond, the bonds of the smaller ring (the second ring on a tie) that
    are not part of the other ring are selected; the last such pair in scan
    order wins. The selected bonds are removed and re-added as single bonds.

    Args:
        mol: RDKit molecule; it is modified in place via ``RemoveBond`` /
            ``AddBond``, so presumably it has to be a ``Chem.RWMol``.

    Returns:
        The same ``mol`` object, unchanged if no two rings share a bond.
    """
    Chem.rdmolops.GetSSSR(mol)
    bond_rings = mol.GetRingInfo().BondRings()
    n_rings = len(bond_rings)

    bonds_to_reset = None
    for first_pos in range(n_rings):
        first = bond_rings[first_pos]
        for second_pos in range(first_pos + 1, n_rings):
            second = bond_rings[second_pos]
            # Fewer distinct bonds than the summed ring sizes => shared bond.
            if len(set(first + second)) >= len(first) + len(second):
                continue
            if len(first) < len(second):
                bonds_to_reset = set(first).difference(second)
            else:
                bonds_to_reset = set(second).difference(first)

    if bonds_to_reset is None:
        return mol

    # Resolve bond indices to atom pairs up front: removing a bond may
    # invalidate the remaining bond indices.
    endpoints = []
    for bond_idx in bonds_to_reset:
        bond = mol.GetBondWithIdx(bond_idx)
        endpoints.append((bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()))

    for begin, end in endpoints:
        mol.RemoveBond(begin, end)
        mol.AddBond(begin, end, Chem.rdchem.BondType.SINGLE)
    return mol


def fix_valency(mol):
    """Try to repair atoms whose valence exceeds the allowed maximum.

    For every atom, ``GetTotalValence() - GetFormalCharge()`` is compared with
    the largest allowed valence of its element. An over-valent atom with a
    negative formal charge simply has its charge reset to 0. The remaining
    over-valent atoms are handled as follows:

    * exactly one: its highest-order bond is lowered by one level;
    * exactly two: the bond between them is lowered if there is one,
      otherwise the highest-order bond of each atom is lowered;
    * more than two: no bond is changed.

    Lowering a bond turns triple into double and double into single; any
    other bond type (single, aromatic, ...) is removed altogether.

    Args:
        mol: RDKit molecule; it is modified in place via ``RemoveBond`` /
            ``AddBond``, so presumably it has to be a ``Chem.RWMol``.

    Returns:
        The same ``mol`` object.
    """
    mol.UpdatePropertyCache(strict=False)
    allowed_valences = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3,
                        'Si': 4, 'P': 5,
                        'S': 6, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2],
                        'Bi': [3, 5]}

    problematic_atoms = []
    for atom in mol.GetAtoms():
        allowed_valence = allowed_valences.get(atom.GetSymbol())
        if allowed_valence is None:
            # No reference valence for this element: leave the atom alone
            # instead of failing with a KeyError.
            continue
        if isinstance(allowed_valence, (list, tuple)):
            # Elements with several allowed valences are only over-valent
            # when they exceed the largest one.
            allowed_valence = max(allowed_valence)

        valence = atom.GetTotalValence() - atom.GetFormalCharge()
        if valence > allowed_valence:
            if atom.GetFormalCharge() < 0:
                atom.SetFormalCharge(0)
            else:
                problematic_atoms.append(atom)

    if not problematic_atoms:
        return mol

    def reduce_bond_level(mol, bond):
        # Lower the bond by one level (triple -> double -> single -> removed).
        begin = bond.GetBeginAtom().GetIdx()
        end = bond.GetEndAtom().GetIdx()
        # Read the type BEFORE removing: the bond object may no longer be
        # valid once it has been removed from the molecule.
        bond_type = bond.GetBondType()
        mol.RemoveBond(begin, end)

        if bond_type == Chem.rdchem.BondType.TRIPLE:
            mol.AddBond(begin, end, Chem.rdchem.BondType.DOUBLE)
        elif bond_type == Chem.rdchem.BondType.DOUBLE:
            mol.AddBond(begin, end, Chem.rdchem.BondType.SINGLE)
        # Every other bond type stays removed.

    def reduce_bond_level_for_atom(atom):
        # Lower the first bond with the highest bond order on this atom.
        highest_order = 0
        highest_order_bond = None
        for bond in atom.GetBonds():
            order = bond.GetBondTypeAsDouble()
            if order > highest_order:
                highest_order = order
                highest_order_bond = bond
        if highest_order_bond is not None:  # atom may have no bonds at all
            reduce_bond_level(mol, highest_order_bond)

    if len(problematic_atoms) == 2:
        atom1, atom2 = problematic_atoms
        bond = mol.GetBondBetweenAtoms(atom1.GetIdx(), atom2.GetIdx())
        if bond is not None:
            # One shared bond fixes both atoms at once.
            reduce_bond_level(mol, bond)
        else:
            reduce_bond_level_for_atom(atom1)
            reduce_bond_level_for_atom(atom2)
    elif len(problematic_atoms) == 1:
        reduce_bond_level_for_atom(problematic_atoms[0])

    return mol


def run_pp_model_batchwise(pp_model, molecules_dict, batch_size):
    """Run the post-processing model over all molecules in mini-batches.

    For every batch ``pp_model.map_to_2d(one_hot, charges, x, node_mask,
    edge_mask)`` is called and its three outputs are detached, moved to the
    CPU and concatenated along dim 0.

    ``molecules_dict`` is modified in place: the keys ``'adjacency_matrices'``,
    ``'atom_types'``, ``'formal_charges'`` and ``'positions'`` are added.

    Args:
        pp_model: object exposing ``map_to_2d`` as described above.
        molecules_dict (dict): contains the prediction of the diffusion model.
            Has keys: 'x', 'node_mask', 'one_hot', 'charges', 'edge_mask',
            all indexable along the first (molecule) dimension.
        batch_size (int): batch size for the model inference.

    Returns:
        molecules_dict (dict): the input dict with the new keys; 'positions'
        refers to the same object as 'x'.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = len(molecules_dict['x'])
    adj_matrices = []
    atom_types = []
    formal_charges = []
    # Stepping by batch_size yields the shorter final batch automatically.
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)

        # prepare batch for the model prediction
        x = molecules_dict['x'][start_idx:end_idx]
        node_mask = molecules_dict['node_mask'][start_idx:end_idx]
        one_hot = molecules_dict['one_hot'][start_idx:end_idx]
        charges = molecules_dict['charges'][start_idx:end_idx]
        edge_mask = molecules_dict['edge_mask'][start_idx:end_idx]

        adj_pred, atom_types_pred, formal_charges_pred = pp_model.map_to_2d(
            one_hot, charges, x, node_mask, edge_mask)

        adj_matrices.append(adj_pred.detach().cpu())
        atom_types.append(atom_types_pred.detach().cpu())
        formal_charges.append(formal_charges_pred.detach().cpu())

    molecules_dict['adjacency_matrices'] = torch.cat(adj_matrices, dim=0)
    molecules_dict['atom_types'] = torch.cat(atom_types, dim=0)
    molecules_dict['formal_charges'] = torch.cat(formal_charges, dim=0)

    # Alias so downstream code can read the coordinates under 'positions'.
    molecules_dict['positions'] = molecules_dict['x']
    return molecules_dict


def get_adj_matrices_with_edge_model(edge_model, molecules_dict, batch_size):
    """Predict bonds, atom types and formal charges with the edge model.

    The molecules are processed batch-wise. ``molecules_dict`` is modified in
    place: a new key ``'adj_matrices'`` is created and the entries
    ``'one_hot'`` and ``'charges'`` are overwritten with the model's
    predictions (which must have the same shapes as the entries they replace).

    Args:
        edge_model (EGNNEdgeModel): callable taking a batch dict with keys
            'positions', 'atom_mask', 'one_hot', 'charges', 'edge_mask' and
            returning ``(adj_pred, h_pred)``.
        molecules_dict (dict): contains the prediction of the diffusion model.
            Has keys: 'x', 'node_mask', 'one_hot', 'charges'. 'one_hot' must
            be 3-dimensional ``(n_samples, n_nodes, n_atom_types)``.
        batch_size (int): batch size for the edge_model inference.

    Returns:
        molecules_dict (dict): the input dict with the new key 'adj_matrices'
        and the replaced 'one_hot' / 'charges' entries.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        AssertionError: if the predicted atom types or formal charges do not
            match the shape of the entries they replace.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = len(molecules_dict['x'])
    adj_matrices = []
    atom_types = []
    formal_charges = []
    # Stepping by batch_size yields the shorter final batch automatically.
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)

        one_hot = molecules_dict['one_hot'][start_idx:end_idx]
        current_batch_size, n_nodes, n_atom_types = one_hot.size()

        # prepare batch for edge_model prediction
        batch = {}
        batch['positions'] = molecules_dict['x'][start_idx:end_idx]
        # Reshape explicitly instead of an argument-less squeeze(): the latter
        # would also drop the batch dimension for a batch of one molecule.
        batch['atom_mask'] = molecules_dict['node_mask'][start_idx:end_idx].reshape(
            current_batch_size, n_nodes)
        batch['one_hot'] = one_hot
        batch['charges'] = molecules_dict['charges'][start_idx:end_idx]

        # Obtain edges: an edge exists between two nodes that are both unmasked
        atom_mask = batch['atom_mask']
        edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)

        # mask diagonal (no self-edges); built on the mask's own device
        diag_mask = ~torch.eye(n_nodes, dtype=torch.bool, device=edge_mask.device).unsqueeze(0)
        edge_mask = edge_mask * diag_mask

        batch['edge_mask'] = edge_mask.reshape(current_batch_size * n_nodes * n_nodes, 1)

        # Run inference
        adj_pred, h_pred = edge_model(batch)

        # Get predicted class labels
        adj_matrices.append(torch.argmax(adj_pred, -1))

        # The first n_atom_types channels of h_pred score the atom type;
        # re-encode the winning class as a one-hot float vector.
        class_ids = torch.arange(n_atom_types, device=h_pred.device).view(1, 1, -1)
        atom_type_pred = (h_pred[:, :, :n_atom_types].argmax(-1).unsqueeze(-1) == class_ids).float()
        atom_types.append(atom_type_pred)

        # The remaining channels score the formal charge; the class index is
        # shifted by one (presumably class 0 encodes a charge of -1 -- confirm
        # against the edge model's training targets).
        formal_charge_pred = h_pred[:, :, n_atom_types:].argmax(-1).unsqueeze(-1) - 1
        formal_charges.append(formal_charge_pred)

    molecules_dict['adj_matrices'] = torch.cat(adj_matrices, dim=0)

    atom_types = torch.cat(atom_types, dim=0)
    assert molecules_dict['one_hot'].shape == atom_types.shape
    molecules_dict['one_hot'] = atom_types

    formal_charges = torch.cat(formal_charges, dim=0)
    assert molecules_dict['charges'].shape == formal_charges.shape
    molecules_dict['charges'] = formal_charges
    return molecules_dict

# Lookup from an integer bond class to the RDKit bond type, indexed by
# position (despite the name this is a list). Index 0 maps to None,
# presumably meaning "no bond".
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

# def build_2D_mols(molecules_dict, dataset_info):
#     """
#     Creates a list of RDKit Mol objects from the generated atoms as well as the predicted structures using the edge model
#     Args:
#         molecules_dict (dict): contains the prediction of the diffusion model and the edge model. 
#                                Has keys: 'x', 'node_mask', 'one_hot', 'charges', 'adj_matrices'
#         dataset_info (dict)

#     Returns:
#         mol_list (list): list of 2D molecules
#     """
#     one_hot = molecules_dict['one_hot']
#     x = molecules_dict['x']
#     node_mask = molecules_dict['node_mask']
#     charges = molecules_dict['charges']
#     adj_matrices = molecules_dict['adj_matrices']
#     atom_decoder = dataset_info["atom_decoder"]
#     n_samples = len(x)

#     if isinstance(node_mask, torch.Tensor):
#         atomsxmol = torch.sum(node_mask, dim=1)
#     else:
#         atomsxmol = [torch.sum(m) for m in node_mask]

#     mol_list = []
#     for i in range(n_samples):
#         try:
#             atom_type = one_hot[i].argmax(1).cpu().detach()
#             pos = x[i].cpu().detach()
#             charge = charges[i].cpu().detach()
#             adj_matrix = adj_matrices[i].cpu().detach()

#             atom_type = atom_type[0:int(atomsxmol[i])]
#             pos = pos[0:int(atomsxmol[i])]
#             charge = charge[0:int(atomsxmol[i])]

#             mol = Chem.RWMol()
#             for atom, charge in zip(atom_type, charge):
#                 a = Chem.Atom(atom_decoder[atom.item()])
#                 a.SetFormalCharge(int(charge.item()))
#                 mol.AddAtom(a)
                
#             # because adj_matrix is symmetric, we'll only look at the upper diagonal
#             # to avoid processing each edge twice
#             all_bonds = torch.nonzero(torch.triu(adj_matrix))
#             for bond in all_bonds:
#                 mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[adj_matrix[bond[0], bond[1]].item()])

#             mol_list.append(mol)
#         except:
#             continue

#     return mol_list


def smiles_from_3d_molecules(molecule_list, dataset_info):
    """Build a molecule from each sampled 3D structure and write it as SMILES.

    Args:
        molecule_list (dict): despite the name, a dict with keys 'one_hot',
            'x', 'node_mask' and 'charges', each indexable per molecule.
            'node_mask' is either one tensor (summed over dim 1) or a
            sequence of per-molecule tensors (each summed entirely) and
            determines how many leading atoms of each molecule are kept.
        dataset_info (dict): forwarded to ``build_molecule``.

    Returns:
        list[str]: one entry per molecule, in input order; the string
        ``'invalid mol'`` marks molecules whose SMILES conversion failed.
    """
    one_hot = molecule_list['one_hot']
    x = molecule_list['x']
    node_mask = molecule_list['node_mask']
    charges = molecule_list['charges']

    # Number of real (unmasked) atoms per molecule.
    if isinstance(node_mask, torch.Tensor):
        atomsxmol = torch.sum(node_mask, dim=1)
    else:
        atomsxmol = [torch.sum(m) for m in node_mask]

    smiles_list = []
    for i in range(len(x)):
        n_atoms = int(atomsxmol[i])

        # Drop the padding entries behind the real atoms.
        atom_type = one_hot[i].argmax(1).cpu().detach()[0:n_atoms]
        pos = x[i].cpu().detach()[0:n_atoms]
        charge = charges[i].cpu().detach()[0:n_atoms]

        mol = build_molecule(pos, atom_type, charge, dataset_info)
        try:
            #mol = Chem.RemoveHs(mol)
            #Chem.SanitizeMol(mol)
            smiles = Chem.MolToSmiles(mol)
        except Exception:
            # Best effort: keep a placeholder so the output stays aligned with
            # the input. KeyboardInterrupt/SystemExit are not swallowed.
            smiles = 'invalid mol'

        smiles_list.append(smiles)

    return smiles_list
