import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import math
import logging
from bidding_train_env.strategy import PlayerBiddingStrategy,PlayerBiddingCritic
from bidding_train_env.dataloader.test_dataloader import TestDataLoader
from bidding_train_env.environment.offline_env import OfflineEnv

import torch
import torch.distributions as D
# Switch the working directory to the parent of the folder that contains this
# file (symlinks resolved via realpath).
# NOTE(review): presumably so that relative paths used by the imported
# bidding_train_env modules resolve against the project root -- confirm.
os.chdir(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

# Configure logging once at import time: INFO level, with timestamp, logger
# name, source file/line and level prepended to every message.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(name)s] [%(filename)s(%(lineno)d)] [%(levelname)s] %(message)s"
)
# Module-level logger used by run_test below.
logger = logging.getLogger(__name__)



def candidate_bids_from_gmm_more(pi, mu, sigma,
                                 taus=(0.5, 0.75, 1.0, 1.25, 1.5),
                                 risk_lams=(0.0, 0.1, 0.25, 0.5),
                                 sample_n=None, clip=None):
    """Derive a de-duplicated pool of candidate values from a Gaussian mixture.

    The pool is the union of three groups:

    * risk-adjusted mixture means ``mean - lam * std`` for every combination
      of temperature ``tau`` in ``taus`` and risk weight ``lam`` in
      ``risk_lams``, where the mixture weights are re-tempered as
      ``softmax(log(pi) / tau)``;
    * every component mean ``mu_k`` and its shifted version
      ``mu_k - 0.5 * sigma_k``;
    * the plain (untempered) mixture mean ``sum(pi * mu)``.

    Args:
        pi: Tensor of mixture weights. If it has more than one dimension,
            ``pi``, ``mu`` and ``sigma`` are all flattened to rows of their
            last dimension and only the last row is kept.
        mu: Tensor of component means, laid out like ``pi``.
        sigma: Tensor of component standard deviations, laid out like ``pi``.
        taus: Temperatures applied to the log-weights; values below ``1e-6``
            are raised to ``1e-6`` before dividing.
        risk_lams: Multipliers of the mixture standard deviation that are
            subtracted from the mixture mean.
        sample_n: Optional cap on the number of returned candidates. When the
            pool holds more distinct values than this, only the ``sample_n``
            values closest to the untempered mixture mean are returned, in
            the order produced by ``torch.topk(..., largest=False)``.
        clip: Optional arguments forwarded to ``Tensor.clamp`` (e.g. a
            ``(min, max)`` pair) and applied before de-duplication.

    Returns:
        1-D tensor of candidates, rounded to 4 decimals and de-duplicated with
        ``torch.unique``.
    """
    # Step 1: bring all three inputs to a flat [M] layout (last row wins).
    if pi.dim() > 1:
        pi, mu, sigma = (t.reshape(-1, t.size(-1))[-1] for t in (pi, mu, sigma))

    tiny = 1e-12
    dev = pi.device

    temperatures = torch.tensor(taus, device=dev, dtype=pi.dtype).view(-1, 1)       # [T, 1]
    risk_weights = torch.tensor(risk_lams, device=dev, dtype=pi.dtype).view(1, -1)  # [1, L]

    # Step 2: temperature-scaled mixture weights, one row per temperature.
    log_weights = torch.log(pi + tiny).view(1, -1)
    min_temperature = torch.tensor(1e-6, device=dev)
    scaled_logits = log_weights / torch.max(temperatures, min_temperature)
    tempered = torch.softmax(scaled_logits, dim=-1)  # [T, M]

    # Step 3: mean and standard deviation of each tempered mixture, using
    # Var = E[x^2] - E[x]^2 with E[x^2] = sum_k w_k * (mu_k^2 + sigma_k^2).
    mix_mean = (tempered * mu.view(1, -1)).sum(dim=-1)
    per_component_second = mu.pow(2) + sigma.pow(2)
    mix_second = (tempered * per_component_second.view(1, -1)).sum(dim=-1)
    mix_var = (mix_second - mix_mean.pow(2)).clamp(min=0.0)  # guard against round-off
    mix_std = torch.sqrt(mix_var + tiny)

    # Step 4: every (temperature, risk weight) pair yields one candidate.
    risk_adjusted = (mix_mean.view(-1, 1) - risk_weights * mix_std.view(-1, 1)).flatten()

    # Untempered mixture mean; also the reference point for sub-sampling.
    prior_mean = (pi * mu).sum()

    pool = torch.cat([
        risk_adjusted,
        mu,                  # component means
        mu - 0.5 * sigma,    # conservative variant of each component
        prior_mean.view(1),
    ])
    if clip is not None:
        pool = pool.clamp(*clip)

    distinct = torch.unique(pool.round(decimals=4))

    over_limit = sample_n is not None and distinct.numel() > sample_n
    if not over_limit:
        return distinct

    # Keep only the sample_n candidates nearest to the untempered mean.
    gap = torch.abs(distinct - prior_mean)
    nearest = torch.topk(gap, k=sample_n, largest=False).indices
    return distinct[nearest]



def getScore_nips(reward, cpa, cpa_constraint):
    """Return ``reward`` scaled by a penalty for exceeding the CPA constraint.

    When ``cpa`` is strictly greater than ``cpa_constraint`` the reward is
    multiplied by ``(cpa_constraint / (cpa + 1e-10)) ** 2``; otherwise the
    multiplier is 1.
    """
    exponent = 2
    # The 1e-10 term keeps the ratio finite; the multiplier stays 1 unless
    # the constraint is violated.
    scale = (
        pow(cpa_constraint / (cpa + 1e-10), exponent)
        if cpa > cpa_constraint
        else 1
    )
    return scale * reward

def run_test(policy_load_dir=None,critic_load_dir =None, policy_method='dt_dist', reweigth_w=0.2, budget_coef=1.0,file_path=None,K=None,rK=None):
    """
    Offline evaluation of a bidding policy with critic-based action selection.

    For every key exposed by ``TestDataLoader(file_path)`` a fresh
    ``PlayerBiddingStrategy`` and ``PlayerBiddingCritic`` are created and one
    episode is simulated with ``OfflineEnv``. At each time step the policy
    proposes candidate actions (GMM-derived candidates plus retrieved ones),
    the critic scores them, and the candidate with the highest
    ``min(q1, q2)`` is executed as ``bid = action * pValue``.

    Args:
        policy_load_dir: Forwarded to ``PlayerBiddingStrategy`` as ``load_dir``.
        critic_load_dir: Forwarded to ``PlayerBiddingCritic`` as ``load_dir``.
        policy_method: Forwarded to ``PlayerBiddingStrategy`` as
            ``baseline_method``.
        reweigth_w: Forwarded to ``PlayerBiddingStrategy`` as ``reweight_w``.
        budget_coef: Multiplier applied to the budget returned by the loader.
        file_path: Forwarded to ``TestDataLoader``.
        K: Forwarded to ``candidate_bids_from_gmm_more`` as ``sample_n``
            (cap on the number of generated candidates; ``None`` = no cap).
        rK: Forwarded to ``agent.bidding`` as ``retrieved_K``.

    Returns:
        Tuple ``(average score, exceed rate, average reward)``, each divided
        by ``len(keys)``.

    NOTE(review): when ``keys`` is empty the final divisions by ``len(keys)``
    divide by zero -- confirm the loader never yields zero keys.
    """
    data_loader = TestDataLoader(file_path=file_path)
    env = OfflineEnv()

    keys, test_dict = data_loader.keys, data_loader.test_dict
    
    # Accumulators summed over all keys; averaged by len(keys) at the end.
    overall_score = 0.0
    exceed_rate = 0.0
    overall_reward = 0.0
    cpa_ratio = 0.0

    # NOTE(review): total_decisions is incremented below but never read.
    total_decisions = 0
    # Counts, across ALL keys, of where the executed action came from.
    source_stats = {
        "gen": 0,   # From GMM generation
        "ret": 0,   # From retrieval
        "fallback": 0
    }
 
    for key in keys:
        num_timeStepIndex, pValues, pValueSigmas, leastWinningCosts, budget, cpa, category= data_loader.mock_data(key)

        budget = budget_coef * budget
        agent = PlayerBiddingStrategy(budget=budget, cpa=cpa, load_dir=policy_load_dir, baseline_method=policy_method, reweight_w=reweigth_w)
        critic = PlayerBiddingCritic(budget=budget,cpa=cpa, load_dir=critic_load_dir)
        rewards = np.zeros(num_timeStepIndex)
        # Per-episode history; the same lists are handed to both the agent
        # and the critic at every step and grow by one entry per time step.
        history = {
            'historyBids': [],
            'historyAuctionResult': [],
            'historyImpressionResult': [],
            'historyLeastWinningCost': [],
            'historyPValueInfo': []
        }
        # Action chosen at the previous non-fallback step (None before the first one).
        actual_excuted_action = None

        for timeStep_index in range(num_timeStepIndex):
            pValue = pValues[timeStep_index]
            pValueSigma = pValueSigmas[timeStep_index]
            leastWinningCost = leastWinningCosts[timeStep_index]

            if agent.remaining_budget < env.min_remaining_budget:
                # Budget is below the environment's minimum: bid zero everywhere.
                bid = np.zeros(pValue.shape[0])
                source_stats["fallback"] += 1 # Record fallback
            else:
                # alpha is returned by the agent but not used in this function.
                bid,retrieved_actions, alpha, pi, mu, sigma = agent.bidding(timeStep_index, pValue, pValueSigma, history["historyPValueInfo"],
                                        history["historyBids"],
                                        history["historyAuctionResult"], history["historyImpressionResult"],
                                        history["historyLeastWinningCost"],
                                        actual_excuted_action=actual_excuted_action,retrieved_K=rK)
                # Obtain generated actions (GMM) - ensure conversion to list
                gen_candidates_tensor = candidate_bids_from_gmm_more(
                    pi, mu, sigma, sample_n=K,
                )
                gen_actions_list = gen_candidates_tensor.detach().cpu().numpy().reshape(-1).tolist()
                ret_actions_list = np.asarray(retrieved_actions).reshape(-1).tolist()
                # Generated candidates occupy indices [0, split_index); retrieved ones follow.
                split_index = len(gen_actions_list)
                actions = gen_actions_list + ret_actions_list

                device = "cuda:0" if torch.cuda.is_available() else "cpu"
                # Unwrap any tensor entries so the list holds plain scalars.
                flat_actions = [a.item() if isinstance(a, torch.Tensor) else a for a in actions]
                # One candidate per row: shape [num_candidates, 1].
                actions_tensor = torch.tensor(flat_actions, dtype=torch.float32, device=device).reshape(-1, 1)
                with torch.no_grad():
                    q1_batch, q2_batch = critic.access_value_batch(
                        actions_tensor, 
                        timeStep_index, pValue, pValueSigma, 
                        history["historyPValueInfo"],
                        history["historyBids"],
                        history["historyAuctionResult"],
                        history["historyImpressionResult"],
                        history["historyLeastWinningCost"],
                        budget, cpa
                    )
                    # Element-wise minimum of the two critic outputs, one value per candidate.
                    q_values = torch.min(q1_batch, q2_batch).flatten().cpu().numpy() # [K]
                max_idx = int(np.argmax(q_values))
                max_action = flat_actions[max_idx] 

                total_decisions += 1
                if max_idx < split_index:
                    source_stats["gen"] += 1
                else:
                    source_stats["ret"] += 1
                actual_excuted_action = np.array([max_action], dtype=np.float32)
                # The agent's own bid is replaced by the critic-selected action scaled by pValue.
                bid = max_action * pValue
                

            tick_value, tick_cost, tick_status, tick_conversion = env.simulate_ad_bidding(pValue, pValueSigma, bid,
                                                                                        leastWinningCost)

            # Handling over-cost (a timestep costs more than the remaining budget of the bidding advertiser)
            over_cost_ratio = max((np.sum(tick_cost) - agent.remaining_budget) / (np.sum(tick_cost) + 1e-4), 0)
            while over_cost_ratio > 0:
                # Randomly zero out a share of the entries with tick_status == 1
                # (in place on `bid`), re-simulate, and repeat until the cost fits.
                pv_index = np.where(tick_status == 1)[0]
                dropped_pv_index = np.random.choice(pv_index, int(math.ceil(pv_index.shape[0] * over_cost_ratio)),
                                                    replace=False)
                bid[dropped_pv_index] = 0
                tick_value, tick_cost, tick_status, tick_conversion = env.simulate_ad_bidding(pValue, pValueSigma, bid,
                                                                                            leastWinningCost)
                over_cost_ratio = max((np.sum(tick_cost) - agent.remaining_budget) / (np.sum(tick_cost) + 1e-4), 0)

            agent.remaining_budget -= np.sum(tick_cost)
            rewards[timeStep_index] = np.sum(tick_conversion)
            # Record this step in the history: (pValue, pValueSigma) pairs,
            # the final bids, least winning costs, (status, status, cost)
            # triples and (conversion, conversion) pairs.
            temHistoryPValueInfo = [(pValue[i], pValueSigma[i]) for i in range(pValue.shape[0])]
            history["historyPValueInfo"].append(np.array(temHistoryPValueInfo))
            history["historyBids"].append(bid)
            history["historyLeastWinningCost"].append(leastWinningCost)
            temAuctionResult = np.array(
                [(tick_status[i], tick_status[i], tick_cost[i]) for i in range(tick_status.shape[0])])
            history["historyAuctionResult"].append(temAuctionResult)
            temImpressionResult = np.array([(tick_conversion[i], tick_conversion[i]) for i in range(pValue.shape[0])])
            history["historyImpressionResult"].append(temImpressionResult)
        # Episode summary for this key.
        all_reward = np.sum(rewards)
        all_cost = agent.budget - agent.remaining_budget
        # Realised cost per unit of reward, capped to the range [0, 100].
        cpa_real = np.clip(all_cost / (all_reward + 1e-10), a_min=0, a_max=100.)
        cpa_constraint = agent.cpa
        score = getScore_nips(all_reward, cpa_real, cpa_constraint)
        overall_score += score
        overall_reward += all_reward
        if cpa_real > cpa_constraint:
            exceed_rate += 1
        cpa_ratio += cpa_real / cpa_constraint

        logger.info(f'Total Reward: {all_reward}')
        logger.info(f'Total Cost: {all_cost}')
        logger.info(f'CPA-real: {cpa_real}')
        logger.info(f'CPA-constraint: {cpa_constraint}')
        logger.info(f'Score: {score}')
        # Running totals over all keys processed so far (source_stats is not
        # reset per key); fallback steps are excluded from the percentages.
        # The 1e-10 avoids dividing by zero when no decision was made yet.
        curr_total = source_stats["gen"] + source_stats["ret"] + 1e-10
        logger.info(f'[Source Tracking] Gen: {source_stats["gen"]} ({source_stats["gen"]/curr_total:.2%}) | '
                    f'Ret: {source_stats["ret"]} ({source_stats["ret"]/curr_total:.2%})')
        
    final_total = source_stats["gen"] + source_stats["ret"] + 1e-10
    logger.info(f'===========> overall average score: {overall_score/len(keys)}  exceed rate: {exceed_rate/len(keys)} total reward: {overall_reward/len(keys)} cpa_ratio:{cpa_ratio/len(keys)}')
    logger.info(f'Action Source Distribution:')
    logger.info(f'  - Generator (GMM): {source_stats["gen"]} actions ({source_stats["gen"]/final_total:.2%})')
    logger.info(f'  - Retrieval (RAG): {source_stats["ret"]} actions ({source_stats["ret"]/final_total:.2%})')
    
    return overall_score/len(keys), exceed_rate/len(keys), overall_reward/len(keys)

if __name__ == '__main__':
    import argparse

    # Default checkpoints live in <parent of this script's folder>/saved_model/.
    this_file = os.path.abspath(__file__)
    saved_model_root = os.path.join(os.path.dirname(os.path.dirname(this_file)), "saved_model")
    default_policy_path = os.path.join(saved_model_root, "dt_dist")
    default_critic_path = os.path.join(saved_model_root, "CQL_critic")

    # Command-line interface.
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline_method", type=str, default="dt_dist")
    parser.add_argument("--policy_ckpt", type=str, default=default_policy_path)
    parser.add_argument("--critic_ckpt", type=str, default=default_critic_path)
    parser.add_argument(
        "--periods", type=int, nargs="+", default=[7],
        help="Traffic-sparse period ids, e.g. 7 8 9 10",
    )
    parser.add_argument(
        "--budget-coef", type=float, default=1.0,
        help="Multiplicative coefficient on budget",
    )
    parser.add_argument(
        "--K", type=int, default=5,
        help="number of candidates to generate",
    )
    parser.add_argument(
        "--rK", type=int, default=5,
        help="number of candidates to retrieve",
    )
    cli = parser.parse_args()

    method_name = cli.baseline_method
    policy_dir = cli.policy_ckpt
    critic_dir = cli.critic_ckpt

    # Traffic files are addressed relative to this script's own folder.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Evaluate each requested period in turn.
    for period in cli.periods:
        traffic_csv = os.path.join(script_dir, f"../../data/traffic/period-{period}.csv")
        print(f"evaluating method: {method_name} with ckpt from {policy_dir}")
        print(f"\n testing on {traffic_csv}... (budget_coef={cli.budget_coef}) ===============>")
        avg_score, avg_exceed_rate, avg_reward = run_test(
            policy_load_dir=policy_dir,
            critic_load_dir=critic_dir,
            policy_method=method_name,
            reweigth_w=0.2,
            budget_coef=cli.budget_coef,
            file_path=traffic_csv,
            K=cli.K,
            rK=cli.rK,
        )