import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss


from torchdrug import data, utils, core
from Bio.PDB import PDBParser, PPBuilder, PDBIO
from Bio.Data.PDBData import protein_letters_3to1, protein_letters_3to1_extended
# from rdkit import Chem

import pathlib, shutil, os

import torch

from torchdrug import models, layers, transforms

class GearNetEncoder(torch.nn.Module):
    """GearNet encoder bundled with the graph construction step it is fed by.

    ``forward`` first runs ``layers.GraphConstruction`` (alpha-carbon nodes plus
    sequential, spatial and KNN edges with 'gearnet' edge features) on the
    incoming batch and then encodes the resulting graph with ``models.GearNet``.
    NOTE(review): the hyper-parameters below presumably have to match the
    pretrained checkpoint passed to ``load_pretrain`` -- confirm before changing.
    """

    def __init__(self) -> None:
        super().__init__()
        self.graph_encoder = models.GearNet(
            input_dim=21,
            hidden_dims=[512, 512, 512, 512, 512, 512],
            batch_norm=True,
            concat_hidden=True,
            short_cut=True,
            readout='sum',
            num_relation=7,
            edge_input_dim=59,
            num_angle_bin=8,
        )
        self.graph_construction_model = layers.GraphConstruction(
            node_layers=[layers.geometry.AlphaCarbonNode()],
            edge_layers=[
                layers.geometry.SequentialEdge(2),
                layers.geometry.SpatialEdge(10.0, 5),
                layers.geometry.KNNEdge(10, 5),
            ],
            edge_feature='gearnet',
        )

    def load_pretrain(self, path):
        """Load a state dict stored at ``path`` into ``self.graph_encoder``.

        path: checkpoint file readable by ``torch.load``; it must contain a
              state dict whose keys match ``self.graph_encoder``.
        """
        # map_location="cpu" lets a checkpoint that was saved from a GPU be read
        # on a CPU-only host; load_state_dict then copies the values into the
        # encoder's parameters on whatever device they currently live.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        params_dict = torch.load(path, map_location="cpu")
        self.graph_encoder.load_state_dict(params_dict)

    def forward(self, batch):
        """Build the graph for ``batch`` and return the encoder output.

        The node input passed to the encoder is ``graph.node_feature`` cast to
        float. The return value is whatever ``self.graph_encoder`` returns.
        """
        graph = self.graph_construction_model(batch)
        output = self.graph_encoder(graph, graph.node_feature.float())
        return output

class ProteinEmbedding(object):
    """Turn a multi-model PDB file into one graph-level feature per model.

    Each model of the input structure is written to its own PDB file inside
    ``cache_dir``; those files are handed to a user supplied data pipeline and
    the resulting batch is encoded with ``graph_encoder``.
    """

    def __init__(self, graph_encoder, cache_dir = 'pdb_cache', cache_limit = 1000) -> None:
        '''
        graph_encoder: graph network encoder
        cache_dir: directory for the cached single-model pdb files; the whole
                   directory is deleted once it holds more than {cache_limit} entries
        cache_limit: maximum number of cached files
        '''
        self.parser = PDBParser()
        # Use the GPU when one is available, otherwise keep the encoder where it
        # is instead of crashing inside .cuda() on a CPU-only host.
        if torch.cuda.is_available():
            graph_encoder = graph_encoder.cuda()
        self.graph_encoder = graph_encoder
        self.cache_dir = cache_dir
        self.cache_limit = cache_limit

    def pdb_process(self, pdb_file, pbd_suffix, data_pipline, **data_pipline_args):
        '''
        pdb_file: path of a pdb file that may contain multiple models
        pbd_suffix: structure id; also used as the prefix of the cached file
                    names, so it can be any string
        data_pipline: callable invoked as
                      data_pipline(single_model_file=<list of paths>, **data_pipline_args);
                      must return an object accepted by the graph encoder
        data_pipline_args: additional keyword arguments for data_pipline

        return: the 'graph_feature' entry of the encoder output
                NOTE(review): presumably shaped (N, D) with N the number of
                models in pdb_file and D the graph feature dim -- confirm.
        '''
        structure = self.parser.get_structure(pbd_suffix, pdb_file)

        io = PDBIO()
        tmp_path = pathlib.Path(self.cache_dir)
        # parents=True so a nested cache_dir such as "a/b/cache" can be created.
        tmp_path.mkdir(parents=True, exist_ok=True)

        single_model_file = []
        for index, model in enumerate(structure):
            io.set_structure(model)
            # Join with the full cache path (not only its last component) and
            # keep plain strings, which is what the original code passed to
            # io.save and to the data pipeline.
            pdb_path = str(tmp_path / f"{pbd_suffix}_{index}.pdb")
            io.save(pdb_path)
            single_model_file.append(pdb_path)

        # generate protein data from the pdb files which can be input to the graph encoder
        batch = data_pipline(single_model_file = single_model_file, **data_pipline_args)

        # NOTE(review): relies on the encoder exposing a `device` attribute.
        protein_feature = self.graph_encoder(batch.to(self.graph_encoder.device))

        self._evict_cache_if_full(tmp_path)
        return protein_feature['graph_feature']

    def _evict_cache_if_full(self, cache_path):
        '''
        Delete the whole cache directory when it holds more than
        self.cache_limit entries.

        cache_path: pathlib.Path of the cache directory
        return: True if the directory was removed, False otherwise
        '''
        cached_count = sum(1 for _ in cache_path.iterdir())
        if cached_count > self.cache_limit:
            print("Too much pdb caches, clean files")
            shutil.rmtree(cache_path)
            return True
        return False

def data_pipline(single_model_file, preprocess_fn = None):
    """
    data pipeline for gearnet

    single_model_file: iterable of pdb file paths, one protein per file
    preprocess_fn: optional callable applied to each ``{'graph': protein}``
                   item before collation; skipped when None

    return: the 'graph' entry of the collated batch
    """
    batch = []
    for file_name in single_model_file:
        protein = data.Protein.from_pdb(file_name)
        if hasattr(protein, "residue_feature"):
            # replace the residue feature by its dense form inside the
            # protein's residue context
            with protein.residue():
                protein.residue_feature = protein.residue_feature.to_dense()
        item = {'graph': protein}
        # preprocess_fn defaults to None; calling it unconditionally raised
        # "TypeError: 'NoneType' object is not callable" for the default.
        if preprocess_fn is not None:
            item = preprocess_fn(item)
        batch.append(item)
    batch = data.graph_collate(batch)
    return batch['graph']

def optimal_transport_lambda_detect(matrix_1, matrix_2, 
                      method = 'Sinkhorn1', 
                      lambda_value = 0.1, reverse_cost = False,
                      cost_matrix = None, ot_logger = True):
    """
    Search a regularisation value for `optimal_transport` and return the
    result obtained with it.

    Search rule, starting from `lambda_value`:
        * solver fails (returns None) before any success: lambda *= 10
        * solver succeeds: lambda is reduced by 10% as long as fewer than two
          costs were recorded or the cost is still strictly decreasing
        * otherwise (cost stopped decreasing, or a failure after a success)
          the search stops and the lambda of the previous iteration is kept

    matrix_1, matrix_2, method, reverse_cost, cost_matrix: forwarded unchanged
        to `optimal_transport`
    lambda_value: initial regularisation value
    ot_logger: print the progress of the search when True

    return: `optimal_transport(...)` evaluated with the selected lambda
    raises: RuntimeError if the solver never succeeds and increasing lambda
            can make no further progress (lambda is 0, inf or NaN); the
            previous behaviour in that situation was an endless loop
    """
    def solve(reg):
        return optimal_transport(matrix_1, 
                                 matrix_2, 
                                 method=method, 
                                 lambda_value=reg, 
                                 reverse_cost=reverse_cost,
                                 cost_matrix = cost_matrix)

    cur_lambda = lambda_value
    ot_costs = []
    lambda_list = [lambda_value]
    final_lambda = None
    search_done = False
    while not search_done:
        ot_mapping, ot_cost = solve(cur_lambda)
        if ot_logger:
            print("Search reasonable lambda value", f'times: {len(lambda_list)}, cost: {ot_cost}, lambda: {cur_lambda}')
        if ot_mapping is None:
            if not ot_costs:
                # nothing worked yet: try a much larger regularisation
                grown_lambda = cur_lambda * 10
                if not np.isfinite(grown_lambda) or grown_lambda == cur_lambda:
                    raise RuntimeError(
                        "optimal transport failed for every tried lambda; "
                        f"cannot increase lambda beyond {cur_lambda}"
                    )
                cur_lambda = grown_lambda
            else:
                # lambda_list[-1] is the lambda that just failed, so [-2] is
                # the last one that produced a valid mapping
                search_done = True
                final_lambda = lambda_list[-2]
        else:
            still_improving = len(ot_costs) <= 1 or ot_costs[-1] > ot_cost
            if still_improving:
                cur_lambda -= cur_lambda * 0.1
            else:
                # the cost did not decrease with the current lambda: fall back
                # to the lambda of the previous iteration
                search_done = True
                final_lambda = lambda_list[-2]
            ot_costs.append(ot_cost)
        lambda_list.append(cur_lambda)
    if ot_logger:
        print("Finally lambda:",final_lambda, lambda_list)
    return solve(final_lambda)

def optimal_transport(matrix_1, matrix_2, 
                      method = 'Sinkhorn1', 
                      lambda_value = 0.1, reverse_cost = False,
                      cost_matrix = None):
    """
    Compute an optimal transport plan between two sample sets and its cost.

    matrix_1: sample matrix of the source
    matrix_2: sample matrix of the target
    method: optimal transport algorithm,
        Sinkhorn1: ot.sinkhorn (M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013),
        Sinkhorn2: same solver as Sinkhorn1. ot.sinkhorn2 returns the transport
            loss instead of the plan, so the plan is obtained from ot.sinkhorn
            and the loss is computed below as sum(M * plan).
        Sinkhorn_Eps: ot.bregman.sinkhorn_epsilon_scaling
        Sinkhorn_Knopp: ot.bregman.sinkhorn_knopp
        Emprirical_Sinkhorn: ot.bregman.empirical_sinkhorn, called on the raw
            samples; M is then only used for the reported cost.
        EMD: ot.emd (Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December).  Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.)
    lambda_value: entropic regularisation value, required by every method except EMD
    reverse_cost: if True the cost matrix M is replaced by M.max() - M
    cost_matrix: optional precomputed cost matrix; when None,
        ot.dist(matrix_1, matrix_2) is used

    return: (ot_mapping, ot_cost), or (None, None) when the total mass of the
        plan is not within [0.99, 1.01] (including a NaN plan), which callers
        treat as "try another lambda"
    """
    supported_methods = {
        'Sinkhorn1', 'Sinkhorn2', 'Sinkhorn_Eps', 'Sinkhorn_Knopp',
        'Emprirical_Sinkhorn', 'EMD',
    }
    assert method in supported_methods, f"unknown optimal transport method: {method}"

    num_data_1 = len(matrix_1)
    num_data_2 = len(matrix_2)
    if cost_matrix is not None:
        M = cost_matrix
    else:
        M = ot.dist(matrix_1, matrix_2) # cost matrix
    if reverse_cost:
        M = M.max() - M

    # assume data follow uniform distribution
    dist_1 = np.ones(num_data_1) / num_data_1
    dist_2 = np.ones(num_data_2) / num_data_2

    # calculate optimal transport plan
    if method in ('Sinkhorn1', 'Sinkhorn2'):
        ot_mapping = ot.sinkhorn(dist_1, dist_2, M, lambda_value)
    elif method == 'Sinkhorn_Eps':
        ot_mapping = ot.bregman.sinkhorn_epsilon_scaling(dist_1, dist_2, M, lambda_value)
    elif method == 'Sinkhorn_Knopp':
        ot_mapping = ot.bregman.sinkhorn_knopp(dist_1, dist_2, M, lambda_value)
    elif method == 'Emprirical_Sinkhorn':
        ot_mapping = ot.bregman.empirical_sinkhorn(
            matrix_1, matrix_2, lambda_value, numIterMax=100000
        )
    elif method == 'EMD':
        ot_mapping = ot.emd(dist_1, dist_2, M)
    else:
        # only reachable when asserts are disabled (python -O)
        raise NotImplementedError

    # calculate transport cost: int_{X\times Y}c(x,y)d\pi(x,y)
    ot_cost = (M * ot_mapping).sum()
    total_mass = ot_mapping.sum()
    # A valid plan carries (almost) unit mass. Written as a positive range test
    # so that a NaN mass is rejected as well.
    if not (0.99 <= total_mass <= 1.01):
        return None, None
    return ot_mapping, ot_cost

if __name__ == '__main__':
    # Demo: embed two multi-model PDB files and compare them with optimal transport.

    # protein transform, required by gearnet
    transform = transforms.Compose([transforms.ProteinView(view='residue')])

    # build the graph encoder and load its pretrained parameters
    graph_encoder = GearNetEncoder()
    graph_encoder.load_pretrain('mc_gearnet_edge.pth')

    # helper object that generates protein embeddings
    pe = ProteinEmbedding(graph_encoder)

    # preprocess_fn is the extra keyword argument forwarded to data_pipline
    pdb_feature_1 = pe.pdb_process('1bq0.pdb', '1bq0', data_pipline, preprocess_fn=transform)
    pdb_feature_2 = pe.pdb_process('1g03.pdb', '1g03', data_pipline, preprocess_fn=transform)

    # the ot cost between the raw pdb features is too large, so normalize first
    pdb_feature_1 = torch.nn.functional.normalize(pdb_feature_1, dim=-1)
    pdb_feature_2 = torch.nn.functional.normalize(pdb_feature_2, dim=-1)

    # hand plain numpy arrays to the optimal transport solver
    samples_1 = pdb_feature_1.detach().cpu().numpy()
    samples_2 = pdb_feature_2.detach().cpu().numpy()
    ot_mapping, ot_cost = optimal_transport(
        samples_1,
        samples_2,
        method='Sinkhorn1',
        lambda_value=0.1,
    )
    print(ot_mapping, ot_cost)