import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
from bidding_train_env.common.utils import save_normalize_dict
from bidding_train_env.common.utils import normalize_state, normalize_reward, save_normalize_dict
from bidding_train_env.baseline.q_critic.replay_buffer import ReplayBuffer
from bidding_train_env.baseline.q_critic.iql_critic import IQL_Critic
import logging
import os
import torch
import random
import argparse
import pandas as pd
import ast


# Switch the working directory to the parent of this script's directory, so the
# relative "results/..." and "saved_model/..." paths used by the training code
# resolve to a fixed location regardless of where the script is launched from.
# NOTE: this is an import-time side effect.
os.chdir(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter





# Date stamp (YYYYMMDD), captured once at import time; it is appended to the
# TensorBoard run directory name.
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d")

def getScore_nips(reward, cpa, cpa_constraint, beta=2):
    """Scale *reward* by the penalty for exceeding the CPA constraint.

    When the realised ``cpa`` is above ``cpa_constraint``, the penalty is
    ``(cpa_constraint / cpa) ** beta``. Otherwise it is 1 and the reward is
    returned unchanged.

    Args:
        reward: Raw reward (e.g. total conversions/value).
        cpa: Realised cost-per-acquisition.
        cpa_constraint: Target CPA that should not be exceeded.
        beta: Penalty exponent (default 2, the previously hard-coded value).

    Returns:
        ``(penalty * reward, penalty)``.
    """
    penalty = 1
    if cpa > cpa_constraint:
        # The small epsilon guards against division by zero.
        coef = cpa_constraint / (cpa + 1e-10)
        penalty = pow(coef, beta)
    return penalty * reward, penalty


def run_iql_baselines(baseline_method='IQL_crticic', sparse_data=False, data_path=None):
    """Set up TensorBoard and logging, then train the IQL critic.

    Args:
        baseline_method: Run name. It is used for the TensorBoard directory
            ``results/<baseline_method>_<YYYYMMDD>`` and is forwarded to
            ``train_model``.
        sparse_data: Forwarded unchanged to ``train_model``.
        data_path: Path of the training CSV, forwarded to ``train_model``.
    """
    log_dir = f"results/{baseline_method}_{formatted_datetime}"
    writer = SummaryWriter(log_dir)
    print(log_dir)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(name)s] [%(filename)s(%(lineno)d)] [%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(__name__)

    try:
        train_model(baseline_method, logger, writer, sparse_data, data_path)
    finally:
        # Flush pending events and release the event-file handle, even if
        # training raises.
        writer.close()


def _safe_literal_eval(val, logger=None):
    """Parse a cell with ``ast.literal_eval``; return it unchanged if it cannot be parsed.

    Missing cells (``pd.isna``) pass through untouched. Parse failures are
    best-effort: the raw value is returned and, if a logger is given, the
    actual error is reported.
    """
    if pd.isna(val):
        return val
    try:
        return ast.literal_eval(val)
    except (ValueError, SyntaxError) as err:
        if logger is not None:
            logger.warning("Could not parse value %r: %s", val, err)
        return val


def _is_missing(x):
    """Return True for ``None`` or a float NaN."""
    return x is None or (isinstance(x, float) and np.isnan(x))


def _append_budget_cpa_cols(df, b_col="budget", c_col="CPAConstraint", target_cols=("state", "next_state")):
    """Append each row's budget and CPA constraint to its state vectors.

    Coerces *b_col* and *c_col* to numeric in place (invalid values become NaN).
    Each non-missing cell of *target_cols* is then replaced by
    ``tuple(flattened_state + [budget, cpa_constraint])``, with NaN budget/CPA
    written as 0.0. Missing cells are left as-is. Mutates and returns *df*.
    """
    df[b_col] = pd.to_numeric(df[b_col], errors="coerce")
    df[c_col] = pd.to_numeric(df[c_col], errors="coerce")

    def _append(x, b, c):
        if _is_missing(x):
            return x
        if isinstance(x, (np.ndarray, list, tuple)):
            # Flatten nested shapes, e.g. (1, 16) -> 16 values.
            base = np.array(x, dtype=float).reshape(-1).tolist()
        else:
            base = [float(x)]
        b = 0.0 if pd.isna(b) else float(b)
        c = 0.0 if pd.isna(c) else float(c)
        return tuple(base + [b, c])

    for col in target_cols:
        df[col] = [_append(x, b, c) for x, b, c in zip(df[col], df[b_col], df[c_col])]
    return df


def _cpa_penalized_reward(df, beta=2.0, eps=1e-10):
    """Return ``reward_continuous`` scaled by the CPA-violation penalty.

    Side effect: stores the realised CPA in ``df["cpa_real"]``. Rows whose
    realised CPA exceeds ``CPAConstraint`` are multiplied by
    ``(CPAConstraint / cpa_real) ** beta``. All other rows, including NaN
    comparisons, keep factor 1.
    """
    df["cpa_real"] = df["realAllCost"] / (df["realAllConversion"] + eps)
    cpa_constraint = df["CPAConstraint"]
    over = df["cpa_real"] > cpa_constraint
    penalties = pd.Series(np.ones(len(df)), index=df.index)
    coef = cpa_constraint[over] / (df.loc[over, "cpa_real"] + eps)
    penalties[over] = np.power(coef, beta)
    return df["reward_continuous"] * penalties


def _as_float_tensor(x, device):
    """Return *x* as a tensor on *device*, converting non-tensors to float32."""
    if not torch.is_tensor(x):
        x = torch.tensor(x, dtype=torch.float32)
    return x.to(device)


def train_model(baseline_method='IQL_crticic', logger=None, writer=None, sparse_data=False, data_path=None,
                step_num=10, batch_size=32):
    """Train an IQL critic on the trajectory CSV at *data_path* and save it.

    Pipeline:
      1. Load the CSV and parse the ``state`` / ``next_state`` cells.
      2. Append each row's ``budget`` and ``CPAConstraint`` to both state vectors.
      3. Compute ``reward_penalized`` = ``reward_continuous`` times the
         CPA-violation penalty.
      4. Normalize states and the penalized reward, and save the normalization
         dict under ``saved_model/<baseline_method>``.
      5. Fill a replay buffer, run *step_num* gradient steps on batches of
         *batch_size*, and save the critic under ``saved_model/<baseline_method>``.

    Args:
        baseline_method: Name of the save directory under ``saved_model/``.
        logger: Logger for progress messages. Defaults to this module's logger.
        writer: Accepted for interface compatibility; not used here.
        sparse_data: Accepted for interface compatibility; not used here.
        data_path: Path of the training CSV.
        step_num: Number of gradient steps. The default of 10 is a debug value;
            the commented-out full-run setting was 400000.
        batch_size: Transitions per gradient step. Default 32; the full-run
            setting was 512.
    """
    if logger is None:
        # Without this, logger.info(...) on step 0 raised AttributeError.
        logger = logging.getLogger(__name__)

    training_data = pd.read_csv(data_path)
    for col in ("state", "next_state"):
        training_data[col] = training_data[col].apply(_safe_literal_eval, args=(logger,))
    is_normalize = True

    training_data = _append_budget_cpa_cols(
        training_data, b_col="budget", c_col="CPAConstraint", target_cols=("state", "next_state"))

    raw_reward = training_data["reward_continuous"]
    training_data["reward_penalized"] = _cpa_penalized_reward(training_data, beta=2.0)
    print(f"Max CPA: {training_data['cpa_real'].max()}")
    print(f"Mean Reward (Before): {raw_reward.mean()}")
    print(f"Mean Reward (After):  {training_data['reward_penalized'].mean()}")

    if is_normalize:
        # State dim 18: presumably 16 raw features plus budget and CPA
        # constraint appended above. Indices 13-17 are the ones normalized.
        normalize_dic = normalize_state(training_data, 18, normalize_indices=[13, 14, 15, 16, 17])
        # Train on the penalized reward. Use "reward" here instead for the raw reward.
        training_data['reward'] = normalize_reward(training_data, "reward_penalized")
        save_normalize_dict(normalize_dic, f"saved_model/{baseline_method}")

    replay_buffer = ReplayBuffer()
    add_to_replay_buffer(replay_buffer, training_data, is_normalize)
    print(len(replay_buffer.memory))
    a_min = float(training_data["action"].min())
    a_max = float(training_data["action"].max())
    print("dataset action range:", a_min, a_max)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = IQL_Critic(state_dim=18, act_dim=1)
    model.to(device)
    model.train()

    for i in range(step_num):
        # Convert each batch component independently, so mixed tensor/array
        # batches are handled.
        states, actions, rewards, next_states, dones = (
            _as_float_tensor(x, device) for x in replay_buffer.sample(batch_size))

        loss_1, loss_2, loss_value = model.step(
            states=states, actions=actions, rewards=rewards, next_states=next_states, dones=dones)
        if i % 1000 == 0:
            logger.info("Step: %d loss_critic1: %s  loss_critic2: %s loss_value: %s",
                        i, np.mean(loss_1), np.mean(loss_2), np.mean(loss_value))

    model.save_net(f"saved_model/{baseline_method}")

def add_to_replay_buffer(replay_buffer, training_data, is_normalize):
    """Push every row of *training_data* into *replay_buffer* as one transition.

    Each push is ``(state, action, reward, next_state, done, realAllCost,
    realAllConversion, CPAConstraint)``. States are 1-D float32 rows; every
    scalar field is a float32 array of shape ``(1,)``. Rows with ``done == 1``
    store an all-zero next state.

    Args:
        replay_buffer: Object exposing ``push(*transition)``.
        training_data: DataFrame with the columns named above.
        is_normalize: If True, read ``normalize_state`` /
            ``normalize_nextstate`` / ``normalize_reward`` instead of
            ``state`` / ``next_state`` / ``reward``.
    """
    if len(training_data) == 0:
        # Nothing to add. This also avoids indexing into an empty state array.
        return

    state_col = "normalize_state" if is_normalize else "state"
    next_state_col = "normalize_nextstate" if is_normalize else "next_state"
    reward_col = "normalize_reward" if is_normalize else "reward"

    def _column(name):
        # One (n, 1) float32 array per scalar field.
        return training_data[name].to_numpy(dtype=np.float32).reshape(-1, 1)

    states = np.asarray(training_data[state_col].tolist(), dtype=np.float32)
    next_states = np.asarray(training_data[next_state_col].tolist(), dtype=np.float32)
    actions = _column("action")
    rewards = _column(reward_col)
    dones = _column("done")
    real_costs = _column("realAllCost")
    real_conversions = _column("realAllConversion")
    cpa_constraints = _column("CPAConstraint")

    rows = zip(states, actions, rewards, next_states, dones, real_costs, real_conversions, cpa_constraints)
    for s, a, r, ns, d, cost, conversion, cpa_constraint in rows:
        # Terminal transitions have no successor state. Store a fresh zero
        # vector so no single array is shared across buffer entries.
        next_state = np.zeros_like(s) if d[0] == 1 else ns
        replay_buffer.push(s, a, r, next_state, d, cost, conversion, cpa_constraint)



if __name__ == "__main__":
    def _str2bool(value):
        """Parse a CLI boolean strictly.

        argparse's ``type=bool`` treats any non-empty string, including
        "False", as True, so an explicit parser is required.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description='training IQL Critics...')
    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(current_dir, "../../data/trajectory/trajectory_data.csv")
    parser.add_argument('--baseline_method', type=str, default='IQL_crticic', help='choose a method to run')
    # NOTE(review): --use_history_rtg and --reweight_w are parsed but not used by this script.
    parser.add_argument('--use_history_rtg', type=_str2bool, default=False, help='whether use Rtg to predict')
    parser.add_argument('--reweight_w', type=float, default=-1, help='for dt_reweight baseline: condition = rtg + w * ctg, -1 for jt_score')
    parser.add_argument('--sparse_data', type=_str2bool, default=True, help='whether use the final stage data of AIGB competition (AuctionNet-sparse)')
    parser.add_argument('--data_path', type=str, default=data_path, help='path to load the dataset')

    args = parser.parse_args()

    run_iql_baselines(baseline_method=args.baseline_method, sparse_data=args.sparse_data, data_path=args.data_path)