# This file has been adapted from PIDM
import torch
import einops

from src.data_utils import generalized_b_xy_c_to_image, generalized_image_to_b_xy_c
from src.grad_utils import GradientsHelper

class ResidualsDarcy:
    """Finite-difference residuals of steady Darcy flow on a regular 2D pixel grid.

    The per-pixel PDE residual is ``-div(K grad p) - f_s`` (expanded with the product rule),
    where ``p`` is the pressure, ``K`` the permeability and ``f_s`` a fixed source field built
    in ``__init__``. Optionally, boundary residuals on the normal pressure gradient of the
    outermost pixel rows/columns are appended as two extra channels.
    """

    def __init__(self, fd_acc, pixels_per_dim, pixels_at_boundary, reverse_d1, device = 'cpu', bcs = 'none', domain_length = 1.,
                 w=0.125, r=10.0):
        """
        Initialize the residual evaluation.

        :param fd_acc: Finite-difference accuracy setting, forwarded unchanged to ``GradientsHelper``.
        :param pixels_per_dim: Number of pixels along each spatial axis.
        :param pixels_at_boundary: If True, the grid spacing is ``domain_length / (pixels_per_dim - 1)``
            (pixels on the boundary); otherwise it is ``domain_length / pixels_per_dim``.
        :param reverse_d1: If True, the spacing along axis 1 is negated (for consistency with the
            visualization) and the sign of the axis-1 boundary residuals is flipped accordingly.
        :param device: Torch device for the gradient stencils and the source field.
        :param bcs: ``'periodic'`` enables periodic stencils in ``GradientsHelper``; any other value does not.
        :param domain_length: Side length of the domain, used only for the finite-difference spacing.
        :param w: Side length of the two square source patches (see ``create_f_s``).
        :param r: Magnitude of the source patches (``+r`` and ``-r``).
        """
        self.gov_eqs = 'darcy'
        self.pixels_at_boundary = pixels_at_boundary
        self.periodic = bcs == 'periodic'
        self.input_dim = 2

        n_intervals = pixels_per_dim - 1 if self.pixels_at_boundary else pixels_per_dim
        d0 = domain_length / n_intervals
        d1 = domain_length / n_intervals

        self.reverse_d1 = reverse_d1
        if self.reverse_d1:
            d1 *= -1.  # this is for later consistency with visualization

        self.grads = GradientsHelper(d0=d0, d1=d1, fd_acc=fd_acc, periodic=self.periodic, device=device)

        self.pixels_per_dim = pixels_per_dim

        # Stationary source field, sampled at the pixel centres of the unit square.
        # NOTE(review): this grid ignores domain_length and pixels_at_boundary, and create_f_s places
        # its patches assuming a unit domain -- confirm this is intended when domain_length != 1.
        domain_size = 1.
        self.pixel_size = domain_size / pixels_per_dim
        centres = torch.linspace(self.pixel_size / 2, domain_size - self.pixel_size / 2, steps=pixels_per_dim)
        X, Y = torch.meshgrid(centres, centres, indexing='ij')
        f_s = self.create_f_s(X, Y, w, r).to(device)  # [pixels_per_dim, pixels_per_dim]
        # Flattened to the b_xy_c layout. It is combined with [batch, pixels_per_dim**2] terms in the
        # residual methods, so presumably its shape is [1, pixels_per_dim**2] -- TODO confirm.
        self.f_s = generalized_image_to_b_xy_c(f_s.unsqueeze(0))

        self.device = device

    # Define the source function using PyTorch operations
    def create_f_s(self, x, y, w = 0.125, r = 10.):
        """Piecewise-constant source: ``+r`` on ``[0, w]^2``, ``-r`` on ``[1 - w, 1]^2``, zero elsewhere.

        If the patches overlap (``w > 0.5``), the ``-r`` patch wins because it is written last.

        :param x: Coordinates along axis 0 (same shape as ``y``).
        :param y: Coordinates along axis 1.
        :param w: Side length of each square patch.
        :param r: Source magnitude.
        :return: Tensor with the shape and dtype of ``x``.
        """
        x_low = torch.abs(x - 0.5 * w) <= 0.5 * w       # 0 <= x <= w
        x_high = torch.abs(x - 1 + 0.5 * w) <= 0.5 * w  # 1 - w <= x <= 1
        y_low = torch.abs(y - 0.5 * w) <= 0.5 * w
        y_high = torch.abs(y - 1 + 0.5 * w) <= 0.5 * w

        result = torch.zeros_like(x)
        result[torch.logical_and(x_low, y_low)] = r
        result[torch.logical_and(x_high, y_high)] = -r
        return result

    def _derivatives(self, p, field):
        """Finite-difference derivatives shared by both residual formulations.

        :param p: Pressure image(s).
        :param field: Permeability or log-permeability image(s).
        :return: ``(p_d0, p_d1, p_d00, p_d11, field_d0, field_d1)``.
        """
        stencil = self.grads.stencil_gradients
        p_d0 = stencil(p, mode='d_d0')
        p_d1 = stencil(p, mode='d_d1')
        p_d00 = stencil(p, mode='d_d00')
        p_d11 = stencil(p, mode='d_d11')
        field_d0 = stencil(field, mode='d_d0')
        field_d1 = stencil(field, mode='d_d1')
        return p_d0, p_d1, p_d00, p_d11, field_d0, field_d1

    def _boundary_residual(self, grad_p):
        """Boundary residual on the normal pressure gradient, in b_xy_c layout.

        :param grad_p: Pressure gradient in b_xy_c layout (two gradient channels).
        :return: Tensor shaped like ``grad_p``. It is zero in the interior. On the axis-0 min/max
            rows it holds ``-p_d0`` / ``+p_d0``. On the axis-1 min/max columns it holds
            ``-p_d1`` / ``+p_d1``, with the signs swapped when ``reverse_d1`` is set.
        """
        grad_p_img = generalized_b_xy_c_to_image(grad_p)
        residual_bc = torch.zeros_like(grad_p_img)
        residual_bc[:, 0, 0, :] = -grad_p_img[:, 0, 0, :]  # xmin / top (acc. to matplotlib visualization)
        residual_bc[:, 0, -1, :] = grad_p_img[:, 0, -1, :]  # xmax / bot

        if self.reverse_d1:
            # d1 spacing is negated in __init__, so the outward sign flips on this axis
            residual_bc[:, 1, :, 0] = grad_p_img[:, 1, :, 0]  # ymin / left
            residual_bc[:, 1, :, -1] = -grad_p_img[:, 1, :, -1]  # ymax / right
        else:
            residual_bc[:, 1, :, 0] = -grad_p_img[:, 1, :, 0]  # ymin / left
            residual_bc[:, 1, :, -1] = grad_p_img[:, 1, :, -1]  # ymax / right

        return generalized_image_to_b_xy_c(residual_bc)

    def _assemble_residual(self, eq_0, grad_p, use_bcs):
        """Stack the PDE residual (and optionally BC residuals) and reshape to image layout.

        :param eq_0: PDE residual of shape [batch, H*W].
        :param grad_p: Pressure gradient in b_xy_c layout.
        :param use_bcs: Whether to append the two boundary-residual channels.
        :return: Residual in image layout with 3 channels (``use_bcs``) or 1 channel.
        """
        if use_bcs:
            residual = torch.cat([eq_0.unsqueeze(-1), self._boundary_residual(grad_p)], dim=-1)
        else:
            # Channel axis goes last in b_xy_c layout (as in the use_bcs branch); unsqueeze(1)
            # would have put the pixel axis in the channel slot.
            residual = eq_0.unsqueeze(-1)
        return generalized_b_xy_c_to_image(residual)

    def compute_residual_direct(self, x0_output, reduce_batch = True, error_fn = torch.abs, use_bcs=True):
        """Darcy residual using the permeability field directly.

        :param x0_output: Fields with channel 0 = pressure and channel 1 = permeability;
            presumably [batch, 2, H, W] -- TODO confirm.
        :param reduce_batch: If True, return ``error_fn(residual)`` averaged over all non-batch dims.
        :param error_fn: Elementwise error applied before the reduction.
        :param use_bcs: Whether to append the boundary-residual channels.
        :return: [batch] tensor if ``reduce_batch``, otherwise the residual in image layout.
        """
        p = x0_output[:, 0]
        permeability_field = x0_output[:, 1]
        p_d0, p_d1, p_d00, p_d11, perm_d0, perm_d1 = self._derivatives(p, permeability_field)

        # Diagonal of the Jacobian of the Darcy velocity u = -K grad p (product rule); only the
        # diagonal is needed for div(u).
        velocity_jacobian = torch.zeros_like(x0_output).unsqueeze(-3).repeat(1, 1, self.input_dim, 1, 1)
        velocity_jacobian[:, 0, 0] = -permeability_field * p_d00 - perm_d0 * p_d0
        velocity_jacobian[:, 1, 1] = -permeability_field * p_d11 - perm_d1 * p_d1
        velocity_jacobian = generalized_image_to_b_xy_c(velocity_jacobian)
        grad_p = generalized_image_to_b_xy_c(torch.stack([p_d0, p_d1], dim=-3))

        # div(u) - f_s, i.e. -div(K grad p) - f_s
        eq_0 = velocity_jacobian[:, :, 0, 0] + velocity_jacobian[:, :, 1, 1] - self.f_s

        residual = self._assemble_residual(eq_0, grad_p, use_bcs)
        if reduce_batch:
            residual = error_fn(residual).mean(dim=tuple(range(1, residual.ndim)))

        return residual

    def compute_residual_direct_logK(self, x0_output, reduce_batch = True, error_fn = torch.abs,
                                    logK_norm_c=1., p_norm_c=1., use_bcs=True, rms=False, rms_eps=1e-6): # If using normalized data
        """Darcy residual using the log-permeability field, divided through by ``K``.

        With ``K = exp(logK_norm_c * logK)``, the PDE residual is
        ``-(p_00 + p_11) - logK_norm_c * (logK_0 p_0 + logK_1 p_1) - f_s * p_norm_c / K``.

        :param x0_output: Fields with channel 0 = pressure and channel 1 = log-permeability;
            presumably [batch, 2, H, W] -- TODO confirm.
        :param reduce_batch: If True, return ``error_fn(residual)`` averaged over all non-batch dims.
        :param error_fn: Elementwise error applied before the reduction.
        :param logK_norm_c: Scale applied to the log-permeability channel (for normalized data).
        :param p_norm_c: Scale applied to the source term (for normalized pressure data).
        :param use_bcs: Whether to append the boundary-residual channels.
        :param rms: If True (and ``reduce_batch``), return ``sqrt(mean + rms_eps)``; this is a true
            RMS only when ``error_fn`` squares its input (e.g. ``torch.square``).
        :param rms_eps: Small constant added inside the square root for numerical stability.
        :return: [batch] tensor if ``reduce_batch``, otherwise the residual in image layout.
        """
        p = x0_output[:, 0]
        logpermeability_field = x0_output[:, 1]
        p_d0, p_d1, p_d00, p_d11, perm_d0, perm_d1 = self._derivatives(p, logpermeability_field)

        velocity_jacobian = torch.zeros_like(x0_output).unsqueeze(-3).repeat(1, 1, self.input_dim, 1, 1)
        velocity_jacobian[:, 0, 0] = -p_d00 - perm_d0 * p_d0 * logK_norm_c
        velocity_jacobian[:, 1, 1] = -p_d11 - perm_d1 * p_d1 * logK_norm_c
        velocity_jacobian = generalized_image_to_b_xy_c(velocity_jacobian)
        grad_p = generalized_image_to_b_xy_c(torch.stack([p_d0, p_d1], dim=-3))

        # Source term divided by K = exp(logK_norm_c * logK), flattened to [batch, H*W]
        permeability_flat = torch.exp(einops.rearrange(logpermeability_field * logK_norm_c, "b dim1 dim2 -> b (dim1 dim2)"))
        eq_0 = velocity_jacobian[:, :, 0, 0] + velocity_jacobian[:, :, 1, 1] - \
                self.f_s * p_norm_c / permeability_flat

        residual = self._assemble_residual(eq_0, grad_p, use_bcs)
        if reduce_batch:
            residual = error_fn(residual).mean(dim=tuple(range(1, residual.ndim)))
            if rms:
                residual = torch.sqrt(residual + rms_eps)  # Numerical stability for RMS error

        return residual
