import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
import utils
from agent.encoder import make_encoder
from torch.distributions import Bernoulli
from timm.models.vision_transformer import PatchEmbed, Block
from torch.distributions import normal
from functools import partial

# Custom layer: tiles its input along dim 1 (see RepeatFeature below).
class RepeatFeature(nn.Module):
    """Tile the input along dimension 1.

    ``forward`` calls ``x.repeat(1, repeat_times, 1)``. For the intended input
    shape ``[batch_size, 1, feature_size]`` the result has shape
    ``[batch_size, repeat_times, feature_size]``.
    """

    def __init__(self, repeat_times):
        super().__init__()
        # Number of copies made along dim 1.
        self.repeat_times = repeat_times

    def forward(self, x):
        # Leave dims 0 and 2 untouched; replicate dim 1.
        tiling = (1, self.repeat_times, 1)
        return x.repeat(*tiling)


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Intended for ``module.apply(weight_init)``.

    - ``nn.Linear``: orthogonal weight, zero bias.
    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: delta-orthogonal init. Every
      weight is zeroed except the spatial centre tap ``[:, :, mid, mid]``,
      which gets an orthogonal matrix scaled by the ReLU gain; bias is zeroed.

    Layers constructed with ``bias=False`` (``m.bias is None``) are supported;
    their bias is simply skipped. Any other module type is left untouched.

    Args:
        m: the module to initialise in place.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        # `mid` is derived from dim 2 and reused for dim 3, so the kernel must be square.
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


# Default step interval used by Actor.log / Critic.log: they only emit when step % log_freq == 0.
LOG_FREQ = 10_000

def gaussian_logprob(noise, log_std):
    """Compute Gaussian log probability.

    Takes standard-normal ``noise`` and the per-dimension ``log_std``, and
    returns the log-density summed over the last dimension (``keepdim=True``):

        sum_i(-0.5 * noise_i**2 - log_std_i) - 0.5 * log(2*pi) * D

    where ``D = noise.size(-1)``.
    """
    action_dim = noise.size(-1)
    per_dim = -0.5 * noise.pow(2) - log_std
    log_norm_const = 0.5 * np.log(2 * np.pi) * action_dim
    return per_dim.sum(-1, keepdim=True) - log_norm_const


def squash(mu, pi, log_pi):
    """Apply squashing function.
    See appendix C from https://arxiv.org/pdf/1812.05905.pdf.

    Maps ``mu`` and ``pi`` through ``tanh``. It also corrects ``log_pi`` for
    the change of variables by subtracting
    ``sum(log(relu(1 - tanh(pi)**2) + 1e-6))`` over the last dimension.
    The ``relu`` and ``1e-6`` keep the log argument positive.

    Args:
        mu: pre-squash mean.
        pi: pre-squash sample, or None.
        log_pi: log-probability of the pre-squash sample, or None. Requires ``pi``.

    Returns:
        ``(tanh(mu), tanh(pi) or None, corrected log_pi or None)``. The input
        ``log_pi`` tensor is not modified.

    Raises:
        ValueError: if ``log_pi`` is given but ``pi`` is None.
    """
    mu = torch.tanh(mu)
    if pi is not None:
        pi = torch.tanh(pi)
    if log_pi is not None:
        if pi is None:
            raise ValueError("squash: log_pi requires pi to be provided")
        # Out-of-place so the caller's tensor is left untouched.
        log_pi = log_pi - torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return mu, pi, log_pi


class Actor(nn.Module):
    """MLP actor network.

    Encodes the observation with ``make_encoder(...)``. An MLP trunk then maps
    the feature to the mean and log-std of a diagonal Gaussian. Samples and
    means are tanh-squashed via ``squash``.
    """

    def __init__(
            self, obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters, stride, device
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        # Bounds that forward() maps the raw log-std into.
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        # Last layer emits [mu, log_std] concatenated, hence 2 * action_dim outputs.
        self.trunk = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2 * action_shape[0])
        )
        # Count of forward() calls.
        self.Step = 0
        # Latest 'mu' / 'std' tensors, written by forward() and read by log().
        self.outputs = dict()
        self.apply(weight_init)

    def init_actor(self):
        """Re-initialise every sub-module with ``weight_init``."""
        self.apply(weight_init)

    def forward(
            self, obs, pre_obs, pre_action, VLM_action, compute_pi=True, compute_log_pi=True, detach_encoder=False,
            detach_trunk=False, detach_weight=False, compute_std_weight=True
    ):
        """Compute the (squashed) policy for ``obs``.

        Args:
            obs: observation passed to the encoder.
            pre_obs, pre_action, VLM_action, detach_trunk, detach_weight,
            compute_std_weight: accepted for interface compatibility; not used here.
            compute_pi: if True, draw a reparameterised sample ``mu + eps * std``.
            compute_log_pi: if True (and a sample was drawn), return its
                tanh-corrected log-probability; otherwise ``log_pi`` is None.
            detach_encoder: forwarded to the encoder as ``detach=``.

        Returns:
            ``(mu, pi, log_pi, log_std, obs_feature, output_weight)``.
            ``mu`` and ``pi`` are tanh-squashed. ``pi`` / ``log_pi`` may be None.
            ``output_weight`` is always None.
        """
        obs_feature = self.encoder(obs, detach=detach_encoder)
        mu, log_std = self.trunk(obs_feature).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]:
        # tanh maps to [-1, 1], then an affine map onto the target range.
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (
                self.log_std_max - self.log_std_min
        ) * (log_std + 1)
        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        if compute_pi:
            # Reparameterisation trick: pi = mu + eps * std with eps ~ N(0, I).
            noise = torch.randn_like(mu)
            pi = mu + noise * std
        else:
            noise = None
            pi = None

        # log_pi only exists for a drawn sample; previously compute_pi=False with
        # compute_log_pi=True crashed on the unbound `noise`.
        if compute_log_pi and noise is not None:
            log_pi = gaussian_logprob(noise, log_std)
        else:
            log_pi = None

        output_weight = None

        # tanh-squash mu/pi and apply the change-of-variables correction to log_pi.
        mu, pi, log_pi = squash(mu, pi, log_pi)
        self.Step = self.Step + 1
        return mu, pi, log_pi, log_std, obs_feature, output_weight

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log histograms of the latest ``self.outputs`` when ``step % log_freq == 0``."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_actor/%s_hist' % k, v, step)

        # L.log_param('train_actor/fc1', self.trunk[0], step)
        # L.log_param('train_actor/fc2', self.trunk[2], step)
        # L.log_param('train_actor/fc3', self.trunk[4], step)


class QFunction(nn.Module):
    """Scalar state-action value head Q(s, a).

    The feature vector and the action are concatenated along dim 1, then fed
    through a 3-layer ReLU MLP that outputs one value per row.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        joint_dim = obs_dim + action_dim
        # The Linear layers sit at indices 0, 2 and 4 of the Sequential.
        layers = [
            nn.Linear(joint_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.trunk = nn.Sequential(*layers)

    def forward(self, obs, action):
        # Both inputs must share the same batch dimension.
        assert obs.size(0) == action.size(0)

        joint = torch.cat((obs, action), dim=1)
        return self.trunk(joint)


class Critic(nn.Module):
    """Critic network, employes two q-functions.

    One encoder is shared by two independent ``QFunction`` heads (Q1, Q2).
    Both heads score the same encoded observation and action.
    """

    def __init__(
            self, obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, stride
    ):
        super().__init__()

        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        feat_dim = self.encoder.feature_dim
        act_dim = action_shape[0]
        # Q1 is constructed before Q2, as before.
        self.Q1 = QFunction(feat_dim, act_dim, hidden_dim)
        self.Q2 = QFunction(feat_dim, act_dim, hidden_dim)

        # Latest q1/q2 tensors, written by forward() and read by log().
        self.outputs = {}
        self.apply(weight_init)

    def forward(self, obs, action, detach_encoder=False):
        """Return ``(q1, q2)`` for ``(obs, action)``.

        ``detach_encoder`` is forwarded to the encoder as ``detach=``, to stop
        gradients from reaching it.
        """
        features = self.encoder(obs, detach=detach_encoder)

        value_1 = self.Q1(features, action)
        value_2 = self.Q2(features, action)

        self.outputs['q1'] = value_1
        self.outputs['q2'] = value_2

        return value_1, value_2

    def log(self, L, step, log_freq=LOG_FREQ):
        """Log encoder stats, q histograms and Q-head Linear params.

        Does nothing unless ``step % log_freq == 0``.
        """
        if step % log_freq:
            return

        self.encoder.log(L, step, log_freq)

        for key, tensor in self.outputs.items():
            L.log_histogram('train_critic/%s_hist' % key, tensor, step)

        # Each QFunction's Linear layers live at trunk indices 0, 2 and 4.
        for layer_idx in range(3):
            pos = 2 * layer_idx
            L.log_param('train_critic/q1_fc%d' % layer_idx, self.Q1.trunk[pos], step)
            L.log_param('train_critic/q2_fc%d' % layer_idx, self.Q2.trunk[pos], step)


# ICM (Intrinsic Curiosity Module)
class ICM(nn.Module):
    """Intrinsic Curiosity Module.

    A shared encoder maps observations to features phi(.). Two heads use them:
      - inverse model: predicts the action from ``[phi(s), phi(s')]``;
      - forward model: predicts ``phi(s')`` from ``[phi(s), a]``.
    """

    def __init__(self, obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, stride):

        super(ICM, self).__init__()
        # State encoder shared by both heads.
        self.encoder = make_encoder(
            encoder_type, obs_shape, encoder_feature_dim, num_layers,
            num_filters, stride
        )

        # NOTE(review): hidden_dim is accepted but unused; both heads use a fixed
        # width of 256 (kept as-is so existing checkpoints still load).

        # Inverse model: predict the action from the state transition.
        self.inverse_model = nn.Sequential(
            nn.Linear(encoder_feature_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, action_shape[0])
        )

        # Forward model: predict the next-state feature from the state and action.
        self.forward_model = nn.Sequential(
            nn.Linear(encoder_feature_dim + action_shape[0], 256),
            nn.ReLU(),
            nn.Linear(256, encoder_feature_dim)
        )

    def forward(self, state, next_state, action, detach_encoder=False):
        """Run both ICM heads.

        Args:
            state: current observation.
            next_state: next observation.
            action: action taken; concatenated with phi(state) along dim 1.
            detach_encoder: forwarded to the encoder as ``detach=`` for both
                encodings (it was previously ignored and hard-coded to False).

        Returns:
            ``(predicted_action, predicted_next_state, encoded_next_state)``.
        """
        # Encode both states.
        encoded_state = self.encoder(state, detach=detach_encoder)
        encoded_next_state = self.encoder(next_state, detach=detach_encoder)

        # Predict the action and the next-state feature.
        predicted_action = self.inverse_model(torch.cat([encoded_state, encoded_next_state], dim=1))

        predicted_next_state = self.forward_model(torch.cat([encoded_state, action], dim=1))

        return predicted_action, predicted_next_state, encoded_next_state