# eval_only.py (with TensorBoard)
import argparse, os, random, yaml, pickle, numpy as np, torch, time
from tensorboardX import SummaryWriter
from network import DecisionTransformer, TIT_DecisionTransformer
from evaluation_bidding import Evaluation

def set_seed(seed: int):
    """Seed the Python, NumPy and torch RNGs and pin the cuDNN flags.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _parse_seeds(parser, raw):
    """Parse the comma-separated ``--seeds`` value into a non-empty list of ints.

    Exits through ``parser.error`` (usage message, exit code 2) when the value
    holds no integers or a token is not an integer, so the failure surfaces
    before any model or data is loaded.
    """
    try:
        seeds = [int(tok) for tok in raw.split(',') if tok.strip()]
    except ValueError:
        parser.error(f'--seeds must be a comma-separated list of integers, got {raw!r}')
    if not seeds:
        parser.error('--seeds must contain at least one integer')
    return seeds


def _load_config(args):
    """Build the run config from the layered YAML files plus CLI overrides.

    Later files override earlier ones: default, then env, then algo. The CLI
    device always wins; the CLI normalisation path and test CSV list win when
    they are non-empty.
    """
    config = {}
    layers = (
        'config/default.yaml',
        f'config/env/{args.env}.yaml',
        f'config/algo/{args.algo}.yaml',
    )
    for path in layers:
        with open(path, 'r', encoding='utf-8') as f:
            # An empty YAML document parses to None; treat it as "no settings".
            config.update(yaml.safe_load(f) or {})

    config['device'] = args.device

    norm_path = args.normalize_dict_path or config.get('normalize_dict_path', '')
    if norm_path:
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point this at normalisation dicts from a trusted source.
        with open(norm_path, 'rb') as f:
            nd = pickle.load(f)
        config['state_mean'] = np.array(nd['state_mean'], dtype=np.float32)
        config['state_std'] = np.array(nd['state_std'], dtype=np.float32)

    if args.test_csv_list.strip():
        config['test_csv_list'] = [x.strip() for x in args.test_csv_list.split(',') if x.strip()]
    elif 'test_csv_list' not in config:
        config['test_csv_list'] = ['./data/traffic_test7/period-7.csv']

    return config


def _load_model(config, model_path):
    """Load the checkpoint at ``model_path`` and return it in eval mode.

    A checkpoint that is a pickled ``torch.nn.Module`` is used as-is, so no
    throw-away model is constructed first. Anything else is treated as a state
    dict and loaded into a model built from ``config``.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=config['device'], weights_only=False)

    if isinstance(checkpoint, torch.nn.Module):
        model = checkpoint
    else:
        # NOTE(review): assumes a plain state dict matching the architecture
        # selected by config['tit'] -- confirm against the training script.
        model_cls = TIT_DecisionTransformer if config.get('tit', False) else DecisionTransformer
        model = model_cls(config).to(config['device'])
        model.load_state_dict(checkpoint)

    model.eval()
    return model


def main():
    """Evaluate a saved model over the given seeds and log results to TensorBoard.

    For every seed, each entry of ``config['env_targets']`` (default ``[0.0]``)
    is evaluated through ``Evaluation.eval_fn``. Per-seed return mean/std and
    the across-seed mean/std of every reported metric are written as scalars
    under ``<logdir>/<algo>/<env>/<timestamp>`` and printed to stdout.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--algo', type=str, default='pdit')
    p.add_argument('--env', type=str, default='hopper_medium')
    p.add_argument('--model_path', type=str, required=True)
    p.add_argument('--device', type=str, default='cuda')
    p.add_argument('--seeds', type=str, default='1')
    p.add_argument('--test_csv_list', type=str, default='')
    p.add_argument('--normalize_dict_path', type=str, default='')
    p.add_argument('--logdir', type=str, default='./log_eval')
    args = p.parse_args()

    # Validate before any expensive loading so a bad value fails fast.
    seeds = _parse_seeds(p, args.seeds)

    config = _load_config(args)
    model = _load_model(config, args.model_path)

    # The evaluator receives the statistics stored on the model itself.
    state_mean = model.state_mean.detach().cpu().numpy()
    state_std = model.state_std.detach().cpu().numpy()

    evaluator = Evaluation(config, state_mean=state_mean, state_std=state_std)

    # TensorBoard
    run_name = time.strftime('%Y%m%d-%H%M%S')
    tb_dir = os.path.join(args.logdir, args.algo, args.env, run_name)
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(tb_dir)
    try:
        writer.add_text('meta/model_path', args.model_path)
        # Convert numpy arrays to lists for the dump only, so the logged config
        # is readable YAML rather than serialised Python objects.
        loggable = {
            k: (v.tolist() if isinstance(v, np.ndarray) else v)
            for k, v in config.items()
        }
        writer.add_text('meta/config', yaml.dump(loggable))

        targets = config.get('env_targets', [0.0])

        all_scores = []
        for si, sd in enumerate(seeds):
            set_seed(sd)
            total = {}
            for tar in targets:
                out = evaluator.eval_fn(tar)(model)  # dict: mean/std
                total.update(out)

                base = f'seed_{sd}/target_{tar}'
                writer.add_scalar(f'{base}/return_mean', out[f'target_{tar}_return_mean'], si)
                writer.add_scalar(f'{base}/return_std', out[f'target_{tar}_return_std'], si)

            all_scores.append(total)
            print(f'[SEED {sd}] {total}')

        # The first seed's keys define the metric set; seeds is non-empty here.
        keys = list(all_scores[0])
        summary = {k: float(np.mean([d[k] for d in all_scores])) for k in keys}
        summary_std = {k: float(np.std([d[k] for d in all_scores])) for k in keys}

        for k in keys:
            writer.add_scalar(f'summary/{k}_mean', summary[k])
            writer.add_scalar(f'summary/{k}_std', summary_std[k])
    finally:
        # Close the writer even if evaluation raises, so logged events are kept.
        writer.close()

    for k in keys:
        print(f'{k}: mean={summary[k]:.6f}  std={summary_std[k]:.6f}')
    print('Logs ->', tb_dir, '\nRun: tensorboard --logdir', args.logdir)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
