from typing import Optional

import torch
import torch.autograd as autograd
import torch.nn as nn
from torch import Tensor

from .mpm import MPMModel, MPMState, MPMStatics
from modules.nclaw.warp import Tape

class DynamicIntegrator(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    The layers live in ``self.mlp`` (an ``nn.Sequential``), so the linear
    layers sit at indices 0, 2 and 4.

    Args:
        state_dim: Width of each input row.
        hidden_dim: Width of both hidden layers.
        output_dim: Width of each output row.
    """

    def __init__(self, state_dim=24, hidden_dim=32, output_dim=15):
        super().__init__()
        widths = (state_dim, hidden_dim, hidden_dim, output_dim)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # The output layer is purely linear: drop the activation appended last.
        stack.pop()
        self.mlp = nn.Sequential(*stack)

    def forward(self, state):
        """Map ``state`` of shape [batch_size, state_dim] to [batch_size, output_dim].

        The caller in this file builds each input row by concatenating the
        flattened x, v, F and stress of one row, and slices the output as
        delta_x (columns 0:3), delta_v (3:6) and delta_F (6:15).
        NOTE(review): with the default ``state_dim`` of 24 the input is
        presumably 3 + 3 + 9 + 9 values -- confirm the stress layout.
        """
        network = self.mlp
        return network(state)


class MPMSimFunction(autograd.Function):
    """Custom autograd function wrapping one ``MPMModel`` step.

    ``forward`` loads the torch tensors into ``state_curr``, runs
    ``model.forward`` with a fresh ``Tape`` and reads the result back from
    ``state_next``.  ``backward`` writes the incoming gradients into
    ``state_next``, calls ``model.backward`` with the same tape and returns
    the gradients read from ``state_curr``.
    """

    @staticmethod
    def forward(
            ctx: autograd.function.FunctionCtx,
            model: MPMModel,
            statics: MPMStatics,
            state_curr: MPMState,
            state_next: MPMState,
            x: Tensor,
            v: Tensor,
            C: Tensor,
            F: Tensor,
            stress: Tensor,
            integrator: Optional[DynamicIntegrator]) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run one simulation step, optionally nudging the inputs first.

        Args:
            ctx: Autograd context; receives the objects ``backward`` needs.
            model: Simulator whose ``forward`` performs the step.
            statics: Passed through to ``model.forward`` unchanged.
            state_curr: State object filled from the input tensors.
            state_next: State object the step writes into.
            x, v, C, F, stress: Input tensors, indexed by row along dim 0.
            integrator: Optional network producing a correction for ``x``,
                ``v`` and ``F``; ``None`` disables the correction.

        Returns:
            ``(x_next, v_next, C_next, F_next)`` read from ``state_next``.
        """
        if integrator is not None:
            batch_size = x.shape[0]
            # One feature row per leading index: every quantity is flattened
            # and concatenated along dim 1.
            features = torch.cat([
                x.view(batch_size, -1),
                v.view(batch_size, -1),
                F.view(batch_size, -1),
                stress.view(batch_size, -1),
            ], dim=1)
            delta = integrator(features)

            # NOTE(review): per the PyTorch docs, operations executed inside
            # Function.forward are not recorded by autograd, and backward
            # returns None for `integrator`, so the integrator's parameters
            # presumably never receive a gradient through this function.
            # Confirm whether that is intended.
            scale = 1e-7  # keeps the learned correction tiny relative to the inputs
            x_in = x + delta[:, 0:3] * scale
            v_in = v + delta[:, 3:6] * scale
            F_in = F + delta[:, 6:15].view(batch_size, 3, 3) * scale
        else:
            x_in, v_in, F_in = x, v, F

        tape: Tape = Tape()
        state_curr.from_torch(x=x_in, v=v_in, C=C, F=F_in, stress=stress)
        model.forward(statics, state_curr, state_next, tape)

        x_next, v_next, C_next, F_next, _ = state_next.to_torch()

        # Everything backward needs to replay the step in reverse.
        ctx.model = model
        ctx.tape = tape
        ctx.statics = statics
        ctx.state_curr = state_curr
        ctx.state_next = state_next

        return x_next, v_next, C_next, F_next

    @staticmethod
    def backward(
            ctx: autograd.function.FunctionCtx,
            grad_x_next: Tensor,
            grad_v_next: Tensor,
            grad_C_next: Tensor,
            grad_F_next: Tensor) -> tuple[
                None, None, None, None,
                Optional[Tensor], Optional[Tensor], Optional[Tensor],
                Optional[Tensor], Optional[Tensor],
                None]:
        """Propagate output gradients back to ``x``, ``v``, ``C``, ``F`` and ``stress``.

        Returns:
            One entry per ``forward`` input (ten in total): ``None`` for
            ``model``, ``statics``, ``state_curr`` and ``state_next``, the five
            tensor gradients read from ``state_curr`` (NaN and infinities
            replaced by 0), and ``None`` for ``integrator``.
        """
        model: MPMModel = ctx.model
        tape: Tape = ctx.tape
        statics: MPMStatics = ctx.statics
        state_curr: MPMState = ctx.state_curr
        state_next: MPMState = ctx.state_next

        state_next.from_torch_grad(
            grad_x=grad_x_next,
            grad_v=grad_v_next,
            grad_C=grad_C_next,
            grad_F=grad_F_next)

        model.backward(statics, state_curr, state_next, tape)

        grad_x, grad_v, grad_C, grad_F, grad_stress = state_curr.to_torch_grad()

        # Scrub NaN / +inf / -inf in place so one bad entry cannot poison the
        # whole optimisation step.
        for grad in (grad_x, grad_v, grad_C, grad_F, grad_stress):
            if grad is not None:
                torch.nan_to_num_(grad, 0.0, 0.0, 0.0)

        return None, None, None, None, grad_x, grad_v, grad_C, grad_F, grad_stress, None


class MPMSim(nn.Module):
    """Base class holding an ``MPMModel`` and an optional integrator network.

    Args:
        model: Simulator used by the subclasses and by :meth:`state`.
        integrator: Optional correction network.  Defaults to ``None`` so
            subclasses that never use one can call ``super().__init__(model)``.
    """

    def __init__(self, model: MPMModel, integrator: Optional[DynamicIntegrator] = None) -> None:
        super().__init__()
        self.model = model
        self.integrator = integrator

    def state(self, x: Tensor, v: Tensor, C: Tensor, F: Tensor, stress: Tensor, state: Optional[MPMState] = None) -> MPMState:
        """Load the given tensors into an ``MPMState`` and return it.

        Args:
            x, v, C, F, stress: Tensors forwarded to ``state.from_torch``.
            state: Existing state to overwrite.  When ``None`` a new one is
                requested from ``self.model.state`` with ``x.size(0)``.

        Returns:
            The state that was filled (the one passed in, or the new one).
        """
        leading_dim = x.size(0)

        if state is None:
            state = self.model.state(leading_dim)
        state.from_torch(x=x, v=v, C=C, F=F, stress=stress)

        return state


class MPMDiffSim(MPMSim):
    """Differentiable single-step simulator that allocates fresh states per call.

    Args:
        model: Simulator used to allocate states and perform the step.
        integrator: Optional correction network handed to ``MPMSimFunction``;
            ``None`` runs the step without any correction.
    """

    def __init__(self, model: MPMModel, integrator: Optional[DynamicIntegrator] = None) -> None:
        # The base class stores `integrator` on self, so it is not assigned
        # again here.
        super().__init__(model, integrator)

    def forward(self, statics: MPMStatics, x: Tensor, v: Tensor, C: Tensor, F: Tensor, stress: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Advance one step through ``MPMSimFunction``.

        Two new states sized by ``x.size(0)`` are requested from the model on
        every call, one for the current step and one for the next.

        Returns:
            ``(x_next, v_next, C_next, F_next)`` as produced by
            ``MPMSimFunction.apply``.
        """
        shape = x.size(0)
        state_curr: MPMState = self.model.state(shape)
        state_next: MPMState = self.model.state(shape)
        return MPMSimFunction.apply(
            self.model, statics, state_curr, state_next,
            x, v, C, F, stress, self.integrator)


class MPMCacheDiffSim(MPMSim):
    """Differentiable simulator that keeps one pair of states per step index.

    Args:
        model: Simulator used to allocate states and perform the step.
        num_steps: Number of step slots to reserve; valid ``step`` values
            index into lists of this length.
        integrator: Optional correction network handed to ``MPMSimFunction``;
            ``None`` runs the step without any correction.
    """

    def __init__(self, model: MPMModel, num_steps: int, integrator: Optional[DynamicIntegrator] = None) -> None:
        super().__init__(model, integrator)
        # Slots are filled lazily in forward() the first time a step is seen.
        self.curr_states = [None] * num_steps
        self.next_states = [None] * num_steps

    def forward(self, statics: MPMStatics, step: int, x: Tensor, v: Tensor, C: Tensor, F: Tensor, stress: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Advance one step, reusing the states cached for ``step``.

        States are created with ``x.size(0)`` on first use of a slot and then
        reused as-is.
        NOTE(review): a cached state is not re-created if ``x.size(0)`` later
        changes for the same ``step`` -- confirm callers keep it constant.

        Returns:
            ``(x_next, v_next, C_next, F_next)`` as produced by
            ``MPMSimFunction.apply``.
        """
        shape = x.size(0)
        if self.curr_states[step] is None:
            self.curr_states[step] = self.model.state(shape)
        if self.next_states[step] is None:
            self.next_states[step] = self.model.state(shape)
        return MPMSimFunction.apply(
            self.model, statics,
            self.curr_states[step], self.next_states[step],
            x, v, C, F, stress, self.integrator)


class MPMForwardSim(MPMSim):
    """Forward-only simulator that steps a single state object in place."""

    def __init__(self, model: MPMModel) -> None:
        # MPMSim.__init__ takes an integrator argument; this simulator never
        # uses one, so pass None explicitly instead of omitting it.
        super().__init__(model, None)

    def forward(self, statics: MPMStatics, state: MPMState) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Advance ``state`` by one step and return its tensors.

        ``state`` is passed to ``model.forward`` as both the current and the
        next state, and no tape is supplied (``None``).

        Returns:
            ``(x_next, v_next, C_next, F_next)`` read back from ``state``.
        """
        self.model.forward(statics, state, state, None)
        x_next, v_next, C_next, F_next, _ = state.to_torch()
        return x_next, v_next, C_next, F_next


class MPMExtraSim(MPMSim):
    """Simulator front-end for ``MPMModel.forward_extra``."""

    def __init__(self, model: MPMModel) -> None:
        # MPMSim.__init__ takes an integrator argument; this simulator never
        # uses one, so pass None explicitly instead of omitting it.
        super().__init__(model, None)

    def forward(self, statics: MPMStatics, state: MPMState, statics_extra: MPMStatics, state_extra: MPMState) -> Tensor:
        """Run ``model.forward_extra`` and return the first tensor of ``state_extra``.

        Args:
            statics, state: Passed to ``model.forward_extra`` unchanged.
            statics_extra, state_extra: Passed to ``model.forward_extra``
                unchanged; ``state_extra`` is read back afterwards.

        Returns:
            The first element of ``state_extra.to_torch()`` (``x_extra``).
        """
        self.model.forward_extra(statics, state, statics_extra, state_extra)
        x_extra, *_ = state_extra.to_torch()
        return x_extra
