from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import mixgate
import torch
import os
from config import get_parse_args
import mixgate.top_model
import mixgate.top_trainer 
import torch.distributed as dist


# Passed as the first argument to mixgate.NpzParser_Pair below.
# NOTE(review): presumably the directory where the parsed dataset is
# stored/cached -- confirm against the parser implementation.
DATA_DIR = './data'

if __name__ == '__main__':
    # Script entry point: parse the dataset, build the top-level model and
    # its trainer, then run three consecutive training stages, writing one
    # checkpoint per stage into the trainer's log directory.
    args = get_parse_args()

    circuit_path = '/root/datasets/merged_mixgate1%75000.npz'

    # Read from the parsed arguments but not consumed below: every stage
    # passes a literal 60 to trainer.train().
    num_epochs = args.num_epochs

    print('[INFO] Parse Dataset')
    dataset = mixgate.NpzParser_Pair(DATA_DIR, circuit_path)
    train_dataset, val_dataset = dataset.get_dataset()

    print('[INFO] Create Model and Trainer')
    # Checkpoint paths handed to TopModel as keyword arguments.
    pretrained_ckpts = {
        'dg_ckpt_aig': './ckpt/model_func_aig.pth',
        'dg_ckpt_xag': './ckpt/model_func_xag.pth',
        'dg_ckpt_xmg': './ckpt/model_func_xmg.pth',
        'dg_ckpt_mig': './ckpt/model_func_mig.pth',
    }
    model = mixgate.top_model.TopModel(args, **pretrained_ckpts)
    trainer = mixgate.top_trainer.TopTrainer(args, model, distributed=True)

    # One row per stage:
    #   (start banner, loss weights, checkpoint file name, reload banner)
    # Learning rate, LR step and epoch count are the same for all stages;
    # only the loss weights change. A reload banner of None means the
    # checkpoint is saved but not loaded back (the final stage).
    stage_table = (
        (
            '[INFO] Stage 1 Training (60 epochs) ...',
            [1.0, 0.0, 0.0, 0.25],
            'stage1_prob_only_model.pth',
            '[INFO] Loading Stage 1 Checkpoint...',
        ),
        (
            '[INFO] Stage 2 Training (60 epochs) ...',
            [1.0, 0.0, 1.0, 0.25],
            'stage2_prob_func_model.pth',
            '[INFO] Loading Stage 2 Checkpoint...',
        ),
        (
            '[INFO] Stage 3 Training (60 epochs) ...',
            [1.0, 0.25, 1.0, 0.1],
            'stage3_final_model.pth',
            None,
        ),
    )

    for start_banner, stage_weights, ckpt_name, reload_banner in stage_table:
        print(start_banner)
        trainer.set_training_args(lr=1e-4, lr_step=50, loss_weight=stage_weights)
        trainer.train(60, train_dataset, val_dataset)

        # Persist this stage's result; for the last stage this is the final model.
        trainer.save(os.path.join(trainer.log_dir, ckpt_name))

        if reload_banner is None:
            continue

        # Load the checkpoint that was just written before the next stage starts.
        print(reload_banner)
        trainer.load(os.path.join(trainer.log_dir, ckpt_name))

    print('[INFO] Training completed!')

    