import torch
from models.gmm import TimeDependentGMM
from models.potential_2d import Box, Slit, HarmonicOscillator, Hill
from models.potential import Cubic
from sklearn.preprocessing import StandardScaler
import numpy as np
import pickle
import pdb

# Registry mapping a potential's configuration name to the class that
# implements it. NewtonianLagrangian looks up U_cfg["name"] in this table and
# instantiates the matching class with no arguments.
# NOTE(review): "Cubic" is the only capitalized key and lookups are
# case-sensitive, so a config must spell it exactly this way -- confirm the
# inconsistency is intended.
POTENTIAL_NAME = {
    "box" : Box,
    "slit": Slit,
    "harmonic_oscillator": HarmonicOscillator,
    "hill": Hill,
    "Cubic": Cubic
}

def lambda_t(t, lms, intervals):
    """Return the piecewise-constant coefficient that applies at time ``t``.

    Args:
        t: Query time.
        lms: Indexable sequence of coefficients; ``lms[i]`` belongs to
            ``intervals[i]``.
        intervals: Iterable of ``(t0, t1)`` pairs. Each pair is treated as the
            half-open range ``[t0, t1)``; in addition, the right endpoint of
            the last pair is accepted and maps to ``lms[-1]``.

    Returns:
        The entry of ``lms`` for the first interval containing ``t``.

    Raises:
        ValueError: If ``intervals`` is empty or ``t`` is not covered by any
            interval.
    """
    # Track the final right endpoint explicitly rather than relying on the
    # loop variable leaking out of the loop, which is undefined (and raised
    # UnboundLocalError) when ``intervals`` is empty.
    last_end = None
    for i, (t0, t1) in enumerate(intervals):
        if t0 <= t < t1:
            return lms[i]
        last_end = t1

    # The intervals are half-open, so the very last endpoint needs its own
    # check to make the overall time range closed on the right.
    if last_end is not None and t == last_end:
        return lms[-1]

    raise ValueError(
        "t={} is not covered by intervals {}".format(t, intervals)
    )
    
class PotentialFreeLagrangian:
    """Kinetic-energy-only Lagrangian with gene-space covariance penalties.

    ``L`` is ``0.5 * |u|^2`` with no potential term. ``dM`` and
    ``dM_potential`` project batch statistics of the state ``x`` through the
    matrix ``self.M`` and restrict them to the rows/columns listed in
    ``self.idx_gene_de``.
    """

    def __init__(self, device='cuda'):
        """Load the projection matrix and index arrays from disk.

        Args:
            device: Torch device on which the loaded tensors are created.
                Defaults to ``'cuda'``, the value that used to be hard-coded,
                so existing no-argument callers behave as before.
        """
        phate_fn = '../data/scRNAseq_embryoid_body/phate.pickle'
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point this at a trusted, locally produced pickle.
        with open(phate_fn, 'rb') as fb:
            phate = pickle.load(fb)

        scaler = StandardScaler()
        scaler.fit(phate.graph.data_nu)
        # M is of shape (5, gene number)
        self.M = torch.tensor(
            phate.graph.data_pca.components_[:5, :] * np.expand_dims(scaler.scale_[:5], axis=1),
            dtype=torch.float32,
            device=device,
        )

        idx_fn = '../data/scRNAseq_embryoid_body/idx_gene_de.npy'
        self.idx_gene_de = torch.tensor(np.load(idx_fn), dtype=torch.int64, device=device)

        ppi_fn = '../data/scRNAseq_embryoid_body/idxs_ppi_400.npy'
        self.idx_ppi = torch.tensor(np.load(ppi_fn), dtype=torch.int64, device=device)
        self.mseloss = torch.nn.MSELoss()

    def L(self, t, x, u):
        """Return ``0.5 * sum(u**2)`` over dim 1, keeping that dim (``t`` and ``x`` are unused)."""
        return 0.5 * torch.sum(torch.pow(u, 2), 1, keepdims=True)

    def inv_L(self, t, x, f):
        """Return ``f`` unchanged (``t`` and ``x`` are unused)."""
        return f

    def dM_potential(self, x):
        """Penalty pulling the normalized projected covariance of ``x`` toward a label matrix.

        The batch covariance of ``x`` (rows are samples) is projected with
        ``self.M``, restricted to ``self.idx_gene_de``, and normalized by its
        diagonal. Entries at the ``self.idx_ppi`` index pairs are pushed
        toward 1 and the remaining off-diagonal entries toward 0.

        Args:
            x: 2-D tensor, samples along dim 0; its feature size must match
                ``self.M.shape[0]``.

        Returns:
            Scalar tensor: the average of the two MSE terms.
        """
        mean_x = torch.mean(x, dim=0, keepdim=True)
        cov = torch.einsum('bd,be->de', x, x) / x.shape[0] - mean_x.t() @ mean_x
        cov = (self.M.t() @ cov @ self.M)[self.idx_gene_de][:, self.idx_gene_de]

        # normalize covariance matrix, range [-1, 1]; the clamp guards against
        # division by zero for (near-)constant features
        std = torch.sqrt(torch.clamp(torch.diag(cov), min=1e-10))
        cov = cov / std.unsqueeze(dim=1) / std.unsqueeze(dim=0)

        # label matrix: -1 on the diagonal (excluded from both terms),
        # 1 at the listed index pairs, 0 everywhere else
        # NOTE(review): only the (row, col) orientation given in idx_ppi is
        # labelled 1; the mirrored entry stays 0 unless idx_ppi lists it too.
        mat_label = -torch.eye(cov.shape[0], device=cov.device)
        mat_label[self.idx_ppi[:, 0], self.idx_ppi[:, 1]] = 1

        linked = cov[mat_label == 1]
        unlinked = cov[mat_label == 0]
        return (self.mseloss(linked, torch.ones_like(linked))
                + self.mseloss(unlinked, torch.zeros_like(unlinked))) / 2

    def dM(self, x, u):
        """Element-wise square of the projected moment term built from ``x`` and ``u``.

        Computes ``E[u x^T] + E[x u^T] - E[u] E[x]^T - E[x] E[u]^T`` over the
        batch (this is the time derivative of the batch covariance of ``x``
        when ``u`` is ``dx/dt``), projects it with ``self.M``, restricts it to
        ``self.idx_gene_de`` and squares every entry.

        Args:
            x: 2-D tensor, samples along dim 0.
            u: Tensor of the same shape as ``x``.

        Returns:
            Square matrix with one row/column per entry of ``self.idx_gene_de``.
        """
        mean_x = torch.mean(x, dim=0, keepdim=True)
        mean_u = torch.mean(u, dim=0, keepdim=True)

        # moment regularization: the four-term expression is cross + cross^T,
        # so a single matmul replaces building a (d, batch, e) tensor just to
        # average it over the batch axis.
        cross = u.t() @ x / x.shape[0] - mean_u.t() @ mean_x
        dcov_dt = cross + cross.t()
        dcov_dt = (self.M.t() @ dcov_dt @ self.M)[self.idx_gene_de][:, self.idx_gene_de]

        return torch.pow(dcov_dt, 2)

class CellularLagrangian:
    """Lagrangian with a kinetic term, a time-dependent GMM potential and an optional velocity-matching term.

    ``L = lm_u2(t) * 0.5 * |u|^2 - lm_U(t) * U(x, t)`` and, when a reference
    velocity ``v`` is supplied, ``+ lm_v(t) * 0.5 * |u - v|^2``.

    Each coefficient (``lm_u2``, ``lm_U``, ``lm_v``) is either a single value
    used at all times or a list/tuple with one value per entry of
    ``intervals``.
    """

    def __init__(self, Xs, ts, n_components_list, lm_u2=1.0, lm_U=1.0, lm_v=1.0, intervals=None, device='cpu'):
        """Build the potential and store the coefficient schedules.

        Args:
            Xs: Passed through to ``TimeDependentGMM``.
            ts: Passed through to ``TimeDependentGMM``; when ``intervals`` is
                None it must also support ``.numpy()`` so the distinct time
                values can be extracted.
            n_components_list: Passed through to ``TimeDependentGMM``.
            lm_u2: Kinetic-term coefficient (scalar or per-interval sequence).
            lm_U: Potential-term coefficient (scalar or per-interval sequence).
            lm_v: Velocity-matching coefficient (scalar or per-interval sequence).
            intervals: Optional list of ``(t0, t1)`` pairs. If None,
                consecutive pairs of the sorted distinct values of ``ts`` are
                used.
            device: Device the potential is moved to.
        """
        if intervals is None:
            t_set = sorted(set(ts.numpy()))
            intervals = list(zip(t_set[:-1], t_set[1:]))
        self.U = TimeDependentGMM(Xs, ts, n_components_list=n_components_list, intervals=intervals).to(device)
        self.intervals = intervals
        self.lm_u2_value = lm_u2
        self.lm_U_value = lm_U
        self.lm_v_value = lm_v

    def _coefficient_at(self, value, t):
        """Resolve a coefficient that is either constant or given per interval."""
        if isinstance(value, (list, tuple)):
            return lambda_t(t, value, self.intervals)
        return value

    def lm_u2(self, t):
        """Kinetic-term coefficient at time ``t``."""
        return self._coefficient_at(self.lm_u2_value, t)

    def lm_U(self, t):
        """Potential-term coefficient at time ``t``."""
        return self._coefficient_at(self.lm_U_value, t)

    def lm_v(self, t):
        """Velocity-matching coefficient at time ``t``."""
        return self._coefficient_at(self.lm_v_value, t)

    def L(self, t, x, u, v=None):
        """Evaluate the Lagrangian.

        Args:
            t: Time passed to the potential and used to look up coefficients.
            x: State tensor passed to the potential.
            u: Velocity tensor; squared and summed over dim 1.
            v: Optional reference velocity; when given, the
                ``lm_v(t) * 0.5 * |u - v|^2`` term is added.

        Returns:
            Tensor with a trailing singleton dim (one value per row of ``u``).
        """
        # The potential output gets a trailing dim so it lines up with the
        # keepdims=True reductions below.
        Uxt = self.U(x, t).unsqueeze(1).float()

        value = self.lm_u2(t) * 0.5 * torch.sum(torch.pow(u, 2), 1, keepdims=True) - self.lm_U(t) * Uxt
        if v is not None:
            value = value + self.lm_v(t) * 0.5 * torch.sum(torch.pow(u - v, 2), 1, keepdims=True)
        return value

    def inv_L(self, t, x, f, v=None):
        """Map ``f`` back to a velocity.

        Returns ``f`` itself when ``v`` is None, otherwise
        ``(lm_v(t) * v + f) / (1 + lm_v(t))``.

        NOTE(review): neither branch involves ``lm_u2``; confirm that is
        intended when ``lm_u2`` differs from 1.
        """
        if v is None:
            return f
        lm_v = self.lm_v(t)
        return (lm_v * v + f) / (1 + lm_v)


class CellularLagrangian_moment(CellularLagrangian):
    """``CellularLagrangian`` plus a gene-space moment penalty exposed through ``dM``.

    The Lagrangian itself (``L`` / ``inv_L``) is inherited unchanged; this
    subclass only loads the projection matrix ``self.M`` and the index array
    ``self.idx_gene_de`` that ``dM`` needs.
    """

    def __init__(self, Xs, ts, n_components_list, lm_u2=1.0, lm_U=1.0, lm_v=1.0, intervals=None, device='cpu'):
        """Initialise the base Lagrangian, then load the projection data from disk.

        All arguments are forwarded to ``CellularLagrangian.__init__``;
        ``device`` is additionally used for the tensors created here.
        """
        super().__init__(Xs=Xs, ts=ts, n_components_list=n_components_list, lm_u2=lm_u2, lm_U=lm_U, lm_v=lm_v, intervals=intervals, device=device)

        phate_fn = '../data/scRNAseq_embryoid_body/phate.pickle'
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point this at a trusted, locally produced pickle.
        with open(phate_fn, 'rb') as fb:
            phate = pickle.load(fb)

        scaler = StandardScaler()
        scaler.fit(phate.graph.data_nu)
        # M is of shape (5, gene number)
        self.M = torch.tensor(
            phate.graph.data_pca.components_[:5, :] * np.expand_dims(scaler.scale_[:5], axis=1),
            dtype=torch.float32,
            device=device,
        )

        idx_fn = '../data/scRNAseq_embryoid_body/idx_gene_de.npy'
        self.idx_gene_de = torch.tensor(np.load(idx_fn), dtype=torch.int64, device=device)

    # NOTE: an earlier, disabled variant overrode L() to add the mean of the
    # dM() entries >= 0.1, weighted by self.lm_M, to every row of the
    # Lagrangian. self.lm_M is never set by this class, so that override was
    # dropped; dM() returns the raw penalty matrix for the caller to reduce.

    def dM(self, x, u):
        """Element-wise square of the projected moment term built from ``x`` and ``u``.

        Computes ``E[u x^T] + E[x u^T] - E[u] E[x]^T - E[x] E[u]^T`` over the
        batch (this is the time derivative of the batch covariance of ``x``
        when ``u`` is ``dx/dt``), projects it with ``self.M``, restricts it to
        ``self.idx_gene_de`` and squares every entry.

        Args:
            x: 2-D tensor, samples along dim 0.
            u: Tensor of the same shape as ``x``.

        Returns:
            Square matrix with one row/column per entry of ``self.idx_gene_de``.
        """
        mean_x = torch.mean(x, dim=0, keepdim=True)
        mean_u = torch.mean(u, dim=0, keepdim=True)

        # moment regularization: the four-term expression is cross + cross^T,
        # so a single matmul replaces building a (d, batch, e) tensor just to
        # average it over the batch axis.
        cross = u.t() @ x / x.shape[0] - mean_u.t() @ mean_x
        dcov_dt = cross + cross.t()
        dcov_dt = (self.M.t() @ dcov_dt @ self.M)[self.idx_gene_de][:, self.idx_gene_de]

        return torch.pow(dcov_dt, 2)


class NewtonianLagrangian:
    """Lagrangian of a mass ``M`` moving in a potential ``U``.

    ``L(t, x, u) = lm_u2 * 0.5 * M * |u|^2 - lm_U * U(x, t)``
    """

    def __init__(self, M, lm_u2=1.0, lm_U=1.0, U_cfg=None, U=None):
        """Store the mass, the term weights and the potential.

        Args:
            M: Mass multiplying the kinetic term.
            lm_u2: Weight of the kinetic term.
            lm_U: Weight of the potential term.
            U_cfg: Mapping whose ``"name"`` entry selects a class from
                ``POTENTIAL_NAME``; only consulted when ``U`` is None.
            U: Ready-made potential, called as ``U(x, t)``. Takes precedence
                over ``U_cfg``.
        """
        self.M = M
        # An explicitly supplied potential wins; otherwise instantiate the
        # registered class with no arguments.
        self.U = U if U is not None else POTENTIAL_NAME[U_cfg["name"]]()
        self.lm_u2 = lm_u2
        self.lm_U = lm_U

    def L(self, t, x, u):
        """Return the weighted kinetic energy minus the weighted potential."""
        potential = self.U(x, t).float()
        speed_sq = torch.pow(u, 2).sum(dim=1, keepdim=True)
        kinetic = self.lm_u2 * 0.5 * self.M * speed_sq
        return kinetic - self.lm_U * potential

    def inv_L(self, t, x, f):
        """Return ``f`` scaled by the inverse mass (``t`` and ``x`` are unused)."""
        inverse_mass = 1 / self.M
        return inverse_mass * f
    
class NullLagrangian:
    """Trivial Lagrangian: zero cost everywhere and an identity inverse map."""

    def L(self, t, x, u):
        """Return a ``(len(u), 1)`` tensor of zeros converted to match ``x``."""
        n_rows = len(u)
        zero_cost = torch.zeros((n_rows, 1))
        # .to(x) adopts the dtype and device of the state tensor.
        return zero_cost.to(x)

    def inv_L(self, t, x, f):
        """Return ``f`` unchanged (``t`` and ``x`` are unused)."""
        return f
    