# evaluation_offline.py
import numpy as np, math, gc
import torch
import os

from offline_eval.test_dataloader import TestDataLoader
from offline_eval.offline_env import OfflineEnv


def getScore_neurips(reward, cpa, cpa_constraint):
    """Return ``reward``, penalised quadratically when the CPA exceeds its constraint.

    If ``cpa > cpa_constraint`` the score is
    ``(cpa_constraint / cpa) ** 2 * reward`` (with a tiny epsilon added to the
    denominator); otherwise the reward is returned unchanged.
    """
    beta = 2
    exceeded = cpa > cpa_constraint
    if not exceeded:
        return reward
    ratio = cpa_constraint / (cpa + 1e-10)
    return ratio ** beta * reward


class Evaluation(object):
    """Offline evaluation harness that replays test CSVs through ``OfflineEnv``.

    For each CSV in ``test_csv_list`` and each key of its ``TestDataLoader``,
    the model is asked once per tick for a scalar multiplier ``alpha``. Bids are
    ``alpha * pValue``, the tick is simulated, and the finished episode is scored
    with ``getScore_neurips``. Scores are averaged per CSV, then across CSVs.
    """

    def __init__(self, config, state_mean, state_std):
        self.config = config
        self.device = config.get('device', 'cpu')
        self.num_eval_episodes = config.get('num_eval_episodes', 1)
        self.max_ep_len = config.get('max_ep_len', 48)
        self.test_csv_list = config.get('test_csv_list', ['./data/traffic_test7/period-7.csv'])
        # Stored for callers; not referenced by the code in this class.
        self.state_mean = state_mean
        self.state_std = state_std
        self.state_dim = 16  # length of the vector built by _build_state_vec

        # Stitch mode: consult model.get_return() each tick (t > 0). When it predicts
        # more than the current target, reset the history and raise the target.
        self.is_stitch = bool(config.get('is_stitch', False))
        # Top-K size for _topk_reduce:
        #   int         -> absolute count
        #   float <= 1  -> fraction of samples
        #   float > 1   -> absolute count
        self.percentile = config.get('percentile', 1)
        self.scale = float(config.get('scale', 1.0))
        self.use_mean_topk = bool(config.get('use_mean_topk', False))

    @staticmethod
    def _empty_history():
        """Return a fresh, empty per-episode history dict."""
        return {
            'historyBids': [],
            'historyAuctionResult': [],
            'historyImpressionResult': [],
            'historyLeastWinningCost': [],
            'historyPValueInfo': [],
        }

    def _build_state_vec(self, t, budget, remaining_budget,
                         pValues, pValueSigmas,
                         histPValueInfo, histBids, histAuction, histImpress, histLWC):
        """Build the 16-feature state for tick ``t``.

        Args:
            t: current tick index (0-based).
            budget: episode budget.
            remaining_budget: budget still unspent.
            pValues: the current tick's pValues.
            pValueSigmas: accepted for interface compatibility; not used.
            histPValueInfo, histBids, histAuction, histImpress, histLWC:
                per-tick history lists as recorded by ``_run_episode``.
                ``histAuction`` rows are (status, status, cost).
                ``histImpress`` rows are (conv, conv).
                ``histPValueInfo`` rows are (pValue, sigma).

        Returns:
            float32 array of length 16, in this order:
            time left, budget left, all-history and last-3 means of
            bid / least-winning-cost / pValue / conversion / win status,
            current pValue mean and count, and pValue counts over the last
            3 ticks and over the whole history.

        The "last 3" features use the three most recent history entries rather
        than indices ``t-3 .. t-1``. After a stitch reset the history is shorter
        than ``t``, and indexing by ``t`` would read empty windows or raise
        ``IndexError``. When ``len(history) == t`` both forms are identical.
        """
        time_left = (self.max_ep_len - t) / self.max_ep_len
        budget_left = remaining_budget / (budget + 1e-10)

        history_xi = [r[:, 0] for r in histAuction]
        history_pValue = [r[:, 0] for r in histPValueInfo]
        history_conversion = [r[:, 1] for r in histImpress]

        def mean_of_means(seq):
            return np.mean([np.mean(a) for a in seq]) if len(seq) else 0.0

        def mean_last3(seq):
            return mean_of_means(seq[-3:])

        his_xi_m = mean_of_means(history_xi)
        his_conv_m = mean_of_means(history_conversion)
        his_lwc_m = mean_of_means(histLWC)
        his_p_m = mean_of_means(history_pValue)
        his_bid_m = mean_of_means(histBids)

        last3_xi_m = mean_last3(history_xi)
        last3_conv_m = mean_last3(history_conversion)
        last3_lwc_m = mean_last3(histLWC)
        last3_p_m = mean_last3(history_pValue)
        last3_bid_m = mean_last3(histBids)

        cur_p_mean = float(np.mean(pValues)) if len(pValues) else 0.0
        cur_p_num = int(len(pValues))
        his_pv_total = sum(len(b) for b in histBids)
        last3_pv_total = sum(len(b) for b in histBids[-3:])

        return np.array([
            time_left, budget_left, his_bid_m, last3_bid_m,
            his_lwc_m, his_p_m, his_conv_m, his_xi_m,
            last3_lwc_m, last3_p_m, last3_conv_m, last3_xi_m,
            cur_p_mean, cur_p_num, last3_pv_total, his_pv_total,
        ], dtype=np.float32)

    def _topk_reduce(self, sample_returns: torch.Tensor):
        """Reduce sampled returns to ``(max, mean_of_top_k)``.

        Returns:
            ``(None, None)`` for an empty input.
            ``(max_ret, None)`` when ``use_mean_topk`` is false.
            Otherwise ``(max_ret, mean_topk)``, where K comes from
            ``self.percentile``: an int is a count, a float <= 1 is a fraction,
            and a float > 1 is a count. K is clamped to ``[1, numel]``.
        """
        flat = sample_returns.reshape(-1)
        if flat.numel() == 0:
            return None, None
        sorted_ret, _ = torch.sort(flat, descending=True)
        max_ret = sorted_ret[0]

        if not self.use_mean_topk:
            return max_ret, None

        n = sorted_ret.numel()
        K = self.percentile
        if isinstance(K, (float, np.floating)):
            K = int(np.ceil(K if K > 1 else max(1, K * n)))
        else:
            K = int(max(1, K))
        K = min(K, n)
        mean_topk = sorted_ret[:K].mean()
        return max_ret, mean_topk

    @staticmethod
    def _simulate_within_budget(env, pValue, pSigma, bid, lwc_t, remaining_budget):
        """Simulate one tick, randomly zeroing winning bids until cost fits the budget.

        ``bid`` is modified in place: dropped bids are set to 0.

        Returns:
            The final ``env.simulate_ad_bidding`` result tuple
            ``(value, cost, status, conversion)``.
        """
        result = env.simulate_ad_bidding(pValue, pSigma, bid, lwc_t)
        while True:
            _, tick_cost, tick_status, _ = result
            total_cost = np.sum(tick_cost)
            over = max((total_cost - remaining_budget) / (total_cost + 1e-4), 0)
            if not over > 0:  # also stops on NaN, like the original `while over > 0`
                break
            winners = np.where(tick_status == 1)[0]
            if len(winners) == 0:
                break
            # Clamp so a (slightly) negative remaining budget (over > 1) cannot
            # request more samples than there are winners.
            n_drop = min(int(np.ceil(winners.shape[0] * over)), winners.shape[0])
            drop = np.random.choice(winners, n_drop, replace=False)
            bid[drop] = 0
            result = env.simulate_ad_bidding(pValue, pSigma, bid, lwc_t)
        return result

    def _run_episode(self, model, data_loader, env, key, target_rew):
        """Roll out the episode for ``key`` and return its ``getScore_neurips`` score.

        Per-episode state starts fresh for every key: the history, the last
        action, and the return target. This keeps one advertiser's stitch
        updates from leaking into the next episode.
        """
        if hasattr(model, "init_eval"):
            model.init_eval()

        num_t, pVals, pSigmas, lwc, budget, cpa, _category = data_loader.mock_data(key)
        remaining_budget = float(budget)
        current_target_return = float(target_rew)

        rewards = np.zeros(num_t, dtype=np.float32)
        history = self._empty_history()
        last_alpha = None
        stitch_error_reported = False

        for t in range(num_t):
            pValue, pSigma, lwc_t = pVals[t], pSigmas[t], lwc[t]

            # Conversions from the previous recorded tick (None when history is empty).
            prev_impressions = history["historyImpressionResult"]
            pre_reward = sum(prev_impressions[-1][:, 1]) if prev_impressions else None

            if remaining_budget < env.min_remaining_budget:
                # Budget exhausted: place no bids this tick.
                bid = np.zeros(pValue.shape[0], dtype=np.float32)
            else:
                state_vec = self._build_state_vec(
                    t, budget, remaining_budget,
                    pValue, pSigma,
                    history["historyPValueInfo"], history["historyBids"],
                    history["historyAuctionResult"], history["historyImpressionResult"],
                    history["historyLeastWinningCost"],
                )

                if self.is_stitch and t > 0 and hasattr(model, "get_return"):
                    try:
                        with torch.no_grad():
                            state_t = torch.from_numpy(state_vec).to(
                                self.device, dtype=torch.float32).unsqueeze(0)
                            prev_action = 0.0 if last_alpha is None else last_alpha
                            last_action_t = torch.tensor(
                                [[prev_action]], device=self.device, dtype=torch.float32)
                            sample_returns = model.get_return(state_t, last_action_t)
                        # NOTE(review): mean_topk is computed when use_mean_topk is set,
                        # but only max_ret is compared below. Confirm this is intended.
                        max_ret, _mean_topk = self._topk_reduce(sample_returns)
                        if max_ret is not None and max_ret.item() > current_target_return:
                            # The model predicts a higher achievable return: drop the
                            # accumulated history and raise the target, net of the
                            # previous tick's (scaled) reward.
                            history = self._empty_history()
                            last_reward = pre_reward if pre_reward is not None else 0.0
                            current_target_return = float(max_ret.item() - last_reward / self.scale)
                    except Exception as exc:
                        # Stitching is best-effort. Report the first failure per
                        # episode instead of silently swallowing every error.
                        if not stitch_error_reported:
                            print(f"[EVAL] stitch skipped for key={key} at t={t}: {exc!r}")
                            stitch_error_reported = True

                with torch.no_grad():
                    alpha = model.take_actions(
                        state_vec,
                        target_return=current_target_return,
                        pre_reward=pre_reward,
                    )
                last_alpha = (float(alpha.detach().cpu().item())
                              if torch.is_tensor(alpha) else float(alpha))
                bid = last_alpha * pValue

            _tick_value, tick_cost, tick_status, tick_conv = self._simulate_within_budget(
                env, pValue, pSigma, bid, lwc_t, remaining_budget)

            remaining_budget -= float(np.sum(tick_cost))
            rewards[t] = float(np.sum(tick_conv))

            history["historyPValueInfo"].append(np.column_stack((pValue, pSigma)))
            history["historyBids"].append(bid)
            history["historyLeastWinningCost"].append(lwc_t)
            history["historyAuctionResult"].append(
                np.column_stack((tick_status, tick_status, tick_cost)))
            history["historyImpressionResult"].append(np.column_stack((tick_conv, tick_conv)))

        total_reward = float(np.sum(rewards))
        total_cost = float(budget - remaining_budget)
        cpa_real = total_cost / (total_reward + 1e-10)
        return getScore_neurips(total_reward, cpa_real, cpa)

    def _run_once(self, model, target_rew):
        """Evaluate ``model`` once over every CSV in ``test_csv_list``.

        Args:
            model: policy object. ``eval()`` and ``take_actions(...)`` are
                called on it. ``init_eval()`` and ``get_return(...)`` are used
                when present.
            target_rew: initial return target passed to ``take_actions`` at the
                start of every episode.

        Returns:
            ``(mean_score, 0)``. ``mean_score`` is the mean over CSVs of each
            CSV's average episode score; a CSV with no keys scores 0, and an
            empty CSV list gives 0. The second element is a placeholder length.
        """
        model.eval()
        if hasattr(model, "init_eval"):
            model.init_eval()

        per_csv_scores = []
        overall_score = 0.0

        for csv_path in self.test_csv_list:
            data_loader = TestDataLoader(file_path=csv_path)
            env = OfflineEnv()
            keys = data_loader.keys

            csv_score = 0.0
            for key in keys:
                gc.collect()
                csv_score += self._run_episode(model, data_loader, env, key, target_rew)

            num_keys = len(keys)
            csv_score = csv_score / num_keys if num_keys else 0.0
            per_csv_scores.append(csv_score)
            overall_score += csv_score
            del data_loader, env, keys
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print(f"[EVAL]{os.path.basename(csv_path)} Period Score: {csv_score:.6f}")

        num_csvs = len(self.test_csv_list)
        mean_score = overall_score / num_csvs if num_csvs else 0.0
        print(f"[EVAL] Overall mean score across {len(per_csv_scores)} CSVs: {mean_score:.6f}")
        return mean_score, 0

    def eval_fn(self, target_rew):
        """Return a callable that runs ``num_eval_episodes`` evaluations of a model.

        The callable returns a dict with the mean and std of the per-run scores,
        keyed by ``target_rew``.
        """
        def fn(model):
            rets, lens = [], []
            for _ in range(self.num_eval_episodes):
                ret, length = self._run_once(model, target_rew)
                rets.append(ret)
                lens.append(length)
            return {
                f"target_{target_rew}_return_mean": float(np.mean(rets)),
                f"target_{target_rew}_return_std": float(np.std(rets)),
            }
        return fn
