import yaml
import argparse
from pathlib import Path
from copy import deepcopy
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from get_model import Model
from baselines import ERM, DIR, LRIGaussian, MixUp, GroupDRO, LRIBern, VREx
from utils import *
from training_procedure import *
import warnings
# Silence every warning category for the whole process from this point on.
warnings.filterwarnings("ignore")


def write_res(res_path, args, report):
    """Send the run arguments, the seed and each metric of a report to ``write_log``.

    Args:
        res_path: Destination handed to ``write_log`` for every line.
        args: Parsed command-line namespace; its string form and ``seed``
            attribute are logged first.
        report: Mapping of metric name to value, logged one entry per line
            as ``name: value``.
    """

    def _report_lines():
        # Lazily produced so each line is formatted right before it is logged.
        yield str(args)
        yield f'seed {args.seed}'
        for metric_name, metric_value in report.items():
            yield f'{metric_name}: {metric_value}'

    for line in _report_lines():
        write_log(line, res_path)


# Methods that ``train`` knows how to build a baseline for.
_SUPPORTED_METHODS = ('erm', 'dir', 'lri_bern', 'mixup', 'groupdro', 'VREx')
_METRIC_KEY_PREFIX = 'metric/best_'


def _strip_metric_prefix(metric_dict):
    """Return a copy of ``metric_dict`` with 'metric/best_' removed from every key."""
    return {key.replace(_METRIC_KEY_PREFIX, ''): value for key, value in metric_dict.items()}


def train(config, method_name, model_name, seed, dataset_name, log_dir, device, shift_type):
    """Build the model and baseline for ``method_name``, run training, and return the metrics.

    Args:
        config: Nested configuration dict. Reads ``config['optimizer']``,
            ``config['data']``, ``config['model'][model_name]``,
            ``config[method_name]`` and ``config[shift_type]``.
        method_name: One of 'erm', 'dir', 'lri_bern', 'mixup', 'groupdro', 'VREx'.
        model_name: Backbone key inside ``config['model']``.
        seed: Random seed forwarded to the data loaders and to ``run_and_log``.
        dataset_name: Dataset identifier forwarded to ``get_data_loaders``.
        log_dir: Directory forwarded to ``run_and_log``.
        device: Device the classifier and extractor are moved to.
        shift_type: Key of the shift section inside ``config``.

    Returns:
        A 4-tuple of dicts ``(auc, ood_auc, acc, ood_acc)`` as produced by
        ``run_and_log``, with the 'metric/best_' prefix stripped from the keys.

    Raises:
        ValueError: If ``method_name`` is not a supported method.
    """
    # Fail fast: previously an unknown method built data loaders and models and
    # then crashed with UnboundLocalError when the metrics were post-processed.
    if method_name not in _SUPPORTED_METHODS:
        raise ValueError(
            f'Unsupported method {method_name!r}; expected one of: {", ".join(_SUPPORTED_METHODS)}')

    optimizer_config = config['optimizer']
    batch_size = optimizer_config['batch_size']
    method_config = config[method_name]
    epochs = method_config['epochs']
    data_config = config['data']
    shift_config = config[shift_type]
    loaders, Dataset = get_data_loaders(dataset_name, batch_size, data_config, seed, shift_config=shift_config)
    # Print arrays and tensors in full (no truncation); these are global settings.
    np.set_printoptions(threshold=np.inf)
    torch.set_printoptions(profile="full")

    model_config = config['model'][model_name]
    use_lig_info = data_config.get('use_lig_info', False)

    clf = Model(model_name, model_config, method_name, method_config, Dataset).to(device)
    # Built for every method, even those that never use it, so the order of
    # module constructions stays exactly as before.
    # NOTE(review): presumably this influences seeded initialisation — confirm before removing.
    extractor = ExtractorMLP(model_config['hidden_size'], method_config, use_lig_info).to(device)

    criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
    optimizer = get_wp_optimizer(clf, optimizer_config)

    if method_name == 'erm':
        baseline = ERM(clf, criterion)
    elif method_name == 'dir':
        # DIR replaces the default extractor with one of twice the hidden size.
        extractor = ExtractorMLP(model_config['hidden_size'] * 2, method_config, use_lig_info).to(device)
        baseline = DIR(clf, extractor, criterion, method_config)
        optimizer = get_dir_optimizer(clf, extractor, optimizer_config, method_config)
    elif method_name == 'lri_bern':
        optimizer = get_optimizer(clf, extractor, optimizer_config, method_config, warmup=False)
        baseline = LRIBern(clf, extractor, criterion, method_config)
    elif method_name == 'mixup':
        baseline = MixUp(clf, criterion, method_config)
    elif method_name == 'groupdro':
        # This baseline is given an unreduced (per-element) loss.
        criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
        baseline = GroupDRO(clf, criterion, method_config)
    else:  # 'VREx' — the only remaining supported method after the check above
        criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
        baseline = VREx(clf, criterion, method_config)

    metric_dict_auc, metric_dict_ood_auc, metric_dict_acc, metric_dict_ood_acc = run_and_log(
        log_dir, epochs, baseline, optimizer, loaders, seed, data_config, method_name)

    return (
        _strip_metric_prefix(metric_dict_auc),
        _strip_metric_prefix(metric_dict_ood_auc),
        _strip_metric_prefix(metric_dict_acc),
        _strip_metric_prefix(metric_dict_ood_acc),
    )


def main():
    """Parse the command line, load the dataset config, train, and write the result logs.

    The YAML config is read from ``./configs/<dataset>.yml``. Backbone-specific
    overrides found under ``config[method][backbone]`` are merged into the
    method section, then ``setting`` and ``coeff`` from the command line are
    written into the shift and method sections. Four result files
    (AUC / OOD AUC / accuracy / OOD accuracy) are written into the log
    directory derived from the arguments.
    """
    parser = argparse.ArgumentParser(description='HEP OoD dataset')
    parser.add_argument('-d', '--dataset', type=str, help='dataset used', default='tau3mu')
    parser.add_argument('-m', '--method', type=str, help='method used', default='erm')
    parser.add_argument('--shift', type=str, help='shift type', default='scaffold')  # option: pileup, signal, scaffold
    parser.add_argument('-b', '--backbone', type=str, help='backbone used', default='egnn')
    parser.add_argument('--cuda', type=int, help='cuda device id, -1 for cpu', default=1)
    parser.add_argument('--seed', type=int, help='random seed', default=0)
    parser.add_argument('--kr', type=float, default=5)
    # argparse does not run `type` on non-string defaults, so the default is
    # given as a string to match what an explicit `--target 50` produces.
    parser.add_argument('--target', type=str, help='target domain info', default='50')
    parser.add_argument('--setting', type=str, help='option: OOD, DA, TL', default='OOD')
    parser.add_argument('--note', type=str, help='note in log name', default='')
    parser.add_argument('--coeff', type=float, default=0.1)
    args = parser.parse_args()

    set_seed(args.seed)
    config_name = args.dataset
    config_path = Path('./configs') / f'{config_name}.yml'
    # Context manager so the config file handle is closed right after parsing.
    with config_path.open('r') as config_file:
        config = yaml.safe_load(config_file)
    # NOTE(review): PyTorch documents anomaly detection as a debugging aid that
    # slows execution — confirm it is meant to stay enabled for every run.
    torch.autograd.set_detect_anomaly(True)

    method_config = config[args.method]
    # Backbone-specific settings, when present and non-empty, override the method defaults.
    if method_config.get(args.backbone, False):
        method_config.update(method_config[args.backbone])
    device = torch.device(f'cuda:{args.cuda}' if args.cuda >= 0 else 'cpu')

    config[args.shift]["setting"] = args.setting
    method_config['coeff'] = args.coeff

    run_name = f'{args.backbone}_{args.shift}_shift_{args.note}_{args.target}_{args.setting}_{args.coeff}'
    log_dir = Path(config['data']['data_dir']) / config_name / args.method / run_name
    log_dir.mkdir(parents=True, exist_ok=True)

    report_dict, report_dict_ood, report_dict_acc, report_dict_ood_acc = train(
        config, args.method, args.backbone, args.seed, args.dataset, log_dir, device, args.shift)

    # One result file per report, written in the same order as before.
    reports_by_file = (
        ('result_auc.txt', report_dict),
        ('result_ood_auc.txt', report_dict_ood),
        ('result_acc.txt', report_dict_acc),
        ('result_ood_acc.txt', report_dict_ood_acc),
    )
    for file_name, report in reports_by_file:
        write_res(log_dir / file_name, args, report)


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
