import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tools import feature_list
from tools.utils import batch_query_cost_dic

from models.s_model import S_SimDec


class Perturbator:
    """
    Group-wise policy training with predictor-guided latent perturbation.

    A state is encoded to a latent vector with the predictor's encoder, the
    latent is perturbed with Gaussian noise, every latent (the original and the
    perturbed ones) is decoded back to the input feature space, and the policy
    is evaluated on each decoded state. ``forward`` returns the resulting loss
    terms.

    Intended for use with S_SimDec as predictor and a ValueNetwork-style policy.
    The predictor attributes accessed by this class are ``c_transform4p``,
    ``c_transform4o``, ``c_transform4c``, ``c_transform4s``, ``fc``,
    ``encoder_lstm`` and ``decoder_lstm``.
    """

    def __init__(
        self,
        predictor: "S_SimDec",
        policy: nn.Module,
        env,
        M=8,
        device=None,
        clip_delta=False,
        delta=None,
        cost_dic=None,
        avg_profit=None,
        feature_list=None,
        otr_reward_coeff=1.0,
        retrieve_index=None,
        action_dim=4,
    ):
        """
        Args:
            predictor: Predictor model (S_SimDec) providing the encoder/decoder
                layers listed in the class docstring.
            policy: Decision network; called with a batch of states and expected
                to return per-action logits.
            env: Environment/config object. ``env.args.dataset`` and
                ``env.args.embed_dim`` are read, ``env.args.mip_coeff`` /
                ``env.args.mil_coeff`` are optional, and ``env.device`` is the
                fallback when ``device`` is not given.
            M: Number of perturbed latents sampled per state (stored as
                ``self.K``).
            device: torch device; defaults to ``env.device``.
            clip_delta: Whether to clip each perturbation element to
                ``[-delta, delta]``.
            delta: Maximum absolute value of each perturbation element; only
                used when ``clip_delta`` is true.
            cost_dic: Forwarded to ``batch_query_cost_dic`` for profit lookup.
            avg_profit: Per-action profits; forwarded to
                ``batch_query_cost_dic`` and used to build the profit loss.
            feature_list: Stored as ``self.feature_list``. Note that ``encode``
                reads the group sizes from the module-level ``feature_list``
                import, not from this argument.
            otr_reward_coeff: Stored on the instance; not read by this class.
            retrieve_index: Forwarded to ``batch_query_cost_dic``; required.
            action_dim: Stored on the instance; not read by this class.

        Raises:
            ValueError: If ``retrieve_index`` is None.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if retrieve_index is None:
            raise ValueError("retrieve_index must not be None")
        self.predictor = predictor
        self.policy = policy
        self.env = env
        self.K = M
        self.device = device or env.device
        self.clip_delta = clip_delta
        self.delta = delta
        self.cost_dic = cost_dic
        self.avg_profit = avg_profit
        self.feature_list = feature_list
        self.otr_reward_coeff = otr_reward_coeff
        self.retrieve_index = retrieve_index
        self.action_dim = action_dim

    def encode(self, s):
        """
        Encode a batch of states into latents plus a diagonal covariance.

        Args:
            s: (batch, feature_dim) tensor whose columns are laid out as
                product | order | customer | ... | shipping features, with the
                group widths taken from ``feature_list`` for the dataset.

        Returns:
            z: (batch, hidden) sum of the last layer's forward and backward
                final hidden states of the encoder LSTM.
            sigma: (batch, hidden, hidden) diagonal matrix on ``self.device``
                whose diagonal is the variance between those two hidden states
                plus 0.01.
        """
        dataset = self.env.args.dataset
        p_len = len(feature_list.product_info[dataset])
        o_len = len(feature_list.order_info[dataset])
        c_len = len(feature_list.customer_info[dataset])
        s_len = len(feature_list.shipping_info[dataset])

        # Shipping features are taken from the tail of the feature vector.
        product = s[:, :p_len]
        order = s[:, p_len : p_len + o_len]
        customer = s[:, p_len + o_len : p_len + o_len + c_len]
        shipping = s[:, -s_len:]

        # Project every feature group through its own linear layer and stack
        # the four embeddings into a length-4 sequence: (batch, 4, embed_dim).
        combined = torch.stack(
            [
                self.predictor.c_transform4p(product),
                self.predictor.c_transform4o(order),
                self.predictor.c_transform4c(customer),
                self.predictor.c_transform4s(shipping),
            ],
            dim=1,
        )
        combined = torch.relu(self.predictor.fc(combined))

        # Even rows of h_n are read as the forward direction and odd rows as
        # the backward direction; only the last layer's pair is used.
        _, (h_n, _) = self.predictor.encoder_lstm(combined)
        h_forward = h_n[0::2][-1]  # (batch, hidden)
        h_backward = h_n[1::2][-1]  # (batch, hidden)
        z = h_forward + h_backward

        # Disagreement between the two directions serves as the per-dimension
        # uncertainty estimate.
        h_variance = torch.var(torch.stack([h_forward, h_backward]), dim=0)

        batch_size, embed_dim = z.size(0), z.size(1)
        sigma = torch.zeros(batch_size, embed_dim, embed_dim, device=self.device)
        # Write all diagonals at once; the 0.01 offset keeps the matrix
        # non-singular when both directions agree.
        sigma.diagonal(dim1=1, dim2=2).copy_(h_variance + 0.01)

        return z, sigma

    def decode(self, z):
        """
        Decode latents by unrolling the predictor's decoder LSTM for 4 steps.

        Args:
            z: (batch, hidden) latent, used as the initial hidden state of every
                decoder layer (the initial cell state is zero).

        Returns:
            (batch, 4 * hidden) tensor: the four step outputs concatenated, one
            per feature group (product, order, customer, shipping).
        """
        decoder = self.predictor.decoder_lstm
        batch_size = z.size(0)
        seq_len = 4  # Number of feature groups (product, order, customer, shipping)

        step_input = torch.zeros(
            batch_size, 1, self.env.args.embed_dim, device=self.device
        )
        h_0 = z.unsqueeze(0).repeat(decoder.num_layers, 1, 1)
        hidden = (h_0, torch.zeros_like(h_0))

        outputs = []
        for _ in range(seq_len):
            out, hidden = decoder(step_input, hidden)
            outputs.append(out.squeeze(1))
            # Feed the previous output back in (no teacher forcing).
            step_input = out
        return torch.cat(outputs, dim=1)

    def sample_perturbations(self, z, sigma, epsilon_p=1.0):
        """
        Sample ``self.K`` Gaussian perturbations of ``z`` with covariance ``sigma``.

        Args:
            z: (batch, embed_dim) latent.
            sigma: (batch, embed_dim, embed_dim) symmetric positive-definite
                covariance matrices.
            epsilon_p: Scale applied to the standard-normal noise before it is
                coloured by the covariance factor.

        Returns:
            List of ``self.K + 1`` tensors of shape (batch, embed_dim): the
            unperturbed ``z`` first, followed by the ``self.K`` perturbed
            latents.
        """
        batch_size, embed_dim = z.size()
        # delta = L @ eps with L @ L^T = sigma has covariance sigma; multiplying
        # by sigma itself would produce covariance sigma @ sigma^T instead.
        scale_tril = torch.linalg.cholesky(sigma)
        clip = self.clip_delta and self.delta is not None

        zs = [z]
        for _ in range(self.K):
            eps = torch.randn(batch_size, embed_dim, device=self.device) * epsilon_p
            delta = torch.bmm(scale_tril, eps.unsqueeze(2)).squeeze(2)
            if clip:
                delta = torch.clamp(delta, -self.delta, self.delta)
            zs.append(z + delta)
        return zs

    def inverse_linear(self, y, linear_layer):
        """
        Approximately invert ``linear_layer`` via the weight's pseudo-inverse.

        Args:
            y: (batch, out_features) output of the layer.
            linear_layer: Layer exposing ``weight`` (out_features, in_features)
                and ``bias`` (may be None).

        Returns:
            (batch, in_features) least-squares pre-image of ``y``.
        """
        bias = linear_layer.bias
        if bias is not None:  # layers created with bias=False have no offset
            y = y - bias
        weight_pinv = torch.linalg.pinv(linear_layer.weight)  # (in, out)
        return y @ weight_pinv.T

    def inverse_transform(self, decoded):
        """
        Map decoder output back to the original feature space.

        ``decoded`` is split into four consecutive column groups sized by the
        ``out_features`` of the product, order and customer layers (the
        remainder goes to shipping), and each group is pushed through the
        pseudo-inverse of its input projection.

        Returns:
            (batch, sum of the four layers' in_features) tensor.
        """
        p_end = self.predictor.c_transform4p.out_features
        o_end = p_end + self.predictor.c_transform4o.out_features
        c_end = o_end + self.predictor.c_transform4c.out_features

        groups = (
            (decoded[:, :p_end], self.predictor.c_transform4p),
            (decoded[:, p_end:o_end], self.predictor.c_transform4o),
            (decoded[:, o_end:c_end], self.predictor.c_transform4c),
            (decoded[:, c_end:], self.predictor.c_transform4s),
        )
        return torch.cat(
            [self.inverse_linear(chunk, layer) for chunk, layer in groups], dim=1
        )

    def forward(self, s):
        """
        Perform one group-wise perturbation step and compute all loss components.

        Args:
            s: (batch, feature_dim) input state.

        Returns:
            dict with
                "profit_loss": mean over all latents of the cross-entropy
                    between the policy logits and the highest-profit action.
                "late_loss": mean over all latents of the cross-entropy between
                    the policy logits and class 1.
                "group_adv_loss": ``mip_coeff * profit_adv + mil_coeff *
                    late_adv``, where each term is the group-centred advantage
                    times the action log-probability, summed over latents and
                    averaged over the batch.
                    NOTE(review): this is the quantity a policy-gradient step
                    would maximise; confirm the caller negates it if it is
                    minimised as a loss.
                "log_probs": (K + 1, batch) detached action log-probabilities.
        """
        # 1. Encode the state and sample the latent group (z itself + K samples).
        z, sigma = self.encode(s)
        zs = self.sample_perturbations(z, sigma)

        # 2. Loop-invariant pieces of the profit loss.
        profits = torch.tensor(self.avg_profit, device=self.device)
        decision_weights = profits / profits.max()
        best_action = torch.argmax(profits)

        log_probs = []
        profit_losses = []
        late_losses = []
        profits_list = []
        late_risk_list = []

        # 3. Evaluate the policy on every decoded latent.
        for z_i in zs:
            tilde_s_i = self.decode(z_i)  # (batch, 4 * hidden)
            tilde_s_orig = self.inverse_transform(tilde_s_i)

            logits = self.policy(tilde_s_orig)
            action_dist = torch.distributions.Categorical(logits=logits)
            a_i = action_dist.sample()
            log_probs.append(action_dist.log_prob(a_i))

            # Profit of the sampled action, looked up by the shared utility.
            profits_list.append(
                batch_query_cost_dic(
                    self.cost_dic,
                    self.avg_profit,
                    self.retrieve_index,
                    tilde_s_orig,
                    a_i,
                    self.device,
                )
            )

            # Late loss: push the policy towards class 1.
            on_time_target = torch.ones(
                logits.size(0), dtype=torch.long, device=logits.device
            )
            late_losses.append(F.cross_entropy(logits, on_time_target))

            # NOTE(review): argmax ranges over every policy output, so this is
            # only in {0, 1} when the policy has two classes - confirm.
            late_risk_list.append(1.0 - torch.argmax(logits, dim=1).float())

            # Profit loss: encourage the highest-profit action. cross_entropy
            # applies log-softmax itself, so it must receive the raw logits.
            profit_losses.append(
                F.cross_entropy(
                    logits,
                    best_action.expand(logits.size(0)),
                    weight=decision_weights,
                )
            )

        # 4. Stack per-latent results.
        profit_losses = torch.stack(profit_losses, dim=0)  # (K + 1,)
        late_losses = torch.stack(late_losses, dim=0)  # (K + 1,)
        log_probs = torch.stack(log_probs, dim=0)  # (K + 1, batch)
        profits_all = torch.stack(profits_list, dim=0)  # (K + 1, batch)
        late_risk_all = torch.stack(late_risk_list, dim=0)  # (K + 1, batch)

        # 5. Group advantages: centre each signal on its mean over the group.
        # Profit: higher is better.
        bar_profit = profits_all.mean(dim=0)
        profit_adv = (
            ((profits_all - bar_profit.unsqueeze(0)) * log_probs).sum(dim=0).mean()
        )
        # Late risk: lower is better, hence the flipped sign.
        bar_late = late_risk_all.mean(dim=0)
        late_adv = (
            ((bar_late.unsqueeze(0) - late_risk_all) * log_probs).sum(dim=0).mean()
        )

        mip_coeff = getattr(self.env.args, "mip_coeff", 1.0)
        mil_coeff = getattr(self.env.args, "mil_coeff", 1.0)
        group_adv = mip_coeff * profit_adv + mil_coeff * late_adv

        return {
            "profit_loss": profit_losses.mean(),
            "late_loss": late_losses.mean(),
            "group_adv_loss": group_adv,
            "log_probs": log_probs.detach(),
        }
