import numpy as np
import torch
from torch.autograd.functional import jacobian
from linear_operator.operators import DiagLinearOperator

import sys
sys.path.insert(0,'../util')
from sample_points import sampled_pts_grid


class Eikonal(object):
    r"""
    Regularised Eikonal equation

        |\nabla u(x)|^2 = f(x)^2 + \epsilon \Delta u(x),  x \in \Omega
        u(x) = 0,                                          x \in \partial\Omega

    with \epsilon = 0.1 by default.

    The unknown is handled as a stacked tensor ``mu`` whose second-to-last
    axis indexes collocation points (the first ``Nd`` are domain points, the
    rest boundary points) and whose last axis holds, in order: the value of
    ``u`` (column 0), ``dim`` gradient components (columns ``1..dim``) and
    the remaining columns, which are summed to form the Laplacian.

    Args:
        eps: coefficient multiplying the Laplacian term.
        bdy: callable evaluated on the boundary points in ``sampled_pts``.
        rhs: callable evaluated on the domain points in ``sampled_pts``; its
            values are matched against ``|grad u|^2 - eps * Lap u``.
        domain: array of shape ``(dim, 2)``; defaults to the unit square.
            NOTE(review): rows are presumably ``[lower, upper]`` bounds per
            axis -- it is passed unchanged to ``sampled_pts_grid``.
    """
    def __init__(self, eps=0.1, bdy=None, rhs=None, domain=None):
        if domain is None:
            # Build a fresh array per instance instead of sharing one mutable
            # default object between all instances.
            domain = np.array([[0, 1], [0, 1]])
        elif not hasattr(domain, 'shape'):
            # Accept plain nested lists/tuples as well.
            domain = np.asarray(domain)
        self.eps = eps
        self.bdy = bdy
        self.rhs = rhs
        self.domain = domain
        self.dim = self.domain.shape[0]

    def sampled_pts(self, N_domain, N_boundary, sampled_type='grid'):
        """
        Sample collocation points and cache the PDE data evaluated on them.

        Sets ``X_domain``, ``X_boundary``, ``Nd``, ``Nb``, ``rhs_f`` and
        ``bdy_g`` on the instance.

        Args:
            N_domain, N_boundary: forwarded to ``sampled_pts_grid``.
            sampled_type: ``'grid'`` keeps the points returned by
                ``sampled_pts_grid``; ``'random'`` additionally perturbs the
                domain points in place with Gaussian noise of scale 1e-2.

        Raises:
            ValueError: if ``sampled_type`` is neither ``'grid'`` nor
                ``'random'``.
        """
        # Validate first so that an invalid option fails before any sampling.
        if sampled_type not in ('grid', 'random'):
            raise ValueError('sampled_type should be grid or random')
        Xd, Xb = sampled_pts_grid(N_domain, N_boundary, self.domain, time_dependent=False)
        if sampled_type == 'random':
            Xd += torch.randn(*Xd.shape) * 1e-2
        self.X_domain = Xd
        self.X_boundary = Xb
        self.Nd, self.Nb = self.X_domain.shape[0], self.X_boundary.shape[0]
        self.rhs_f = self.rhs(self.X_domain)
        self.bdy_g = self.bdy(self.X_boundary)

    def loss(self, mu, power=2.0):
        """
        Residual-based loss of the stacked solution ``mu``.

        With ``r`` the squared 2-norm of the residual (PDE part followed by
        boundary part), the loss is ``0.5 * r**(power/2)``; for
        ``power != 2`` the term ``0.5 * (1 - power/2) * (Nd + Nb) * log(r)``
        is added.

        Returns:
            tuple ``(loss, domain_residual)``: the loss averaged over its
            last axis and the residual restricted to the ``Nd`` domain points.
        """
        lhs_u = self.lhs_f(mu)
        target = torch.cat([self.rhs_f, self.bdy_g], -1)
        e_u = lhs_u - target  # PDE residual followed by boundary residual
        r = e_u.pow(2).sum(-1)
        res = .5 * r.pow(power / 2.0)
        if power != 2:
            res = res + .5 * (1 - power / 2) * (self.Nd + self.Nb) * r.log()
        return res.mean(-1), e_u[..., :self.Nd]

    def lhs_f(self, u_):
        """
        Evaluate the left-hand side operator on the stacked solution ``u_``.

        Returns the concatenation (last axis) of ``|grad u|^2 - eps * Lap u``
        on the domain points and the values of ``u`` on the boundary points.
        """
        _, u_b, du, Lap_u = self.extract_solution(u_)[0]
        interior = du.square().sum(-1) - self.eps * Lap_u
        return torch.cat([interior, u_b], -1)

    def d_lhs(self, u_):
        """
        Jacobian of ``lhs_f`` with respect to ``u_``.

        Side effect: ``u_`` is flagged as requiring gradients.
        """
        # ``requires_grad_`` is a no-op on non-leaf tensors, whereas assigning
        # ``u_.requires_grad = True`` is expected to raise for them
        # (not re-verified here).
        u_.requires_grad_(True)
        return jacobian(self.lhs_f, u_)

    def linearization(self, u_):
        """Return ``(lhs_f(u_), d_lhs(u_))``: the value and Jacobian at ``u_``."""
        # The value is computed before ``d_lhs`` toggles ``requires_grad``.
        fun = self.lhs_f(u_)
        jac = self.d_lhs(u_)
        return fun, jac

    def propagate_distribution(self, mu, covar, u0=None, linrz = None, interleaved=True, eps=1e-6, diag=True):
        """
        Push a mean/covariance pair through the linearised operator.

        Args:
            mu: mean of the stacked solution.
            covar: covariance over the flattened entries of ``mu``.
            u0: linearisation point; defaults to ``mu``.
            linrz: optional precomputed ``(fun, jac)`` pair as returned by
                ``linearization``; computed at ``u0`` when omitted.
            interleaved: when False, the trailing (input) axes of the
                Jacobian are reversed before being flattened.
            eps: jitter added to the diagonal of the output covariance.
            diag: when True only the output variances are computed and
                returned as a ``DiagLinearOperator``; otherwise the full
                covariance is returned.

        Returns:
            tuple ``(mu_, covar_)`` with ``mu_ = fun + jac . (mu - u0)`` and
            ``covar_ = jac covar jac^T + eps * I`` (or its diagonal).
        """
        if u0 is None:
            u0 = mu
        if linrz is None:
            linrz = self.linearization(u0)
        fun, jac = linrz
        # Contract the Jacobian with the displacement over all input axes.
        input_axes = tuple(range(-mu.ndim, 0))
        mu_ = fun + (jac * (mu - u0)).sum(dim=input_axes)
        if not interleaved:
            # Keep the output axis first and reverse the order of the others.
            jac = jac.permute((0,) + tuple(range(-1, -jac.ndim, -1)))
        jac = jac.reshape(jac.shape[0], -1)
        if not diag:
            jitter = eps * torch.eye(jac.shape[0], dtype=jac.dtype, device=covar.device)
            covar_ = jac.matmul(covar).matmul(jac.transpose(-1, -2)) + jitter
        else:
            # Row-wise quadratic form: diagonal of jac covar jac^T.
            var_ = (jac.matmul(covar) * jac).sum(-1) + eps
            covar_ = DiagLinearOperator(var_)
        return mu_, covar_

    def extract_solution(self, mu, var=None):
        """
        Split a stacked mean (and optionally a same-shaped ``var``) into parts.

        Args:
            mu: stacked solution, see the class docstring for the layout.
            var: tensor sliced exactly like ``mu``; defaults to all ones.

        Returns:
            two lists ``[u_d, u_b, du, Lap_u]`` and ``[v_d, v_b, dv, Lap_v]``:
            values on the domain points, values on the boundary points,
            gradient components on the domain points, and the sum of the
            remaining columns on the domain points.
        """
        if var is None:
            var = torch.ones_like(mu)

        def split(t):
            # Column 0: values; columns 1..dim: gradient; rest: summed.
            return [
                t[..., :self.Nd, 0],
                t[..., self.Nd:, 0],
                t[..., :self.Nd, 1:1 + self.dim],
                t[..., :self.Nd, 1 + self.dim:].sum(-1),
            ]

        return split(mu), split(var)

    def plot_solution(self, u, ax=None, **kwargs):
        """
        Draw a filled contour plot of ``u`` over the domain points.

        Args:
            u: values to plot, passed to ``contourf`` together with the
                domain coordinates reshaped to a square grid.
            ax: matplotlib axes to draw on; defaults to the current axes.
            **kwargs: forwarded to ``Axes.contourf``.
        """
        import matplotlib.pyplot as plt
        # NOTE(review): assumes the sampled points form a square grid with
        # sqrt(Nd + Nb) points per side, two of which lie on the boundary.
        N_pts = int(np.sqrt(self.Nd + self.Nb)) - 2
        grid_shape = (N_pts,) * 2
        x = self.X_domain[:, 0].reshape(grid_shape)
        y = self.X_domain[:, 1].reshape(grid_shape)
        if ax is None:
            ax = plt.gca()
        ctf = ax.contourf(x, y, u, **kwargs)
        clb = plt.colorbar(ctf, ax=ax)
        # Match the colour-bar tick label size to that of the axes.
        clb.ax.tick_params(labelsize=ax.yaxis.get_tick_params()['labelsize'])