import argparse
import os

import torch
import torch.nn as nn
from tqdm import tqdm
from easydict import EasyDict as edict

from src.models.utils import cal_3d_loss
from src.data_loader.utils import get_data, get_train_val_split
from src.data_loader.data_set import Data_Set
from src.models.port_model import peclr_to_torchvision
from src.models.rn_25D_wMLPref import RN_25D_wMLPref

from src.constants import (
    COMET_KWARGS,
    HYBRID2_CONFIG,
    BASE_DIR,
    TRAINING_CONFIG_PATH,
)
from src.utils import get_console_logger, read_json

def run_single_epoch(model, data_loader, optimizer=None, isGPU=True, scheduler=None):
    """Run one pass over ``data_loader`` and return the summed batch loss.

    If ``optimizer`` is given, the model is switched to train mode and a
    zero_grad / backward / step is performed for every batch. Otherwise the
    model is switched to eval mode and autograd is disabled for the whole pass.

    Args:
        model: network called as ``model(image, K)``; its output is indexed
            with ``'kp25d'`` to obtain the prediction fed to ``cal_3d_loss``.
        data_loader: iterable of batch dicts with tensor entries ``'image'``,
            ``'K'``, ``'joints3D'``, ``'scale'`` and ``'joints_valid'``. It must
            support ``len()``, which is used for the progress bar.
        optimizer: optimizer to step, or ``None`` for an evaluation pass.
        isGPU: use ``cuda:0`` when CUDA is available; otherwise run on CPU.
        scheduler: currently unused; kept only for interface compatibility.

    Returns:
        float: sum of ``loss.item()`` over all batches (not averaged, so it
        scales with the number of batches in ``data_loader``).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() and isGPU else "cpu")
    is_train = optimizer is not None

    # ``train(False)`` is exactly ``eval()``.
    model.train(is_train)
    model = model.to(device)

    total_loss = 0.0
    # Only build autograd graphs when we will back-propagate, so an evaluation
    # pass never retains activations even if the caller omits torch.no_grad().
    with torch.set_grad_enabled(is_train):
        for batch in tqdm(data_loader, total=len(data_loader)):
            image = batch['image'].to(device)
            K = batch['K'].to(device)

            prediction = model(image, K)['kp25d']

            joint3d_gt = batch['joints3D'].to(device)
            scale = batch['scale'].to(device)
            joints_valid = batch['joints_valid'].to(device)

            loss = cal_3d_loss(prediction, joint3d_gt, scale, K, joints_valid)
            total_loss += loss.item()

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    return total_loss

def train_model(model, train_loader, val_loader, lr, epoch, do_train=True, do_eval=True, print_log=True, isGPU=True, save_dir=None, epoch_prefix=""):
    """Train ``model`` with Adam for ``epoch`` epochs and checkpoint it.

    Each epoch runs an optional training pass and an optional validation pass
    (under ``torch.no_grad()``) through ``run_single_epoch``. With validation
    enabled, ``model.state_dict()`` is saved to ``{epoch_prefix}_epoch_{N}.pth``
    whenever the validation loss improves. Without validation there is no
    "best" epoch, so only the final epoch's weights are saved.

    Args:
        model: network to optimise; all of ``model.parameters()`` go to Adam.
        train_loader: loader for the training pass.
        val_loader: loader for the validation pass.
        lr: Adam learning rate.
        epoch: number of epochs to run.
        do_train: run the training pass each epoch (train loss is 0 otherwise).
        do_eval: run the validation pass each epoch (val loss is 0 otherwise).
        print_log: print per-epoch losses, save notices and the final lists.
        isGPU: forwarded to ``run_single_epoch``.
        save_dir: directory for checkpoints (created if missing); ``None``
            writes them to the current working directory.
        epoch_prefix: filename prefix for checkpoints.

    Returns:
        tuple[list, list]: per-epoch summed train losses and validation losses.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr)
    num_epochs = epoch  # keep the argument distinct from the loop counter
    best_val_loss = float("inf")
    train_loss_list = []
    val_loss_list = []

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    def save_checkpoint(ep):
        filename = f'{epoch_prefix}_epoch_{ep}.pth'
        path = filename if save_dir is None else os.path.join(save_dir, filename)
        torch.save(model.state_dict(), path)
        if print_log:
            print('save model')

    for ep in range(1, num_epochs + 1):
        if do_train:
            train_loss = run_single_epoch(model, train_loader, optimizer=optimizer, isGPU=isGPU)
        else:
            train_loss = 0
        if do_eval:
            with torch.no_grad():
                val_loss = run_single_epoch(model, val_loader, optimizer=None, isGPU=isGPU)
        else:
            val_loss = 0
        if print_log:
            print(f'epoch {ep} train_loss: {train_loss}, val_loss: {val_loss}')
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # Only a real validation loss can identify a "best" model; a constant
        # 0 would otherwise freeze the checkpoint at epoch 1.
        if do_eval and val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(ep)

    if not do_eval and num_epochs >= 1:
        # No validation signal: keep the most-trained weights instead.
        save_checkpoint(num_epochs)

    if print_log:
        print("train_loss_list")
        print(train_loss_list)
        print("val_loss_list")
        print(val_loss_list)

    return train_loss_list, val_loss_list

def main(model_name, pretrained_dir="F:\\peclr"):
    """Fine-tune ``RN_25D_wMLPref`` (``rn152`` backend) from a PeCLR checkpoint.

    Loads the training config from ``TRAINING_CONFIG_PATH``, builds an 80/20
    train/validation split of the ``freihand`` source, and ports the PeCLR
    weights into ``model.backend_model``. It then freezes that backend except
    its ``fc`` layer and trains for 100 epochs at lr 5e-4.

    Args:
        model_name: checkpoint filename inside ``pretrained_dir``. Its name
            without the extension becomes the prefix for saved epoch files.
        pretrained_dir: directory holding the PeCLR checkpoints; defaults to
            the location this script originally hard-coded.
    """
    BATCH_SIZE = 128
    NUM_WORKERS = 4
    pretrained_model_path = os.path.join(pretrained_dir, model_name)

    train_param = edict(read_json(TRAINING_CONFIG_PATH))
    train_param['train_ratio'] = 0.8
    # data preparation
    data = get_data(
        Data_Set, train_param, sources=["freihand"], experiment_type="supervised"
    )
    train_data_loader, val_data_loader = get_train_val_split(
        data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
    )

    model = RN_25D_wMLPref(backend_model='rn152')

    # Load pretrained weights, then freeze the backend except its fc layer.
    # Parameters outside ``backend_model`` are not touched here.
    peclr_to_torchvision(model.backend_model, pretrained_model_path)
    model.backend_model.requires_grad_(False)
    model.backend_model.fc.requires_grad_(True)

    # splitext handles any extension length (the old [:-5] assumed ".ckpt").
    checkpoint_prefix = os.path.splitext(model_name)[0]
    train_model(model, train_data_loader, val_data_loader, 5e-4, 100, epoch_prefix=checkpoint_prefix)
    
if __name__ == "__main__":
    # Command-line entry point: the only option is the checkpoint file name.
    cli = argparse.ArgumentParser(description='fine tuning')
    cli.add_argument('--pretrained', required=True, help='pretrained_model.ckpt')
    main(cli.parse_args().pretrained)