import random

from torch import nn
from torch_scatter import scatter_mean
from torchdrug import tasks, core

import torch


class End2EndDocking(tasks.Task, core.Configurable):
    """End-to-end protein-ligand docking task.

    Wraps ``model`` (called as ``model(batch, iter_i=..., metric=...)``) and turns its
    predicted ligand (compound) coordinates into a weighted multi-term training loss,
    evaluation metrics, and per-complex coordinate / trajectory outputs.

    Args:
        model: network returning a dict with at least ``"compound_node_coord"``. It must
            also return ``"affinity"`` / ``"confidence"`` when the matching losses are
            configured, and ``"trajectory_list"`` when called with ``print_trajectory=True``.
        criterion (dict): maps loss name -> weight. Supported names are
            ``"Coordinate MSD"``, ``"Inter Distance MSE"``, ``"Intra Distance MSE"``,
            ``"LAS Distance MSE"``, ``"Affinity Loss"`` and ``"Confidence Loss"``.
        verbose (int): stored on the instance; not used by this class.
        max_iter_num (int): ``iter_i`` passed to the model at evaluation time. During
            training ``iter_i`` is drawn uniformly from ``[1, max_iter_num]``.
        inter_distance_threshold (float): true protein-compound distances above this
            value only penalise predictions that are closer than the threshold.
        intra_distance_threshold (float): same as above, for compound-compound distances.
    """

    def __init__(self, model, criterion, verbose=0, max_iter_num=8,
                 inter_distance_threshold=10, intra_distance_threshold=8):
        super(End2EndDocking, self).__init__()
        self.model = model
        self.criterion = criterion
        self.verbose = verbose
        self.max_iter_num = max_iter_num
        self.inter_distance_threshold = inter_distance_threshold
        self.intra_distance_threshold = intra_distance_threshold

        # A complex contributes to the "Coordinate MSD" loss only if its predicted-pose
        # RMSD is below this value, or if its pocket is flagged as the native one.
        self.pocket_rmsd_threshold = 8

        # Confidence targets:
        # - RMSD < threshold_1: regress onto a linearly decaying target that is
        #   1 at RMSD 0 and ``confidence_pos_threshold`` at RMSD threshold_1;
        # - threshold_1 <= RMSD < threshold_2: only penalise confidence above
        #   ``confidence_pos_threshold``;
        # - RMSD >= threshold_2: only penalise confidence above ``confidence_neg_threshold``.
        self.confidence_pos_threshold = 0.6
        self.confidence_neg_threshold = 0.4
        self.confidence_rmsd_threshold_1 = 10
        self.confidence_rmsd_threshold_2 = 15

    def set_loss_weight(self, key, value):
        """Set the weight of an already-configured loss term.

        Raises:
            NotImplementedError: if ``key`` is not one of the configured criteria.
        """
        if key not in self.criterion.keys():
            raise NotImplementedError()
        self.criterion[key] = value

    @staticmethod
    def _thresholded_distance_se(pred_dist, true_dist, threshold):
        """Return element-wise squared distance errors with a one-sided tail.

        Where ``true_dist <= threshold`` this is ``(pred - true) ** 2``. Where
        ``true_dist > threshold`` it is ``relu(threshold - pred) ** 2``, so a prediction
        at or beyond the threshold costs nothing.
        """
        se = (pred_dist - true_dist) ** 2
        out_range_mask = true_dist > threshold
        se[out_range_mask] = torch.clamp(threshold - pred_dist[out_range_mask], min=0) ** 2
        return se

    def forward(self, batch):
        """Compute the weighted training loss and per-criterion metrics for ``batch``.

        Returns:
            tuple: ``(all_loss, metric)``. ``all_loss`` is a scalar tensor equal to
            ``sum(weight * loss)`` over ``self.criterion``. ``metric`` maps each
            criterion name to its unweighted loss value.

        Raises:
            ValueError: if ``self.criterion`` contains an unknown criterion name.
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        # Passed to the model through ``predict``; NOTE(review): the model may also
        # write into it — confirm against the model implementation.
        metric = {key: None for key in self.criterion.keys()}

        pred = self.predict(batch, all_loss, metric)
        (true_compound_coord, true_inter_dist, true_intra_dist,
         true_LAS_dist, true_affinity) = self.target(batch)

        pred_compound_coord = pred["pred_compound_coord"]
        pred_inter_dist = pred["pred_inter_dist"]
        pred_intra_dist = pred["pred_intra_dist"]
        pred_LAS_dist = pred["pred_LAS_dist"]
        pred_affinity = pred.get("pred_affinity")
        pred_confidence = pred.get("pred_confidence")

        # Per-complex pose RMSD (no gradient), used for masking and confidence targets.
        coord_sd = ((pred_compound_coord.detach() - true_compound_coord) ** 2).sum(dim=-1)
        coord_rmsd = scatter_mean(coord_sd, index=batch['compound'].batch, dim=0).sqrt().detach()
        pocket_loss_mask = torch.logical_or(coord_rmsd < self.pocket_rmsd_threshold,
                                            batch.is_equivalent_native_pocket)
        # Broadcast the per-complex mask to per-atom granularity.
        coord_loss_mask = pocket_loss_mask[batch['compound'].batch]

        for criterion, weight in self.criterion.items():
            if criterion == "Coordinate MSD":
                if coord_loss_mask.any():
                    msd = ((pred_compound_coord[coord_loss_mask]
                            - true_compound_coord[coord_loss_mask]) ** 2).sum(dim=-1).mean()
                else:
                    # No complex passes the mask: use a zero leaf tensor that requires
                    # grad, so the accumulated loss stays differentiable.
                    msd = torch.tensor(0, requires_grad=True, dtype=all_loss.dtype, device=all_loss.device)
                metric["Coordinate MSD"] = msd.clone()
                all_loss += weight * msd
            elif criterion == "Inter Distance MSE":
                inter_mse = self._thresholded_distance_se(
                    pred_inter_dist, true_inter_dist, self.inter_distance_threshold).mean()
                metric["Inter Distance MSE"] = inter_mse
                all_loss += inter_mse * weight
            elif criterion == "Intra Distance MSE":
                intra_mse = self._thresholded_distance_se(
                    pred_intra_dist, true_intra_dist, self.intra_distance_threshold).mean()
                metric["Intra Distance MSE"] = intra_mse
                all_loss += intra_mse * weight
            elif criterion == "LAS Distance MSE":
                las_mse = ((pred_LAS_dist - true_LAS_dist) ** 2).mean()
                metric["LAS Distance MSE"] = las_mse
                # Accumulated exactly once (it used to be added twice, silently doubling
                # its effective weight relative to the reported metric).
                all_loss += las_mse * weight
            elif criterion == "Affinity Loss":
                native_mask = batch.is_equivalent_native_pocket
                decoy_gap = 0
                # zeros_like keeps the prediction's dtype/device (torch.zeros was always float32).
                affinity_loss = torch.zeros_like(pred_affinity)
                # Native pockets: plain squared error against the true affinity.
                affinity_loss[native_mask] = \
                    ((pred_affinity - true_affinity) ** 2)[native_mask]
                # Decoy pockets: only penalise predictions above (true - decoy_gap).
                affinity_loss[~native_mask] = \
                    (((pred_affinity - (true_affinity - decoy_gap)).relu()) ** 2)[~native_mask]
                metric["Affinity Loss"] = affinity_loss.mean()
                all_loss += affinity_loss.mean() * weight
            elif criterion == "Confidence Loss":
                rmsd_1 = self.confidence_rmsd_threshold_1
                rmsd_2 = self.confidence_rmsd_threshold_2
                undecoy_mask = coord_rmsd < rmsd_1
                decoy_mask_1 = torch.logical_and(coord_rmsd >= rmsd_1, coord_rmsd < rmsd_2)
                # ``>=`` so that RMSD == threshold_2 is covered; with ``>`` such a pose fell
                # into no mask and silently received zero loss.
                decoy_mask_2 = coord_rmsd >= rmsd_2

                true_confidence = 1 - (1 - self.confidence_pos_threshold) * coord_rmsd / rmsd_1
                confidence_loss = torch.zeros_like(pred_confidence)
                confidence_loss[undecoy_mask] = \
                    ((true_confidence - pred_confidence) ** 2)[undecoy_mask]
                confidence_loss[decoy_mask_1] = \
                    ((pred_confidence - self.confidence_pos_threshold).relu() ** 2)[decoy_mask_1]
                confidence_loss[decoy_mask_2] = \
                    ((pred_confidence - self.confidence_neg_threshold).relu() ** 2)[decoy_mask_2]

                metric["Confidence Loss"] = confidence_loss.mean()
                all_loss += confidence_loss.mean() * weight
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)

        return all_loss, metric

    def get_distance_info(self, compound_coord, protein_coord, batch):
        """Compute flattened distance vectors for every complex in ``batch``.

        For each complex ``i`` (delimited by the ``ptr`` offsets), this computes:
        - all protein-compound pairwise distances (``cdist`` flattened);
        - all compound-compound pairwise distances (``cdist`` flattened);
        - distances along the complex's ``("compound", "LAS", "compound")`` edges.
        The per-complex results are concatenated along dim 0.

        Returns:
            tuple: ``(inter_dist, intra_dist, LAS_dist)``, each a 1-D tensor.
        """
        # TODO: use sparse operation
        interactive_distance_list = []
        internal_distance_list = []
        LAS_distance_list = []
        # Loop-invariant node offsets.
        compound_ptr = batch['compound'].ptr
        protein_ptr = batch['protein'].ptr
        for i in range(len(protein_ptr) - 1):
            compound_coord_i = compound_coord[compound_ptr[i]:compound_ptr[i + 1]]
            protein_coord_i = protein_coord[protein_ptr[i]:protein_ptr[i + 1]]
            inter_dist = torch.cdist(protein_coord_i, compound_coord_i).reshape(-1)  # TODO
            intra_dist = torch.cdist(compound_coord_i, compound_coord_i).reshape(-1)
            # batch[i] yields the i-th complex, whose LAS edge_index is local to that
            # complex, so it indexes compound_coord_i directly.
            LAS_edge_index = batch[i][("compound", "LAS", "compound")]['edge_index']
            LAS_src, LAS_dst = LAS_edge_index[0], LAS_edge_index[1]
            LAS_dist = (compound_coord_i[LAS_src] - compound_coord_i[LAS_dst]).norm(dim=-1)
            interactive_distance_list.append(inter_dist)
            internal_distance_list.append(intra_dist)
            LAS_distance_list.append(LAS_dist)

        inter_dist = torch.cat(interactive_distance_list, dim=0)
        intra_dist = torch.cat(internal_distance_list, dim=0)
        LAS_dist = torch.cat(LAS_distance_list, dim=0)

        return inter_dist, intra_dist, LAS_dist

    def predict(self, batch, all_loss=None, metric=None):
        """Run the model and derive distance predictions from its coordinates.

        ``iter_i`` is random in ``[1, max_iter_num]`` while training and equal to
        ``max_iter_num`` otherwise. ``all_loss`` is accepted but unused here.
        ``"pred_affinity"`` / ``"pred_confidence"`` are only included when ``metric``
        contains ``"Affinity Loss"`` / ``"Confidence Loss"``.
        """
        if self.training:
            iter_i = random.randint(1, self.max_iter_num)
        else:
            iter_i = self.max_iter_num
        outputs = {}
        model_outputs = self.model(batch, iter_i=iter_i, metric=metric)

        pred_inter_dist, pred_intra_dist, pred_LAS_dist = self.get_distance_info(
            compound_coord=model_outputs["compound_node_coord"],
            protein_coord=batch["protein"].coords,
            batch=batch)
        outputs["pred_inter_dist"] = pred_inter_dist
        outputs["pred_intra_dist"] = pred_intra_dist
        outputs["pred_LAS_dist"] = pred_LAS_dist
        outputs["pred_compound_coord"] = model_outputs["compound_node_coord"]
        if metric is not None and "Affinity Loss" in metric.keys():
            outputs["pred_affinity"] = model_outputs["affinity"]
        if metric is not None and "Confidence Loss" in metric.keys():
            outputs["pred_confidence"] = model_outputs["confidence"]
        return outputs

    def target(self, batch):
        """Return ground truth as ``(coord, inter_dist, intra_dist, LAS_dist, affinity)``."""
        true_compound_coord = batch["compound"].true_coords
        protein_coord = batch["protein"].coords
        true_affinity = batch["affinity"]

        true_inter_dist, true_intra_dist, true_LAS_dist = self.get_distance_info(true_compound_coord, protein_coord, batch)
        return true_compound_coord, true_inter_dist, true_intra_dist, true_LAS_dist, true_affinity

    def evaluate_metric(self, batch, eval_metrics=('rmsd', 'rmsd < 2A', 'rmsd < 5A')):
        """Compute the requested per-complex evaluation metrics.

        Unknown names in ``eval_metrics`` are ignored. ``"affinity mse"``,
        ``"confidence precision"`` and ``"confidence"`` need the model to have produced
        the corresponding output, i.e. ``"Affinity Loss"`` / ``"Confidence Loss"`` must
        be configured criteria.

        Returns:
            dict: metric name -> per-complex tensor.
        """
        result_dict = {}
        pred = self.predict(batch, metric={key: None for key in self.criterion.keys()})
        pred_compound_coord = pred["pred_compound_coord"]

        true_compound_coord, _, _, _, true_affinity = self.target(batch)
        compound_batch = batch['compound'].batch
        sd = ((pred_compound_coord - true_compound_coord) ** 2).sum(-1)
        msd = scatter_mean(src=sd, index=compound_batch, dim=0)
        rmsd = torch.sqrt(msd)

        pred_centroid = scatter_mean(src=pred_compound_coord, index=compound_batch, dim=0)
        true_centroid = scatter_mean(src=true_compound_coord, index=compound_batch, dim=0)
        centroid_dis = (pred_centroid - true_centroid).norm(dim=-1)

        for metric in eval_metrics:
            if metric == 'rmsd':
                result_dict['rmsd'] = rmsd
            elif metric == 'rmsd < 2A':
                result_dict['rmsd < 2A'] = (rmsd < 2).to(torch.float)
            elif metric == 'rmsd < 5A':
                result_dict['rmsd < 5A'] = (rmsd < 5).to(torch.float)
            elif metric == 'centroid dis':
                result_dict["centroid dis"] = centroid_dis
            elif metric == 'centroid dis < 2A':
                result_dict['centroid dis < 2A'] = (centroid_dis < 2).to(torch.float)
            elif metric == 'centroid dis < 5A':
                result_dict['centroid dis < 5A'] = (centroid_dis < 5).to(torch.float)
            elif metric == "affinity mse":
                # Read lazily: "pred_affinity" only exists when "Affinity Loss" is configured.
                result_dict["affinity mse"] = (pred["pred_affinity"] - true_affinity) ** 2
            elif metric == "confidence precision":
                pred_confidence = pred["pred_confidence"]
                undecoy_mask = rmsd < self.confidence_rmsd_threshold_1
                confidence_precision = torch.zeros(pred_confidence.shape).to(pred_confidence.device)
                confidence_precision[undecoy_mask] = (pred_confidence > self.confidence_pos_threshold)[undecoy_mask].to(torch.float)
                confidence_precision[~undecoy_mask] = (pred_confidence < self.confidence_pos_threshold)[~undecoy_mask].to(torch.float)
                result_dict["confidence precision"] = confidence_precision
            elif metric == "confidence":
                result_dict["confidence"] = pred["pred_confidence"]

        return result_dict

    def output_coord(self, batch, iter_i=None):
        """Return a list with one detached CPU coordinate tensor per complex in ``batch``.

        ``iter_i`` defaults to ``self.max_iter_num``.
        """
        iter_i = self.max_iter_num if iter_i is None else iter_i
        model_outputs = self.model(batch, iter_i=iter_i, metric=None)
        pred_compound_coord = model_outputs["compound_node_coord"]
        compound_batch = batch['compound'].batch
        batch_size = len(batch["protein"].ptr) - 1
        # TODO
        return [pred_compound_coord[compound_batch == i].detach().cpu() for i in range(batch_size)]

    def output_trajectory(self, batch, iter_i=None):
        """Return ``{pdb_id: [coords_step_0, coords_step_1, ...]}`` from the model's trajectory.

        ``iter_i`` defaults to ``self.max_iter_num``. The tensors are detached but stay
        on their original device.
        """
        iter_i = self.max_iter_num if iter_i is None else iter_i
        model_outputs = self.model(batch, iter_i=iter_i, metric=None, print_trajectory=True)
        trajectory_list = model_outputs["trajectory_list"]
        trajectory_dict = {key: [] for key in batch.pdb}
        # Per-complex atom masks do not change across steps; build them once.
        compound_batch = batch['compound'].batch
        masks = [compound_batch == i_batch for i_batch in range(len(batch.pdb))]
        for trajectory in trajectory_list:
            for pdb, i_mask in zip(batch.pdb, masks):
                trajectory_dict[pdb].append(trajectory[i_mask].detach())

        return trajectory_dict

    # def output_affinity(self, batch, iter_i=None):
    #     if iter_i is None:
    #         iter_i = self.max_iter_num
    #     else:
    #         iter_i = iter_i
    #     model_outputs = self.model(batch, iter_i=iter_i, metric={key: None for key in self.criterion.keys()})
    #     affinity = model_outputs["confidence"]
    #     return affinity.detach().cpu()








