import os
import time
import math
import pickle
import argparse
import numpy as np
import torch
from contextlib import nullcontext
from evaluator import evaluate_model
from tqdm import tqdm

from model import GPTConfig, GPT
from pg_utils import get_m, get_m_td, get_scores_6  # these helpers must be defined in pg_utils.py

# Command-line interface, declared as (flag, add_argument keyword arguments).
# Entries are registered in the order listed here.
_CLI_SPEC = (
    ('--dataset', dict(type=str, default='graph')),
    ('--n_layer', dict(type=int, default=1)),
    ('--n_head', dict(type=int, default=1)),
    ('--n_embd', dict(type=int, default=120)),
    ('--max_iters', dict(type=int, default=10000)),
    ('--num_nodes', dict(type=int, default=100)),
    ('--load_ckpt_num', dict(type=int, default=10000)),
    ('--save_interval', dict(type=int, default=2000)),
    ('--loss_func', dict(type=str, default='reinforce_with_token_kl')),
    ('--score_func', dict(type=int, default=6)),
    ('--fix_att', dict(action='store_true')),
    ('--gen_temp', dict(type=float, default=1.0)),
    ('--reach_type', dict(type=str, default='true')),
    ('--adj_type', dict(type=str, default='true')),
    ('--train_type', dict(type=str, default='simple')),
    ('--kl_constant', dict(type=float, default=0.0)),
    ('--reward_bias', dict(type=float, default=0.0)),
    ('--dtype', dict(type=str, default='bfloat16', choices=['float32', 'float16', 'bfloat16'])),
    ('--device', dict(type=str, default='cuda')),
    ('--lr', dict(type=float, default=3e-5)),
    # Evaluation-related defaults, used when evaluating during training.
    ('--eval_interval', dict(type=int, default=50)),  # evaluate every N iterations
    ('--eval_temperature', dict(type=float, default=0.00001)),
    ('--eval_batch_size', dict(type=int, default=500)),
    ('--eval_type_data', dict(type=str, default='simple_test')),  # test-set type
)

parser = argparse.ArgumentParser()
for _flag, _spec in _CLI_SPEC:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# -------------------------------
# Device & precision setup
# -------------------------------
device = args.device
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
# autocast takes a device *type* ('cuda' / 'cpu'), so derive it from --device
# instead of hard-wiring 'cuda'; otherwise --device cpu gets a CUDA autocast context.
device_type = 'cuda' if 'cuda' in device else 'cpu'
# Mixed precision is only enabled for the reduced-precision dtypes.
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if args.dtype != 'float32' else nullcontext()
# Gradient scaling is only active for float16; for other dtypes the scaler is disabled.
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
# Number of prompts sampled (and sequences generated) per training iteration.
batch_size = 512
# -------------------------------
# Paths & data loading
# -------------------------------
data_dir = os.path.join('data', f'{args.dataset}/{args.num_nodes}_1_1')
meta_path = os.path.join(data_dir, f'{args.train_type}_meta.pkl')

# Dataset metadata: sequence length, vocabulary size and token <-> id maps.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
block_size = meta['block_size']
vocab_size = meta['vocab_size']
stoi, itos = meta['stoi'], meta['itos']

# Output directory; its name encodes the run configuration.
att_type = 'fix_att' if args.fix_att else 'unfix_att'
out_dir = f'on_policy_out/{att_type}/{args.dataset}/{args.n_layer}_{args.n_head}_{args.n_embd}_{args.num_nodes}_from_{args.load_ckpt_num}_{args.score_func}_{args.reward_bias}_{args.gen_temp}_{args.adj_type}_{args.reach_type}_{args.loss_func}_{args.kl_constant}_{args.train_type}'
os.makedirs(out_dir, exist_ok=True)

# Training data (used to sample the starting tokens for generation),
# reshaped to one sequence of block_size + 1 tokens per row.
train_start = np.memmap(os.path.join(data_dir, f'{args.train_type}_train.bin'), dtype=np.uint16, mode='r')
train_tensor = torch.from_numpy(train_start.astype(np.int64)).view(-1, block_size + 1).to(device)

# Graph structure matrices.
true_adj = np.load(f'{data_dir}/{args.adj_type}_adj_matrix.npy')
true_reach = np.load(f'{data_dir}/{args.reach_type}_reach_matrix.npy')
# Set the diagonal for the first vocab_size - 2 indices, i.e. each of them reaches itself.
# NOTE(review): presumably the remaining two vocabulary entries are non-node tokens - confirm.
_diag = np.arange(vocab_size - 2)
true_reach[_diag, _diag] = 1

# Build the check (mask) matrices; the 'TD' loss uses its own variant.
if args.loss_func == 'TD':
    Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check = get_m_td(vocab_size, true_adj, true_reach)
else:
    Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check = get_m(vocab_size, true_adj, true_reach)

# Resolve the score function by name. Only get_scores_6 is imported into this
# module, so looking in globals() alone would silently map every other
# --score_func value to get_scores_6; also search pg_utils, and say so when
# the fallback is actually used.
import pg_utils

score_func_name = f"get_scores_{args.score_func}"
my_get_score = globals().get(score_func_name) or getattr(pg_utils, score_func_name, None)
if my_get_score is None:
    print(f"Warning: {score_func_name} not found, falling back to get_scores_6")
    my_get_score = get_scores_6

# -------------------------------
# Model initialisation
# -------------------------------
model_args = dict(
    n_layer=args.n_layer,
    n_head=args.n_head,
    n_embd=args.n_embd,
    block_size=block_size,
    vocab_size=vocab_size,
    bias=False,
    dropout=0.0
)
# `model` is the policy being trained; `old_model` supplies the reference
# logits used for the KL term.
model = GPT(GPTConfig(**model_args), fix_att=args.fix_att).to(device)
old_model = GPT(GPTConfig(**model_args), fix_att=args.fix_att).to(device)

# Optionally start from a pretrained checkpoint (--load_ckpt_num -1 skips loading).
if args.load_ckpt_num != -1:
    ckpt_path = f'./pretrained_out/{att_type}/{args.dataset}/simple/{args.n_layer}_{args.n_head}_{args.n_embd}_{args.num_nodes}_1_1/{args.load_ckpt_num}_ckpt.pt'
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model'])

# The reference model must start as an exact copy of the policy. Syncing it
# from `model` covers both cases: with a checkpoint it receives the same
# weights as before, and without one it no longer keeps an unrelated random
# initialisation. It is never trained, so keep it in eval mode and frozen.
old_model.load_state_dict(model.state_dict())
old_model.eval()
old_model.requires_grad_(False)

# -------------------------------
# Optimizer & learning-rate schedule constants
# -------------------------------
_adamw_settings = dict(lr=args.lr, weight_decay=1e-1, betas=(0.9, 0.95))
optimizer = torch.optim.AdamW(model.parameters(), **_adamw_settings)

# Schedule constants: decay horizon, warmup length (1/20 of the run), LR floor.
lr_decay_iters = args.max_iters
warmup_iters = args.max_iters // 20
min_lr = 5e-5

def get_lr(it):
    """Return the learning rate for training iteration ``it``.

    The rate is constant: every iteration uses ``args.lr``. ``it`` is accepted
    so callers can keep passing the current step. The warmup / cosine-decay
    code that used to follow the unconditional return could never run and has
    been removed.
    """
    return args.lr

# -------------------------------
# Training loop
# -------------------------------
log_file = os.path.join(out_dir, "train.log")


class EvalArgs:
    """Plain attribute container handed to evaluate_model as its args object."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


# Evaluation splits for each training-set flavour. Any other --train_type
# falls back to --eval_type_data instead of hitting an undefined `datas`.
if args.train_type == 'aug':
    datas = ['train2train', 'train2test', 'test2train', 'test2test']
elif args.train_type == 'simple':
    datas = ['simple_test']
else:
    datas = [args.eval_type_data]

# The gradient scaler is only driven explicitly for float16 training.
use_grad_scaler = args.dtype == 'float16'

for iter_num in tqdm(range(args.max_iters + 1)):
    # Periodic evaluation on every configured split.
    if iter_num % args.eval_interval == 0:
        print(f"[{iter_num}] Running evaluation...")
        for eval_type_data in datas:
            eval_args = EvalArgs(
                dataset=args.dataset,
                data_dir=f"{args.num_nodes}_1_1",
                type_data=eval_type_data,
                test_num=None,
                batch_size=args.eval_batch_size,
                temperature=args.eval_temperature,
                write_result=False,
                fix_att=args.fix_att,
                result_name=None,
                out_dir=out_dir,  # reuse the training output directory
            )

            accuracy = evaluate_model(model, eval_args, device, log_file=log_file, step=iter_num)
            print(f"[{iter_num}] Eval Accuracy on {eval_type_data}: {accuracy:.4f}")
            with open(log_file, 'a') as f:
                f.write(f"[{iter_num}] Eval Accuracy on {eval_type_data}: {accuracy:.4f}\n")

    # Periodic checkpoint.
    if iter_num % args.save_interval == 0:
        print(f"[{iter_num}] Saving checkpoint...")
        with open(log_file, 'a') as f:
            f.write(f"iter {iter_num}: saving checkpoint\n")
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': model_args,
            'iter_num': iter_num,
        }, os.path.join(out_dir, f'{iter_num}_ckpt.pt'))

    # Set the learning rate for this iteration.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Sample trajectories from the current policy (no gradients needed here).
    with torch.no_grad():
        # Pick batch_size random training sequences and keep their first two tokens as the prompt.
        start_idx = torch.randperm(train_tensor.size(0))[:batch_size]
        batch_input = train_tensor[start_idx, :2]  # shape: [B, 2]

        # Generate the rest of each sequence.
        generated = model.generate(batch_input, block_size - 1, temperature=args.gen_temp, top_k=vocab_size)

        # Reference-model logits (used for the KL term).
        old_logits = old_model.get_logits(generated[:, :block_size])

        # Compute the reward.
        scores = my_get_score(generated, Start_Check, Adj_Check, Path_Check, Reach_Check, Pad_Check)
        if iter_num % 10 == 1:
            # Compare each row's position-2 score with that row's maximum.
            # The max is taken without keepdim so both sides have matching
            # shapes; with keepdim=True the comparison broadcast [B] against
            # [B, 1] into a [B, B] matrix and the reported rate was wrong.
            row_max = scores.max(dim=1)[0]
            success_rate = (scores[:, 2] == row_max).sum().item() / scores.shape[0]
            print(f"[{iter_num}] Success rate: {success_rate:.3f}")

        scores = (scores - args.reward_bias) * 1.0  # score_scale=1

    # Assemble the training batch.
    X = generated[:, :block_size]          # input sequence
    Y = generated[:, 2:block_size + 1]     # targets
    Z = scores                             # reward

    # Forward pass under the configured autocast context.
    with ctx:
        logits, loss = model(X, Y, Z, train_mode=args.loss_func, trans_func='original', old_logits=old_logits, kl_constant=args.kl_constant)

    # Backward pass; float16 goes through the gradient scaler, and gradients
    # are unscaled before clipping so the norm threshold applies to true values.
    if use_grad_scaler:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
    else:
        loss.backward()

    # Gradient clipping.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Parameter update.
    if use_grad_scaler:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    optimizer.zero_grad()

    # Logging.
    if iter_num % 100 == 0:
        loss_value = loss.item()  # read the loss once for both outputs
        print(f"[{iter_num}] loss: {loss_value:.4f}, lr: {lr:.6f}, dtype: {args.dtype}")
        with open(log_file, 'a') as f:
            f.write(f"iter {iter_num}: loss {loss_value:.4f}\n")
            
        