from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import importlib
import argparse

import _init_paths
import torch 
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from model.scorenet import get_score_net
from loss.multi_loss import MultiLossFactory, MultiLossFactoryQuantile
from utils.rescore import train_core, train_test_core
from utils.rescore import valid_core
from utils.rescore import read_rescore_data

def parse_args(argv=None):
    """Parse the command-line arguments for rescore-network training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the parser be driven from tests or notebooks.

    Returns:
        argparse.Namespace with ``hpe_dataset``, ``hpe_model`` and ``opts``
        (the remaining command-line tokens, collected via
        ``argparse.REMAINDER``).
    """
    parser = argparse.ArgumentParser(description='Train rescore network')
    # general
    parser.add_argument('--hpe_dataset',
                        help='dataset name, please choose from coco or crowdpose',
                        required=True,
                        type=str)
    parser.add_argument('--hpe_model',
                        help='model name',
                        required=True,
                        type=str)
    # Everything after the named options is collected verbatim.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    return parser.parse_args(argv)

def import_cfg(args):
    """Import and return the config module ``<hpe_dataset>.<hpe_model>``.

    Args:
        args: Parsed CLI namespace; only ``hpe_dataset`` and ``hpe_model``
            are read.

    Returns:
        The module object produced by ``importlib.import_module``.
    """
    module_name = f"{args.hpe_dataset}.{args.hpe_model}"
    return importlib.import_module(module_name)
  
def create_save_path(cfg):
    """Ensure the checkpoint directory exists and return its path.

    The directory is ``output/<cfg.train['save_path']>``, resolved relative
    to the current working directory.

    Args:
        cfg: Config object whose ``train`` mapping has a ``'save_path'`` key.

    Returns:
        The relative directory path as a string.

    Raises:
        FileExistsError: if the path already exists but is not a directory.
    """
    save_path = os.path.join('output', cfg.train['save_path'])
    # exist_ok=True replaces the racy "exists? then makedirs" pair (two runs
    # starting together could both pass the check and one would crash).
    os.makedirs(save_path, exist_ok=True)
    return save_path

def lr_lambda(current_step, warmup_steps=250):
    """Linear warmup multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    Ramps linearly from 0.0 at step 0 up to 1.0 at ``warmup_steps``. It then
    stays at 1.0, keeping the learning rate constant after warmup.

    Args:
        current_step: Step counter that ``LambdaLR`` passes in.
        warmup_steps: Length of the warmup. The default of 250 is the value
            that was previously hard-coded. Values <= 0 disable warmup.

    Returns:
        Float multiplier applied to the optimizer's base learning rate.
    """
    if current_step < warmup_steps:
        # max(1, ...) guards the division even for degenerate warmup lengths.
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0  # Keep LR constant after warmup

def main():
    """Train the rescore network and write checkpoints under ``output/``.

    Loads train/test data with ``read_rescore_data`` and trains with
    ``MultiLossFactoryQuantile``. After every epoch it evaluates with
    ``MultiLossFactory`` through ``valid_core``. Checkpoints go into the
    directory returned by ``create_save_path``:

    * ``model_best.pth``    -- whenever the test order loss improves,
    * ``model_{epoch}.pth`` -- when the 0-based epoch index is a multiple of 5,
    * ``model_final.pth``   -- after the last epoch.
    """
    args = parse_args()
    cfg = import_cfg(args)

    # create data (the pose arrays are returned but not used by this script)
    x_train, y_train, _pose_train, s_train, shape_train = read_rescore_data(args, cfg)
    x_test, y_test, _pose_test, s_test, shape_test = read_rescore_data(args, cfg, is_train=False)

    # create model
    model = get_score_net(cfg, input_channel=cfg.model['feature_channels'], is_train=True).cuda()

    # create loss: quantile variant for training, MultiLossFactory for evaluation
    train_loss_fn = MultiLossFactoryQuantile(cfg).cuda()
    test_loss_fn = MultiLossFactory(cfg).cuda()

    # create optimizer with the linear-warmup schedule from lr_lambda.
    # NOTE(review): the scheduler is handed to train_test_core, which
    # presumably calls scheduler.step() -- confirm the step granularity.
    optimizer = optim.Adam(model.parameters(), lr=cfg.optimizer['lr'])
    scheduler = LambdaLR(optimizer, lr_lambda)

    save_path = create_save_path(cfg)
    # float('inf') guarantees the first evaluated epoch is saved as "best".
    # The old 1e6 sentinel silently skipped it when losses started above 1e6.
    best_loss = float('inf')
    for epoch in range(cfg.train['epochs']):
        # train
        # NOTE(review): x_test / shape_test are also passed to the training
        # routine -- confirm test data is not used for gradient updates.
        train_l1_loss, train_order_loss, quantile_loss, train_order_acc = train_test_core(
            x_train, x_test, y_train, s_train, shape_train, shape_test,
            optimizer, scheduler, model, train_loss_fn,
            cfg.train['batch_size'], cfg.train['pose_size'])
        print("step:", epoch+1, "train_l1_loss:", train_l1_loss, "train_order_loss:", train_order_loss, "quantile_loss:", quantile_loss, "train_order_acc:", train_order_acc)

        # evaluate
        test_l1_loss, test_order_loss, test_order_acc = valid_core(
            x_test, y_test, s_test, model, test_loss_fn,
            cfg.test['batch_size'], cfg.test['pose_size'])
        print("step:", epoch+1, "test_l1_loss:", test_l1_loss, "test_order_loss:", test_order_loss, "test_order_acc:", test_order_acc)

        # "Best" model selection is driven by the test-split order loss.
        if test_order_loss < best_loss:
            torch.save(model.state_dict(), os.path.join(save_path, 'model_best.pth'))
            best_loss = test_order_loss

        # Periodic snapshot; file names use the 0-based epoch index, while the
        # log lines above print epoch+1.
        if epoch % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_path, f'model_{epoch}.pth'))
    torch.save(model.state_dict(), os.path.join(save_path, 'model_final.pth'))

if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
