# Implementation of Consistency Trajectory Model
# https://arxiv.org/abs/2310.02279

import math
import numpy as np

import torch
from torch import nn
from agents.helpers import Losses


class Trajectory(nn.Module):
    """Consistency Trajectory Model (CTM) wrapper around a denoising network.

    ``model`` is called as ``model(noisy_action, t, state, s)``. Its raw output is
    combined with the noisy input through a skip/out preconditioning
    (``predict_g``). ``predict_G`` uses that estimate to jump from noise level
    ``t`` to a lower level ``s``. ``gamma_sample`` chains such jumps along a
    Karras-style noise schedule to draw actions.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 model,
                 max_action,
                 n_time_steps=40,
                 loss_type='l2',
                 clip_denoised=True,
                 action_norm=False,
                 sigma_min=0.002,
                 sigma_max=80.0,
                 rho=7,
                 gamma=0,
                 ) -> None:
        """
        Args:
            state_dim: dimensionality of the conditioning state (stored; not
                read elsewhere in this class).
            action_dim: dimensionality of sampled actions.
            model: network called as ``model(action, t, state, s)``.
            max_action: bound used for tanh squashing and for the final clamp.
            n_time_steps: number of jumps performed by ``gamma_sample``.
            loss_type: key into ``Losses`` selecting the training loss.
            clip_denoised: stored; not read elsewhere in this class.
            action_norm: if True, noisy and predicted actions are squashed with
                ``max_action * tanh(.)``.
            sigma_min: smallest noise level of the schedule.
            sigma_max: largest noise level of the schedule.
            rho: curvature exponent of the noise schedule.
            gamma: stochasticity of gamma-sampling; 0 adds no noise between
                steps.
        """
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.model = model
        self.clip_denoised = clip_denoised
        self.action_norm = action_norm

        # parameters for sampling
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

        self.n_time_steps = n_time_steps
        self.gamma = gamma

        # n_time_steps + 1 decreasing noise levels from sigma_max to sigma_min.
        self.t_seq = self._karras_schedule(self.sigma_max, self.sigma_min, n_time_steps, self.rho)

        self.loss_fn = Losses[loss_type]()

    @staticmethod
    def _karras_schedule(start, end, num_steps, rho):
        """Return ``num_steps + 1`` points interpolated linearly in ``x ** (1/rho)`` space.

        The first point is ``start`` and the last is (numerically) ``end``.
        ``start``/``end`` may be Python scalars or tensors.
        """
        start_r = start ** (1 / rho)
        end_r = end ** (1 / rho)
        return [(start_r + i / num_steps * (end_r - start_r)) ** rho for i in range(num_steps + 1)]

    @staticmethod
    def _as_time_tensor(value, action):
        """Broadcast a Python scalar time to a ``(batch_size, 1)`` float32 tensor.

        Tensors are returned unchanged. ``batch_size`` is ``action.shape[0]``.
        """
        if isinstance(value, (int, float)):
            return torch.full((action.shape[0], 1), float(value),
                              dtype=torch.float32, device=action.device)
        return value

    def _squash(self, action):
        """Apply ``max_action * tanh`` when ``action_norm`` is enabled; identity otherwise."""
        if self.action_norm:
            return self.max_action * torch.tanh(action)
        return action

    def _perturb(self, action, z, sigma):
        """Add noise ``sigma * z`` to ``action``, then squash (if enabled)."""
        return self._squash(action + sigma * z)

    def predict_g(self, state, action, t, s) -> torch.Tensor:
        """Preconditioned network output g(x_t, t, s).

        ``t`` and ``s`` may be Python scalars (broadcast over the batch) or
        tensors. With ``s == t`` this is what ``loss_dsm`` trains toward the
        clean action.
        """
        t = self._as_time_tensor(t, action)  # (batch_size, 1)
        s = self._as_time_tensor(s, action)

        raw = self.model(action, t, state, s)

        # sigma_data = 0.5 (0.25 == sigma_data ** 2). At t == sigma_min,
        # c_skip == 1 and c_out == 0, so g returns its input unchanged there.
        t_shift = t - self.sigma_min
        c_skip_t = 0.25 / (t_shift.pow(2) + 0.25)  # (batch, 1)
        c_out_t = 0.5 * t_shift / (t.pow(2) + 0.25).pow(0.5)
        return self._squash(c_skip_t * action + c_out_t * raw)

    def predict_G(self, state, action, t, s):
        """Jump from noise level ``t`` to ``s``: ``s/t * x_t + (1 - s/t) * g(x_t, t, s)``."""
        denoised = self.predict_g(state, action, t, s)
        ratio = s / t
        return ratio * action + (1 - ratio) * denoised

    def derivative(self, state, action_t, t):
        """Return the ODE direction ``(x_t - g(x_t, t, t)) / t``."""
        denoised_action = self.predict_g(state, action_t, t, t)
        return (action_t - denoised_action) / t

    def heun_solver(self, state, action, t, s, num_steps=1):
        """Integrate ``action`` from noise level ``t`` down to ``s`` with Heun's method.

        Uses ``num_steps`` steps on the same rho-warped schedule as sampling.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        time_points = self._karras_schedule(t, s, num_steps, self.rho)

        for t_i, t_i_next in zip(time_points[:-1], time_points[1:]):
            dt = t_i_next - t_i
            d = self.derivative(state, action, t_i)
            action_euler = action + dt * d  # Euler predictor
            d_next = self.derivative(state, action_euler, t_i_next)
            # Trapezoidal corrector; the state must advance so the next step
            # starts from this step's result.
            action = action + dt / 2 * (d + d_next)
        return action

    def loss_ctm(self, state, action, z, t, u, s, ema_model, weights=None):
        """CTM loss between student and EMA-teacher estimates, both mapped to ``sigma_min``.

        The student jumps ``t -> s``. The teacher (no grad) jumps ``u -> s``
        from the same noise ``z``. Both inputs get the same ``action_norm``
        preprocessing so the teacher sees inputs from the same distribution
        as the student.
        """
        action_t = self._perturb(action, z, t)

        ## student prediction
        action_s_est = self.predict_G(state, action_t, t, s)
        action_est = ema_model.predict_G(state, action_s_est, s, self.sigma_min)

        ## teacher prediction
        with torch.no_grad():
            action_u = self._perturb(action, z, u)
            action_s_target = ema_model.predict_G(state, action_u, u, s)
            action_target = ema_model.predict_G(state, action_s_target, s, self.sigma_min)

        return self.loss_fn(action_est, action_target, weights=weights)

    def loss_dsm(self, state, action, z, t, weights=None):
        """Denoising loss: ``g(x_t, t, t)`` should recover the clean ``action``."""
        action_t = self._perturb(action, z, t)
        action_est = self.predict_g(state, action_t, t, t)
        return self.loss_fn(action_est, action, weights=weights)

    def gamma_sample(self, state):
        """Draw actions with CTM gamma-sampling along ``self.t_seq``.

        Returns a tensor of shape ``(state.size(0), action_dim)`` clamped to
        ``[-max_action, max_action]``.
        """
        action = torch.randn(state.size(0), self.action_dim).to(state.device) * self.sigma_max
        for i in range(self.n_time_steps):
            t = self.t_seq[i]
            t_next = self.t_seq[i + 1]
            # Jump slightly past t_next, then re-inject gamma-scaled noise.
            t_next_hat = (1 - self.gamma ** 2) ** 0.5 * t_next
            action_hat = self._squash(self.predict_G(state, action, t, t_next_hat))
            z = torch.randn_like(action_hat)
            action = self._squash(action_hat + self.gamma * t_next * z)

        action.clamp_(-self.max_action, self.max_action)
        return action

    def forward(self, state) -> torch.Tensor:
        """Sample actions for ``state`` via ``gamma_sample``."""
        return self.gamma_sample(state)

    def get_last_layer_weight(self):
        """Return ``self.model.final_layer.weight`` (the wrapped network must define ``final_layer``)."""
        return self.model.final_layer.weight
