from typing import Any
from pprint import pprint
import json
from pathlib import Path
import argparse

import yaml
import ast
import torch
from torch.utils.data import DataLoader

from utils.hyper import get_src_model_dir, get_output_dir
from utils.seed import fix_seed
from utils.sampling import get_sampling
from utils.my_loggers import TxtLogger, set_logger
from model import create_regressor, CNNRegressorClassifier_DROPOUT, MLPRegressorClassifier_DROPOUT
from dataset import get_datasets
from utils.hist_utils import RegValueHist
from engines import train_hist
from utils.optim_utils import create_optimizers_schedulers_main

def update_nested_dict(d, keys, value):
    """Store ``value`` at the nested path ``keys`` inside ``d`` (in place).

    Every key except the last selects (or, if missing, creates as an empty
    dict via ``setdefault``) the next nested level; the last key is then
    assigned ``value``, overwriting any existing entry.

    Args:
        d: Root mapping that is modified in place.
        keys: Indexable sequence of keys forming the path. An empty sequence
            raises ``IndexError``.
        value: Object stored under the final key.
    """
    parents, leaf = keys[:-1], keys[-1]
    node = d
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value

def parse_args():
    """Parse CLI arguments, merge them into the source config, and run ``main``.

    The file given by ``-c/--config_source_yaml_dir`` is read as JSON when its
    name ends with ``.json`` and as YAML (``yaml.safe_load``) otherwise; its
    top level must be a mapping. Every parsed option except ``update`` then
    overwrites the same-named top-level config key. The optional flags default
    to ``argparse.SUPPRESS``, so only flags actually passed on the command
    line appear. Finally ``--update`` applies dotted-path overrides given as
    KEY VALUE pairs, e.g. ``--update optimizer.config.lr 0.001``. Each value
    is parsed with ``ast.literal_eval`` when possible and kept as the raw
    string otherwise.

    Exits via ``parser.error`` (status 2) when ``--update`` has an odd number
    of tokens or the config file does not contain a top-level mapping.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", '--config_source_yaml_dir', required=True, help="..")
    parser.add_argument("--seed", type=int, default=argparse.SUPPRESS)
    parser.add_argument("--dset", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--train_aug_type", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--aug_type", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--source", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--target", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--task_info", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--tr_val_split", type=float, default=argparse.SUPPRESS)
    parser.add_argument("--data_dir", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--model_dir", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--src_model_name", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--main_dir", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--exp_name", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--key_info", type=str, default=argparse.SUPPRESS)
    parser.add_argument("--update", type=str, nargs='+', help="List of key-value pairs to update nested config, e.g., --update ssa.adapt.weight_exp 1.5 optimizer.config.lr 0.001")

    args = parser.parse_args()

    # --update is consumed as KEY VALUE pairs; reject a dangling key before
    # doing any work instead of failing later with an IndexError.
    updates = args.update or []
    if len(updates) % 2:
        parser.error(
            "--update expects KEY VALUE pairs, got an odd number of tokens: "
            f"{updates}"
        )

    with open(args.config_source_yaml_dir, "r", encoding="utf-8") as f:
        if args.config_source_yaml_dir.endswith(".json"):
            config = json.load(f)
        else:
            config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty file (and a list/scalar for
    # non-mapping documents); the key assignments below need a dict.
    if not isinstance(config, dict):
        parser.error(
            f"config file {args.config_source_yaml_dir} must contain a "
            f"top-level mapping, got {type(config).__name__}"
        )

    for key, value in vars(args).items():
        if key != "update":
            config[key] = value

    for dotted_key, raw_value in zip(updates[::2], updates[1::2]):
        try:
            # Turn "1.5", "True", "[1, 2]", ... into Python objects.
            value = ast.literal_eval(raw_value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a Python literal (e.g. a bare word): keep the raw string.
            # Catching only literal_eval's documented errors lets
            # KeyboardInterrupt/SystemExit propagate, unlike a bare except.
            value = raw_value
        update_nested_dict(config, dotted_key.split('.'), value)

    main(config)


def main(config):
    fix_seed(config['seed'])
    config['config_save_name'] = config['key_info'] + "_config.yaml"
    config["cls"]["num_classes"] = []
    
    src_model_dir, src_model_path = get_src_model_dir(config)
    output_files_dir = get_output_dir(config)
    config['output_files_dir'] = output_files_dir
    
    pprint(config)
    
    Path(output_files_dir).mkdir(parents=True, exist_ok=True)
    with Path(output_files_dir, config['config_save_name']).open("w", encoding="utf-8") as f:
        yaml.dump(config, f)
    
    # ============= Logger ==================
    log_file = set_logger(config['key_info'])
    log_path = Path(output_files_dir) / log_file
    txt_logger, fh, sh = TxtLogger(filename=log_path)
    


    # ============= Model ================
    txt_logger.info(f"===== Init Model Loading =====")
    # DO
    regressor_do = create_regressor(config, dropout=True).cuda()
    regressor_do.load_state_dict(
        torch.load(src_model_path), strict=False)
    txt_logger.info(f"Sampling Regressor (do) loaded: {src_model_path}")
    # Ref
    regressor_ref = create_regressor(config, dropout=False).cuda()
    regressor_ref.load_state_dict(
        torch.load(src_model_path), strict=False)
    txt_logger.info(f"Refenrence Source Regressor loaded: {src_model_path}")
    

    # ============= DataLoader ================
    # target dataloader
    txt_logger.info(f"===== Data Preparing =====")
    _, __, target_ds, target_aug_ds = get_datasets(config, config['target'])
    tr_bs = config["batch_size"]
    val_bs = 4 * config["batch_size"]
    target_aug_dl = DataLoader(target_aug_ds, batch_size=tr_bs, shuffle=True)  
    target_dl = DataLoader(target_ds, batch_size=val_bs, shuffle=False)  


    # ============= training process ================
    cls_regressor = None
    txt_logger.info(f"===== Training Begins =====")
    
    # 1. sampling
    txt_logger.info(f"===== 1. Sampling =====")
    sample_dict = get_sampling(
        regressor_do, target_dl, config['dropout'])

    txt_logger.info(f"===== 2. y_hist Construction =====")
    y_hist = RegValueHist(sample_dict, **config["cls"]["config"])
    config["cls"]["num_classes"].append(y_hist.num_classes)
    with Path(output_files_dir, config['config_save_name']).open("w", encoding="utf-8") as f:
        yaml.dump(config, f)
    
    # model preparation
    regressor = create_regressor(config, dropout=False).cuda()
    reg_model_param = torch.load(src_model_path)
    regressor.load_state_dict(reg_model_param, strict=False)
    match data_type := config['data_type']:
        case 'image':
            cls_regressor = CNNRegressorClassifier_DROPOUT(
                base_model=regressor,
                num_classes=y_hist.num_classes,
                backbone=config["net_backbone"],
            ).cuda()
        case 'table':
            net_info = config['net_info']
            cls_regressor = MLPRegressorClassifier_DROPOUT(
                net_info[0], net_info[1], net_info[2],
                base_model=regressor,
                num_classes=y_hist.num_classes,
            ).cuda()
    txt_logger.info(f"CLS_REG model initialized with class num [{y_hist.num_classes}]")
    
    # optimizer preperation
    opt_dict, bn_mode_dict, scheduler_dict, module_state_dict = create_optimizers_schedulers_main(
    cls_regressor, config, "optimizer")
        

    # 3. main training
    txt_logger.info(f"===== 3. Main Adaptation =====\n")
    model_config = {
        'net': cls_regressor,
        'net_ref': regressor_ref,
    }
    opt_config = {
        'opt': opt_dict,
        'schedulers': scheduler_dict,
        'module_state': module_state_dict,
        'bn_mode': bn_mode_dict
    }
    dl_config = {
        'tr_dl': target_aug_dl,
        'tr_dl_name': f"AUG_{config['aug_type']}",
        'val_dl': target_dl,
        'val_dl_name': "VAL",
    }
    others = {
        "y_hist": y_hist,
        "train_config": config["train"],
    }
    train_hist(**model_config, **opt_config, **dl_config, **others)

    txt_logger.removeHandler(fh)
    txt_logger.removeHandler(sh)


if __name__ == "__main__":
    # Script entry point: parse_args() builds the config from the CLI and
    # the config file, then calls main(config) itself.
    parse_args()
