from pprint import pprint
import json
from pathlib import Path
import os
import os.path as osp
import csv
from filelock import FileLock
import yaml
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from utils.hyper import get_src_model_dir, get_output_dir
from utils.seed import fix_seed
from utils.evaluation import evaluate_net_reg_Metric_logger, evaluation_logPrint
from utils.my_loggers import TxtLogger, set_logger
from model import create_regressor
from dataset import get_datasets





import argparse

def parse_args():
    """Parse the command line, overlay it on the config file, and run ``main``.

    The config file given by ``-c/--config_source_yaml_dir`` is read with
    ``json.load`` when its path ends in ``.json`` and with ``yaml.safe_load``
    otherwise.  Every argument present in the parsed namespace (including the
    config path itself) is then written into the loaded config, and the
    result is passed to ``main``.
    """
    # Optional overrides as (flag, type).  They are registered with
    # default=argparse.SUPPRESS so that an option the user did not pass is
    # absent from the namespace and therefore cannot clobber the file's value.
    override_specs = (
        ("--seed", int),
        ("--dset", str),
        ("--train_aug_type", str),
        ("--aug_type", str),
        ("--source", str),
        ("--target", str),
        ("--task_info", str),
        ("--tr_val_split", float),
        ("--data_dir", str),
        ("--model_dir", str),
        ("--src_model_name", str),
        ("--main_dir", str),
        ("--exp_name", str),
        ("--key_info", str),
    )

    cli = argparse.ArgumentParser()
    cli.add_argument("-c", '--config_source_yaml_dir', required=True, help="path")
    for flag, value_type in override_specs:
        cli.add_argument(flag, type=value_type, default=argparse.SUPPRESS)

    namespace = cli.parse_args()
    config_path = namespace.config_source_yaml_dir

    with open(config_path, "r", encoding="utf-8") as handle:
        parse = json.load if config_path.endswith(".json") else yaml.safe_load
        config = parse(handle)

    # Command-line values take precedence over values from the file.
    overrides = vars(namespace)
    for name in overrides:
        config[name] = overrides[name]

    main(config)
    


def main(config):
    """Train a regressor on the source domain and report source/target metrics.

    Steps, in order: call ``fix_seed``; resolve and create the model and
    output directories; dump the effective config as YAML into the output
    directory; set up the text logger; then build the model and data loaders,
    train for ``config['epoch']`` epochs with an MSE loss, save the model
    weights, and log / append the final metrics to a CSV file.

    Args:
        config: Mutable mapping of settings.  Keys read here are ``seed``,
            ``key_info``, ``source``, ``target``, ``batch_size``, ``epoch``,
            ``aug_type``, ``dset``, ``model_dir`` and (optionally)
            ``task_info``; ``create_optimizer`` additionally reads
            ``optimizer``.  ``config_save_name`` and ``output_files_dir`` are
            written into it.

    The logger handlers returned by ``TxtLogger`` are removed even when
    training or evaluation raises.
    """
    fix_seed(config['seed'])
    config['config_save_name'] = config['key_info'] + "_config.yaml"

    src_model_dir, _ = get_src_model_dir(config)
    output_files_dir = get_output_dir(config)
    config['output_files_dir'] = output_files_dir
    pprint(config)

    Path(src_model_dir).mkdir(parents=True, exist_ok=True)
    Path(output_files_dir).mkdir(parents=True, exist_ok=True)
    with Path(output_files_dir, config['config_save_name']).open("w", encoding="utf-8") as f:
        yaml.dump(config, f)

    # ============= Logger ==================
    log_file = set_logger(config['key_info'])
    log_path = Path(output_files_dir) / log_file
    txt_logger, fh, sh = TxtLogger(filename=log_path)

    try:
        _train_and_report(config, src_model_dir, txt_logger)
    finally:
        # Detach the handlers on failure as well, so an exception during
        # training does not leave them attached to the logger.
        txt_logger.removeHandler(fh)
        txt_logger.removeHandler(sh)


def _evaluate_and_log(regressor, dataloader, info_str):
    """Evaluate ``regressor`` on ``dataloader``, log the metrics, and return them."""
    metric_logger_dict = evaluate_net_reg_Metric_logger(regressor, dataloader, info_str)
    evaluation_logPrint(metric_logger_dict, info_str)
    return metric_logger_dict


def _train_and_report(config, src_model_dir, txt_logger):
    """Build model and loaders, train, save the weights, and report final metrics."""
    # ============= Model ================
    txt_logger.info("===== Init Model Loading =====")
    regressor = create_regressor(config, dropout=False).cuda()

    # ============= DataLoader ================
    txt_logger.info("===== Data Preparing =====")
    train_ds, _, val_ds, _ = get_datasets(config, config['source'])
    _, _, tar_val_ds, tar_val_aug_ds = get_datasets(config, config['target'])
    tr_bs = config["batch_size"]
    val_bs = 4 * config["batch_size"]  # evaluation batches are 4x the training batch size
    train_dl = DataLoader(train_ds, batch_size=tr_bs, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=val_bs, shuffle=False)
    tar_val_dl = DataLoader(tar_val_ds, batch_size=val_bs, shuffle=False)
    tar_val_aug_dl = DataLoader(tar_val_aug_ds, batch_size=val_bs, shuffle=False)

    # ============= training process ================
    txt_logger.info("===== Training Begins =====")
    opt = create_optimizer(regressor, config)
    step = 0

    # Baseline evaluation of the untrained model.  The results also seed the
    # "latest metrics" so the final report works when config['epoch'] is 0.
    regressor.eval()
    src_val_metric_logger_dict = _evaluate_and_log(
        regressor, val_dl, "Before Training | SOURCE Val")
    tar_val_metric_logger_dict = _evaluate_and_log(
        regressor, tar_val_dl, "Before Training | TARGET Val")
    tar_aug_metric_logger_dict = _evaluate_and_log(
        regressor, tar_val_aug_dl, f"Before Training | TARGET Val AUG ({config['aug_type']})")

    rd = -1  # index of the last completed epoch; stays -1 if no epoch runs
    for rd in range(config['epoch']):
        # Evaluation at the end of each epoch switches to eval mode, so
        # re-enter train mode once per epoch rather than once per batch.
        regressor.train()
        for batch in train_dl:
            opt.zero_grad()
            step += 1
            x, y, _ = batch
            x = x.cuda()
            y = y.float().flatten().cuda()

            y_pred = regressor(x)

            # NOTE(review): y is flattened to 1-D; this assumes regressor(x)
            # is 1-D as well (F.mse_loss broadcasts mismatched shapes) - confirm.
            loss = F.mse_loss(y_pred, y)

            loss.backward()
            opt.step()

        regressor.eval()
        # source validation - no aug
        src_val_metric_logger_dict = _evaluate_and_log(
            regressor, val_dl, f"source train epoch [{rd}] - val on source")
        # target validation - no aug
        tar_val_metric_logger_dict = _evaluate_and_log(
            regressor, tar_val_dl, f"source train epoch [{rd}] - val on target")
        # target validation - aug
        tar_aug_metric_logger_dict = _evaluate_and_log(
            regressor, tar_val_aug_dl,
            f"source train epoch [{rd}] - aug val on target ({config['aug_type']})")

    # The file name carries the total number of optimizer steps taken.
    model_path = osp.join(osp.abspath(src_model_dir), f"source_model_{step}.pt")
    torch.save(regressor.state_dict(), model_path)
    txt_logger.info(f"End of training, Model saved into {model_path}")

    # final print
    src_val_reg = src_val_metric_logger_dict['reg_metrics']
    tar_val_reg = tar_val_metric_logger_dict['reg_metrics']
    tar_aug_reg = tar_aug_metric_logger_dict['reg_metrics']
    final_info = f"""
End of source training.
seed, epoch, dataset, task_info, src_domain, test_domain, aug_type, mae, rmse, r2, R
{config['seed']}, {rd}, {config['dset']}, {config.get('task_info', 'NA')}, {config['source']}, {config['source']}, val, {src_val_reg['mae']:.4f}, {src_val_reg['rmse']:.4f}, {src_val_reg['r2']:.4f}, {src_val_reg['R']:.4f}
{config['seed']}, {rd}, {config['dset']}, {config.get('task_info', 'NA')}, {config['source']}, {config['target']}, val, {tar_val_reg['mae']:.4f}, {tar_val_reg['rmse']:.4f}, {tar_val_reg['r2']:.4f}, {tar_val_reg['R']:.4f}
{config['seed']}, {rd}, {config['dset']}, {config.get('task_info', 'NA')}, {config['source']}, {config['target']}, {config["aug_type"]}, {tar_aug_reg['mae']:.4f}, {tar_aug_reg['rmse']:.4f}, {tar_aug_reg['r2']:.4f}, {tar_aug_reg['R']:.4f}
"""
    txt_logger.info(final_info)
    write_experiment_result_csv(config, rd, src_val_reg, tar_val_reg, tar_aug_reg, file_path=f"{config['model_dir']}/{config['dset']}_src_tr_all.csv")

def create_optimizer(net, config):
    """Instantiate the optimizer described by ``config['optimizer']`` for ``net``.

    Args:
        net: Module whose ``parameters()`` are handed to the optimizer.
        config: Mapping with ``config['optimizer']['name']`` (an attribute
            path relative to ``torch.optim``, e.g. ``"Adam"``) and
            ``config['optimizer']['config']`` (keyword arguments for the
            optimizer constructor).

    Returns:
        The constructed optimizer instance.

    Raises:
        AttributeError: If the name does not resolve inside ``torch.optim``.
    """
    # Resolve the name by attribute lookup instead of eval(): eval would run
    # any Python expression embedded in the config value.  Dotted paths are
    # still accepted by walking one attribute at a time.
    optimizer_factory = torch.optim
    for attribute in config['optimizer']['name'].split("."):
        optimizer_factory = getattr(optimizer_factory, attribute)

    return optimizer_factory(net.parameters(), **config["optimizer"]["config"])


def write_experiment_result_csv(config, rd, src_val_reg, tar_val_reg, tar_aug_reg, file_path="final_metrics.csv"):
    """Append the three final evaluation rows of one run to a shared CSV file.

    One row is written per evaluation: source validation, target validation,
    and augmented target validation.  The columns are seed, epoch, dataset,
    task_info, src_domain, test_domain, aug_type, mae, rmse, r2, R.

    Args:
        config: Mapping providing ``seed``, ``dset``, ``source``, ``target``,
            ``aug_type`` and optionally ``task_info`` (``'NA'`` if absent).
        rd: Epoch value stored in the ``epoch`` column.
        src_val_reg: Metrics for source validation, indexed by ``'mae'``,
            ``'rmse'``, ``'r2'`` and ``'R'``.
        tar_val_reg: Same metrics for target validation.
        tar_aug_reg: Same metrics for augmented target validation.
        file_path: Destination CSV (``str`` or path-like).  Missing parent
            directories are created.

    The append is performed while holding ``FileLock(file_path + ".lock")``,
    and the header is written only when the file is missing or empty.
    """
    def as_scalar(value):
        # Metric values may expose .item() (as the original code required)
        # or already be plain Python numbers; accept both.
        return value.item() if hasattr(value, "item") else value

    file_path = os.fspath(file_path)  # accept pathlib.Path as well as str
    metric_names = ("mae", "rmse", "r2", "R")
    task_info = config.get('task_info', 'NA')
    src_domain = f"{config['source']}"

    # (test_domain, aug_type, metrics) for each row, in output order.
    evaluations = (
        (f"{config['source']}", "val", src_val_reg),
        (f"{config['target']}", "val", tar_val_reg),
        (f"{config['target']}", config['aug_type'], tar_aug_reg),
    )
    rows = [
        [
            config['seed'], rd, config['dset'], task_info,
            src_domain, test_domain, aug_type,
            *(as_scalar(metrics[name]) for name in metric_names),
        ]
        for test_domain, aug_type, metrics in evaluations
    ]

    header = ["seed", "epoch", "dataset", "task_info", "src_domain", "test_domain", "aug_type", "mae", "rmse", "r2", "R"]

    # Make sure the destination directory exists, so results of a finished
    # run are not lost to a FileNotFoundError.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    lock = FileLock(file_path + ".lock")

    with lock:
        # Decide about the header while holding the lock; an existing but
        # empty file still needs one.
        needs_header = not os.path.exists(file_path) or os.path.getsize(file_path) == 0
        with open(file_path, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(header)
            writer.writerows(rows)

# Script entry point: parse_args() reads the command line and the config file,
# then calls main() itself with the merged config.
if __name__ == "__main__":
    parse_args()




