import numpy as np
import cvxpy as cp
import torch
import gpytorch
from rdkit.Chem import AllChem, Descriptors, MolFromSmiles

def print_list(list):
    """Print how many items ``list`` holds, then each item on its own line.

    :param list: a sized iterable. The name shadows the builtin and is kept
        only so that existing keyword callers keep working.
    """
    items = list
    count = len(items)
    print("List length: ", count)
    for item in items:
        print(item)



# Helper functions

def get_bigmij(vi, vj, W):
    """
    Compute M(i,j) for designs i and j.

    Solves the QP
        min_x  x^T x - 2 (vj - vi)^T x   s.t.   W x >= max(0, W (vj - vi))  (row-wise)
    and returns sqrt(optimal value + ||vj - vi||^2). That quantity equals the
    Euclidean distance from vj - vi to the feasible set.

    :param vi, vj: (D,1) ndarrays
    :param W: (n_constraint,D) ndarray; any number of constraint rows is supported
    :return: M(i,j) as a shape-(1,) ndarray.
    """
    D = W.shape[1]
    diff = vj - vi  # (D,1)
    P = 2 * np.eye(D)
    q = (-2 * diff).ravel()
    G = -W
    # h_n = -max(0, w_n^T (vj - vi)), computed for every row so the cone may have
    # any number of constraints.
    h = -np.maximum(0, W @ diff).ravel()

    # Define and solve the CVXPY problem.
    x = cp.Variable(D)
    prob = cp.Problem(cp.Minimize((1 / 2) * cp.quad_form(x, P) + q.T @ x),
                      [G @ x <= h])
    prob.solve()

    # prob.value + ||diff||^2 is the squared distance ||x* - diff||^2, which is >= 0.
    # Clip tiny negative values from solver round-off so np.sqrt does not give NaN.
    squared_dist = np.maximum(prob.value + diff.T @ diff, 0.0)
    return np.sqrt(squared_dist).ravel()


def get_alpha(rind, W):
    """
    Compute alpha_rind for row rind of W.

    Solves the SOCP
        max_x  W[rind] . x   s.t.   W[n] . x >= 0 for every row n,   ||x||_2 <= 1
    as a minimisation of the negated objective, and returns the maximum.

    :param rind: row index
    :param W: (n_constraint,D) ndarray
    :return: alpha_rind.
    """
    n_rows, dim = W.shape
    x = cp.Variable(dim)

    def soc(c_vec, d_val, A_mat, b_vec):
        # cp.SOC(t, X) encodes ||X||_2 <= t; here t = c.x + d and X = A x + b.
        return cp.SOC(c_vec.T @ x + d_val, A_mat @ x + b_vec)

    # One constraint per cone row: ||0|| <= W[n].x, i.e. W[n].x >= 0.
    constraints = [
        soc(W[n, :], np.zeros(1), np.zeros((1, dim)), np.zeros(1))
        for n in range(n_rows)
    ]
    # Unit-ball constraint: ||x||_2 <= 1.
    constraints.append(soc(np.zeros(dim), np.ones(1), np.eye(dim), np.zeros(dim)))

    objective = -W[rind, :]
    problem = cp.Problem(cp.Minimize(objective.T @ x), constraints)
    problem.solve()
    return -problem.value


def get_alpha_vec(W):
    """
    Collect get_alpha for every row of W into a column vector.

    :param W: an (n_constraint,D) ndarray
    :return: alpha_vec, an (n_constraint,1) float ndarray whose n-th entry is get_alpha(n, W)
    """
    n_rows = W.shape[0]
    alphas = np.zeros((n_rows, 1))
    for row in range(n_rows):
        alphas[row, 0] = get_alpha(row, W)
    return alphas


def get_smallmij(vi, vj, W, alpha_vec):
    """
    Compute m(i,j) for designs i and j.

    m(i,j) = min_n max(0, W[n] . (vj - vi)) / alpha_vec[n].

    :param vi, vj: (D,1) ndarrays
    :param W: (n_constraint,D) ndarray
    :param alpha_vec: (n_constraint,1) ndarray of alphas of W
    :return: m(i,j) as a numpy scalar.
    """
    projections = np.matmul(W, vj - vi)
    clipped = np.where(projections < 0, 0, projections)
    return np.min(clipped / alpha_vec)


def is_covered_SOCP(vi, vj, eps, W):
    """
    Check if vi is eps covered by vj for cone matrix W by solving an SOCP feasibility problem.

    The problem asks for an x satisfying three groups of constraints:
      (1) W[n] . x >= 0 for every row n;
      (2) ||x + (vi - vj)||_2 <= eps;
      (3) W[n] . x + W[n] . (vi - vj) >= 0 for every row n.

    :param vi, vj: (D,1) ndarrays
    :param W: An (n_constraint,D) ndarray
    :param eps: float
    :return: Boolean, True when the solver produced a value for x.
    """
    n_rows, dim = W.shape
    gap = vi - vj  # (D,1)
    x = cp.Variable(dim)

    def soc(c_vec, d_val, A_mat, b_vec):
        # cp.SOC(t, X) encodes ||X||_2 <= t; here t = c.x + d and X = A x + b.
        return cp.SOC(c_vec.T @ x + d_val, A_mat @ x + b_vec)

    cone = [
        soc(W[n, :], np.zeros(1), np.zeros((1, dim)), np.zeros(1))
        for n in range(n_rows)
    ]
    ball = [soc(np.zeros(dim), eps * np.ones(1), np.eye(dim), gap.ravel())]
    shifted = [
        soc(W[n, :], np.dot(W[n, :], gap), np.zeros((1, dim)), np.zeros(1))
        for n in range(n_rows)
    ]

    # Zero objective: only feasibility matters.
    problem = cp.Problem(cp.Minimize(np.zeros(dim).T @ x), cone + ball + shifted)
    problem.solve()
    # cvxpy leaves x.value as None when it finds no feasible point.
    return x.value is not None


def is_covered(vi, vj, eps, W):
    """
    Check if vi is eps covered by vj for cone matrix W.

    Delegates directly to is_covered_SOCP. A cheaper Euclidean shortcut
    (``||vi - vj||^2 <= eps**2``) is intentionally not applied.

    :param vi, vj: (D,1) ndarrays
    :param W: An (n_constraint,D) ndarray
    :param eps: float
    :return: Boolean.
    """
    covered = is_covered_SOCP(vi, vj, eps, W)
    return covered

    
def get_pareto_set(mu, W, alpha_vec, return_mask = False):
    """
    Find the indices of Pareto designs (rows of mu) with respect to the cone W.

    Rows are processed one pivot at a time. Against each pivot, a row survives
    only if m(row, pivot) == 0 and M(row, pivot) > 0; the pivot itself always
    survives. Eliminated rows are dropped before the next pivot is chosen.

    :param mu: An (n_points, D) array
    :param W: (n_constraint,D) ndarray
    :param alpha_vec: (n_constraint,1) ndarray of alphas of W
    :param return_mask: True to return a boolean mask instead of indices
    :return: If return_mask is True, an (n_points, ) boolean array;
        otherwise an (n_efficient_points, ) integer array of indices into mu.
    """
    total = mu.shape[0]
    survivors = np.arange(total)  # original indices of the rows still in play
    remaining = mu
    pivot = 0
    while pivot < len(remaining):
        v_pivot = remaining[pivot].reshape(-1, 1)
        keep = np.zeros(remaining.shape[0], dtype=bool)
        for idx, row in enumerate(remaining):
            v_row = row.reshape(-1, 1)
            keep[idx] = (get_smallmij(v_row, v_pivot, W, alpha_vec) == 0) and (get_bigmij(v_row, v_pivot, W) > 0)
        keep[pivot] = True
        survivors = survivors[keep]
        remaining = remaining[keep]
        # The pivot's new position is the number of survivors before it; move past it.
        pivot = np.sum(keep[:pivot]) + 1

    if not return_mask:
        return survivors
    mask = np.zeros(total, dtype=bool)
    mask[survivors] = True
    return mask

    
def get_delta(mu, W, alpha_vec):
    """
    Compute Delta^*_i = max_j m(i, j) for each i in [n_points].

    Values start at 0 and are only raised, so every entry is >= 0.

    :param mu: An (n_points, D) array
    :param W: (n_constraint,D) ndarray
    :param alpha_vec: (n_constraint,1) ndarray of alphas of W
    :return: An (n_points, 1) array whose i-th entry is Delta^*_i
    """
    n = mu.shape[0]
    Delta = np.zeros(n)
    for i in range(n):
        vi = mu[i, :].reshape(-1, 1)  # invariant across the inner loop
        for j in range(n):
            mij = get_smallmij(vi, mu[j, :].reshape(-1, 1), W, alpha_vec)
            # A plain comparison (rather than max()) skips NaN values of m(i, j).
            if mij > Delta[i]:
                Delta[i] = mij

    return Delta.reshape(-1, 1)


def get_uncovered_set(p_opt_miss, p_opt_hat, mu, eps, W):
    """
    Return the missed Pareto designs that no returned design eps-covers.

    A design i in p_opt_miss is uncovered when is_covered(mu_i, mu_j, eps, W)
    is False for every j in p_opt_hat.

    :param p_opt_miss: ndarray of indices of Pareto optimal points not in p_opt_hat
    :param p_opt_hat: ndarray of indices of designs in returned Pareto set
    :param mu: An (n_points,D) mean reward matrix
    :param eps: float
    :param W: An (n_constraint,D) ndarray
    :return: ndarray of indices of points in p_opt_miss that are not epsilon covered
        (an empty integer array when every point is covered)
    """
    uncovered_set = [
        i for i in p_opt_miss
        if not any(
            is_covered(mu[i, :].reshape(-1, 1), mu[j, :].reshape(-1, 1), eps, W)
            for j in p_opt_hat  # any() stops at the first covering design
        )
    ]

    if not uncovered_set:
        # np.array([]) would be float64, which cannot be used to index arrays.
        return np.array([], dtype=int)
    return np.array(uncovered_set)


class CategoricalKernel(gpytorch.kernels.Kernel):
    r"""Exponential kernel over categorical (exact-match) features.

    For each feature, the distance between two inputs is 0 when the values are
    equal and 1 otherwise, and each indicator is divided by the lengthscale.
    Unless ``last_dim_is_batch`` is set, the scaled distances are averaged over
    the last dimension, giving `exp(-mean(dist(x1, x2) / lengthscale))`.
    With ``last_dim_is_batch`` the feature axis is moved to the front instead.

    Note: This kernel is NOT differentiable w.r.t. the inputs.
    """

    has_lengthscale = True

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        # Pairwise per-feature mismatch indicators.
        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled = mismatch / self.lengthscale.unsqueeze(-2)
        reduced = scaled.transpose(-3, -1) if last_dim_is_batch else scaled.mean(-1)
        covar = torch.exp(-reduced)
        return torch.diagonal(covar, dim1=-1, dim2=-2) if diag else covar


class CategoricalKernel_actv(gpytorch.kernels.Kernel):
    r"""A Kernel for categorical features, optionally restricted to selected feature columns.

    Computes `exp(-dist(x1, x2) / lengthscale)`, where
    `dist(x1, x2)` is zero if `x1 == x2` and one if `x1 != x2`.
    If the last dimension is not a batch dimension, then the
    mean is considered.

    When ``active_dims`` is passed to ``forward``, only those columns of the
    last (feature) dimension of x1/x2 are compared. The lengthscale's last
    dimension must then be 1 or equal to the number of selected columns, so
    that the division broadcasts.

    Note: This kernel is NOT differentiable w.r.t. the inputs.
    """

    has_lengthscale = True

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        active_dims = None,
        **kwargs,
    ) -> torch.Tensor:
        if active_dims is not None:
            # Select feature columns (last dim), not data points (rows).
            x1 = x1[..., active_dims]
            x2 = x2[..., active_dims]
        delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        dists = delta / self.lengthscale.unsqueeze(-2)
        if last_dim_is_batch:
            dists = dists.transpose(-3, -1)
        else:
            dists = dists.mean(-1)
        res = torch.exp(-dists)
        if diag:
            res = torch.diagonal(res, dim1=-1, dim2=-2)
        return res



class TanimotoKernel(gpytorch.kernels.Kernel):
    """Tanimoto similarity kernel scaled by a learnable variance.

    k(x, x') = variance * <x, x'> / (||x||^2 + ||x'||^2 - <x, x'>)

    Inputs may carry leading batch dimensions: x1 is (..., n, d), x2 is (..., m, d).
    Note: the result is NaN when both vectors in a pair are all-zero (0 / 0).

    :param variance: initial value of the (unconstrained) output-scale parameter
    :param kwargs: forwarded to ``gpytorch.kernels.Kernel`` (e.g. active_dims, batch_shape)
    """

    def __init__(self, variance=1.0, **kwargs):
        super().__init__(**kwargs)
        self.register_parameter(
            name='variance',
            parameter=torch.nn.Parameter(torch.tensor(variance, dtype=torch.double))
        )

    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
        # last_dim_is_batch is accepted for interface compatibility but not used.
        if x2 is None:
            x2 = x1

        x1s = x1.pow(2).sum(dim=-1)  # (..., n)
        x2s = x2.pow(2).sum(dim=-1)  # (..., m)

        if diag:
            # Only k(x1[i], x2[i]) is needed; avoid building the full (n, m) matrix.
            cross = (x1 * x2).sum(dim=-1)
            return self.variance * cross / (x1s + x2s - cross)

        # Batched inner products: (..., n, d) @ (..., d, m) -> (..., n, m).
        cross_product = x1 @ x2.transpose(-2, -1)
        denominator = x1s.unsqueeze(-1) + x2s.unsqueeze(-2) - cross_product
        return self.variance * cross_product / denominator

def _morgan_fingerprints(smiles_list, bond_radius, nBits):
    """Return Morgan bit-vector fingerprints for smiles_list as an array, one row per molecule."""
    rdkit_mols = [MolFromSmiles(smiles) for smiles in smiles_list]
    return np.asarray(
        [AllChem.GetMorganFingerprintAsBitVect(mol, bond_radius, nBits=nBits) for mol in rdkit_mols]
    )


def _fragment_features(smiles_list):
    """Return an (n_mols, n_fragments) array of RDKit fragment-descriptor values."""
    # descList[115:] contains fragment-based features only
    # (https://www.rdkit.org/docs/source/rdkit.Chem.Fragments.html).
    # NOTE(review): this offset depends on the installed RDKit version's descList — confirm.
    fragments = {d[0]: d[1] for d in Descriptors.descList[115:]}
    X = np.zeros((len(smiles_list), len(fragments)))
    for i, smiles in enumerate(smiles_list):
        mol = MolFromSmiles(smiles)
        try:
            features = [fragments[d](mol) for d in fragments]
        except Exception as err:
            raise Exception('molecule {}'.format(i) + ' is not canonicalised') from err
        X[i, :] = features
    return X


def featurise_mols(smiles_list, representation, bond_radius=3, nBits=2048):
    """
    Featurise molecules according to representation.

    :param smiles_list: list of molecule SMILES
    :param representation: str giving the representation:
        'fingerprints' -> Morgan bit-vector fingerprints;
        'fragments'    -> RDKit fragment descriptors;
        'fragprints'   -> fingerprints and fragments concatenated column-wise;
        anything else  -> smiles_list is returned unchanged (SMILES representation).
    :param bond_radius: int giving the bond radius for Morgan fingerprints. Default is 3
    :param nBits: int giving the bit vector length for Morgan fingerprints. Default is 2048
    :return: X, the featurised molecules (or smiles_list itself for the SMILES case)
    :raises Exception: if a fragment descriptor fails for a molecule
    """
    if representation == 'fingerprints':
        X = _morgan_fingerprints(smiles_list, bond_radius, nBits)
    elif representation == 'fragments':
        X = _fragment_features(smiles_list)
    elif representation == 'fragprints':
        # Fingerprints honour bond_radius / nBits here as well.
        X = np.concatenate(
            (_morgan_fingerprints(smiles_list, bond_radius, nBits),
             _fragment_features(smiles_list)),
            axis=1,
        )
    else:
        # SMILES
        return smiles_list

    return X



class PyTorchSSK(gpytorch.kernels.Kernel):
    """Normalised string subsequence kernel over zero-padded integer sequences.

    Each row of x1/x2 is treated as one string. Zero entries are treated as
    padding and removed before comparison. The result is normalised as
    k(s, t) / sqrt(k(s, s) * k(t, t)).

    :param n: number of DP levels summed in string_kernel_single (levels 0 .. n-1)
    :param lbda: decay factor applied in the DP recursion
    """

    def __init__(self, n, lbda):
        super(PyTorchSSK, self).__init__()
        self.n = n
        self.lbda = lbda

    def string_kernel_single(self, x1, x2):
        """Return the unnormalised kernel value between two token sequences.

        x1/x2 are indexable sequences of tokens compared with ``==``. The
        k_prim table appears to follow the K' recursion of the Lodhi et al.
        SSK (level 0 is identically one).
        """
        len1, len2 = len(x1), len(x2)
        k_prim = torch.zeros((self.n, len1, len2))
        k_prim[0, :, :] = 1.0

        for i in range(1, self.n):
            for sj in range(i, len1):
                toret = 0.
                for tk in range(i, len2):
                    if x1[sj-1] == x2[tk-1]:
                        toret = self.lbda * (toret + self.lbda * k_prim[i-1, sj-1, tk-1])
                    else:
                        toret *= self.lbda
                    k_prim[i, sj, tk] = toret + self.lbda * k_prim[i, sj-1, tk]

        k = 0.
        for i in range(self.n):
            for sj in range(i, len1):
                for tk in range(i, len2):
                    if x1[sj] == x2[tk]:
                        k += self.lbda * self.lbda * k_prim[i, sj, tk]

        return k

    @staticmethod
    def _strip_padding(row):
        """Return the non-zero entries of a 1-D tensor as a list of Python ints."""
        return row[row != 0].long().tolist()

    def forward(self, x1, x2, diag=False, **params):
        """Compute the normalised kernel matrix, or only its diagonal when diag=True.

        Rows are indexed along dim 0 (x1.size(0) strings).
        NOTE(review): inputs with extra leading batch dimensions are not handled
        specially — presumably only 2-D (num_strings, max_len) input is used.
        """
        batch_shape = x1.shape[:-2]
        num_strings_x1, num_strings_x2 = x1.size(0), x2.size(0)

        # Strip padding once per string rather than once per (i, j) pair.
        strings_x1 = [self._strip_padding(x1[i]) for i in range(num_strings_x1)]
        strings_x2 = [self._strip_padding(x2[j]) for j in range(num_strings_x2)]

        # Self-similarities used for normalisation.
        diag_x1 = torch.zeros(num_strings_x1)
        for i, s in enumerate(strings_x1):
            diag_x1[i] = self.string_kernel_single(s, s)
        diag_x2 = torch.zeros(num_strings_x2)
        for j, t in enumerate(strings_x2):
            diag_x2[j] = self.string_kernel_single(t, t)

        if diag:
            # gpytorch asks only for k(x1[i], x2[i]); x1 and x2 have matching row counts.
            result = torch.zeros(num_strings_x1)
            for i, (s, t) in enumerate(zip(strings_x1, strings_x2)):
                result[i] = self.string_kernel_single(s, t)
            result /= torch.sqrt(diag_x1 * diag_x2)
            return result.view(*batch_shape, num_strings_x1)

        result = torch.zeros(num_strings_x1, num_strings_x2)
        for i, s in enumerate(strings_x1):
            for j, t in enumerate(strings_x2):
                result[i, j] = self.string_kernel_single(s, t)

        # Normalizing the kernel matrix
        norm_factor = torch.sqrt(diag_x1[:, None] * diag_x2[None, :])
        result /= norm_factor

        return result.view(*batch_shape, num_strings_x1, num_strings_x2)