import argparse
import math
import os
import pickle
import time
from contextlib import nullcontext

import numpy as np
import torch

from model import GPTConfig, GPT
from evaluator import evaluate_model  # 👈 导入共享评估函数

# -------------------------------
# Argument parsing (including the evaluation options)
# -------------------------------
parser = argparse.ArgumentParser(
    description='Pretrain a GPT model and periodically evaluate it with evaluate_model.'
)

# Data selection and model shape.
parser.add_argument('--dataset', type=str, default='graph',
                    help='Dataset name; data is read from data/<dataset>/<num_nodes>_1_1.')
parser.add_argument('--type', type=str, default='simple',
                    help='Data variant prefix; selects <type>_meta.pkl and <type>_train.bin.')
parser.add_argument('--n_layer', type=int, default=1,
                    help='Model depth, passed to GPTConfig as n_layer.')
parser.add_argument('--n_head', type=int, default=1,
                    help='Number of attention heads, passed to GPTConfig as n_head.')
parser.add_argument('--n_embd', type=int, default=120,
                    help='Embedding width, passed to GPTConfig as n_embd.')
parser.add_argument('--max_iters', type=int, default=10000,
                    help='Index of the last training iteration (the loop runs max_iters + 1 steps).')
parser.add_argument('--num_nodes', type=int, default=100,
                    help='Number of graph nodes; selects the data directory <num_nodes>_1_1.')
parser.add_argument('--batch_size', type=int, default=256,
                    help='Number of sequences per training batch.')
parser.add_argument('--fix_att', action='store_true',
                    help='Forwarded to GPT(fix_att=...); also selects the fix_att/unfix_att output directory.')
parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float32', 'float16', 'bfloat16'],
                    help='Compute precision; non-float32 values enable CUDA autocast, '
                         'float16 additionally enables gradient scaling.')
parser.add_argument('--lr', type=float, default=3e-4,
                    help='Learning rate used by the AdamW optimizer.')

# Checkpointing.
parser.add_argument('--save_suffix', type=str, default='',
                    help='String appended to the output directory name.')
parser.add_argument('--load_ckpt_path', type=str, default=None,
                    help="Checkpoint file whose 'model' state dict initializes the model before training.")
parser.add_argument('--save_interval', type=int, default=1000,
                    help='Save a checkpoint and report the averaged loss every N iterations.')

# Evaluation options.
parser.add_argument('--eval_interval', type=int, default=200,
                    help='Run evaluate_model every N iterations.')
parser.add_argument('--eval_temperature', type=float, default=0.00001,
                    help='Sampling temperature forwarded to evaluate_model.')
parser.add_argument('--eval_batch_size', type=int, default=500,
                    help='Batch size forwarded to evaluate_model.')
args = parser.parse_args()

# -------------------------------
# Device & precision setup
# -------------------------------
device = 'cuda'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]

# Full precision needs no autocast, so `with ctx:` becomes a no-op via
# contextlib.nullcontext (imported at the top of the file).
if args.dtype == 'float32':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)

# Loss scaling is only enabled for float16; for the other dtypes the scaler is
# constructed disabled.
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))

# -------------------------------
# Paths & dataset metadata
# -------------------------------
# Dataset files live under data/<dataset>/<num_nodes>_1_1.
data_dir = os.path.join('data', f'{args.dataset}/{args.num_nodes}_1_1')
meta_path = os.path.join(data_dir, f'{args.type}_meta.pkl')

# NOTE: unpickling can execute arbitrary code; only load metadata files you trust.
with open(meta_path, 'rb') as meta_file:
    meta = pickle.load(meta_file)

# Vocabulary mappings and model dimensions recorded in the metadata file.
stoi = meta['stoi']
itos = meta['itos']
block_size, vocab_size = meta['block_size'], meta['vocab_size']

# The output directory encodes attention mode, dataset, data variant and model shape.
att_type = 'unfix_att' if not args.fix_att else 'fix_att'
out_dir = f'pretrained_out/{att_type}/{args.dataset}/{args.type}/{args.n_layer}_{args.n_head}_{args.n_embd}_{args.num_nodes}_1_1' + args.save_suffix
os.makedirs(out_dir, exist_ok=True)

# Log file shared by the training loop and evaluate_model.
log_file = os.path.join(out_dir, "train.log")

# -------------------------------
# Data loader
# -------------------------------
# The training file is a flat array of uint16 values; memory-map it read-only
# so batches are sliced from disk instead of loading the whole file.
train_bin_path = os.path.join(data_dir, f'{args.type}_train.bin')
train_data = np.memmap(train_bin_path, dtype=np.uint16, mode='r')

def get_batch():
    """Sample a random batch of (input, target) sequences from ``train_data``.

    ``train_data`` is read as consecutive fixed-length records of
    ``block_size + 1`` values; every sampled sequence starts on a record
    boundary. Any trailing partial record is ignored.

    Returns:
        A tuple ``(x, y)`` of int64 tensors on ``device``, each of shape
        ``(args.batch_size, block_size)``. ``y`` is the same record as ``x``
        shifted one position to the right.

    Raises:
        ValueError: if ``train_data`` does not contain one complete record.
    """
    data_size = block_size + 1

    # Count complete records. The earlier bound `(len - data_size) // data_size`
    # was one too small: randint's upper bound is exclusive, so the last record
    # was never sampled and a single-record file raised inside torch.randint.
    num_records = len(train_data) // data_size
    if num_records < 1:
        raise ValueError(
            f"train_data holds {len(train_data)} values; need at least {data_size} for one sequence"
        )

    # Record-aligned start offsets, converted to plain ints for slicing.
    starts = (torch.randint(num_records, (args.batch_size,)) * data_size).tolist()
    x = torch.stack([torch.from_numpy(train_data[i:i + block_size].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(train_data[i + 1:i + 1 + block_size].astype(np.int64)) for i in starts])
    return x.to(device), y.to(device)

# -------------------------------
# Model initialization
# -------------------------------
# Kept as a plain dict because it is also stored inside every checkpoint.
model_args = {
    'n_layer': args.n_layer,
    'n_head': args.n_head,
    'n_embd': args.n_embd,
    'block_size': block_size,
    'vocab_size': vocab_size,
    'bias': False,
    'dropout': 0.0,
}
gpt_config = GPTConfig(**model_args)
model = GPT(gpt_config, fix_att=args.fix_att).to(device)

# Optionally start from the weights stored under the 'model' key of a checkpoint.
if args.load_ckpt_path is not None:
    print(f"Loading checkpoint from {args.load_ckpt_path}")
    checkpoint = torch.load(args.load_ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model'])

# -------------------------------
# Optimizer & learning-rate schedule settings
# -------------------------------
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    betas=(0.9, 0.95),
    weight_decay=1e-1,
)

# Schedule hyperparameters: warmup length (5% of the run), decay horizon
# (the whole run) and the learning-rate floor.
warmup_iters = args.max_iters // 20
lr_decay_iters = args.max_iters
min_lr = 5e-5

def get_lr(it):
    """Return the learning rate to use at iteration ``it``.

    The rate is constant: ``args.lr`` is returned for every iteration, so
    ``it`` is accepted only to keep the call signature stable.

    The warmup/cosine-decay statements that previously followed the
    unconditional ``return`` could never execute and have been removed.
    """
    return args.lr

# -------------------------------
# Training loop
# -------------------------------
class EvalArgs:
    """Attribute bag handed to evaluate_model; keyword arguments become attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def _append_log(message):
    """Append ``message`` as a single line to the run's log file."""
    with open(log_file, 'a') as log_handle:
        log_handle.write(message + "\n")


VAL_BATCHES = 5     # batches averaged for the periodic loss estimate
LOG_INTERVAL = 200  # iterations between routine loss log lines

model.train()
X, Y = get_batch()
best_val_loss = float('inf')  # NOTE(review): never updated or read in the visible code

for iter_num in range(args.max_iters + 1):

    # Periodic accuracy evaluation. A non-positive interval disables it
    # (and avoids a modulo-by-zero).
    if args.eval_interval > 0 and iter_num % args.eval_interval == 0:
        # NOTE(review): the split names are hard-coded and do not follow --type; confirm intended.
        for eval_type_data in ['simple_train', 'simple_test']:
            eval_args = EvalArgs(
                dataset=args.dataset,
                data_dir=f"{args.num_nodes}_1_1",
                type_data=eval_type_data,
                test_num=None,
                batch_size=args.eval_batch_size,
                temperature=args.eval_temperature,
                write_result=False,
                fix_att=args.fix_att,
                result_name=None,
                out_dir=out_dir,
                ckpt_iter=iter_num,  # used for naming only, not for loading
            )

            accuracy = evaluate_model(model, eval_args, device, log_file=log_file, step=iter_num)
            eval_message = f"[{iter_num}] Eval Accuracy on {eval_type_data}: {accuracy:.4f}"
            print(eval_message)
            _append_log(eval_message)

        # evaluate_model may leave the model in eval mode (its body is not
        # visible here); make sure the optimization step runs in training mode.
        model.train()

    # Apply this iteration's learning rate to every parameter group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    with ctx:
        _, loss = model(X, Y, train_mode='pretrain')

    # The scaler is constructed with enabled=(dtype == 'float16'); when it is
    # disabled, scale() returns the loss unchanged, unscale_() and update() are
    # no-ops and step() simply calls optimizer.step(), so one code path serves
    # every dtype.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients must be unscaled before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad()

    # Draw the batch for the next iteration.
    X, Y = get_batch()

    # Periodic loss estimate and checkpoint. A non-positive interval disables it.
    if args.save_interval > 0 and iter_num % args.save_interval == 0:
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            # NOTE(review): get_batch() samples from train_data, so this
            # "val loss" is an averaged training-data loss, not a held-out one.
            for _ in range(VAL_BATCHES):
                X_val, Y_val = get_batch()
                with ctx:
                    _, loss_val = model(X_val, Y_val, train_mode='pretrain')
                val_loss += loss_val.item()
        val_loss /= VAL_BATCHES
        model.train()

        train_loss = loss.item()  # read once; each .item() call copies from the device
        print(f"iter {iter_num}: train loss {train_loss:.4f}, val loss {val_loss:.4f}, lr {lr:.6f}")

        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': model_args,
            'iter_num': iter_num,
        }, os.path.join(out_dir, f'{iter_num}_ckpt.pt'))

        _append_log(f"iter {iter_num}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")

    # Routine loss log.
    if iter_num % LOG_INTERVAL == 0:
        train_loss = loss.item()
        print(f"iter {iter_num}: loss {train_loss:.4f}, lr {lr:.6f}, dtype={args.dtype}")
        _append_log(f"iter {iter_num}: loss {train_loss:.4f}")