import torch
import torch.nn as nn
import numpy as np
import sympy as sp
import math

class SINDyRegression(nn.Module):
    """Sparse linear regression between lambdified differential invariants.

    The last entry of ``differential_invariants`` is the regression target and
    the remaining entries are candidate features. The model predicts the target
    as a linear combination of the features with learnable weights ``W``. A 0/1
    ``mask`` lets small coefficients be pruned permanently (see
    ``set_threshold``).
    """

    # Column names, in order, of the field/derivative part of the array given to F.
    _FIELD_KEYS = ('u', 'dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx', 'dudt')

    def __init__(self, terms, differential_invariants, **kwargs):
        """Lambdify the invariants and create the coefficient vector.

        Args:
            terms: symbols passed to ``sympy.lambdify`` as the argument list.
                The dicts given to ``forward``/``loss`` are unpacked as keyword
                arguments, so their keys must match these symbol names.
            differential_invariants: sequence of sympy expressions; all but the
                last are features, the last one is the target.
            **kwargs: ``device`` (optional) -- device on which ``W`` and ``mask``
                are created; PyTorch's default device is used when omitted.
        """
        super().__init__()
        device = kwargs.get("device")
        # NOTE(review): input_DI uses lambdify's default modules while output_DI
        # maps exp -> torch.exp; features containing exp() may not work on
        # tensors that require grad or live on GPU -- confirm this is intended.
        self.input_DI = sp.lambdify(terms, differential_invariants[:-1])
        self.output_DI = sp.lambdify(terms, differential_invariants[-1], modules={'exp': torch.exp})
        self.W = nn.Parameter(torch.randn(1, len(differential_invariants) - 1, device=device))
        # A buffer (not a plain attribute) so Module.to()/cuda() move it together
        # with W; non-persistent so state_dict keeps exactly the keys it had.
        self.register_buffer("mask", torch.ones_like(self.W), persistent=False)

    def forward(self, x):
        """Predict the target invariant from the (masked) feature invariants.

        Args:
            x: mapping from symbol name to value, unpacked into the lambdified
                features.

        Returns:
            ``features @ (W * mask).T``. Assuming each feature evaluates to a 1-D
            tensor of length n, the result has shape ``(n, 1)``.
        """
        features = torch.stack(self.input_DI(**x)).T
        return features @ (self.W * self.mask).T

    def loss(self, x):
        """Return the mean squared error between the prediction and the target."""
        target = self.output_DI(**x).unsqueeze(1)
        return nn.functional.mse_loss(self.forward(x), target)

    def set_threshold(self, threshold):
        """Prune coefficients whose absolute value is not above ``threshold``.

        Entries that were already masked out stay masked, so pruning is
        cumulative across calls.
        """
        with torch.no_grad():
            keep = torch.logical_and(torch.abs(self.W) > threshold, self.mask)
            self.mask.copy_(keep)

    def print(self):
        """Print the effective (masked) coefficients."""
        print(self.W * self.mask)

    def F(self, x, pde):
        """Return the residual ``forward - target`` for stacked column data.

        Args:
            x: 2-D tensor. For 'KdV', 'KS' and 'Burgers' its columns are
                u, dudx, dudxdx, dudxdxdx, dudxdxdxdx, dudt. For 'nKdV' the
                first column is t, followed by the same six columns.
            pde: name of the equation that selects the column layout.

        Raises:
            ValueError: if ``pde`` is not one of the supported names.
        """
        if pde in ('KdV', 'KS', 'Burgers'):
            x_dict = {'t': None, 'x': None}
            offset = 0
        elif pde == 'nKdV':
            x_dict = {'t': x[:, 0], 'x': None}
            offset = 1  # the field columns start after the leading t column
        else:
            raise ValueError(
                f"Unsupported pde {pde!r}; expected one of 'KdV', 'KS', 'Burgers', 'nKdV'"
            )
        x_dict.update({key: x[:, i + offset] for i, key in enumerate(self._FIELD_KEYS)})
        return self.forward(x_dict) - self.output_DI(**x_dict).unsqueeze(1)
