import torch
from torch import nn
from agents.helpers import SinusoidalPosEmb


class TrajectoryFunction(nn.Module):
    """Sample actions by iterating a learned trajectory function over noise levels.

    The wrapped ``model`` is called as ``model(action, t, state, s)``. Its output
    is blended with the input action through skip/out coefficients (computed
    with ``sigma_data = 0.5``) and then stepped along a decreasing schedule of
    noise levels running from ``sigma_max`` down to ``sigma_min``.

    Args:
        state_dim: Dimensionality of the state; stored but not otherwise used
            by this class.
        action_dim: Size of the last dimension of the sampled action tensor.
        model: Callable network invoked as ``model(action, t, state, s)``.
        max_action: Sampled actions are clamped to ``[-max_action, max_action]``.
        n_time_steps: Number of noise levels in the sampling schedule.
        eps: Offset subtracted from ``t`` when computing the skip/out
            coefficients.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 model,
                 max_action,
                 n_time_steps = 100,
                 eps = 0.002
                 ):
        super().__init__()

        self.eps = eps
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.model = model

        self.sigma_min = 0.002
        self.sigma_max = 80.0
        self.rho = 7
        self.n_time_steps = n_time_steps
        # gamma == 0 makes gamma_sample deterministic after the initial draw:
        # no noise is re-injected between steps.
        self.gamma = 0

        # rho-spaced schedule, interpolated linearly in sigma ** (1 / rho) space
        # from sigma_max (index 0) to sigma_min (last index).
        # NOTE(review): this looks like the EDM / Karras et al. schedule - confirm.
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        # Guard the divisor so n_time_steps == 1 yields [sigma_max] instead of
        # raising ZeroDivisionError; larger values are unaffected.
        n_intervals = max(n_time_steps - 1, 1)
        self.time_steps = [
            (max_inv_rho + i / n_intervals * (min_inv_rho - max_inv_rho)) ** self.rho
            for i in range(n_time_steps)
        ]

    @staticmethod
    def _as_time_column(value, like):
        """Broadcast a Python scalar to a ``(batch, 1)`` float32 tensor on ``like``'s device.

        Anything that is not an ``int``/``float`` (e.g. a tensor) is returned as is.
        """
        if isinstance(value, (int, float)):
            return torch.full(
                (like.shape[0], 1),
                float(value),
                dtype=torch.float32,
                device=like.device,
            )
        return value

    def predict_trajectory_original(self, state, action, t, s):
        """Evaluate the skip-connected network prediction at time ``t``.

        Args:
            state: Conditioning tensor forwarded to the model.
            action: Current (noisy) action tensor; its first dimension is used
                as the batch size when ``t``/``s`` are scalars.
            t: Current noise level, either a Python number or a tensor.
            s: Target noise level, either a Python number or a tensor.

        Returns:
            ``c_skip(t) * action + c_out(t) * model(action, t, state, s)``.
        """
        t = self._as_time_column(t, action)  # (batch_size, 1) when scalar
        s = self._as_time_column(s, action)

        prediction = self.model(action, t, state, s)

        # sigma_data = 0.5, hence the constants 0.25 (= sigma_data ** 2) and 0.5.
        t_shifted = t - self.eps
        c_skip_t = 0.25 / (t_shifted.pow(2) + 0.25)  # (batch, 1)
        c_out_t = 0.5 * t_shifted / (t.pow(2) + 0.25).pow(0.5)
        return c_skip_t * action + c_out_t * prediction

    def predict_trajectory(self, state, action, t, s):
        """Move ``action`` from noise level ``t`` to noise level ``s``.

        Returns the mix ``(s / t) * action + (1 - s / t) * prediction`` where
        ``prediction`` comes from :meth:`predict_trajectory_original`.
        """
        action_pre = self.predict_trajectory_original(state, action, t, s)
        ratio = s / t
        return ratio * action + (1 - ratio) * action_pre

    def gamma_sample(self, state):
        """Draw one action per row of ``state`` by walking the time-step schedule.

        At each step the trajectory function jumps from ``t`` to a slightly
        reduced target ``sqrt(1 - gamma**2) * t_next`` and then adds
        ``gamma * t_next`` worth of Gaussian noise. The result is clamped
        in place to ``[-max_action, max_action]``.
        """
        # NOTE(review): the initial sample is uniform in [0, 1) rather than
        # Gaussian noise scaled by sigma_max - confirm this is intentional.
        action = torch.rand(state.size(0), self.action_dim).to(state.device)
        jump_scale = (1 - self.gamma ** 2) ** 0.5  # loop-invariant
        for i in range(self.n_time_steps - 1):
            t = self.time_steps[i]
            t_next = self.time_steps[i + 1]
            t_next_hat = jump_scale * t_next
            action_hat = self.predict_trajectory(state, action, t, t_next_hat)
            z = torch.randn_like(action_hat)
            action = action_hat + self.gamma * t_next * z

        action.clamp_(-self.max_action, self.max_action)
        return action

    def forward(self, state) -> torch.Tensor:
        """Sample an action batch for ``state`` (delegates to :meth:`gamma_sample`)."""
        return self.gamma_sample(state)

    def get_last_layer_weight(self):
        """Return ``self.model.final_layer.weight``.

        Requires the wrapped model to expose a ``final_layer`` attribute.
        """
        return self.model.final_layer.weight
