from typing import Literal, Union, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchode
from torchdiffeq import odeint, odeint_adjoint

"""Code is heavily inspired by: 
https://github.com/rtqichen/torchdiffeq/blob/master/examples/ode_demo.py
and
https://github.com/martenlienen/torchode/blob/main/docs/torchdiffeq.ipynb
"""


class Torchode_solver:
    """Expose a torchdiffeq-style ``odeint`` backed by a compiled torchode solver.

    Builds a Dopri5 stepper with an integral step-size controller, wraps both in
    ``torchode.AutoDiffAdjoint`` and compiles the result once with
    ``torch.compile``. Repeated solves (e.g. inside a training loop) therefore
    reuse the same compiled solver.

    Args:
        odefunc: right-hand side of the ODE, called by torchode as ``f(t, y)``.
        atol: absolute tolerance of the step-size controller. Defaults to 1e-6.
        rtol: relative tolerance of the step-size controller. Defaults to 1e-3.
    """

    def __init__(self, odefunc, atol: float = 1e-6, rtol: float = 1e-3):
        self.term = torchode.ODETerm(odefunc)
        self.step_method = torchode.Dopri5(term=self.term)
        self.step_size_controller = torchode.IntegralController(
            atol=atol, rtol=rtol, term=self.term
        )
        self.solver = torchode.AutoDiffAdjoint(self.step_method, self.step_size_controller)
        # Compile once here rather than per solve call.
        self.jit_solver = torch.compile(self.solver)

    def odeint(self, y0, t):
        """Solve the initial value problem and return the states at ``t``.

        torchode works on batches of problems. It presumably expects ``y0`` of
        shape (batch, state_dim) and ``t`` of shape (batch, n_eval); the
        returned ``ys`` is then (batch, n_eval, state_dim). TODO confirm
        against the torchode version in use.

        Args:
            y0: initial states.
            t: evaluation times.

        Returns:
            ``sol.ys`` from the torchode solution.
        """
        problem = torchode.InitialValueProblem(y0=y0, t_eval=t)
        sol = self.jit_solver.solve(problem)
        return sol.ys


class LinearNODE(nn.Module):
    """Neural ODE with a linear, bias-free right-hand side ``dy/dt = A y``.

    After :meth:`fit`, the learned system matrix ``A`` is returned by
    :meth:`get_system_matrix`.

    Args:
        early_stopping_threshold: training stops as soon as the total loss
            (data loss + L1 penalty) drops below this value.
    """

    def __init__(self, early_stopping_threshold: float = 1e-5):
        super().__init__()
        # Created lazily in ``_prepare`` once the state dimension is known.
        self.net = None
        self.early_stopping_threshold = early_stopping_threshold

    def forward(self, t, y) -> torch.Tensor:
        """Evaluate the ODE right-hand side; ``t`` is unused (autonomous system)."""
        return self.net(y)

    def __call__(self, *args, **kwargs):
        # NOTE(review): this bypasses nn.Module.__call__, so module hooks are
        # never run. It is kept to preserve the existing behavior.
        return self.forward(*args, **kwargs)

    def _prepare(self, system_dim: int):
        """(Re)create the linear layer and draw its weights from N(0, 0.1**2)."""
        self.net = nn.Linear(system_dim, system_dim, bias=False)
        nn.init.normal_(self.net.weight, mean=0, std=0.1)

    def _make_optimizer(self, name: str, learning_rate: float) -> optim.Optimizer:
        """Build the named optimizer over the linear layer's parameters."""
        if name == "adam":
            return optim.Adam(self.net.parameters(), lr=learning_rate)
        if name == "rmsprop":
            return optim.RMSprop(self.net.parameters(), lr=learning_rate)
        raise ValueError(f"Unknown optimizer {name!r}; expected 'adam' or 'rmsprop'.")

    @staticmethod
    def _print_device_info() -> None:
        """Report CUDA availability (informational only; nothing is moved to a GPU)."""
        cuda_available = torch.cuda.is_available()

        print(f"CUDA Available: {cuda_available}")

        if cuda_available:
            num_gpus = torch.cuda.device_count()
            print(f"Number of GPUs available: {num_gpus}")
            for i in range(num_gpus):
                print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        else:
            print("CUDA is not available. Using CPU.")

    def fit(
        self,
        xs: np.ndarray,
        t: np.ndarray,
        niters: int = 10000,
        print_freq: int = 50,
        learning_rate: float = 1e-3,
        reg_lambda: float = 1e-3,
        scale_dim: bool = False,
        optimizer: Literal["adam", "rmsprop"] = "rmsprop",
        ode_integration_method: Literal["dopri5", "adams"] = "dopri5",
        solver: Union[None, Literal["torchode", "adjoint"]] = "torchode",
    ) -> None:
        """Fit the system matrix to a single observed trajectory.

        The model is integrated from ``xs[0]`` over ``t`` and compared with
        ``xs`` through the mean absolute error. An L1 penalty on the system
        matrix, weighted by ``reg_lambda``, is added to the loss. The inputs are
        converted to torch's default floating dtype so that float64 NumPy
        arrays match the dtype of the linear layer.

        Args:
            xs (np.ndarray): observed states, shape (time, state_variables).
            t (np.ndarray): observation times, shape (time,).
            niters (int, optional): maximum number of optimization steps; must
                be >= 1. Defaults to 10000.
            print_freq (int, optional): print the loss every ``print_freq``
                iterations; 0 disables progress output. Defaults to 50.
            learning_rate (float, optional): optimizer learning rate.
                Defaults to 1e-3.
            reg_lambda (float, optional): weight of the L1 penalty on the
                system matrix. Defaults to 1e-3.
            scale_dim (bool, optional): currently unused apart from being
                printed. Defaults to False.
            optimizer ("adam" | "rmsprop", optional): Defaults to "rmsprop".
            ode_integration_method ("dopri5" | "adams", optional): method
                passed to torchdiffeq; ignored by the torchode solver, which
                always uses Dopri5. Defaults to "dopri5".
            solver (None | "torchode" | "adjoint", optional): ``None`` uses
                torchdiffeq's ``odeint``, ``"adjoint"`` its ``odeint_adjoint``,
                and ``"torchode"`` a compiled torchode solver.
                Defaults to "torchode".

        Raises:
            ValueError: for an unknown optimizer or solver, ``niters < 1``, or
                ``xs``/``t`` shapes that do not match.
        """
        # Validate everything up front, before any output or model state changes.
        if optimizer not in ("adam", "rmsprop"):
            raise ValueError(f"Unknown optimizer {optimizer!r}; expected 'adam' or 'rmsprop'.")
        if solver not in (None, "torchode", "adjoint"):
            raise ValueError(f"Unknown solver {solver!r}; expected None, 'torchode' or 'adjoint'.")
        if niters < 1:
            raise ValueError(f"niters must be >= 1, got {niters}.")
        if np.ndim(xs) != 2 or np.ndim(t) != 1 or np.shape(xs)[0] != np.shape(t)[0]:
            raise ValueError(
                "Expected xs of shape (time, state_variables) and t of shape (time,), "
                f"got {np.shape(xs)} and {np.shape(t)}."
            )

        print(f"scale_dim: {scale_dim}")
        print(f"reg_lambda: {reg_lambda}")
        print(f"xs.shape: {xs.shape}")
        print(f"interval [{t[0]}, {t[-1]}], t.shape: {t.shape}")
        self._print_device_info()

        # Match the dtype of the linear layer (float64 NumPy input would
        # otherwise fail inside nn.Linear).
        dtype = torch.get_default_dtype()
        xs = torch.as_tensor(xs, dtype=dtype)
        t = torch.as_tensor(t, dtype=dtype)
        y0 = xs[0]

        self._prepare(system_dim=xs.shape[1])
        opt = self._make_optimizer(optimizer, learning_rate)
        torchode_solver = Torchode_solver(self) if solver == "torchode" else None

        for itr in range(1, niters + 1):
            opt.zero_grad()
            if solver is None:
                pred_y = odeint(self, y0, t, method=ode_integration_method)
            elif solver == "adjoint":
                pred_y = odeint_adjoint(self, y0, t, method=ode_integration_method)
            else:
                # torchode solves batches: add a batch axis of size 1, then drop it.
                pred_y = torchode_solver.odeint(y0.reshape(1, -1), t.reshape(1, -1))[0]

            data_loss = torch.mean(torch.abs(pred_y - xs))
            l1_reg = self.net.weight.abs().sum()
            loss = data_loss + reg_lambda * l1_reg

            loss.backward()
            opt.step()

            # Keep only the value on the module, not the autograd graph.
            self.loss = loss.detach()
            loss_value = self.loss.item()

            # print_freq == 0 disables progress output.
            if print_freq and itr % print_freq == 0:
                print(f"\riteration {itr:06d} | loss = {loss_value}", end="")

            if loss_value < self.early_stopping_threshold:
                break

        print(f"\rOptimization finished after {itr} / {niters} steps with final loss = {self.loss.item()} (early stopping threshold = {self.early_stopping_threshold})")

    def get_system_matrix(self) -> np.ndarray:
        """Return the learned system matrix ``A`` (the linear layer's weight)."""
        params = list(self.parameters())
        assert len(params) == 1, f"There should only be a single parameter, the estimated system matrix. However, found {len(params)} parameters."
        return params[0].detach().cpu().numpy()

    def get_info(self) -> Dict:
        """Return the model name and the last total loss (NaN before any fit)."""
        if hasattr(self, "loss"):
            optimization_error = self.loss.item()
        else:
            optimization_error = np.nan
        return {
            "name": "LinearNODE",
            "optimization_error": optimization_error,
        }
        
        

if __name__ == '__main__':
    import scipy.integrate

    # Ground truth: dx/dt = A x with A = [[-0.5, 0.2], [0.0, -0.1]].
    def ode(x, _t):
        return np.array([-0.5 * x[0] + 0.2 * x[1], -0.1 * x[1]])

    # Use ONE time grid both to generate the data and to fit the model. Fitting
    # against a different grid than the data was sampled on distorts the
    # recovered system matrix.
    t = np.linspace(0, 5, 256, endpoint=True, dtype=np.float32)
    xs = scipy.integrate.odeint(
        func=ode,
        y0=np.array([3, 2], dtype=np.float32),
        t=t,
    ).astype(np.float32)

    model = LinearNODE()
    model.fit(xs, t, solver="torchode", print_freq=10, optimizer="rmsprop", niters=1000)
    print(model.get_system_matrix())
    print(model.get_info())