import numpy as np

from src.pdedata import PDEDataBase


class InversePDEData(PDEDataBase):
    """
    Base data container for inverse PDE problems.

    After the parent initializer runs, the single ``ref_sol`` attribute is
    removed and replaced by two reference slots, ``ref_sol_u`` and
    ``ref_sol_a``, both initially ``None``. They must hold callables before
    ``prepare_test_data()`` is used. ``prepare_train_data()``,
    ``prepare_boundary_data()`` and ``compute_loss_grad()`` are placeholders
    that raise ``NotImplementedError`` until overridden.
    """

    def __init__(self) -> None:
        """
        Initialise the parent and swap ``ref_sol`` for the u/a reference slots.
        """
        super().__init__()
        # The parent initializer is expected to have created ``ref_sol``;
        # deleting it raises AttributeError otherwise.
        del self.ref_sol
        self.ref_sol_u = None
        self.ref_sol_a = None
        self.P_inv = None

    def setup(self, config: dict):
        """
        Store the configuration and run every preparation step in order.

        ``init_problem`` is not defined in this class; it is presumably
        provided by the parent or by a subclass.
        """
        self.config = config
        pipeline = (
            "init_problem",
            "prepare_test_data",
            "prepare_train_data",
            "prepare_boundary_data",
        )
        # Each step is looked up only when its turn comes, exactly as a
        # sequence of direct method calls would do.
        for step_name in pipeline:
            getattr(self, step_name)()

    def prepare_test_data(self):
        """
        Build a regular test grid and evaluate both reference solutions on it.

        Reads ``"Spatial Temporal Dimension"`` and ``"Bounding Box"`` from
        ``self.config``. The bounding box is indexed as a flat sequence of
        (lower, upper) pairs, one pair per dimension. Stores the flattened
        grid in ``self.test_X`` and the concatenation of ``ref_sol_u`` and
        ``ref_sol_a`` outputs (along the last axis) in ``self.test_Y``.

        Raises:
            ValueError: if ``ref_sol_u`` or ``ref_sol_a`` is still ``None``.
        """
        settings = self.config
        n_dims = settings["Spatial Temporal Dimension"]
        bounds = settings["Bounding Box"]

        missing_reference = self.ref_sol_u is None or self.ref_sol_a is None
        if missing_reference:
            raise ValueError("ref_sol_u and ref_sol_a must be set.")

        # Fixed per-axis resolution for 2 and 3 dimensions; for any other
        # dimension pick a resolution giving roughly 1e4 grid points overall.
        preset_resolution = {2: 50, 3: 30}
        if n_dims in preset_resolution:
            per_axis = preset_resolution[n_dims]
        else:
            per_axis = int(1e4**(1 / n_dims))

        axes = []
        for axis in range(n_dims):
            lower = bounds[2 * axis]
            upper = bounds[2 * axis + 1]
            axes.append(np.linspace(lower, upper, per_axis))

        # meshgrid is called with its default indexing; the grid is then
        # flattened to one row per point and one column per dimension.
        mesh = np.meshgrid(*axes)
        self.test_X = np.stack(mesh, axis=-1).reshape(-1, n_dims)

        u_reference = self.ref_sol_u(self.test_X)
        a_reference = self.ref_sol_a(self.test_X)
        self.test_Y = np.concatenate([u_reference, a_reference], axis=-1)

    def prepare_train_data(self):
        """
        Placeholder for building the training data; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("prepare_train_data() is not implemented.")

    def prepare_boundary_data(self):
        """
        Placeholder for building the boundary data; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("prepare_boundary_data() is not implemented.")

    def get_train_data(self):
        """
        Return the pair ``(train_X, train_Y)`` stored on the instance.
        """
        inputs = self.train_X
        targets = self.train_Y
        return inputs, targets

    def get_boundary_data(self):
        """
        Return the pair ``(train_X_b, train_Y_b)`` stored on the instance.
        """
        boundary_inputs = self.train_X_b
        boundary_targets = self.train_Y_b
        return boundary_inputs, boundary_targets

    def compute_loss_grad(self):
        """
        Placeholder for computing the loss and gradients; always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("compute_loss_grad() is not implemented.")

    def set_u_data(self, u):
        """
        Flatten ``u`` and store it as a NumPy array in ``self.u_data``.

        ``u`` is presumably a torch tensor, given the detach/cpu/numpy
        chain -- confirm against callers.
        """
        flattened = u.reshape(-1)
        host_copy = flattened.detach().cpu()
        self.u_data = host_copy.numpy()

    def set_a_data(self, a):
        """
        Flatten ``a`` and store it as a NumPy array in ``self.a_data``.

        ``a`` is presumably a torch tensor, given the detach/cpu/numpy
        chain -- confirm against callers.
        """
        flattened = a.reshape(-1)
        host_copy = flattened.detach().cpu()
        self.a_data = host_copy.numpy()