# refer to consistent actor critic paper implementation: https://github.com/quantumiracle/Consistency_Model_For_Reinforcement_Learning
# https://arxiv.org/abs/2309.16984
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
import operator

# Compatibility fix for older PyTorch versions
class SiLU(nn.Module):
    """Fallback SiLU (a.k.a. Swish) activation computing ``x * sigmoid(x)``.

    Only used by ``get_silu`` when ``torch.nn`` does not provide ``nn.SiLU``.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

# Use built-in SiLU if available, otherwise use our implementation
def get_silu():
    """Return a fresh SiLU module: ``nn.SiLU`` when available, else the local fallback."""
    return nn.SiLU() if hasattr(nn, 'SiLU') else SiLU()

# for consistency model
def kerras_boundaries(sigma, eps, N, T):
    """Return N time-discretization boundaries spaced uniformly in ``t ** (1/sigma)``.

    The i-th boundary is ``(eps**(1/sigma) + i/(N-1) * (T**(1/sigma) - eps**(1/sigma))) ** sigma``,
    so the sequence starts at ``eps`` and ends at ``T``. The result is a 1-D tensor
    of length ``N`` (``N == 1`` divides by zero, as in the original formula).
    """
    inv_sigma = 1 / sigma
    low = eps ** inv_sigma
    span = T ** inv_sigma - low
    grid = [(low + i / (N - 1) * span) ** sigma for i in range(N)]
    return torch.tensor(grid)

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar times/positions.

    For an input of shape ``(batch,)`` the output has shape ``(batch, 2 * (dim // 2))``:
    the first half holds ``sin(x * f_k)`` and the second half ``cos(x * f_k)``, with
    geometric frequencies ``f_k = exp(-k * log(10000) / (dim // 2 - 1))``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

#-----------------------------------------------------------------------------#
#---------------------------------- sampling ---------------------------------#
#-----------------------------------------------------------------------------#


def extract(a, t, x_shape):
    """Gather ``a[..., t]`` and reshape it to broadcast against a tensor of shape ``x_shape``.

    The result has shape ``(t.shape[0], 1, 1, ...)`` with ``len(x_shape) - 1`` trailing ones.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)


def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Builds ``alpha_bar`` on ``timesteps + 1`` evenly spaced points in [0, 1],
    normalizes it to start at 1, and returns the per-step betas
    ``1 - alpha_bar[i+1] / alpha_bar[i]`` clipped to [0, 0.999].
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)
    alpha_bar = np.cos(((grid / n_points) + s) / (1 + s) * np.pi * 0.5) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.tensor(np.clip(1 - step_ratio, a_min=0, a_max=0.999), dtype=dtype)


def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2, dtype=torch.float32):
    """Return ``timesteps`` betas spaced linearly from ``beta_start`` to ``beta_end`` (inclusive)."""
    return torch.tensor(np.linspace(beta_start, beta_end, timesteps), dtype=dtype)


def vp_beta_schedule(timesteps, dtype=torch.float32):
    """Discretized variance-preserving (VP) beta schedule with b_min=0.1, b_max=10.

    For step t in 1..T: ``beta_t = 1 - exp(-b_min/T - 0.5*(b_max - b_min)*(2t - 1)/T**2)``.
    """
    b_min, b_max = 0.1, 10.
    steps = np.arange(1, timesteps + 1)
    log_alpha = -b_min / timesteps - 0.5 * (b_max - b_min) * (2 * steps - 1) / timesteps ** 2
    return torch.tensor(1 - np.exp(log_alpha), dtype=dtype)

#-----------------------------------------------------------------------------#
#---------------------------------- losses -----------------------------------#
#-----------------------------------------------------------------------------#

class WeightedLoss(nn.Module):
    """Base class for element-wise losses with optional, non-differentiable weights.

    Subclasses implement ``_loss(pred, targ)`` returning an unreduced,
    element-wise loss tensor; ``forward`` applies the weights and the reduction.
    """

    def __init__(self):
        super().__init__()

    def _loss(self, pred, targ):
        raise NotImplementedError('subclasses of WeightedLoss must implement _loss')

    def forward(self, pred, targ, weights=None, take_mean=True):
        '''
            pred, targ : tensor [ batch_size x action_dim ]
            weights    : optional tensor broadcastable to the element-wise loss.
                         It is detached, so no gradient flows into it. ``None``
                         (the default) means no weighting; this replaces the old
                         shared ``torch.tensor(1.0)`` default, which had the same
                         effect but was a single tensor created at import time.
            take_mean  : if True, return the scalar mean of the weighted loss;
                         otherwise return the weighted element-wise loss.
        '''
        loss = self._loss(pred, targ)
        if weights is not None:
            loss = loss * weights.detach()
        return loss.mean() if take_mean else loss

class WeightedL1(WeightedLoss):
    """Element-wise absolute error ``|pred - targ|`` (unreduced)."""

    def _loss(self, pred, targ):
        diff = pred - targ
        return diff.abs()

class WeightedL2(WeightedLoss):
    """Element-wise squared error ``(pred - targ) ** 2`` via ``F.mse_loss`` without reduction."""

    def _loss(self, pred, targ):
        squared = F.mse_loss(pred, targ, reduction='none')
        return squared

class WeightedHuber(WeightedLoss):
    """Element-wise pseudo-Huber loss ``sqrt((pred - targ)**2 + c**2) - c``.

    ``c = 0.00054 * sqrt(d)``, where ``d`` is the number of elements per sample
    (the product of ``pred.shape[1:]``). NOTE(review): this looks like the
    constant from improved consistency training; confirm it against the paper.
    """

    def _loss(self, pred, targ):
        # The initializer 1 makes d == 1 for 1-D ``pred``. Without it, reduce()
        # over the empty ``shape[1:]`` raises TypeError.
        # (math.prod would also work but requires Python 3.8.)
        d = reduce(operator.mul, pred.shape[1:], 1)
        c = 0.00054 * math.sqrt(d)
        return torch.sqrt((pred - targ) ** 2 + c ** 2) - c



# Registry mapping a loss name (as used in configs) to its WeightedLoss subclass.
Losses = dict(
    l1=WeightedL1,
    l2=WeightedL2,
    pseudo_huber=WeightedHuber,
)


class EMA:
    """Exponential moving average (EMA) of model parameters.

    Each call to ``update_model_average(ma_model, current_model)`` replaces every
    parameter of ``ma_model`` with ``beta * ma + (1 - beta) * current``.
    """

    def __init__(self, beta):
        super().__init__()
        # Weight kept on the previous (averaged) value; 1 - beta goes to the new value.
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model``, pairing them in order.

        Raises:
            ValueError: if the two models have a different number of parameter
                tensors. A plain ``zip`` would silently leave the extras
                un-averaged.
        """
        current_params = list(current_model.parameters())
        ma_params = list(ma_model.parameters())
        if len(current_params) != len(ma_params):
            raise ValueError(
                'EMA.update_model_average: models have different numbers of parameters '
                f'({len(ma_params)} vs {len(current_params)})'
            )
        for current, averaged in zip(current_params, ma_params):
            averaged.data = self.update_average(averaged.data, current.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``; returns ``new`` if there is no old value."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def set(self, ema):
        """Set the decay factor ``beta``."""
        self.beta = ema

#-----------------------------------------------------------------------------#
#---------------------------------- models -----------------------------------#
#-----------------------------------------------------------------------------#

class MLP(nn.Module):
    """
    MLP Model

    Concatenates the input ``x``, a learned sinusoidal embedding of ``time`` and
    ``state``, then applies three 256-wide SiLU layers and a linear head that
    outputs ``action_dim`` features.
    """
    def __init__(self,
                 state_dim,
                 action_dim,
                 device,
                 t_dim=16):

        super().__init__()
        self.device = device

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, 2 * t_dim),
            get_silu(),
            nn.Linear(2 * t_dim, t_dim),
        )

        width = 256
        layers = [nn.Linear(state_dim + action_dim + t_dim, width), get_silu()]
        for _ in range(2):
            layers += [nn.Linear(width, width), get_silu()]
        self.mid_layer = nn.Sequential(*layers)

        self.final_layer = nn.Linear(width, action_dim)

    def forward(self, x, time, state):
        # Accept times shaped (batch, 1) as well as (batch,).
        if time.dim() > 1:
            time = time.squeeze(1)
        features = torch.cat([x, self.time_mlp(time), state], dim=1)
        return self.final_layer(self.mid_layer(features))

class ResNetBlock(nn.Module):
    """Residual MLP block: ``x + Linear(ReLU(Linear(LayerNorm(Dropout(x)))))``.

    The inner width is ``4 * hidden_dim``. The block outputs ``hidden_dim``
    features and adds them to ``x``, so ``hidden_dim`` must equal ``in_features``.
    """

    def __init__(self, in_features, hidden_dim, dropout_rate=0.1):
        super().__init__()
        expanded = 4 * hidden_dim
        self.layer = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.LayerNorm(in_features),
            nn.Linear(in_features, expanded),
            nn.ReLU(),
            nn.Linear(expanded, hidden_dim),
        )

    def forward(self, x):
        return self.layer(x) + x

class LN_Resnet(nn.Module):
    """Time- and state-conditioned residual MLP.

    ``[x, time_embedding, state]`` goes through a linear+ReLU input layer, three
    ``ResNetBlock``s of width ``hidden_size`` (with dropout ``dropout_rate``) and
    a ReLU + linear head producing ``action_dim`` outputs.
    """

    def __init__(self, state_dim, action_dim, device, t_dim=16, hidden_size=256, dropout_rate=0.1):
        super().__init__()
        self.device = device

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, 2 * t_dim),
            get_silu(),
            nn.Linear(2 * t_dim, t_dim),
        )

        self.input_layer = nn.Sequential(
            nn.Linear(state_dim + action_dim + t_dim, hidden_size),
            nn.ReLU(),
        )
        self.resnet_block1 = ResNetBlock(hidden_size, hidden_size, dropout_rate)
        self.resnet_block2 = ResNetBlock(hidden_size, hidden_size, dropout_rate)
        self.resnet_block3 = ResNetBlock(hidden_size, hidden_size, dropout_rate)
        self.output_layer = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_size, action_dim),
        )

    def forward(self, x, time, state):
        # Accept times shaped (batch, 1) as well as (batch,).
        if time.dim() > 1:
            time = time.squeeze(1)
        h = self.input_layer(torch.cat([x, self.time_mlp(time), state], dim=1))
        for block in (self.resnet_block1, self.resnet_block2, self.resnet_block3):
            h = block(h)
        return self.output_layer(h)



def blk(ic, oc, num_groups=32):
    """Build a conv block: (GroupNorm -> SiLU -> 3x3 Conv) applied twice.

    Replaces a lambda assigned to a name (PEP 8 E731) and makes the GroupNorm
    group count configurable.

    Args:
        ic: number of input channels.
        oc: number of output channels.
        num_groups: GroupNorm group count. Both ``ic`` and ``oc`` must be
            divisible by it, as ``nn.GroupNorm`` requires. Defaults to 32, the
            value that was previously hard-coded.

    Returns:
        ``nn.Sequential`` mapping ``(N, ic, H, W)`` to ``(N, oc, H, W)``.
        With 3x3 kernels and padding=1, H and W are unchanged.
    """
    return nn.Sequential(
        nn.GroupNorm(num_groups, num_channels=ic),
        get_silu(),
        nn.Conv2d(ic, oc, 3, padding=1),
        nn.GroupNorm(num_groups, num_channels=oc),
        get_silu(),
        nn.Conv2d(oc, oc, 3, padding=1),
    )

class Unet(nn.Module):
    """Small convolutional U-Net computing F(x, t) for image-shaped inputs.

    The down path halves H and W twice and the up path doubles them twice.
    Skips are concatenated on the way up, and the raw input ``x`` is
    concatenated before the final 3x3 convolution, which returns
    ``n_channel`` channels. A sinusoidal embedding of ``t`` with 2*D features
    is projected and added to the feature map after every down-path layer.

    NOTE(review): every ``blk`` uses GroupNorm(32, ...), so ``D`` must be a
    multiple of 32. Because of the two halvings and doublings, H and W
    presumably must be divisible by 4 for the skip shapes to match; confirm
    with callers.
    """
    def __init__(self, 
        n_channel: int,
        D: int = 128,
        device: torch.device = torch.device("cpu"),
        ) -> None:
        super(Unet, self).__init__()
        self.device = device

        # D geometric frequencies exp(-log(10000) * k / D), k = 0..D-1.
        # This is a plain attribute (not a registered buffer), so forward()
        # moves it to t's device on every call.
        self.freqs = torch.exp(
            -math.log(10000) * torch.arange(start=0, end=D, dtype=torch.float32) / D
        )

        # Layers 0 (input conv) and 2 (D -> 2D) change width; in forward() they
        # are downsampled and saved as skips. Layers 1 and 3 are residual.
        self.down = nn.Sequential(
            *[
                nn.Conv2d(n_channel, D, 3, padding=1),
                blk(D, D),
                blk(D, 2 * D),
                blk(2 * D, 2 * D),
            ]
        )

        # One projection of the 2D-dim time embedding per down-path layer,
        # matching that layer's output width (D, D, 2D, 2D). forward() only
        # indexes this container; it never calls it as a sequence.
        self.time_downs = nn.Sequential(
            nn.Linear(2 * D, D),
            nn.Linear(2 * D, D),
            nn.Linear(2 * D, 2 * D),
            nn.Linear(2 * D, 2 * D),
        )

        self.mid = blk(2 * D, 2 * D)

        # Even layers are residual. Odd layers consume a concatenated skip:
        # idx 1 takes 2D + 2D = 4D channels and outputs D; idx 3 takes
        # D + D = 2D channels and outputs 2D.
        self.up = nn.Sequential(
            *[
                blk(2 * D, 2 * D),
                blk(2 * 2 * D, D),
                blk(D, D),
                nn.Conv2d(2 * D, 2 * D, 3, padding=1),
            ]
        )
        # Final layer sees the 2D up-path features concatenated with the raw input.
        self.last = nn.Conv2d(2 * D + n_channel, n_channel, 3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # time embedding
        # NOTE(review): t presumably has shape (batch, 1), so it broadcasts
        # against freqs[None] of shape (1, D). A 1-D t of length batch would not
        # broadcast unless batch == D. TODO confirm with callers.
        args = t.float() * self.freqs[None].to(t.device)
        # sin and cos halves concatenated -> 2D features per sample.
        t_emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1).to(x.device)

        x_ori = x

        # perform F(x, t)
        # Down path: odd layers are width-preserving residual blocks. Even layers
        # change width, are halved spatially (default interpolate mode) and are
        # saved as skips. Every layer is followed by adding its time projection,
        # broadcast over H and W.
        hs = []
        for idx, layer in enumerate(self.down):
            if idx % 2 == 1:
                x = layer(x) + x
            else:
                x = layer(x)
                x = F.interpolate(x, scale_factor=0.5)
                hs.append(x)

            x = x + self.time_downs[idx](t_emb)[:, :, None, None]

        x = self.mid(x)

        # Up path: even layers are residual. Odd layers concatenate the most
        # recent skip (LIFO), upsample x2 (nearest), then apply the layer.
        for idx, layer in enumerate(self.up):
            if idx % 2 == 0:
                x = layer(x) + x
            else:
                x = torch.cat([x, hs.pop()], dim=1)
                x = F.interpolate(x, scale_factor=2, mode="nearest")
                x = layer(x)

        x = self.last(torch.cat([x, x_ori], dim=1))

        return x
    
# Residual block for MLPs, used to build deeper and more stable networks.
class ResidualBlock(nn.Module):
    """Conditioned MLP residual block.

    Returns ``x + layer2(SiLU(layer1(x) + cond_projection(cond)))``. The condition
    is projected to ``hidden_dim`` and added before the activation, and the
    output width equals ``input_dim``.
    """

    def __init__(self, input_dim, hidden_dim, cond_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, input_dim)
        self.cond_projection = nn.Linear(cond_dim, hidden_dim)
        self.activation = get_silu()

    def forward(self, x, cond):
        # Fuse the input and the condition in the hidden space.
        hidden = self.layer1(x) + self.cond_projection(cond)
        return self.layer2(self.activation(hidden)) + x

class ResidualVectorMLP(nn.Module):
    """Residual MLP over action vectors, conditioned on a state and two times ``t`` and ``s``.

    The action ``x`` is projected to ``hidden_dim`` and refined by ``num_blocks``
    ``ResidualBlock``s. Every block receives the same conditioning vector
    ``[state, emb(t), emb(s)]``, and a final linear layer maps back to ``action_dim``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, t_dim=16, num_blocks=4):
        super().__init__()
        self.action_dim = action_dim

        # Separate (unshared) embedding networks for times t and s (borrowed from SongUNet).
        self.time_t_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 4),
            get_silu(),
            nn.Linear(t_dim * 4, t_dim),
        )
        self.time_s_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 4),
            get_silu(),
            nn.Linear(t_dim * 4, t_dim),
        )

        # Conditioning width = state + emb(t) + emb(s); it is injected into every block.
        cond_dim = state_dim + t_dim + t_dim
        # Only the action enters the hidden stream; state and times arrive via the condition.
        self.input_projection = nn.Linear(action_dim, hidden_dim)

        # Stack of conditioned residual blocks.
        self.residual_blocks = nn.ModuleList(
            [ResidualBlock(hidden_dim, hidden_dim * 2, cond_dim) for _ in range(num_blocks)]
        )

        # Output layer back to the action dimension.
        self.output_projection = nn.Linear(hidden_dim, action_dim)

    def forward(self, x, t, s, state, **kwargs):
        # SinusoidalPosEmb expects 1-D times. Accept (batch, 1) as the sibling
        # models do; without this, the embeddings become rank 3 and the
        # concatenation with ``state`` fails.
        if len(t.shape) > 1:
            t = t.squeeze(1)
        if len(s.shape) > 1:
            s = s.squeeze(1)

        # 1. Embed the two times.
        t_emb = self.time_t_mlp(t)
        s_emb = self.time_s_mlp(s)

        # 2. Build the conditioning vector from state and the time embeddings.
        cond = torch.cat([state, t_emb, s_emb], dim=1)

        # 3. Project the input action.
        x = self.input_projection(x)

        # 4. Run the residual blocks.
        for block in self.residual_blocks:
            x = block(x, cond)

        # 5. Project to the output dimension.
        return self.output_projection(x)
    
class ResidualVectorUNet(nn.Module):
    """U-Net-shaped residual MLP over flat vectors, conditioned on two times.

    ``[x, state]`` is projected to ``dims[0]``. The encoder applies a conditioned
    ``ResidualBlock`` at each level, stores it as a skip and widens to the next
    level. After a bottleneck block, the decoder concatenates each skip, applies
    a conditioned block and narrows back. Every block is conditioned on
    ``[emb(time), emb(time_s)]``.

    Args:
        state_dim: width of ``state``.
        action_dim: width of ``x`` and of the output.
        dims: level widths from the input level down to the bottleneck. It is
            now a tuple default instead of a shared mutable list; any indexable
            sequence works.
        t_dim: width of each time embedding.
    """

    def __init__(self, state_dim, action_dim, dims=(64, 128, 256, 512), t_dim=32):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Embedding networks for the times t and s (same design as in the MLP models).
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 4),
            get_silu(),
            nn.Linear(t_dim * 4, t_dim),
        )
        self.time_s_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 4),
            get_silu(),
            nn.Linear(t_dim * 4, t_dim),
        )
        cond_dim = t_dim + t_dim

        # Input layer mapping (action + state) to the first width (a Linear despite the name).
        self.init_conv = nn.Linear(action_dim + state_dim, dims[0])

        # --- Encoder (downsampling path) ---
        self.downs = nn.ModuleList([])
        for i in range(len(dims) - 1):
            self.downs.append(nn.ModuleList([
                ResidualBlock(dims[i], dims[i] * 2, cond_dim),
                # ResidualBlock keeps the width, so this Linear does the widening.
                nn.Linear(dims[i], dims[i + 1]),
            ]))

        # --- Bottleneck ---
        self.mid_block = ResidualBlock(dims[-1], dims[-1] * 2, cond_dim)

        # --- Decoder (upsampling path) ---
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(dims) - 1)):
            self.ups.append(nn.ModuleList([
                # Input width is dims[i+1] + dims[i] because of the skip concatenation.
                ResidualBlock(dims[i + 1] + dims[i], (dims[i + 1] + dims[i]) * 2, cond_dim),
                nn.Linear(dims[i + 1] + dims[i], dims[i]),
            ]))

        # Final output layer.
        self.final_layer = nn.Linear(dims[0], action_dim)

    def forward(self, x, time, time_s, state, **kwargs):
        # SinusoidalPosEmb expects 1-D times; accept (batch, 1) like the sibling models.
        if len(time.shape) > 1:
            time = time.squeeze(1)
        if len(time_s.shape) > 1:
            time_s = time_s.squeeze(1)

        # 1. Embed the times and build the shared condition.
        t_emb = self.time_mlp(time)
        s_emb = self.time_s_mlp(time_s)
        time_cond = torch.cat([t_emb, s_emb], dim=1)

        # 2. Concatenate action and state as the initial input.
        x = torch.cat([x, state], dim=1)
        x = self.init_conv(x)

        # 3. Encoder path.
        skips = []
        for res_block, downsample in self.downs:
            x = res_block(x, time_cond)
            skips.append(x)
            x = downsample(x)

        # 4. Bottleneck.
        x = self.mid_block(x, time_cond)

        # 5. Decoder path: consume skips in reverse (LIFO) order.
        for res_block, upsample in self.ups:
            x = torch.cat([x, skips.pop()], dim=1)
            x = res_block(x, time_cond)
            x = upsample(x)

        # 6. Output.
        return self.final_layer(x)


class VectorMLP(nn.Module):
    """Plain MLP over ``[x, state, emb(t), emb(s)]`` producing ``action_dim`` outputs.

    A single ``time_mlp`` (shared weights) embeds both ``t`` and ``s``. Three
    SiLU layers of width ``hidden_dim`` follow, then a linear head.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=512, t_dim=16):
        super().__init__()
        self.action_dim = action_dim

        # Simple time embedding, as in the original MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, 2 * t_dim),
            get_silu(),
            nn.Linear(2 * t_dim, t_dim),
        )

        # All inputs are concatenated directly: x + state + emb(t) + emb(s).
        in_width = action_dim + state_dim + 2 * t_dim
        layers = [nn.Linear(in_width, hidden_dim), get_silu()]
        for _ in range(2):
            layers += [nn.Linear(hidden_dim, hidden_dim), get_silu()]
        self.mid_layer = nn.Sequential(*layers)

        self.final_layer = nn.Linear(hidden_dim, action_dim)

    def forward(self, x, t, s, state, **kwargs):
        # Accept times shaped (batch, 1) as well as (batch,).
        t, s = (v.squeeze(1) if len(v.shape) > 1 else v for v in (t, s))
        embeddings = [self.time_mlp(v) for v in (t, s)]
        hidden = self.mid_layer(torch.cat([x, state, *embeddings], dim=1))
        return self.final_layer(hidden)

class VectorUNet(nn.Module):
    """Two-level MLP "U-Net" over flat vectors, with concatenation skip connections.

    ``[x, state, emb(time), emb(time_s)]`` goes through two SiLU encoders
    (widths ``hidden_dim`` and ``2 * hidden_dim``) and a bottleneck. Two decoders
    follow, each consuming its output concatenated with the matching encoder
    output. One ``time_mlp`` is shared by both times.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, t_dim=16):
        super().__init__()
        self.action_dim = action_dim
        wide = 2 * hidden_dim

        # Time embedding.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, 2 * t_dim),
            get_silu(),
            nn.Linear(2 * t_dim, t_dim),
        )

        in_width = action_dim + state_dim + 2 * t_dim

        # Encoder.
        self.encoder1 = nn.Sequential(nn.Linear(in_width, hidden_dim), get_silu())
        self.encoder2 = nn.Sequential(nn.Linear(hidden_dim, wide), get_silu())

        # Bottleneck.
        self.bottleneck = nn.Sequential(nn.Linear(wide, wide), get_silu())

        # Decoder: input widths are doubled by the skip concatenations.
        self.decoder2 = nn.Sequential(nn.Linear(2 * wide, hidden_dim), get_silu())
        self.decoder1 = nn.Sequential(nn.Linear(wide, action_dim))

    def forward(self, x, time, time_s, state, **kwargs):
        # Accept times shaped (batch, 1) as well as (batch,).
        if time.dim() > 1:
            time = time.squeeze(1)
        if time_s.dim() > 1:
            time_s = time_s.squeeze(1)

        features = torch.cat([x, state, self.time_mlp(time), self.time_mlp(time_s)], dim=1)

        skip1 = self.encoder1(features)
        skip2 = self.encoder2(skip1)
        middle = self.bottleneck(skip2)

        up = self.decoder2(torch.cat([middle, skip2], dim=1))
        return self.decoder1(torch.cat([up, skip1], dim=1))


# -----------------------------------------------------------------------------#
# ---------------------------------- data sampler -----------------------------#
# -----------------------------------------------------------------------------#

class Data_Sampler(object):
    """Offline RL transitions stored as float tensors on ``device``, with uniform minibatch sampling.

    ``data`` must provide the arrays ``observations``, ``actions``,
    ``next_observations``, ``rewards`` and ``terminals``. Rewards and
    ``not_done = 1 - terminals`` are stored as ``(N, 1)`` columns.
    ``reward_tune`` selects an optional reward transform; unrecognised values,
    including the default ``'no'``, leave rewards unchanged.
    """

    def __init__(self, data, device, reward_tune='no'):
        self.device = device

        def load(key, column=False):
            tensor = torch.from_numpy(data[key])
            if column:
                tensor = tensor.view(-1, 1)
            return tensor.float().to(self.device)

        self.state = load('observations')
        self.action = load('actions')
        self.next_state = load('next_observations')
        reward = load('rewards', column=True)
        self.not_done = 1. - load('terminals', column=True)

        self.size = self.state.shape[0]
        self.state_dim = self.state.shape[1]
        self.action_dim = self.action.shape[1]

        tuners = {
            'normalize': lambda r: (r - r.mean()) / r.std(),
            'iql_antmaze': lambda r: r - 1.0,
            'iql_locomotion': lambda r: iql_normalize(r, self.not_done),
            'cql_antmaze': lambda r: (r - 0.5) * 4.0,
            'antmaze': lambda r: (r - 0.25) * 2.0,
        }
        tune = tuners.get(reward_tune)
        self.reward = reward if tune is None else tune(reward)

    def sample(self, batch_size):
        """Return ``(state, action, next_state, reward, not_done)`` at ``batch_size`` uniform random indices."""
        idx = torch.randint(0, self.size, size=(batch_size,))
        buffers = (self.state, self.action, self.next_state, self.reward, self.not_done)
        return tuple(buf[idx].to(self.device) for buf in buffers)


def iql_normalize(reward, not_done):
    """Scale rewards by ``1000 / (max_return - min_return)`` over completed episodes.

    Episodes end at transitions where ``not_done == 0``. A trailing segment with
    no terminal is not counted as an episode (same as before). Per-episode
    returns are computed vectorized from a float64 running sum, replacing the
    former per-transition Python loop, which forced one device sync per step
    on GPU.

    Args:
        reward: reward tensor; Data_Sampler passes shape ``(N, 1)``.
        not_done: tensor with the same number of elements as ``reward``; 0 marks
            the last transition of an episode.

    Returns:
        ``reward``, scaled in place and returned.

    Raises:
        ValueError: if there is no terminal transition (previously an opaque
            RuntimeError from ``max`` of an empty tensor), or if all episode
            returns are equal (previously a silent division by zero producing
            inf/nan rewards).
    """
    flat_reward = reward.reshape(-1)
    ends = torch.nonzero(not_done.reshape(-1) == 0, as_tuple=False).reshape(-1)
    if ends.numel() == 0:
        raise ValueError('iql_normalize: no terminal transitions found; cannot compute episode returns')

    # Running sum sampled at each episode end; consecutive differences give
    # per-episode returns. float64 limits cancellation error over long datasets.
    running = torch.cumsum(flat_reward.to(torch.float64), dim=0)[ends]
    returns = running - torch.cat([running.new_zeros(1), running[:-1]])

    span = (returns.max() - returns.min()).item()
    if span == 0:
        raise ValueError('iql_normalize: all episode returns are equal; cannot normalize by their range')

    reward /= span
    reward *= 1000.
    return reward
