from utils import *
import torch
from tqdm import tqdm


def eval_one_batch(baseline, optimizer, data, epoch, method_name, multi_optim=False, phase='iid_valid'):
    """Run ``baseline`` on one batch with its sub-modules switched to eval mode.

    Evaluation never steps an optimizer, so ``optimizer`` must be ``None``.
    ``method_name`` and ``multi_optim`` are unused; they exist so this function
    can be called with the same arguments as ``train_one_batch``.

    Args:
        baseline: model wrapper with a ``clf`` module, optionally ``extractor``
            and/or ``domain_adv`` modules, and a ``forward_pass`` method.
        optimizer: must be ``None``.
        data: batch forwarded unchanged to ``baseline.forward_pass``.
        epoch: current epoch index, forwarded to ``forward_pass``.
        phase: split name forwarded to ``forward_pass``.

    Returns:
        ``(loss_dict, logits)`` where ``logits`` has been passed through ``to_cpu``.
    """
    assert optimizer is None
    # Optional sub-modules are switched to eval mode only when the baseline has them.
    for module_name in ('extractor', 'domain_adv'):
        if hasattr(baseline, module_name):
            getattr(baseline, module_name).eval()
    baseline.clf.eval()
    # NOTE(review): gradient tracking is not disabled here (no torch.no_grad());
    # presumably deliberate in case some forward_pass needs autograd — confirm.
    _, batch_losses, logits = baseline.forward_pass(data, epoch, phase)
    return batch_losses, to_cpu(logits)


def train_one_batch(baseline, optimizers, data, epoch, method_name, multi_optim=False, phase='train'):
    """Perform one optimisation step of ``baseline`` on a single batch.

    The forward pass always runs with phase ``"train"``; the ``phase`` and
    ``method_name`` arguments are accepted for signature compatibility only.

    Args:
        baseline: model wrapper with a ``clf`` module, an optional ``extractor``
            module and a ``forward_pass`` method returning ``(loss, loss_dict, logits)``.
        optimizers: a single optimizer, or (when ``multi_optim`` is true) an
            indexable sequence with one optimizer per entry of ``loss``.
        data: batch forwarded to ``forward_pass``.
        epoch: current epoch index, forwarded to ``forward_pass``.
        multi_optim: if true, ``loss`` is treated as a sequence and each element
            is back-propagated and stepped with the optimizer at the same index.

    Returns:
        ``(loss_dict, logits)`` where ``logits`` has been passed through ``to_cpu``.
    """
    baseline.clf.train()
    if hasattr(baseline, 'extractor'):
        baseline.extractor.train()

    loss, batch_losses, logits = baseline.forward_pass(data, epoch, "train")
    if multi_optim:
        # One optimizer per loss term: zero, backprop and step each pair in order.
        for position, loss_term in enumerate(loss):
            optimizers[position].zero_grad()
            loss_term.backward()
            optimizers[position].step()
    else:
        optimizers.zero_grad()
        loss.backward()
        optimizers.step()
    return batch_losses, to_cpu(logits)


def train_one_batch_DA(baseline, optimizers, data_s, data_t, epoch, multi_optim=False, phase='train'):
    """Perform one domain-adaptation optimisation step on a source/target batch pair.

    The pair ``(data_s, data_t)`` is handed to ``baseline.forward_pass`` as a
    single tuple, always with phase ``"train"``; ``phase`` itself is unused.

    Args:
        baseline: model wrapper with a ``clf`` module, an optional ``domain_adv``
            module and a ``forward_pass`` method returning ``(loss, loss_dict, logits)``.
        optimizers: a single optimizer, or (when ``multi_optim`` is true) an
            indexable sequence with one optimizer per entry of ``loss``.
        data_s: source-domain batch.
        data_t: target-domain batch.
        epoch: current epoch index, forwarded to ``forward_pass``.
        multi_optim: if true, each element of ``loss`` is back-propagated and
            stepped with the optimizer at the same index.

    Returns:
        ``(loss_dict, logits)`` where ``logits`` has been passed through ``to_cpu``.
    """
    baseline.clf.train()
    if hasattr(baseline, 'domain_adv'):
        baseline.domain_adv.train()

    batch_pair = (data_s, data_t)
    loss, batch_losses, logits = baseline.forward_pass(batch_pair, epoch, "train")
    if multi_optim:
        for position, loss_term in enumerate(loss):
            optimizers[position].zero_grad()
            loss_term.backward()
            optimizers[position].step()
    else:
        optimizers.zero_grad()
        loss.backward()
        optimizers.step()
    return batch_losses, to_cpu(logits)


def run_one_epoch(baseline, optimizer, data_loader, epoch, phase, seed, method_name):
    """Run ``baseline`` over every batch of ``data_loader`` and summarise the epoch.

    When ``phase == 'train'`` each batch goes through ``train_one_batch``
    (which steps ``optimizer``); any other phase uses ``eval_one_batch``,
    which requires ``optimizer`` to be ``None``.

    Args:
        baseline: model wrapper exposing ``forward_pass`` and ``device``.
        optimizer: optimizer for training (an indexable sequence of optimizers
            when ``method_name == 'dir'``); ``None`` for evaluation.
        data_loader: iterable of batches that expose ``.to(device)`` and ``.y``.
        epoch: current epoch index, forwarded to the model and to ``log_epoch_v2``.
        phase: split name, e.g. ``'train'``, ``'iid_valid'``, ``'ood_test'``.
        seed: run seed, forwarded to ``log_epoch_v2``.
        method_name: method identifier; ``'dir'`` enables per-loss optimizers.

    Returns:
        ``(org_clf_acc, org_clf_auc, all_loss_dict, avg_loss)``.
        ``all_loss_dict`` maps each loss term to its mean over the batches.
        The other three values are those returned by ``log_epoch_v2``.

    Raises:
        ValueError: if ``data_loader`` yields no batches (previously this
            surfaced as an ``UnboundLocalError`` at the return statement).
    """
    run_one_batch = train_one_batch if phase == 'train' else eval_one_batch
    # Pad 'test' to five characters so log descriptions line up with 'train'/'valid'.
    # The padding is cosmetic: it is passed to the logger only, never to the model.
    log_phase = 'test ' if phase == 'test' else phase
    multi_optim = method_name == 'dir'
    log_dict = {'org_clf_logits': [], 'clf_labels': []}
    loss_sums = {}
    num_batches = 0
    for data in tqdm(data_loader):
        loss_dict, org_clf_logits = run_one_batch(baseline, optimizer, data.to(baseline.device), epoch, method_name,
                                                  multi_optim, phase)
        log_dict['org_clf_logits'].append(org_clf_logits)
        log_dict['clf_labels'].append(to_cpu(data.y))
        for k, v in loss_dict.items():
            loss_sums[k] = loss_sums.get(k, 0) + v
        num_batches += 1

    if num_batches == 0:
        raise ValueError(f'data_loader for phase {phase!r} yielded no batches; cannot compute epoch metrics')

    # Average each loss term over the batches seen in this epoch.
    all_loss_dict = {k: v / num_batches for k, v in loss_sums.items()}
    _, org_clf_acc, org_clf_auc, avg_loss = log_epoch_v2(epoch, log_phase, all_loss_dict, log_dict, seed,
                                                         batch=False)
    return org_clf_acc, org_clf_auc, all_loss_dict, avg_loss


def run_one_epoch_DA(baseline, optimizer, data_loader, epoch, phase, seed, method_name, iters_per_epoch=0):
    """Run one domain-adaptation training epoch, or one evaluation pass, and summarise it.

    Training mode (``iters_per_epoch > 0``): ``data_loader`` must be a pair
    ``(source_iter, target_iter)`` of iterators. ``next()`` is called on each
    exactly ``iters_per_epoch`` times, and each drawn pair goes through
    ``train_one_batch_DA``. An iterator that runs out early raises
    ``StopIteration``. Only the source labels are logged.

    Evaluation mode (``iters_per_epoch <= 0``): ``data_loader`` is an iterable
    of batches, each passed to ``eval_one_batch``. In this mode ``optimizer``
    must be ``None``.

    Returns:
        ``(org_clf_acc, org_clf_auc, all_loss_dict, avg_loss)``.
        ``all_loss_dict`` maps each loss term to its mean over the
        iterations/batches. The other three values are those returned by
        ``log_epoch_v2``.

    Raises:
        ValueError: in evaluation mode, if ``data_loader`` yields no batches
            (previously an ``UnboundLocalError`` at the return statement).
    """
    log_dict = {'org_clf_logits': [], 'clf_labels': []}
    loss_sums = {}
    if iters_per_epoch > 0:
        # training process
        train_source_iter, train_target_iter = data_loader
        num_batches = iters_per_epoch
        for _ in tqdm(range(iters_per_epoch)):
            data_s = next(train_source_iter).to(baseline.device)
            data_t = next(train_target_iter).to(baseline.device)

            loss_dict, org_clf_logits = train_one_batch_DA(baseline, optimizer, data_s, data_t, epoch, False, phase)
            # Give 1-D source labels a trailing axis before logging. This runs after
            # the forward pass, so it only affects what is logged for this batch.
            if len(data_s.y.shape) == 1:
                data_s.y = data_s.y.unsqueeze(1)
            log_dict['org_clf_logits'].append(org_clf_logits)
            log_dict['clf_labels'].append(to_cpu(data_s.y))
            for k, v in loss_dict.items():
                loss_sums[k] = loss_sums.get(k, 0) + v
    else:
        num_batches = 0
        for data in tqdm(data_loader):
            loss_dict, org_clf_logits = eval_one_batch(baseline, optimizer, data.to(baseline.device), epoch,
                                                       method_name, False, phase)
            log_dict['org_clf_logits'].append(org_clf_logits)
            log_dict['clf_labels'].append(to_cpu(data.y))
            for k, v in loss_dict.items():
                loss_sums[k] = loss_sums.get(k, 0) + v
            num_batches += 1
        if num_batches == 0:
            raise ValueError(f'data_loader for phase {phase!r} yielded no batches; cannot compute epoch metrics')

    # Average each loss term over the iterations/batches of this epoch.
    all_loss_dict = {k: v / num_batches for k, v in loss_sums.items()}
    _, org_clf_acc, org_clf_auc, avg_loss = log_epoch_v2(epoch, phase, all_loss_dict, log_dict, seed,
                                                         batch=False)
    return org_clf_acc, org_clf_auc, all_loss_dict, avg_loss


def run_and_log(log_dir, epochs, baseline, optimizer, loaders, seed, data_config, method_name, run_ood=True):
    """Train ``baseline`` for ``epochs`` epochs, evaluating and logging after each one.

    Each epoch runs these passes in order:
    - one training pass over ``loaders['train']``;
    - in-distribution validation and test passes;
    - when ``run_ood`` is true, passes over ``loaders['ood_val']`` and
      ``loaders['ood_test']``.

    Per-epoch lines are written with ``write_log`` to text files under
    ``log_dir``. The best-epoch metric dicts are updated through
    ``update_and_save_best_epoch_res_auc`` and ``update_and_save_best_epoch_res_acc``.

    Args:
        log_dir: directory object supporting the ``/`` operator (e.g. ``pathlib.Path``).
        epochs: number of epochs to run.
        baseline: model wrapper passed to ``run_one_epoch``.
        optimizer: optimizer used for the training pass.
        loaders: mapping containing ``'train'`` plus ``'val'``/``'test'`` when
            ``run_ood`` is false, or ``'iid_val'``/``'iid_test'``/``'ood_val'``/
            ``'ood_test'`` when ``run_ood`` is true.
        seed: run seed, forwarded to the epoch runner and metric updaters.
        data_config: not used by this function; kept for interface compatibility.
        method_name: method identifier forwarded to ``run_one_epoch``.
        run_ood: whether to also evaluate and log the OOD splits.

    Returns:
        ``(auc, ood_auc, acc, ood_acc)`` metric dicts when ``run_ood`` is true,
        otherwise ``(auc, acc)``.
    """
    auc_log = log_dir / 'logging_auc.txt'
    ood_auc_log = log_dir / 'logging_ood_auc.txt'
    acc_log = log_dir / 'logging_acc.txt'
    ood_acc_log = log_dir / 'logging_ood_acc.txt'
    loss_log = log_dir / 'logging_train_loss.txt'

    best_auc = deepcopy(init_metric_dict_ood_auc)
    best_acc = deepcopy(init_metric_dict_ood_acc)
    if run_ood:
        best_ood_auc = deepcopy(init_metric_dict_ood_auc)
        best_ood_acc = deepcopy(init_metric_dict_ood_acc)
        val_key, test_key = 'iid_val', 'iid_test'
    else:
        val_key, test_key = 'val', 'test'

    def evaluate(split, current_epoch, tag):
        # Evaluation passes never step an optimizer.
        return run_one_epoch(baseline, None, loaders[split], current_epoch, tag, seed, method_name)

    for epoch in range(epochs):
        train_res = run_one_epoch(baseline, optimizer, loaders['train'], epoch, 'train', seed, method_name)
        val_res = evaluate(val_key, epoch, 'iid_valid')
        test_res = evaluate(test_key, epoch, 'iid_test')
        if run_ood:
            ood_val_res = evaluate('ood_val', epoch, 'ood_valid')
            ood_test_res = evaluate('ood_test', epoch, 'ood_test')

        # Epoch results are (acc, auc, loss_dict, avg_loss) tuples.
        train_acc, train_auc, train_losses, _ = train_res
        loss_terms = ' '.join(f'{name} {value}' for name, value in train_losses.items())
        write_log(f'epoch {epoch} {loss_terms}', log_file=loss_log)
        write_log(
            f'epoch {epoch} val_loss {val_res[-1]} test_loss {test_res[-1]} train_auc {train_auc} val_auc {val_res[1]} test_auc {test_res[1]}',
            log_file=auc_log)
        write_log(
            f'epoch {epoch} val_loss {val_res[-1]} test_loss {test_res[-1]} train_acc {train_acc} val_acc {val_res[0]} test_acc {test_res[0]}',
            log_file=acc_log)
        best_auc = update_and_save_best_epoch_res_auc(seed, baseline, train_res, val_res, test_res, best_auc,
                                                      epoch, auc_log, log_dir)
        best_acc = update_and_save_best_epoch_res_acc(seed, baseline, train_res, val_res, test_res,
                                                      best_acc, epoch, acc_log, log_dir)

        if not run_ood:
            continue

        write_log(
            f'epoch {epoch} ood_val_loss {ood_val_res[-1]} ood_test_loss {ood_test_res[-1]} train_auc {train_auc} ood_val_auc {ood_val_res[1]} ood_test_auc {ood_test_res[1]}',
            log_file=ood_auc_log)
        write_log(
            f'epoch {epoch} ood_val_loss {ood_val_res[-1]} ood_test_loss {ood_test_res[-1]} train_acc {train_acc} ood_val_acc {ood_val_res[0]} ood_test_acc {ood_test_res[0]}',
            log_file=ood_acc_log)
        best_ood_auc = update_and_save_best_epoch_res_auc(seed, baseline, train_res, ood_val_res, ood_test_res,
                                                          best_ood_auc, epoch, ood_auc_log, log_dir, is_ood=True)
        best_ood_acc = update_and_save_best_epoch_res_acc(seed, baseline, train_res, ood_val_res, ood_test_res,
                                                          best_ood_acc, epoch, ood_acc_log, log_dir, is_ood=True)

    if run_ood:
        return best_auc, best_ood_auc, best_acc, best_ood_acc
    return best_auc, best_acc


def run_and_log_DA(log_dir, epochs, baseline, optimizer, loaders, seed, method_name, iters_per_epoch):
    """Train a domain-adaptation ``baseline`` for ``epochs`` epochs, logging after each one.

    Each epoch runs these passes in order:
    - one training pass of ``iters_per_epoch`` iterations drawing from
      ``loaders['train_source']`` and ``loaders['train_target']``;
    - evaluation passes over ``'iid_val'``, ``'iid_test'``, ``'ood_val'`` and
      ``'ood_test'``.

    Per-epoch lines are written with ``write_log`` to text files under
    ``log_dir``. The best-epoch metric dicts are updated through
    ``update_and_save_best_epoch_res_auc`` and ``update_and_save_best_epoch_res_acc``.

    Args:
        log_dir: directory object supporting the ``/`` operator (e.g. ``pathlib.Path``).
        epochs: number of epochs to run.
        baseline: model wrapper passed to ``run_one_epoch_DA``.
        optimizer: optimizer used for the training pass.
        loaders: mapping with keys ``'train_source'``, ``'train_target'``,
            ``'iid_val'``, ``'iid_test'``, ``'ood_val'`` and ``'ood_test'``.
        seed: run seed, forwarded to the epoch runner and metric updaters.
        method_name: method identifier forwarded to ``run_one_epoch_DA``.
        iters_per_epoch: number of source/target batch pairs per training epoch.

    Returns:
        ``(auc, ood_auc, acc, ood_acc)`` best-epoch metric dicts.
    """
    auc_log = log_dir / 'logging_auc.txt'
    ood_auc_log = log_dir / 'logging_ood_auc.txt'
    acc_log = log_dir / 'logging_acc.txt'
    ood_acc_log = log_dir / 'logging_ood_acc.txt'
    loss_log = log_dir / 'logging_train_loss.txt'

    best_auc = deepcopy(init_metric_dict_ood_auc)
    best_acc = deepcopy(init_metric_dict_ood_acc)
    best_ood_auc = deepcopy(init_metric_dict_ood_auc)
    best_ood_acc = deepcopy(init_metric_dict_ood_acc)

    def evaluate(split, current_epoch, tag):
        # Evaluation passes never step an optimizer and use the default iters_per_epoch=0.
        return run_one_epoch_DA(baseline, None, loaders[split], current_epoch, tag, seed, method_name)

    for epoch in range(epochs):
        train_iters = (loaders['train_source'], loaders['train_target'])
        train_res = run_one_epoch_DA(baseline, optimizer, train_iters, epoch, 'train', seed, method_name,
                                     iters_per_epoch)
        val_res = evaluate('iid_val', epoch, 'iid_valid')
        test_res = evaluate('iid_test', epoch, 'iid_test')
        ood_val_res = evaluate('ood_val', epoch, 'ood_valid')
        ood_test_res = evaluate('ood_test', epoch, 'ood_test')

        # Epoch results are (acc, auc, loss_dict, avg_loss) tuples.
        train_acc, train_auc, train_losses, _ = train_res
        loss_terms = ' '.join(f'{name} {value}' for name, value in train_losses.items())
        write_log(f'epoch {epoch} {loss_terms}', log_file=loss_log)
        write_log(
            f'epoch {epoch} val_loss {val_res[-1]} test_loss {test_res[-1]} train_auc {train_auc} val_auc {val_res[1]} test_auc {test_res[1]}',
            log_file=auc_log)
        write_log(
            f'epoch {epoch} val_loss {val_res[-1]} test_loss {test_res[-1]} train_acc {train_acc} val_acc {val_res[0]} test_acc {test_res[0]}',
            log_file=acc_log)
        write_log(
            f'epoch {epoch} ood_val_loss {ood_val_res[-1]} ood_test_loss {ood_test_res[-1]} train_auc {train_auc} ood_val_auc {ood_val_res[1]} ood_test_auc {ood_test_res[1]}',
            log_file=ood_auc_log)
        write_log(
            f'epoch {epoch} ood_val_loss {ood_val_res[-1]} ood_test_loss {ood_test_res[-1]} train_acc {train_acc} ood_val_acc {ood_val_res[0]} ood_test_acc {ood_test_res[0]}',
            log_file=ood_acc_log)

        best_auc = update_and_save_best_epoch_res_auc(seed, baseline, train_res, val_res, test_res, best_auc,
                                                      epoch, auc_log, log_dir)
        best_acc = update_and_save_best_epoch_res_acc(seed, baseline, train_res, val_res, test_res,
                                                      best_acc, epoch, acc_log, log_dir)
        best_ood_auc = update_and_save_best_epoch_res_auc(seed, baseline, train_res, ood_val_res, ood_test_res,
                                                          best_ood_auc, epoch, ood_auc_log, log_dir, is_ood=True)
        best_ood_acc = update_and_save_best_epoch_res_acc(seed, baseline, train_res, ood_val_res, ood_test_res,
                                                          best_ood_acc, epoch, ood_acc_log, log_dir, is_ood=True)

    return best_auc, best_ood_auc, best_acc, best_ood_acc

