import deepxde as dde
import numpy as np

# Default sample counts. BasePDE / BaseTimePDE copy these into their
# ``num_*_points`` attributes and use them as ``training_points`` defaults.
DEFAULT_NUM_DOMAIN_POINTS = 8_192
DEFAULT_NUM_BOUNDARY_POINTS = 2_048
DEFAULT_NUM_TEST_POINTS = 8_192
DEFAULT_NUM_INITIAL_POINTS = 2_048


class BasePDE:
    """Base description of a PDE problem solved with DeepXDE.

    ``create_model`` requires ``pde``, ``geom``, ``bbox``, the output
    configuration and at least one PDE loss to be set (see :meth:`check`).
    Loss terms are registered with :meth:`set_pdeloss` and :meth:`add_bcs`.

    Attributes:
        pde: PDE function passed to ``dde.data.PDE``.
        bcs: list of DeepXDE condition objects built by :meth:`add_bcs`.
        geom: DeepXDE geometry; its ``dim`` is the input dimension.
        bbox: bounding box; only checked for presence in this class.
        loss_config: list of ``{'name': ..., 'type': ...}`` dicts, one per loss term.
        output_config: list of per-output dicts, each with at least ``'name'``.
        num_domain_points, num_boundary_points, num_test_points: sample counts
            forwarded to ``dde.data.PDE``.
        ref_sol: not used by this class's methods.
        ref_data: reference data array filled by :meth:`load_ref_data`.
        recommend_net: not used by this class's methods.
    """

    def __init__(self):
        self.pde = None
        self.bcs = None
        self.geom = None
        self.bbox = None
        self.loss_config = []
        self.output_config = None

        self.num_domain_points = DEFAULT_NUM_DOMAIN_POINTS
        self.num_boundary_points = DEFAULT_NUM_BOUNDARY_POINTS
        self.num_test_points = DEFAULT_NUM_TEST_POINTS

        self.ref_sol = None
        self.ref_data = None

        self.recommend_net = None

    @property
    def input_dim(self):
        """Number of network inputs, taken from ``self.geom.dim``."""
        return self.geom.dim

    @property
    def output_dim(self):
        """Number of network outputs, i.e. ``len(self.output_config)``.

        Raises:
            ValueError: if ``output_config`` has not been set.
        """
        if self.output_config is None:
            raise ValueError("output_config not set")
        return len(self.output_config)

    @output_dim.setter
    def output_dim(self, value):
        # With no output_config yet, create default names y_1 ... y_value;
        # otherwise only assert that the existing config has `value` entries.
        if self.output_config is None:
            self.output_config = [{'name': f'y_{i+1}'} for i in range(value)]
        else:
            assert self.output_dim == value, "output_config and output_dim not matched"

    def _count_losses(self, loss_type):
        """Return how many ``loss_config`` entries have ``'type' == loss_type``."""
        return sum(c['type'] == loss_type for c in self.loss_config)

    @property
    def num_pde(self):
        """Number of loss terms registered with type ``'pde'``."""
        return self._count_losses('pde')

    @property
    def num_gepinn(self):
        """Number of loss terms registered with type ``'gepinn'``."""
        return self._count_losses('gepinn')

    @property
    def num_boundary(self):
        """Number of loss terms registered with type ``'boundary'``."""
        return self._count_losses('boundary')

    @property
    def num_loss(self):
        """Total number of registered loss terms."""
        return len(self.loss_config)

    @staticmethod
    def _parse_comsol_time(field):
        """Return the float following the first ``t=`` in ``field``, or None if absent/empty."""
        index = field.find("t=")
        if index == -1:
            return None
        # Whitespace split tolerates trailing newlines/tabs after the value.
        tokens = field[index + 2:].split(maxsplit=1)
        return float(tokens[0]) if tokens else None

    @classmethod
    def _read_comsol_times(cls, datapath, num_columns):
        """Read per-column times from a ``%`` header line of ``datapath``.

        Considers every line that starts with ``%`` and contains exactly
        ``num_columns`` ``@`` markers; the last such line wins. Each
        ``@``-separated field after the first is parsed with
        :meth:`_parse_comsol_time`.

        Returns:
            A list of floats (with None for fields lacking a ``t=`` value), or
            None if no line matched.
        """
        times = None
        # Only ASCII markers ('%', '@', 't=') are inspected, so undecodable
        # bytes elsewhere in the file are replaced instead of aborting the read.
        with open(datapath, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                if line.startswith('%') and line.count('@') == num_columns:
                    times = [cls._parse_comsol_time(field) for field in line.split('@')[1:]]
        return times

    def trans_time_data_to_dataset(self, datapath):
        """Reshape time-series ``self.ref_data`` into one row per (input row, time).

        The loaded data is expected to have ``input_dim - 1`` coordinate columns
        followed by one column per (time, output component), ordered time-major
        (all components for the first time, then the next time, ...). Times are
        read from the ``%`` header of ``datapath`` (fields like ``@ t=<value>``).

        Afterwards ``self.ref_data`` has columns
        ``[x_0, ..., x_{input_dim-2}, t, y_1, ..., y_{output_dim}]`` with rows
        grouped by original input row, one row per time.

        Raises:
            AssertionError: if the number of value columns is not a multiple of
                ``output_dim``.
            ValueError: if no matching header line is found, it lists no times,
                or some field lacks a ``t=`` value.
        """
        data = self.ref_data
        num_value_columns = data.shape[1] - self.input_dim + 1
        num_slices = num_value_columns // self.output_dim
        assert num_slices * self.output_dim == num_value_columns, "Data shape is not multiple of pde.output_dim"

        times = self._read_comsol_times(datapath, num_slices * self.output_dim)
        if not times or None in times:
            raise ValueError("Reference Data not in Comsol format or does not contain time info")
        # The header holds one time per (time, component) column; keep one per time slice.
        times = np.array(times[::self.output_dim])

        t_grid, x0_grid = np.meshgrid(times, data[:, 0])
        columns = [x0_grid.reshape(-1)]
        for i in range(1, self.input_dim - 1):
            # Repeat each coordinate once per time slice, matching meshgrid's row order.
            columns.append(np.repeat(data[:, i], num_slices))
        columns.append(t_grid.reshape(-1))
        for i in range(self.output_dim):
            # Component i sits in every output_dim-th column starting at input_dim - 1.
            columns.append(data[:, self.input_dim - 1 + i::self.output_dim].reshape(-1))
        self.ref_data = np.stack(columns).T

    def load_ref_data(self, datapath, transform_fn=None, t_transpose=False):
        """Load reference data from a text file into ``self.ref_data`` as float32.

        Lines starting with ``%`` are skipped as comments by ``np.loadtxt``.

        Args:
            datapath: path to the data file.
            transform_fn: optional callable applied to the (possibly reshaped) array.
            t_transpose: if True, reshape time-series columns with
                :meth:`trans_time_data_to_dataset` before ``transform_fn``.
        """
        self.ref_data = np.loadtxt(datapath, comments="%").astype(np.float32)
        if t_transpose:
            self.trans_time_data_to_dataset(datapath)
        if transform_fn is not None:
            self.ref_data = transform_fn(self.ref_data)

    def set_pdeloss(self, names=None, num=1):
        """Register PDE loss terms: one per entry of ``names``, else ``num`` named ``pde_<i>``."""
        if names is not None:
            self.loss_config += [{"name": name, "type": 'pde'} for name in names]
        else:
            self.loss_config += [{"name": f"pde_{i}", "type": 'pde'} for i in range(num)]

    def add_bcs(self, config, geom=None):
        """Build DeepXDE conditions from ``config`` and register one boundary loss each.

        Each entry of ``config`` is a dict with a ``'type'`` key (``'dirichlet'``,
        ``'robin'``, ``'ic'``, ``'operator'``, ``'neumann'``, ``'periodic'`` or
        ``'pointset'``) plus the keys that type reads (``'function'``, ``'bc'``,
        ``'component_x'``, ``'points'``, ``'values'``). ``'component'`` is
        optional and defaults to 0. Entries without a ``'name'`` get one written
        into the dict in place (e.g. ``'dirichletbc_1'``, ``'ic_2'``).

        Args:
            config: iterable of condition dicts (mutated to add missing names).
            geom: geometry for the conditions; defaults to ``self.geom``.

        Raises:
            ValueError: for an unknown ``'type'``.
        """
        geom = geom if geom is not None else self.geom

        if self.bcs is None:
            self.bcs = []
        for bc in config:
            bc_type = bc['type']
            if bc.get('name') is None:
                bc['name'] = bc_type + ('' if bc_type == 'ic' else 'bc') + f"_{len(self.bcs) + 1}"
            component = bc.get('component', 0)
            if bc_type == 'dirichlet':
                self.bcs.append(dde.DirichletBC(geom, bc['function'], bc['bc'], component=component))
            elif bc_type == 'robin':
                self.bcs.append(dde.RobinBC(geom, bc['function'], bc['bc'], component=component))
            elif bc_type == 'ic':
                self.bcs.append(dde.IC(geom, bc['function'], bc['bc'], component=component))
            elif bc_type == 'operator':
                self.bcs.append(dde.OperatorBC(geom, bc['function'], bc['bc']))
            elif bc_type == 'neumann':
                self.bcs.append(dde.NeumannBC(geom, bc['function'], bc['bc'], component=component))
            elif bc_type == 'periodic':
                self.bcs.append(dde.PeriodicBC(geom, bc['component_x'], bc['bc'], component=component))
            elif bc_type == 'pointset':
                self.bcs.append(dde.PointSetBC(bc['points'], bc['values'], component=component))
            else:
                raise ValueError(f"Unknown bc type: {bc['type']}")
            self.loss_config.append({'name': bc['name'], 'type': 'boundary'})

    def training_points(self, domain=DEFAULT_NUM_DOMAIN_POINTS, boundary=DEFAULT_NUM_BOUNDARY_POINTS, test=DEFAULT_NUM_TEST_POINTS, mul=1):
        """Set the domain/boundary/test sample counts, each multiplied by ``mul``."""
        self.num_domain_points = domain * mul
        self.num_boundary_points = boundary * mul
        self.num_test_points = test * mul

    def check(self):
        """Validate that the problem is fully specified before building a model.

        Raises:
            ValueError: if ``pde``, ``geom``, ``output_config`` or ``bbox`` is
                None, if no PDE loss is registered, or if a PDE loss appears
                after a non-PDE loss in ``loss_config``.
        """
        if self.pde is None:
            raise ValueError("PDE could not be None")
        if self.geom is None:
            raise ValueError("geometry could not be None")
        if self.output_config is None:
            raise ValueError("output config could not be None, please set output dim or output config")
        if self.bbox is None:
            raise ValueError("bbox could not be None")
        if self.num_pde == 0:
            raise ValueError("No pde loss specified")
        # PDE losses must occupy the first num_pde slots of loss_config.
        if any(c['type'] != 'pde' for c in self.loss_config[:self.num_pde]):
            raise ValueError("All PDE loss should be set before Boundary loss to avoid potential issues with methods like NTK")

    def create_model(self, net):
        """Validate the problem and build a ``dde.Model`` on a ``dde.data.PDE`` dataset.

        Stores ``net``, the data object and the model on ``self`` and sets
        ``model.pde`` back to this instance.

        Returns:
            The created ``dde.Model``.
        """
        self.check()
        self.net = net
        self.data = dde.data.PDE(
            self.geom,
            self.pde,
            self.bcs,
            num_domain=self.num_domain_points,
            num_boundary=self.num_boundary_points,
            num_test=self.num_test_points,
        )
        self.model = dde.Model(self.data, net)
        self.model.pde = self
        return self.model


class BaseTimePDE(BasePDE):
    """Time-dependent variant of :class:`BasePDE` built on ``self.geomtime``.

    ``geomtime`` supplies the input dimension, the default geometry for
    conditions, and the geometry of the ``dde.data.TimePDE`` dataset. Adds a
    configurable number of initial points and reshapes reference data along
    time by default.
    """

    def __init__(self):
        super().__init__()
        self.geomtime = None
        self.num_initial_points = DEFAULT_NUM_INITIAL_POINTS

    @property
    def input_dim(self):
        """Number of network inputs, taken from ``self.geomtime.dim``."""
        return self.geomtime.dim

    def add_bcs(self, config, geom=None):
        """Same as :meth:`BasePDE.add_bcs`, but ``geom`` defaults to ``self.geomtime``.

        The optional ``geom`` argument is accepted so this override keeps the
        base-class signature. If both ``geom`` and ``geomtime`` are None, the
        base method falls back to ``self.geom``.
        """
        super().add_bcs(config, self.geomtime if geom is None else geom)

    def load_ref_data(self, datapath, transform_fn=None, t_transpose=True):
        """Same as :meth:`BasePDE.load_ref_data`, but ``t_transpose`` defaults to True."""
        super().load_ref_data(datapath, transform_fn, t_transpose)

    def training_points(
        self,
        domain=DEFAULT_NUM_DOMAIN_POINTS,
        boundary=DEFAULT_NUM_BOUNDARY_POINTS,
        initial=DEFAULT_NUM_INITIAL_POINTS,
        test=DEFAULT_NUM_TEST_POINTS,
        mul=1,
    ):
        """Set domain/boundary/initial/test sample counts, each multiplied by ``mul``."""
        self.num_domain_points = domain * mul
        self.num_boundary_points = boundary * mul
        self.num_initial_points = initial * mul
        self.num_test_points = test * mul

    def create_model(self, net):
        """Validate the problem and build a ``dde.Model`` on a ``dde.data.TimePDE`` dataset.

        Stores ``net``, the data object and the model on ``self`` and sets
        ``model.pde`` back to this instance.

        Returns:
            The created ``dde.Model``.

        Raises:
            ValueError: from :meth:`check`, or if ``geomtime`` is None.
        """
        self.check()
        # check() only validates ``geom``; the dataset below is built on ``geomtime``.
        if self.geomtime is None:
            raise ValueError("geomtime could not be None")
        self.net = net
        self.data = dde.data.TimePDE(
            self.geomtime,
            self.pde,
            self.bcs,
            num_domain=self.num_domain_points,
            num_boundary=self.num_boundary_points,
            num_initial=self.num_initial_points,
            num_test=self.num_test_points,
        )
        self.model = dde.Model(self.data, net)
        self.model.pde = self
        return self.model
