import numpy as np
import os
from bidding_train_env.baseline.q_critic.iql_critic import IQL_Critic
from bidding_train_env.baseline.q_critic.cql_critic import CQL_Critic
from bidding_train_env.strategy.base_strategy import BaseBiddingStrategy
import torch
import pickle





class IQLBiddingCritic(BaseBiddingStrategy):
    """
    IQL Critic (for evaluating the Q value of a given action).

    Wraps a trained ``IQL_Critic`` network and builds the 18-dimensional
    hand-crafted state vector for the current bidding step. Callers can score
    a single action (``access_value``) or a batch of candidate actions
    (``access_value_batch``).
    """

    # Number of ticks per episode; used to compute the ``time_left`` feature.
    NUM_TICKS = 48
    # Length of the vector built by ``_build_state_vector``.
    STATE_DIM = 18
    # Action dimensionality the critic network is built with.
    ACT_DIM = 1
    # State indices handed to ``take_critics`` for normalization:
    # current_pv_num, last_three_pv_num_total, historical_pv_num_total,
    # budget, cpa (see ordering in ``_build_state_vector``).
    NORMALIZE_INDICES = (13, 14, 15, 16, 17)

    def __init__(self, budget=100, name="IQL-Critic", cpa=2, category=1,
                 load_dir=None):
        """
        Build the critic network and load its weights.

        Args:
            budget: Episode budget; also used to scale ``budget_left``.
            name: Strategy name forwarded to ``BaseBiddingStrategy``.
            cpa: Target CPA, stored and forwarded to the base class.
            category: Advertiser category, stored and forwarded.
            load_dir: Directory containing ``iql_critic.pt``. Defaults to
                ``<project_root>/save_model/IQL_critic``.
        """
        super().__init__(budget, name, cpa, category)

        if load_dir is None:
            file_dir = os.path.dirname(os.path.realpath(__file__))
            proj_dir = os.path.dirname(os.path.dirname(file_dir))  # Go back to project root
            load_dir = os.path.join(proj_dir, "save_model", "IQL_critic")
        ckpt_path = os.path.join(load_dir, "iql_critic.pt")
        # No separate normalization file is loaded here; normalization of
        # NORMALIZE_INDICES is delegated to ``take_critics`` (presumably using
        # statistics stored with the checkpoint -- confirm in IQL_Critic).

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build IQL critic and load weights
        self.critic = IQL_Critic(state_dim=self.STATE_DIM, act_dim=self.ACT_DIM)
        self.critic.load_net(ckpt_path, device=device)
        self.critic.to(device)
        self.device = device

        # Record budget related variables
        self.cpa = cpa
        self.budget = budget
        self.category = category
        # Initialised here as well as in reset() so that state features can
        # be built even if reset() has not been called yet.
        self.remaining_budget = self.budget
        self.remaining_budget_last = self.budget

    def reset(self):
        """Restore the remaining-budget trackers to the full budget."""
        self.remaining_budget = self.budget
        self.remaining_budget_last = self.budget

    def _build_state_vector(self, timeStepIndex, pValues, pValueSigmas,
                            historyPValueInfo, historyBid, historyAuctionResult,
                            historyImpressionResult, historyLeastWinningCost, budget, cpa):
        """
        Build the 18-dimensional hand-crafted feature vector (float32).

        Same features as the original DtBiddingCritic. ``pValueSigmas`` is
        accepted for interface compatibility but is not used.

        Feature order: time_left, budget_left, historical/last-3 bid means,
        historical LWC / pValue / conversion / xi means, last-3 LWC / pValue /
        conversion / xi means, current pValue mean, current pv count,
        last-3 pv count total, historical pv count total, budget, cpa.
        """
        time_left = (self.NUM_TICKS - timeStepIndex) / self.NUM_TICKS
        budget_left = self.remaining_budget / self.budget if self.budget > 0 else 0

        # Column 0 of auction results / pValue info, column 1 of impression results.
        history_xi = [res[:, 0] for res in historyAuctionResult]
        history_pValue = [res[:, 0] for res in historyPValueInfo]
        history_conversion = [res[:, 1] for res in historyImpressionResult]

        def safe_mean(arrs):
            # Mean of per-step means; 0 when there is no history yet.
            return np.mean([np.mean(a) for a in arrs]) if arrs else 0

        historical_xi_mean = safe_mean(history_xi)
        historical_conversion_mean = safe_mean(history_conversion)
        historical_LWC_mean = safe_mean(historyLeastWinningCost)
        historical_pValues_mean = safe_mean(history_pValue)
        historical_bid_mean = safe_mean(historyBid)

        def mean_of_last_n(history, n):
            # Mean of per-step means over the last ``n`` steps (0 if none).
            last = history[max(0, len(history) - n):]
            return 0 if len(last) == 0 else np.mean([np.mean(x) for x in last])

        last_three_xi_mean = mean_of_last_n(history_xi, 3)
        last_three_conversion_mean = mean_of_last_n(history_conversion, 3)
        last_three_LWC_mean = mean_of_last_n(historyLeastWinningCost, 3)
        last_three_pValues_mean = mean_of_last_n(history_pValue, 3)
        last_three_bid_mean = mean_of_last_n(historyBid, 3)

        current_pValues_mean = np.mean(pValues)
        current_pv_num = len(pValues)
        historical_pv_num_total = sum(len(b) for b in historyBid) if historyBid else 0
        # Indexed by tick: assumes len(historyBid) == timeStepIndex (one
        # entry per elapsed tick) -- kept as-is to match training features.
        last_three_pv_num_total = sum([len(historyBid[i]) for i in range(max(0, timeStepIndex - 3), timeStepIndex)]) if historyBid else 0

        state_vec = np.array([
            time_left, budget_left, historical_bid_mean, last_three_bid_mean,
            historical_LWC_mean, historical_pValues_mean, historical_conversion_mean,
            historical_xi_mean, last_three_LWC_mean, last_three_pValues_mean,
            last_three_conversion_mean, last_three_xi_mean,
            current_pValues_mean, current_pv_num, last_three_pv_num_total,
            historical_pv_num_total, budget, cpa,
        ], dtype=np.float32)
        return state_vec

    def access_value(self, action, timeStepIndex, pValues, pValueSigmas,
                     historyPValueInfo, historyBid, historyAuctionResult,
                     historyImpressionResult, historyLeastWinningCost, budget, cpa):
        """
        Score a single scalar ``action`` in the current state.

        Returns whatever ``self.critic.take_critics`` returns for a batch of
        one (per the original contract: (value_1, value_2), shape compatible
        with the upper layer, wrapped as [1, 1, 1]).
        """
        # Construct 18-dimensional state features
        s = self._build_state_vector(timeStepIndex, pValues, pValueSigmas,
                                     historyPValueInfo, historyBid, historyAuctionResult,
                                     historyImpressionResult, historyLeastWinningCost, budget, cpa)

        # Convert to tensors
        s_t = torch.tensor(s, dtype=torch.float32, device=self.device).unsqueeze(0)        # [1, 18]
        a_t = torch.tensor(action, dtype=torch.float32, device=self.device).reshape(1, 1)  # [1, 1]

        # A fresh list is passed so the critic sees list (not tuple) indices.
        q_values = self.critic.take_critics(
            s_t, a_t, normalize_indices=list(self.NORMALIZE_INDICES))
        return q_values

    def access_value_batch(self, actions_tensor, timeStepIndex, pValues, pValueSigmas,
                           historyPValueInfo, historyBid, historyAuctionResult,
                           historyImpressionResult, historyLeastWinningCost, budget, cpa):
        """
        Score K candidate actions against the same current state.

        Input:
        - actions_tensor: K candidate actions, shape [K, 1] (a flat [K]
          tensor/array is also accepted). Converted to float32 on the
          critic's device.
        Returns:
        - Whatever ``take_critics`` returns for the batch (per the original
          contract: (q1_batch, q2_batch), each of shape [K, 1]).
        """
        s = self._build_state_vector(timeStepIndex, pValues, pValueSigmas,
                                     historyPValueInfo, historyBid, historyAuctionResult,
                                     historyImpressionResult, historyLeastWinningCost, budget, cpa)

        # Always coerce dtype/device; as_tensor returns the input unchanged
        # when both already match and keeps the autograd graph otherwise.
        actions = torch.as_tensor(actions_tensor, dtype=torch.float32,
                                  device=self.device).reshape(-1, self.ACT_DIM)
        K = actions.shape[0]

        # s: [dim] -> [1, dim] -> [K, dim]. repeat() (not expand) gives each
        # row its own storage in case the critic normalizes in place.
        s_t = torch.tensor(s, dtype=torch.float32, device=self.device).unsqueeze(0)
        s_t_expanded = s_t.repeat(K, 1)

        q_values = self.critic.take_critics(
            s_t_expanded,
            actions,
            normalize_indices=list(self.NORMALIZE_INDICES),
        )
        return q_values



class CQLBiddingCritic(BaseBiddingStrategy):
    """
    CQL Critic (for evaluating the Q value of a given action).

    Wraps a trained ``CQL_Critic`` network and builds the 18-dimensional
    hand-crafted state vector for the current bidding step. Callers can score
    a single action (``access_value``) or a batch of candidate actions
    (``access_value_batch``).
    """

    # Number of ticks per episode; used to compute the ``time_left`` feature.
    NUM_TICKS = 48
    # Length of the vector built by ``_build_state_vector``.
    STATE_DIM = 18
    # Action dimensionality the critic network is built with.
    ACT_DIM = 1
    # State indices handed to ``take_critics`` for normalization:
    # current_pv_num, last_three_pv_num_total, historical_pv_num_total,
    # budget, cpa (see ordering in ``_build_state_vector``).
    NORMALIZE_INDICES = (13, 14, 15, 16, 17)

    def __init__(self, budget=100, name="CQL-Critic", cpa=2, category=1,
                 load_dir=None):
        """
        Build the critic network and load its weights.

        Args:
            budget: Episode budget; also used to scale ``budget_left``.
            name: Strategy name forwarded to ``BaseBiddingStrategy``.
            cpa: Target CPA, stored and forwarded to the base class.
            category: Advertiser category, stored and forwarded.
            load_dir: Directory containing ``cql_critic.pt``. Defaults to
                ``<project_root>/save_model/CQL_critic``.
        """
        super().__init__(budget, name, cpa, category)

        if load_dir is None:
            file_dir = os.path.dirname(os.path.realpath(__file__))
            proj_dir = os.path.dirname(os.path.dirname(file_dir))  # Go back to project root
            load_dir = os.path.join(proj_dir, "save_model", "CQL_critic")
        ckpt_path = os.path.join(load_dir, "cql_critic.pt")
        # No separate normalization file is loaded here; normalization of
        # NORMALIZE_INDICES is delegated to ``take_critics`` (presumably using
        # statistics stored with the checkpoint -- confirm in CQL_Critic).

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build CQL critic and load weights
        self.critic = CQL_Critic(state_dim=self.STATE_DIM, act_dim=self.ACT_DIM)
        self.critic.load_net(ckpt_path, device=device)
        self.critic.to(device)
        self.device = device

        # Record budget related variables
        self.cpa = cpa
        self.budget = budget
        self.category = category
        # Initialised here as well as in reset() so that state features can
        # be built even if reset() has not been called yet.
        self.remaining_budget = self.budget
        self.remaining_budget_last = self.budget

    def reset(self):
        """Restore the remaining-budget trackers to the full budget."""
        self.remaining_budget = self.budget
        self.remaining_budget_last = self.budget

    def _build_state_vector(self, timeStepIndex, pValues, pValueSigmas,
                            historyPValueInfo, historyBid, historyAuctionResult,
                            historyImpressionResult, historyLeastWinningCost, budget, cpa):
        """
        Build the 18-dimensional hand-crafted feature vector (float32).

        ``pValueSigmas`` is accepted for interface compatibility but unused.

        Feature order: time_left, budget_left, historical/last-3 bid means,
        historical LWC / pValue / conversion / xi means, last-3 LWC / pValue /
        conversion / xi means, current pValue mean, current pv count,
        last-3 pv count total, historical pv count total, budget, cpa.
        """
        time_left = (self.NUM_TICKS - timeStepIndex) / self.NUM_TICKS
        budget_left = self.remaining_budget / self.budget if self.budget > 0 else 0

        # Column 0 of auction results / pValue info, column 1 of impression results.
        history_xi = [res[:, 0] for res in historyAuctionResult]
        history_pValue = [res[:, 0] for res in historyPValueInfo]
        history_conversion = [res[:, 1] for res in historyImpressionResult]

        def safe_mean(arrs):
            # Mean of per-step means; 0 when there is no history yet.
            return np.mean([np.mean(a) for a in arrs]) if arrs else 0

        historical_xi_mean = safe_mean(history_xi)
        historical_conversion_mean = safe_mean(history_conversion)
        historical_LWC_mean = safe_mean(historyLeastWinningCost)
        historical_pValues_mean = safe_mean(history_pValue)
        historical_bid_mean = safe_mean(historyBid)

        def mean_of_last_n(history, n):
            # Mean of per-step means over the last ``n`` steps (0 if none).
            last = history[max(0, len(history) - n):]
            return 0 if len(last) == 0 else np.mean([np.mean(x) for x in last])

        last_three_xi_mean = mean_of_last_n(history_xi, 3)
        last_three_conversion_mean = mean_of_last_n(history_conversion, 3)
        last_three_LWC_mean = mean_of_last_n(historyLeastWinningCost, 3)
        last_three_pValues_mean = mean_of_last_n(history_pValue, 3)
        last_three_bid_mean = mean_of_last_n(historyBid, 3)

        current_pValues_mean = np.mean(pValues)
        current_pv_num = len(pValues)
        historical_pv_num_total = sum(len(b) for b in historyBid) if historyBid else 0
        # Indexed by tick: assumes len(historyBid) == timeStepIndex (one
        # entry per elapsed tick) -- kept as-is to match training features.
        last_three_pv_num_total = sum([len(historyBid[i]) for i in range(max(0, timeStepIndex - 3), timeStepIndex)]) if historyBid else 0

        state_vec = np.array([
            time_left, budget_left, historical_bid_mean, last_three_bid_mean,
            historical_LWC_mean, historical_pValues_mean, historical_conversion_mean,
            historical_xi_mean, last_three_LWC_mean, last_three_pValues_mean,
            last_three_conversion_mean, last_three_xi_mean,
            current_pValues_mean, current_pv_num, last_three_pv_num_total,
            historical_pv_num_total, budget, cpa,
        ], dtype=np.float32)
        return state_vec

    def access_value(self, action, timeStepIndex, pValues, pValueSigmas,
                     historyPValueInfo, historyBid, historyAuctionResult,
                     historyImpressionResult, historyLeastWinningCost, budget, cpa):
        """
        Score a single scalar ``action`` in the current state.

        Returns whatever ``self.critic.take_critics`` returns for a batch of
        one (per the original contract: (value_1, value_2), shape compatible
        with the upper layer, wrapped as [1, 1, 1]).
        """
        # Construct 18-dimensional state features
        s = self._build_state_vector(timeStepIndex, pValues, pValueSigmas,
                                     historyPValueInfo, historyBid, historyAuctionResult,
                                     historyImpressionResult, historyLeastWinningCost, budget, cpa)

        # Convert to tensors
        s_t = torch.tensor(s, dtype=torch.float32, device=self.device).unsqueeze(0)        # [1, 18]
        a_t = torch.tensor(action, dtype=torch.float32, device=self.device).reshape(1, 1)  # [1, 1]

        # A fresh list is passed so the critic sees list (not tuple) indices.
        q_values = self.critic.take_critics(
            s_t, a_t, normalize_indices=list(self.NORMALIZE_INDICES))
        return q_values

    def access_value_batch(self, actions_tensor, timeStepIndex, pValues, pValueSigmas,
                           historyPValueInfo, historyBid, historyAuctionResult,
                           historyImpressionResult, historyLeastWinningCost, budget, cpa):
        """
        Score K candidate actions against the same current state.

        Input:
        - actions_tensor: K candidate actions, shape [K, 1] (a flat [K]
          tensor/array is also accepted). Converted to float32 on the
          critic's device.
        Returns:
        - Whatever ``take_critics`` returns for the batch (per the original
          contract: (q1_batch, q2_batch), each of shape [K, 1]).
        """
        s = self._build_state_vector(timeStepIndex, pValues, pValueSigmas,
                                     historyPValueInfo, historyBid, historyAuctionResult,
                                     historyImpressionResult, historyLeastWinningCost, budget, cpa)

        # Always coerce dtype/device; as_tensor returns the input unchanged
        # when both already match and keeps the autograd graph otherwise.
        actions = torch.as_tensor(actions_tensor, dtype=torch.float32,
                                  device=self.device).reshape(-1, self.ACT_DIM)
        K = actions.shape[0]

        # s: [dim] -> [1, dim] -> [K, dim]. repeat() (not expand) gives each
        # row its own storage in case the critic normalizes in place.
        s_t = torch.tensor(s, dtype=torch.float32, device=self.device).unsqueeze(0)
        s_t_expanded = s_t.repeat(K, 1)

        q_values = self.critic.take_critics(
            s_t_expanded,
            actions,
            normalize_indices=list(self.NORMALIZE_INDICES),
        )
        return q_values

