import os
from typing import Optional, Sized

import torch
import wandb
from torch.utils.data import Sampler
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_adj
from torch_scatter import scatter_mean
from tqdm import tqdm
from rdkit import Chem
from rdkit.Geometry import Point3D
import pandas as pd
import random
import numpy as np
from numpy.linalg import norm
import data



def output_mol(dataset, task, output_folder, device="cuda", watch_ligand_list=('6kqi', '6hd6', '6qqw'),
               wandb_log_dict=None):
    """Write the ligand coordinates predicted by ``task`` to one SDF file per sample.

    Parameters
    ----------
    dataset :
        Dataset fed to the DataLoader. ``get_index``, ``get_ligand_file_path_sdf``
        and ``get_ligand_file_path_mol2`` are called on it with the integer
        taken from ``batch['idx']``.
    task :
        Object providing ``eval()`` and ``output_coord(batch)``. The result of
        ``output_coord`` is iterated, one entry per sample of the batch.
    output_folder : str
        Directory receiving ``predict_{index}_ligand.sdf`` for every sample.
    device : str
        Device each batch is moved to before prediction.
    watch_ligand_list : sequence of str
        The first four characters of ``index`` are looked up in this sequence
        to decide whether the molecule is logged.
    wandb_log_dict : dict or None
        When not None, it is filled in place with a ``wandb.Molecule`` under
        the key ``index`` for every watched ligand.
    """
    dataloader = DataLoader(dataset, batch_size=4, follow_batch=['x', 'y', 'compound_pair'],
                            shuffle=False,
                            num_workers=0)
    task.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        mol_coords = task.output_coord(batch)
        for i, predict_coord in enumerate(mol_coords):
            idx = int(batch['idx'][i])
            index = dataset.get_index(idx)
            ligand_original_file_sdf = dataset.get_ligand_file_path_sdf(idx)
            ligand_original_file_mol2 = dataset.get_ligand_file_path_mol2(idx)
            output_file = os.path.join(output_folder, f"predict_{index}_ligand.sdf")
            try:
                mol = write_mol(reference_file=ligand_original_file_sdf, coords=predict_coord, output_file=output_file)
            except Exception:
                # The SDF reference could not be used; retry with the mol2 file.
                # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
                mol = write_mol(reference_file=ligand_original_file_mol2, coords=predict_coord, output_file=output_file)

            if index[:4] in watch_ligand_list and wandb_log_dict is not None:
                wandb_log_dict[index] = wandb.Molecule.from_rdkit(mol, convert_to_3d_and_optimize=False)


def tankbind_output_mol(dataset, task, output_folder, device="cuda", use_configuration=True):
    """Turn predicted distance maps into ligand coordinates and save them.

    For every sample, ``task.output_distance_map(batch)`` yields a predicted
    distance map which ``distance_optimize_compound_coords`` converts into
    coordinates. The coordinates are written to
    ``predict_{index}_ligand.sdf`` and per-sample statistics are collected in
    ``info.pkl`` inside ``output_folder``.

    Parameters
    ----------
    dataset :
        Dataset fed to the DataLoader. ``get_index``, ``get_ligand_file_path_sdf``
        and ``get_ligand_file_path_mol2`` are called on it with the integer
        taken from ``batch['idx']``.
    task :
        Object providing ``eval()`` and ``output_distance_map(batch)``.
    output_folder : str
        Directory receiving the SDF files and ``info.pkl``.
    device : str
        Device each batch is moved to before prediction.
    use_configuration : bool
        Forwarded unchanged to ``distance_optimize_compound_coords``.
    """
    dataloader = DataLoader(dataset, batch_size=3, follow_batch=['x', 'y', 'compound_pair'],
                            shuffle=False,
                            num_workers=0)
    info = []
    task.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        pred_inter_dist_list = task.output_distance_map(batch)
        for i, pred_inter_dist in enumerate(pred_inter_dist_list):
            idx = int(batch['idx'][i])
            index = dataset.get_index(idx)
            ligand_original_file_sdf = dataset.get_ligand_file_path_sdf(idx)
            ligand_original_file_mol2 = dataset.get_ligand_file_path_mol2(idx)
            output_file = os.path.join(output_folder, f"predict_{index}_ligand.sdf")

            sample = batch[i]
            predict_coord, loss, rmsd = distance_optimize_compound_coords(
                reference_compound_coords=sample['compound'].true_coords,
                y_pred=pred_inter_dist,
                protein_coords=sample['protein'].coords,
                LAS_edge_index=sample[("compound", "LAS", "compound")].edge_index,
                use_configuration=use_configuration
            )
            print(f"RMSD {index}: {rmsd}")

            # Distance between the centroids of the predicted and the true pose.
            true_coord = sample['compound'].true_coords.cpu().numpy()
            predict_centroid = predict_coord.mean(axis=0)
            true_centroid = true_coord.mean(axis=0)
            centroid_dis = norm(predict_centroid - true_centroid, 2)
            info.append([index, rmsd, loss, predict_coord, true_coord, predict_centroid, true_centroid, centroid_dis])
            try:
                write_mol(reference_file=ligand_original_file_sdf, coords=predict_coord, output_file=output_file)
            except Exception:
                # The SDF reference could not be used; retry with the mol2 file.
                # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
                write_mol(reference_file=ligand_original_file_mol2, coords=predict_coord, output_file=output_file)
    info = pd.DataFrame(info, columns=['index', 'rmsd', 'loss', 'coords', "true_coords", "centroid", "true_centroid", "centroid_dis"])
    info.to_pickle(os.path.join(output_folder, "info.pkl"))
    print("Finshed")


def post_optimize_output_mol(dataset, task, output_folder, device="cuda", rigid=False):
    """Refine the coordinates predicted by ``task`` and save poses and statistics.

    Each prediction from ``task.output_coord(batch)`` is passed through
    ``post_optimize_compound_coords``, written to an SDF file, and summarised
    in a pickled DataFrame on which ``get_statistics`` is finally run.

    Parameters
    ----------
    dataset :
        Dataset fed to the DataLoader. ``get_index``, ``get_ligand_file_path_sdf``
        and ``get_ligand_file_path_mol2`` are called on it with the integer
        taken from ``batch['idx']``.
    task :
        Object providing ``eval()`` and ``output_coord(batch)``.
    output_folder : str
        Directory receiving the SDF files and the statistics pickle.
    device : str
        Device the batches and the predicted coordinates are moved to.
    rigid : bool
        If True, the optimiser is called with ``LAS_edge_index=None`` and the
        outputs are named ``*_ligand_fix_rigid.sdf`` /
        ``post_optimize_info_rigid.pkl``. Otherwise the sample's
        ``("compound", "LAS", "compound")`` edge index is passed and the
        outputs are named ``*_ligand_fix.sdf`` / ``post_optimize_info.pkl``.
    """
    dataloader = DataLoader(dataset, batch_size=4, follow_batch=['x', 'y', 'compound_pair'],
                            shuffle=False,
                            num_workers=0)
    info = []
    task.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        mol_coords = task.output_coord(batch)
        for i, raw_coord in enumerate(mol_coords):
            predict_coord = raw_coord.to(device)
            sample = batch[i]
            true_coord = sample['compound'].true_coords
            idx = int(batch['idx'][i])
            index = dataset.get_index(idx)
            ligand_original_file_sdf = dataset.get_ligand_file_path_sdf(idx)
            ligand_original_file_mol2 = dataset.get_ligand_file_path_mol2(idx)

            # Only the output name and the edge index differ between the modes.
            if rigid:
                output_file = os.path.join(output_folder, f"predict_{index}_ligand_fix_rigid.sdf")
                las_edge_index = None
            else:
                output_file = os.path.join(output_folder, f"predict_{index}_ligand_fix.sdf")
                las_edge_index = sample[("compound", "LAS", "compound")].edge_index
            predict_coord, loss, rmsd = post_optimize_compound_coords(
                reference_compound_coords=true_coord,
                predict_compound_coords=predict_coord,
                LAS_edge_index=las_edge_index,
            )

            true_coord = true_coord.cpu().numpy()
            predict_centroid = predict_coord.mean(axis=0)
            true_centroid = true_coord.mean(axis=0)
            centroid_dis = norm(predict_centroid - true_centroid, 2)
            print(f"RMSD {index}: {rmsd}")
            info.append([index, rmsd, loss, predict_coord, true_coord, predict_centroid, true_centroid, centroid_dis])
            try:
                write_mol(reference_file=ligand_original_file_sdf, coords=predict_coord, output_file=output_file)
            except Exception:
                # The SDF reference could not be used; retry with the mol2 file.
                # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
                write_mol(reference_file=ligand_original_file_mol2, coords=predict_coord, output_file=output_file)

    info = pd.DataFrame(info, columns=['index', 'rmsd', 'loss', 'coords', "true_coords", "centroid", "true_centroid", "centroid_dis"])
    if rigid:
        info_file = os.path.join(output_folder, "post_optimize_info_rigid.pkl")
    else:
        info_file = os.path.join(output_folder, "post_optimize_info.pkl")
    info.to_pickle(info_file)
    get_statistics(info_file)
    print("Finshed")

def tankbind_output_pocket(dataset, task, output_folder, device="cuda"):
    """Score every sample with ``task.output_affinity`` and report the best one per compound.

    The per-sample table (columns ``index``, ``affinity``, ``is_equivalent``,
    ``true_affinity``) is pickled to ``info_affinity.pkl`` in
    ``output_folder``. Samples are then grouped by the first four characters
    of ``index`` and, for the highest-affinity rows of each group, the number
    and fraction of "equivalent" rows and the RMSE between predicted and true
    affinity are printed.

    Parameters
    ----------
    dataset :
        Dataset fed to the DataLoader. ``get_index(idx)`` and ``info.iloc[idx]``
        (with ``num_contact`` and ``native_num_contact``) are read from it.
    task :
        Object providing ``eval()`` and ``output_affinity(batch)``.
    output_folder : str
        Directory receiving ``info_affinity.pkl``.
    device : str
        Device each batch is moved to before prediction.
    """
    dataloader = DataLoader(dataset, batch_size=5, follow_batch=['x', 'y', 'compound_pair'],
                            shuffle=False,
                            num_workers=0)
    affinity_info = []
    task.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        affinity_list = task.output_affinity(batch)
        for i, affinity in enumerate(affinity_list):
            idx = int(batch['idx'][i])
            index = dataset.get_index(idx)
            record = dataset.info.iloc[idx]
            # A sample counts as "equivalent" when it has more than 90% of the native contacts.
            is_equivalent = (record.num_contact / record.native_num_contact) > 0.9
            true_affinity = batch['affinity'][i]
            affinity_info.append([index, float(affinity), is_equivalent, float(true_affinity)])

    affinity_info = pd.DataFrame(affinity_info, columns=['index', 'affinity', "is_equivalent", "true_affinity"])
    # The pickle is written before `compound_name` is added, so it holds only the four columns above.
    affinity_info.to_pickle(os.path.join(output_folder, "info_affinity.pkl"))

    compound_name = [index[:4] for index in affinity_info['index'].to_list()]
    affinity_info['compound_name'] = compound_name
    # Use the "max" alias: passing the builtin `max` to transform is deprecated in recent pandas.
    group_best = affinity_info.groupby(['compound_name'])['affinity'].transform("max")
    max_indices = group_best == affinity_info['affinity']
    # Ties on the maximum keep several rows for one compound (original note: "Duplicate: 6jan_0").
    affinity_info_max = affinity_info[max_indices]
    print(f"Find Equivalent {affinity_info_max.is_equivalent.sum()}")
    print(f"Find Equivalent Fraction {affinity_info_max.is_equivalent.sum() / len(affinity_info_max.is_equivalent)}")

    print(f"Affinity: {np.sqrt(((affinity_info_max.affinity - affinity_info_max.true_affinity)**2).mean())}")

    print("Finshed")



def end2end_evaluate(dataset, task, output_folder, device="cuda"):
    """Evaluate ``task.predict`` on ``dataset`` and print pose-quality statistics.

    For every sample the confidence score, the RMSD of the predicted and of
    the initial coordinates against the true coordinates, and the distance
    between predicted and true centroids are collected. The table is pickled
    to ``info_score.pkl`` in ``output_folder``. Statistics are then printed
    for the highest-scoring rows of each compound, where a compound is
    identified by the first four characters of ``index``.

    Parameters
    ----------
    dataset :
        Dataset fed to the DataLoader. ``get_index(idx)`` and ``info.iloc[idx]``
        (with ``num_contact`` and ``native_num_contact``) are read from it.
    task :
        Object providing ``eval()``, ``criterion`` (its keys are used to build
        the ``metric`` argument) and ``predict(batch, metric=...)``, whose
        result is read at ``'pred_confidence'`` and ``'pred_compound_coord'``.
    output_folder : str
        Directory receiving ``info_score.pkl``.
    device : str
        Device each batch is moved to before prediction.
    """
    dataloader = DataLoader(dataset, batch_size=5, follow_batch=['x', 'y', 'compound_pair'],
                            shuffle=False,
                            num_workers=0)
    score_info = []
    task.eval()
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        outputs = task.predict(batch, metric={key: None for key in task.criterion.keys()})
        score = outputs['pred_confidence']
        pred_compound_coord = outputs["pred_compound_coord"]
        true_compound_coord = batch["compound"].true_coords
        init_compound_coord = batch["compound"].init_coords
        compound_batch = batch['compound'].batch

        # Per-atom squared deviation, averaged per sample through the batch vector.
        sd = ((pred_compound_coord - true_compound_coord) ** 2).sum(-1)
        init_sd = ((init_compound_coord - true_compound_coord) ** 2).sum(-1)
        rmsd = torch.sqrt(scatter_mean(src=sd, index=compound_batch, dim=0))
        init_rmsd = torch.sqrt(scatter_mean(src=init_sd, index=compound_batch, dim=0))

        pred_centroid = scatter_mean(src=pred_compound_coord, index=compound_batch, dim=0)
        true_centroid = scatter_mean(src=true_compound_coord, index=compound_batch, dim=0)
        centroid_dis = (pred_centroid - true_centroid).norm(dim=-1)
        for i in range(len(score)):
            idx = int(batch['idx'][i])
            index = dataset.get_index(idx)
            record = dataset.info.iloc[idx]
            # A sample counts as "equivalent" when it has more than 90% of the native contacts.
            is_equivalent = (record.num_contact / record.native_num_contact) > 0.9
            true_affinity = batch['affinity'][i]
            score_info.append([index, float(score[i]), is_equivalent, float(rmsd[i]), float(centroid_dis[i]),
                               float(true_affinity), float(init_rmsd[i])])

    score_info = pd.DataFrame(score_info, columns=['index', 'score', "is_equivalent", "rmsd", "centroid_dis", "true_affinity", "init_rmsd"])
    # The pickle is written before `compound_name` is added.
    score_info.to_pickle(os.path.join(output_folder, "info_score.pkl"))

    compound_name = [index[:4] for index in score_info['index'].to_list()]
    score_info['compound_name'] = compound_name
    # Use the "max" alias: passing the builtin `max` to transform is deprecated in recent pandas.
    group_best = score_info.groupby(['compound_name'])['score'].transform("max")
    max_indices = group_best == score_info['score']
    # Ties on the maximum keep several rows for one compound (original note: "Duplicate: 6jan_0").
    score_info_max = score_info[max_indices]
    print(f"Find Equivalent {score_info_max.is_equivalent.sum()}")
    print(f"Find Equivalent Fraction {score_info_max.is_equivalent.sum() / len(score_info_max.is_equivalent)}")
    print(">>>>>>>>")
    rmsd = score_info_max.rmsd
    print(f"Mean RMSD {rmsd.mean()} A")
    print(f"< 5A Percentiles: {(rmsd < 5).sum() / len(rmsd)}")
    print(f"< 2A Percentiles: {(rmsd < 2).sum() / len(rmsd)}")
    # NOTE(review): "Finished" is printed in the middle of the report; position kept to preserve output.
    print("Finished")
    print(f"Quantile 25%: {rmsd.quantile(0.25)}")
    print(f"Quantile 50%: {rmsd.quantile(0.5)}")
    print(f"Quantile 75%: {rmsd.quantile(0.75)}")
    print(">>>>>>>>")
    centroid_dis = score_info_max.centroid_dis
    print(f"Mean Centroid Dis {centroid_dis.mean()} A")
    print(f"Mean Centroid Dis < 5A Percentiles: {(centroid_dis < 5).sum() / len(centroid_dis)}")
    print(f"Mean Centroid Dis < 2A Percentiles: {(centroid_dis < 2).sum() / len(centroid_dis)}")
    print(f"Mean Centroid Dis Quantile 25%: {centroid_dis.quantile(0.25)}")
    print(f"Mean Centroid Dis Quantile 50%: {centroid_dis.quantile(0.5)}")
    print(f"Mean Centroid Dis Quantile 75%: {centroid_dis.quantile(0.75)}")







def write_mol(reference_file, coords, output_file):
    """Copy a reference molecule, overwrite its atom positions and save it as SDF.

    Parameters
    ----------
    reference_file : str
        Molecule file loaded through ``read_molecule`` with ``sanitize=True``
        and ``remove_hs=True``.
    coords :
        Indexable collection; ``coords[i]`` must unpack into three values that
        ``float()`` accepts. Only the first ``mol.GetNumAtoms()`` entries are read.
    output_file : str
        Path of the SDF file to write.

    Returns
    -------
    The loaded molecule with the updated conformer.

    Raises
    ------
    ValueError
        If no molecule could be loaded from ``reference_file``.
    """
    mol = read_molecule(reference_file, sanitize=True, remove_hs=True)
    # Treat any non-molecule result (None or an exception object) as a load failure.
    if mol is None or isinstance(mol, Exception):
        raise ValueError(f"Could not load a molecule from {reference_file!r}")
    conf = mol.GetConformer()
    for i in range(mol.GetNumAtoms()):
        x, y, z = coords[i]
        conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
    w = Chem.SDWriter(output_file)
    try:
        w.SetKekulize(False)
        w.write(mol)
    finally:
        # Close the writer even when writing fails so the file handle is released.
        w.close()
    return mol


def post_optimize_loss_function(epoch, x, predict_compound_coords, compound_pair_dis_constraint,
                                LAS_distance_constraint_mask=None, mode=0):
    """Loss that keeps ``x`` near a predicted pose while restoring pairwise distances.

    Parameters
    ----------
    epoch : int
        Optimisation step; the configuration term is weighted by ``0.2e-3 * epoch``.
    x : torch.Tensor
        Coordinates being optimised.
    predict_compound_coords : torch.Tensor
        Target coordinates; the per-row distance to ``x`` is clamped at 10.
    compound_pair_dis_constraint : torch.Tensor
        Target pairwise distances compared against ``torch.cdist(x, x)``.
    LAS_distance_constraint_mask : torch.Tensor or None
        Boolean mask selecting the pairs whose distance is constrained. When
        given, an excluded-volume penalty for distances below 1.22 is added
        over the whole pairwise matrix. When None, every pair is constrained.
    mode : int
        0 sums the clamped distances, 1 their squares, 2 their square roots.

    Returns
    -------
    tuple
        ``(loss, (interaction_loss, configuration_loss))`` with the two
        components given as Python numbers.

    Raises
    ------
    NotImplementedError
        For any other ``mode``.
    """
    # TODO: clamp large dis?
    offset = torch.clamp((x - predict_compound_coords).norm(dim=-1), max=10)
    if mode == 0:
        interaction_loss = offset.sum()
    elif mode == 1:
        interaction_loss = (offset ** 2).sum()
    elif mode == 2:
        # probably not a good choice. x^0.5 has infinite gradient at x=0. added 1e-5 for numerical stability.
        interaction_loss = ((offset.abs() + 1e-5) ** 0.5).sum()
    else:
        raise NotImplementedError()

    pair_dis = torch.cdist(x, x)
    deviation = (pair_dis - compound_pair_dis_constraint).abs()
    if LAS_distance_constraint_mask is None:
        configuration_loss = 1 * deviation.sum()
    else:
        restraint = 1 * deviation[LAS_distance_constraint_mask].sum()
        # basic excluded-volume. the distance between compound atoms should be at least 1.22Å
        overlap = 2 * ((1.22 - pair_dis).relu()).sum()
        configuration_loss = restraint + overlap

    config_weight = 0.2 * 1e-3 * (epoch)  # TODO: fix weight
    loss = interaction_loss + config_weight * configuration_loss
    return loss, (interaction_loss.item(), configuration_loss.item())


def tankbind_distance_loss_function(epoch, y_pred, x, protein_nodes_xyz, compound_pair_dis_constraint,
                                    use_configuration,
                                    LAS_distance_constraint_mask=None, mode=0):
    """Loss matching protein-to-compound distances of ``x`` to a predicted distance map.

    Parameters
    ----------
    epoch : int
        Optimisation step. Below 500 only the interaction term is returned as
        loss; from 500 on the configuration term is added with weight
        ``5e-3 * (epoch - 500)``.
    y_pred : torch.Tensor
        Predicted distances, compared with ``torch.cdist(protein_nodes_xyz, x)``
        clamped at 10.
    x : torch.Tensor
        Compound coordinates being optimised.
    protein_nodes_xyz : torch.Tensor
        Protein node coordinates.
    compound_pair_dis_constraint : torch.Tensor
        Target pairwise distances compared against ``torch.cdist(x, x)``.
    use_configuration : bool
        When False the configuration term is a constant zero tensor.
    LAS_distance_constraint_mask : torch.Tensor or None
        Boolean mask selecting the constrained pairs. When given, an
        excluded-volume penalty for distances below 1.22 is added as well.
    mode : int
        0 uses absolute errors, 1 squared errors, 2 square roots of absolute errors.

    Returns
    -------
    tuple
        ``(loss, (interaction_loss, configuration_loss))`` with the two
        components given as Python numbers.

    Raises
    ------
    NotImplementedError
        For any other ``mode``.
    """
    inter_dis = torch.clamp(torch.cdist(protein_nodes_xyz, x), max=10)
    residual = inter_dis - y_pred
    if mode == 0:
        interaction_loss = (residual.abs()).sum()
    elif mode == 1:
        interaction_loss = (residual ** 2).sum()
    elif mode == 2:
        # probably not a good choice. x^0.5 has infinite gradient at x=0. added 1e-5 for numerical stability.
        interaction_loss = ((residual.abs() + 1e-5) ** 0.5).sum()
    else:
        raise NotImplementedError()

    if use_configuration:
        pair_dis = torch.cdist(x, x)
        deviation = (pair_dis - compound_pair_dis_constraint).abs()
        if LAS_distance_constraint_mask is None:
            configuration_loss = 1 * deviation.sum()
        else:
            restraint = 1 * deviation[LAS_distance_constraint_mask].sum()
            # basic excluded-volume. the distance between compound atoms should be at least 1.22Å
            overlap = 2 * ((1.22 - pair_dis).relu()).sum()
            configuration_loss = restraint + overlap
    else:
        configuration_loss = torch.tensor(0, device=interaction_loss.device)

    components = (interaction_loss.item(), configuration_loss.item())
    if epoch < 500:
        return interaction_loss, components
    anneal = 5e-3 * (epoch - 500)  # TODO: Annealing
    return interaction_loss + anneal * configuration_loss, components


def distance_optimize_compound_coords(reference_compound_coords, y_pred, protein_coords, use_configuration,
                                      total_epoch=5000,
                                      LAS_edge_index=None, mode=0):
    """Fit compound coordinates to a predicted protein-compound distance map.

    Coordinates are initialised uniformly at random in a cube of half-width 5
    around the mean of ``protein_coords`` and optimised with Adam (lr 0.1)
    against ``tankbind_distance_loss_function``.

    Parameters
    ----------
    reference_compound_coords : torch.Tensor
        Reference coordinates. They provide the shape and device of the
        optimised tensor, the pairwise distance targets
        (``torch.cdist`` of the reference with itself) and the RMSD reference.
    y_pred : torch.Tensor
        Predicted distance map forwarded to the loss.
    protein_coords : torch.Tensor
        Protein coordinates forwarded to the loss; their mean is the
        initialisation centre.
    use_configuration : bool
        Forwarded to the loss.
    total_epoch : int
        Number of optimisation steps.
    LAS_edge_index : torch.Tensor or None
        Edge index converted into a dense boolean mask of constrained pairs.
        None means no mask is passed to the loss.
    mode : int
        Forwarded to the loss.

    Returns
    -------
    tuple
        ``(coords, loss, rmsd)``: the optimised coordinates as a NumPy array,
        the loss evaluated at the start of the last step (NaN when
        ``total_epoch`` is 0) and the RMSD of the final coordinates against
        the reference, both as Python floats.
    """
    if LAS_edge_index is not None:
        # Fix the mask size to the atom count. Without `max_num_nodes` the size is
        # inferred from the largest index present in the edge list, which is too small
        # when the highest-numbered atoms have no edge.
        num_atoms = reference_compound_coords.shape[0]
        LAS_distance_constraint_mask = to_dense_adj(
            LAS_edge_index, max_num_nodes=num_atoms).squeeze(0).to(torch.bool)
    else:
        LAS_distance_constraint_mask = None
    compound_pair_dis_constraint = torch.cdist(reference_compound_coords, reference_compound_coords)
    # random initialization. center at the protein center.
    init_center = protein_coords.mean(axis=0)
    x = (5 * (2 * torch.rand(reference_compound_coords.shape, device=reference_compound_coords.device) - 1)
         + init_center.reshape(1, 3).detach())
    x.requires_grad = True
    # The coordinates themselves are the only parameters being optimised.
    optimizer = torch.optim.Adam([x], lr=0.1)
    last_loss = None
    for epoch in range(total_epoch):
        optimizer.zero_grad()
        loss, _ = tankbind_distance_loss_function(epoch, y_pred, x, protein_coords,
                                                  compound_pair_dis_constraint,
                                                  LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                                                  mode=mode,
                                                  use_configuration=use_configuration)
        loss.backward()
        optimizer.step()
        # Keep only the latest value; converting to a float every step is not needed.
        last_loss = loss.detach()
    final_loss = float("nan") if last_loss is None else last_loss.item()
    # Only the final RMSD is reported, so it is computed once after the loop.
    final_rmsd = compute_RMSD(reference_compound_coords, x.detach()).item()
    return x.detach().cpu().numpy(), final_loss, final_rmsd


def post_optimize_compound_coords(reference_compound_coords, predict_compound_coords,
                                  total_epoch=1000, LAS_edge_index=None, mode=0):
    """Refine predicted compound coordinates with ``post_optimize_loss_function``.

    Optimisation starts from a copy of ``predict_compound_coords`` and runs
    Adam (lr 0.1) for ``total_epoch`` steps.

    Parameters
    ----------
    reference_compound_coords : torch.Tensor
        Reference coordinates. They provide the pairwise distance targets
        (``torch.cdist`` of the reference with itself) and the RMSD reference.
    predict_compound_coords : torch.Tensor
        Predicted coordinates: the starting point and the target the
        interaction term of the loss pulls towards.
    total_epoch : int
        Number of optimisation steps.
    LAS_edge_index : torch.Tensor or None
        Edge index converted into a dense boolean mask of constrained pairs.
        None means no mask is passed to the loss.
    mode : int
        Forwarded to the loss.

    Returns
    -------
    tuple
        ``(coords, loss, rmsd)``: the optimised coordinates as a NumPy array,
        the loss evaluated at the start of the last step (NaN when
        ``total_epoch`` is 0) and the RMSD of the final coordinates against
        the reference, both as Python floats.
    """
    if LAS_edge_index is not None:
        # Fix the mask size to the atom count. Without `max_num_nodes` the size is
        # inferred from the largest index present in the edge list, which is too small
        # when the highest-numbered atoms have no edge.
        num_atoms = reference_compound_coords.shape[0]
        LAS_distance_constraint_mask = to_dense_adj(
            LAS_edge_index, max_num_nodes=num_atoms).squeeze(0).to(torch.bool)
    else:
        LAS_distance_constraint_mask = None
    compound_pair_dis_constraint = torch.cdist(reference_compound_coords, reference_compound_coords)
    # Detach so that a prediction still attached to a model graph can serve as a
    # fixed target and as the starting value of a leaf tensor.
    predict_compound_coords = predict_compound_coords.detach()
    x = predict_compound_coords.clone()
    x.requires_grad = True
    optimizer = torch.optim.Adam([x], lr=0.1)
    last_loss = None
    for epoch in range(total_epoch):
        optimizer.zero_grad()
        loss, _ = post_optimize_loss_function(epoch, x, predict_compound_coords,
                                              compound_pair_dis_constraint,
                                              LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                                              mode=mode,
                                              )
        loss.backward()
        optimizer.step()
        # Keep only the latest value; converting to a float every step is not needed.
        last_loss = loss.detach()
    final_loss = float("nan") if last_loss is None else last_loss.item()
    # Only the final RMSD is reported, so it is computed once after the loop.
    final_rmsd = compute_RMSD(reference_compound_coords, x.detach()).item()
    return x.detach().cpu().numpy(), final_loss, final_rmsd




def get_statistics(pkl_file):
    """Print summary statistics of the ``rmsd`` and ``centroid_dis`` columns of a pickled table.

    Parameters
    ----------
    pkl_file : str
        Path handed to ``pandas.read_pickle``; the loaded object must provide
        the columns ``rmsd`` and ``centroid_dis``.

    For each column the mean, the fraction of values below 5 and below 2, and
    the 25%, 50% and 75% quantiles are printed. Nothing is returned.
    """
    table = pd.read_pickle(pkl_file)

    # RMSD section.
    rmsd = table['rmsd']
    rmsd_total = len(rmsd)
    rmsd_below_5 = (rmsd < 5).sum() / rmsd_total
    rmsd_below_2 = (rmsd < 2).sum() / rmsd_total
    rmsd_q25, rmsd_q50, rmsd_q75 = (rmsd.quantile(q) for q in (0.25, 0.5, 0.75))
    print(f"Mean RMSD {rmsd.mean()} A")
    print(f"RMSD < 5A Percentiles: {rmsd_below_5}")
    print(f"RMSD < 2A Percentiles: {rmsd_below_2}")
    print(f"RMSD Quantile 25%: {rmsd_q25}")
    print(f"RMSD Quantile 50%: {rmsd_q50}")
    print(f"RMSD Quantile 75%: {rmsd_q75}")

    # Centroid-distance section.
    centroid_dis = table["centroid_dis"]
    centroid_total = len(centroid_dis)
    centroid_below_5 = (centroid_dis < 5).sum() / centroid_total
    centroid_below_2 = (centroid_dis < 2).sum() / centroid_total
    centroid_q25, centroid_q50, centroid_q75 = (centroid_dis.quantile(q) for q in (0.25, 0.5, 0.75))
    print(f"Mean Centroid Dis {centroid_dis.mean()} A")
    print(f"Mean Centroid Dis < 5A Percentiles: {centroid_below_5}")
    print(f"Mean Centroid Dis < 2A Percentiles: {centroid_below_2}")
    print(f"Mean Centroid Dis Quantile 25%: {centroid_q25}")
    print(f"Mean Centroid Dis Quantile 50%: {centroid_q50}")
    print(f"Mean Centroid Dis Quantile 75%: {centroid_q75}")




def compute_RMSD(a, b):
    """Return the root-mean-square deviation between ``a`` and ``b``.

    Squared differences are summed over the last axis, averaged over all
    remaining entries, and the square root of that mean is returned as a
    tensor.
    """
    squared_diff = (a - b) ** 2
    per_point = squared_diff.sum(axis=-1)
    return torch.sqrt(per_point.mean())


def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    # From EquiBind https://github.com/HannesStark/EquiBind/
    """Load a molecule from a ``.mol2`` or ``.sdf`` file.

    Parameters
    ----------
    molecule_file : str
        Path of the molecule file. The extension selects the reader: ``.mol2``
        or ``.sdf``. For ``.sdf`` the first record of the file is used.
    sanitize : bool
        Whether to run ``Chem.SanitizeMol`` on the loaded molecule. See
        https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
        Default to False.
    calc_charges : bool
        Accepted for interface compatibility; this implementation ignores it.
        Default to False.
    remove_hs : bool
        Whether to remove hydrogens via ``Chem.RemoveHs``. Default to False.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol or None
        The loaded molecule, or None when the file could not be parsed or when
        sanitization / hydrogen removal failed.

    Raises
    ------
    ValueError
        If ``molecule_file`` ends with neither ``.mol2`` nor ``.sdf``.
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    else:
        # Raise the error; returning the exception object would let callers
        # mistake it for a molecule.
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2 and .sdf, got {}'.format(molecule_file))

    if mol is None:
        # The reader could not parse the file.
        return None

    try:
        if sanitize:
            Chem.SanitizeMol(mol)
        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception:
        # Best effort: a molecule that cannot be processed is reported as None.
        return None

    return mol


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global generator and torch with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def record_data(key, value, step, logger=None, use_wandb=True):
    """Report a single ``key``/``value`` pair.

    When ``use_wandb`` is true the pair is sent to ``wandb.log`` at ``step``.
    When ``logger`` is given, ``"{key}: {value}"`` is emitted through
    ``logger.warning``.
    """
    if use_wandb:
        payload = {key: value}
        wandb.log(payload, step=step)
    if logger is None:
        return
    logger.warning(f"{key}: {value}")



def get_eval_task(task):
    """Return the wrapped module of a ``DistributedDataParallel`` task, else ``task`` itself."""
    return task.module if isinstance(task, DistributedDataParallel) else task


def uniform_random_rotation(x):
    """Apply one random 3D rotation to a set of points.

    Arguments:
        x: vector or set of vectors that can be reshaped to (n, 3), where n is
            the number of vectors
    Returns:
        Array of shape (n, 3). The value is computed as
        ``(x - mean) @ M + mean @ M`` with ``mean`` the mean coordinate of x,
        which is algebraically ``x @ M``: the mean coordinate is rotated
        together with the points.
    Three numbers are drawn from NumPy's global random generator.
    Algorithm taken from "Fast Random Rotation Matrices" (James Arvo, 1992):
    https://doi.org/10.1016/B978-0-08-050755-2.50034-8
    """
    # Random variables in [0, 1) (naming is same as paper). The draw order
    # x2, x3, x1 determines the result for a given seed.
    x2 = 2 * np.pi * np.random.rand()
    x3 = np.random.rand()
    x1 = np.random.rand()

    # Random rotation about the z axis.
    angle = 2 * np.pi * x1
    R = np.eye(3)
    R[0, 0] = np.cos(angle)
    R[1, 1] = np.cos(angle)
    R[0, 1] = -np.sin(angle)
    R[1, 0] = np.sin(angle)

    # Householder reflection built from the vector v.
    radial = np.sqrt(x3)
    v = np.array([np.cos(x2) * radial, np.sin(x2) * radial, np.sqrt(1 - x3)])
    H = np.eye(3) - (2 * np.outer(v, v))
    M = -(H @ R)

    points = x.reshape((-1, 3))
    mean_coord = np.mean(points, axis=0)
    centered = points - mean_coord
    return (centered @ M) + mean_coord @ M


# For Post_Optimization


# R = 3x3 rotation matrix
# t = 3x1 column vector
# This already takes residue identity into account.


if __name__ == "__main__":
    # Nothing is executed when this module is run as a script.
    pass
