from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import importlib
import argparse

import _init_paths
import torch 
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from model.scorenet import get_score_net
from loss.multi_loss import MultiLossFactory, MultiLossFactoryGaussian, MultiLossFactoryCRPS
from utils.rescore import train_core, train_test_core
from utils.rescore import valid_core
from utils.rescore import read_rescore_data

def parse_args(argv=None):
    """Parse command-line arguments for rescore-network training.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). Defaults to ``None``, so existing
            ``parse_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with:
          - ``hpe_dataset`` (str, required): pose-estimation dataset name.
          - ``hpe_model`` (str, required): pose-estimation model name.
          - ``opts``: every remaining token after the named options, as a list.
            NOTE(review): ``opts`` is collected but not applied by ``main()``
            in this file.
    """
    parser = argparse.ArgumentParser(description='Train rescore network')
    # general
    parser.add_argument('--hpe_dataset',
                        help='dataset name, please choose from coco or crowdpose',
                        required=True,
                        type=str)
    parser.add_argument('--hpe_model',
                        help='model name',
                        required=True,
                        type=str)
    # REMAINDER swallows everything left on the command line into a list.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    return parser.parse_args(argv)

def import_cfg(args):
    """Load the config module selected on the command line.

    The module path is built as ``<args.hpe_dataset>.<args.hpe_model>`` and
    imported with ``importlib``, so it must be importable from ``sys.path``.

    Args:
        args: Namespace exposing ``hpe_dataset`` and ``hpe_model``.

    Returns:
        The imported module object.
    """
    module_path = "{}.{}".format(args.hpe_dataset, args.hpe_model)
    return importlib.import_module(module_path)
  
def create_save_path(cfg, root='output'):
    """Return the run's checkpoint directory, creating it if necessary.

    The directory is ``os.path.join(root, cfg.train['save_path'])``.

    Args:
        cfg: Config object whose ``train`` mapping has a ``'save_path'`` key.
        root: Parent directory for all runs. Defaults to ``'output'``
            (relative to the current working directory), as before.

    Returns:
        The joined directory path as a string.

    Raises:
        FileExistsError: if the path already exists but is not a directory
            (previously this was silently returned and failed later on save).
    """
    save_path = os.path.join(root, cfg.train['save_path'])
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(save_path, exist_ok=True)
    return save_path

def lr_lambda(current_step, warmup_steps=250):
    """Linear-warmup LR multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    The factor ramps linearly from 0.0 at step 0 to 1.0 at ``warmup_steps``.
    It then stays at 1.0, so the LR is constant after warmup.

    Args:
        current_step: Step counter supplied by ``LambdaLR``. It increments once
            per ``scheduler.step()`` call; whether that is per batch or per
            epoch depends on the caller.
        warmup_steps: Length of the warmup ramp. Defaults to 250, the value
            previously hard-coded here. Values <= 0 disable warmup for
            non-negative steps.

    Returns:
        Float multiplier applied to the optimizer's base learning rate.
    """
    if current_step < warmup_steps:
        # max(1, ...) guards the division against a zero-length ramp.
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0  # Keep LR constant after warmup

def _count_parameters(model):
    """Return ``(total, trainable)`` parameter-element counts for ``model``."""
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)
    return total, trainable


def main():
    """Train the rescore network and record per-epoch metrics.

    Steps:
      1. Parse CLI args and import the ``<hpe_dataset>.<hpe_model>`` config.
      2. Load train and test rescore data via ``read_rescore_data``.
      3. Build the score net, create CRPS losses, an Adam optimizer, and a
         linear-warmup ``LambdaLR``.
      4. For each epoch, train and then validate. Save ``model_best.pth``
         whenever the test order loss improves. Save ``model_<epoch>.pth``
         every 5 epochs, using 0-based epochs: 0, 5, 10, ...
      5. Save ``model_final.pth``. Pickle the six per-epoch metric lists to
         ``loss_acc.pkl`` in the save directory.

    Requires a CUDA device: the model and losses are moved with ``.cuda()``.
    """
    import pickle

    args = parse_args()
    cfg = import_cfg(args)

    # create data (pose_* and shape_* are returned but not used below)
    x_train, y_train, pose_train, s_train, shape_train = read_rescore_data(args, cfg)
    x_test, y_test, pose_test, s_test, shape_test = read_rescore_data(args, cfg, is_train=False)
    print(x_train.shape)

    # create model
    model = get_score_net(cfg, input_channel=cfg.model['feature_channels'], is_train=True).cuda()
    total_params, trainable_params = _count_parameters(model)
    print(f"Total parameters:     {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # create loss: CRPS objective. Alternatives previously tried here were
    # MultiLossFactory(cfg) and MultiLossFactoryGaussian(cfg).
    train_loss_fn = MultiLossFactoryCRPS(cfg).cuda()
    test_loss_fn = MultiLossFactoryCRPS(cfg).cuda()

    # create optimizer with linear LR warmup
    optimizer = optim.Adam(model.parameters(), lr=cfg.optimizer['lr'])
    scheduler = LambdaLR(optimizer, lr_lambda)

    train_l1_losses, train_order_losses, train_acc = [], [], []
    test_l1_losses, test_order_losses, test_acc = [], [], []
    save_path = create_save_path(cfg)
    # Use +inf rather than a finite sentinel so the first epoch always
    # produces a best checkpoint, whatever the loss scale.
    best_loss = float('inf')
    for epoch in range(cfg.train['epochs']):
        # training pass
        train_l1_loss, train_order_loss, train_order_acc = train_core(
            x_train, y_train, s_train,
            optimizer, scheduler, model, train_loss_fn,
            cfg.train['batch_size'], cfg.train['pose_size'])
        train_l1_losses.append(train_l1_loss)
        train_order_losses.append(train_order_loss)
        train_acc.append(train_order_acc)
        print("step:", epoch+1, "train_l1_loss:", train_l1_loss, "train_order_loss:", train_order_loss, "train_order_acc:", train_order_acc)

        # validation pass
        test_l1_loss, test_order_loss, test_order_acc = valid_core(
            x_test, y_test, s_test, model, test_loss_fn,
            cfg.test['batch_size'], cfg.test['pose_size'])
        test_l1_losses.append(test_l1_loss)
        test_order_losses.append(test_order_loss)
        test_acc.append(test_order_acc)
        print("step:", epoch+1, "test_l1_loss:", test_l1_loss, "test_order_loss:", test_order_loss, "test_order_acc:", test_order_acc)

        # Best-checkpoint selection is driven by the test *order* loss.
        if test_order_loss < best_loss:
            torch.save(model.state_dict(), os.path.join(save_path, 'model_best.pth'))
            best_loss = test_order_loss

        # Periodic snapshot; file names use the 0-based epoch index.
        if epoch % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_path, f'model_{epoch}.pth'))
    torch.save(model.state_dict(), os.path.join(save_path, 'model_final.pth'))

    # Per-epoch curves, e.g. for offline plotting.
    with open(os.path.join(save_path, 'loss_acc.pkl'), 'wb') as f:
        pickle.dump((train_l1_losses, train_order_losses, train_acc,
                     test_l1_losses, test_order_losses, test_acc), f)

# Entry point: start training only when run as a script, not on import.
if __name__ == "__main__":
    main()
