import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, mean_absolute_error


def print_time_info(string, end='\n', dash_top=False, dash_bot=False, file=None):
    """Print ``string`` prefixed with the current local time.

    The printed line has the form ``[YYYY-mm-dd HH:MM:SS] <str(string)>``.

    Args:
        string: Object to print; converted with ``str()``.
        end: Terminator passed to ``print`` for the message line.
        dash_top: If true, print a row of dashes as wide as the message first.
        dash_bot: If true, print a row of dashes as wide as the message after.
        file: Stream passed through to ``print`` (``None`` means stdout).
    """
    stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
    message = '[' + stamp + '] ' + str(string)
    rule = '-' * len(message)

    if dash_top:
        print(rule, file=file)
    print(message, end=end, file=file)
    if dash_bot:
        print(rule, file=file)


def write_log(print_str, log_file, print_=True):
    """Optionally echo a message to stdout and append it to a log file.

    Args:
        print_str: Message to log. Any object is accepted; it is converted
            with ``str()`` for the file, matching ``print_time_info`` which
            already stringifies its input (previously a non-str message
            crashed ``f.write`` with ``TypeError``).
        log_file: Path of the file to append to, or ``None`` to skip file
            logging.
        print_: If true, also print the message with a timestamp via
            ``print_time_info``.

    Each message is written *preceded* (not followed) by a newline.
    """
    if print_:
        print_time_info(print_str)
    if log_file is None:
        return
    with open(log_file, 'a') as f:
        f.write('\n' + str(print_str))


def to_item(tensor):
    """Return the Python scalar held by a ``torch.Tensor``; pass anything else through.

    ``None`` and non-tensor values are returned unchanged. Tensors are
    unwrapped with ``.item()``, which requires a single-element tensor.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return tensor.item()


def log_epoch_v2(epoch, phase, loss_dict, log_dict, seed, batch):
    """Build a one-line classification progress message and return the scores.

    The message contains a seed/epoch header, a phase status (``'.....'``
    while ``batch`` is truthy, ``'done'`` otherwise), every entry of
    ``loss_dict`` formatted to three decimals, and the evaluation text from
    ``get_eval_score_v2``.

    Returns:
        ``(desc, org_clf_acc, org_clf_auc, loss_dict['pred'])``.
    """
    header = f'[Seed {seed}, Epoch: {epoch}]: '
    header += f'{phase}....., ' if batch else f'{phase} done, '
    loss_text = ''.join(f'{name}: {value:.3f}, ' for name, value in loss_dict.items())

    eval_desc, org_clf_acc, org_clf_auc = get_eval_score_v2(epoch, phase, log_dict, batch)
    return header + loss_text + eval_desc, org_clf_acc, org_clf_auc, loss_dict['pred']


def log_epoch_regress(epoch, phase, loss_dict, log_dict, seed, batch):
    """Build a one-line regression progress message and return the MAE.

    The message contains a seed/epoch header, a phase status (``'.....'``
    while ``batch`` is truthy, ``'done'`` otherwise), every entry of
    ``loss_dict`` formatted to three decimals, and the MAE text from
    ``get_eval_score_regress``.

    Returns:
        ``(desc, regress_mae, loss_dict['pred'])``.
    """
    header = f'[Seed {seed}, Epoch: {epoch}]: '
    header += f'{phase}....., ' if batch else f'{phase} done, '
    loss_text = ''.join(f'{name}: {value:.3f}, ' for name, value in loss_dict.items())

    eval_desc, regress_mae = get_eval_score_regress(log_dict, batch)
    return header + loss_text + eval_desc, regress_mae, loss_dict['pred']


def get_eval_score_regress(log_dict, batch=False):
    """Compute the mean absolute error of logged regression outputs.

    Args:
        log_dict: Mapping with lists under ``'regress_out'`` and
            ``'regress_labels'``; each list holds tensors that are
            concatenated with ``torch.cat`` when ``batch`` is false.
        batch: If truthy, score only the last element of each list;
            otherwise score the concatenation of all elements.

    Returns:
        ``(desc, regress_mae)`` where ``desc`` is ``'mae: <value:.3f>'`` and
        ``regress_mae`` is whatever sklearn's ``mean_absolute_error`` returns.
    """
    if batch:
        predictions = log_dict['regress_out'][-1]
        targets = log_dict['regress_labels'][-1]
    else:
        predictions = torch.cat(log_dict['regress_out'])
        targets = torch.cat(log_dict['regress_labels'])

    regress_mae = mean_absolute_error(targets, predictions)
    return f'mae: {regress_mae:.3f}', regress_mae


def get_eval_score_v2(epoch, phase, log_dict, batch):
    """Compute accuracy (and, for a full pass, ROC-AUC) of a binary classifier.

    Multi-class evaluation is not supported: the AUC is computed from
    ``sigmoid(logits)`` against the labels, which only makes sense for one
    logit per sample. (This replaces a constant ``mul_class = False`` flag
    whose assert could never fire and whose ``else 0`` branch was unreachable.)

    Args:
        epoch: Unused; kept for call-site compatibility.
        phase: Unused; kept for call-site compatibility.
        log_dict: Mapping with lists of tensors under ``'org_clf_logits'``
            and ``'clf_labels'`` (presumably one entry per logged batch).
        batch: If truthy, score only the last entry of each list and skip
            AUC; otherwise concatenate all entries and also compute AUC.

    Returns:
        ``(desc, org_clf_acc, org_clf_auc)``; ``org_clf_auc`` is ``None``
        when ``batch`` is truthy.
    """
    if batch:
        logits = log_dict['org_clf_logits'][-1]
        clf_labels = log_dict['clf_labels'][-1]
    else:
        logits = torch.cat(log_dict['org_clf_logits'])
        clf_labels = torch.cat(log_dict['clf_labels'])

    # Reshape predictions to the label shape so the element-wise comparison
    # cannot silently broadcast (e.g. (N, 1) against (N,) would give (N, N)).
    org_clf_preds = get_preds_from_logits(logits).reshape(clf_labels.shape)
    org_clf_acc = (org_clf_preds == clf_labels).sum().item() / clf_labels.shape[0]
    desc = f'org_acc: {org_clf_acc:.3f}, '

    org_clf_auc = None
    if not batch:
        org_clf_auc = roc_auc_score(clf_labels, logits.sigmoid())
        desc += f'org_auc: {org_clf_auc:.3f}, '
    return desc, org_clf_acc, org_clf_auc


def get_preds_from_logits(logits):
    """Convert classifier logits into hard predictions as a float tensor.

    An input with at least two dimensions and more than one column in dim 1
    is treated as multi-class and reduced with ``argmax`` over dim 1.
    Anything else is treated as binary and thresholded at
    ``sigmoid(logits) > 0.5``, keeping the input's shape. That covers a
    single logit column ``(N, 1)`` and a flat vector ``(N,)``; the previous
    ``logits.shape[1]`` check raised ``IndexError`` on the flat vector.
    """
    if logits.dim() > 1 and logits.shape[1] > 1:  # multi-class
        return logits.argmax(dim=1).float()
    # binary: one logit per sample
    return (logits.sigmoid() > 0.5).float()


def log(*args):
    """Print ``args`` (space-separated, like ``print``) after a local-time stamp."""
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print('[' + stamp + ']', *args)


def load_checkpoint(model, model_dir, model_name, map_location=None):
    """Load ``<model_dir>/<model_name>.pt`` into ``model`` in place.

    The file must be a dict with a ``'model_state_dict'`` entry, which is the
    layout ``save_checkpoint`` writes. ``model_dir`` may be a ``str`` or any
    path-like object. It is wrapped in ``Path``, so a plain string no longer
    fails on the ``/`` operator; ``save_checkpoint`` already accepts both.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        model_dir: Directory containing the checkpoint.
        model_name: Checkpoint file name without the ``.pt`` suffix.
        map_location: Forwarded to ``torch.load`` (e.g. to load GPU weights
            on CPU).
    """
    checkpoint = torch.load(Path(model_dir) / f'{model_name}.pt', map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])


def save_checkpoint(model, model_dir, model_name):
    """Save ``model``'s state dict to ``<model_dir>/<model_name>.pt``.

    The file holds a dict with one key, ``'model_state_dict'``.
    ``model_dir`` may be a ``str`` or a path-like object; it is interpolated
    into the path string.
    """
    target = '{}/{}.pt'.format(model_dir, model_name)
    payload = {'model_state_dict': model.state_dict()}
    torch.save(payload, target)


def update_and_save_best_epoch_res_auc(seed, baseline, train_res, valid_res, test_res, metric_dict, epoch, file_path,
                                       model_dir, is_ood=False):
    """Track the best-validation-AUC epoch in ``metric_dict`` and checkpoint it.

    ``train_res``/``valid_res``/``test_res`` are indexable results whose
    index 1 is the AUC and whose last element is the loss (presumably
    ``(acc, auc, loss)`` -- confirm against the caller). An epoch is a new
    best when its validation AUC is strictly higher than
    ``metric_dict['metric/best_clf_valid_auc']``, or equal to it with a
    strictly lower validation loss.

    On every call, ``'metric/clf_train_loss'`` and
    ``'metric/best_clf_train_auc'`` are overwritten with the current
    training values. NOTE(review): the latter is set even when the epoch is
    not a new best, despite its name -- confirm that is intended.

    On a new best, the epoch, validation loss/AUC and test AUC are recorded,
    an update line is logged via ``write_log``, and, if ``model_dir`` is not
    ``None``, ``baseline`` is saved as ``model_auc_<seed>`` (or
    ``model_ood_auc_<seed>`` when ``is_ood``).

    Returns:
        ``metric_dict``, mutated in place.
    """
    auc_pos = 1

    val_auc = valid_res[auc_pos]
    best_auc = metric_dict['metric/best_clf_valid_auc']
    beats_best = val_auc > best_auc
    # The loss lookup only happens on an exact AUC tie.
    ties_with_lower_loss = val_auc == best_auc and valid_res[-1] < metric_dict['metric/best_clf_valid_loss']

    metric_dict['metric/clf_train_loss'] = train_res[-1]
    metric_dict['metric/best_clf_train_auc'] = train_res[auc_pos]

    if not (beats_best or ties_with_lower_loss):
        return metric_dict

    metric_dict['metric/best_clf_epoch'] = epoch
    metric_dict['metric/best_clf_valid_loss'] = valid_res[-1]
    metric_dict['metric/best_clf_valid_auc'] = valid_res[auc_pos]
    metric_dict['metric/best_clf_test_auc'] = test_res[auc_pos]
    write_log(f'***Update in epoch {epoch}!***', log_file=file_path)

    if model_dir is not None:
        checkpoint_name = f'model_ood_auc_{seed}' if is_ood else f'model_auc_{seed}'
        save_checkpoint(baseline, model_dir, model_name=checkpoint_name)
    return metric_dict


def update_and_save_best_epoch_res_acc(seed, baseline, train_res, valid_res, test_res, metric_dict, epoch, file_path,
                                       model_dir, is_ood=False):
    """Track the best-validation-accuracy epoch in ``metric_dict`` and checkpoint it.

    ``train_res``/``valid_res``/``test_res`` are indexable results whose
    index 0 is the accuracy and whose last element is the loss (presumably
    ``(acc, auc, loss)`` -- confirm against the caller). An epoch is a new
    best when its validation accuracy is strictly higher than
    ``metric_dict['metric/best_clf_valid_acc']``, or equal to it with a
    strictly lower validation loss.

    On every call, ``'metric/clf_train_loss'`` and
    ``'metric/best_clf_train_acc'`` are overwritten with the current
    training values. NOTE(review): the latter is set even when the epoch is
    not a new best, despite its name -- confirm that is intended.

    On a new best, the epoch, validation loss/accuracy and test accuracy are
    recorded, an update line is logged via ``write_log``, and, if
    ``model_dir`` is not ``None``, ``baseline`` is saved as
    ``model_acc_<seed>`` (or ``model_ood_acc_<seed>`` when ``is_ood``).

    Returns:
        ``metric_dict``, mutated in place.
    """
    acc_pos = 0

    val_acc = valid_res[acc_pos]
    best_acc = metric_dict['metric/best_clf_valid_acc']
    beats_best = val_acc > best_acc
    # The loss lookup only happens on an exact accuracy tie.
    ties_with_lower_loss = val_acc == best_acc and valid_res[-1] < metric_dict['metric/best_clf_valid_loss']

    metric_dict['metric/clf_train_loss'] = train_res[-1]
    metric_dict['metric/best_clf_train_acc'] = train_res[acc_pos]

    if not (beats_best or ties_with_lower_loss):
        return metric_dict

    metric_dict['metric/best_clf_epoch'] = epoch
    metric_dict['metric/best_clf_valid_loss'] = valid_res[-1]
    metric_dict['metric/best_clf_valid_acc'] = valid_res[acc_pos]
    metric_dict['metric/best_clf_test_acc'] = test_res[acc_pos]
    write_log(f'***Update in epoch {epoch}!***', log_file=file_path)

    if model_dir is not None:
        checkpoint_name = f'model_ood_acc_{seed}' if is_ood else f'model_acc_{seed}'
        save_checkpoint(baseline, model_dir, model_name=checkpoint_name)
    return metric_dict


def update_and_save_best_epoch_regress(seed, baseline, train_res, valid_res, test_res, metric_dict, epoch, file_path,
                                       model_dir):
    """Track the lowest-validation-MAE epoch in ``metric_dict`` and checkpoint it.

    ``train_res``/``valid_res``/``test_res`` are indexable results whose
    index 0 is the MAE and whose last element is the loss (presumably
    ``(mae, loss)`` -- confirm against the caller). An epoch is a new best
    when its validation MAE is strictly lower than
    ``metric_dict['metric/best_regrs_valid_mae']``, or equal to it with a
    strictly lower validation loss.

    On every call, ``'metric/regrs_train_loss'`` and
    ``'metric/best_regrs_train_mae'`` are overwritten with the current
    training values. NOTE(review): the latter is set even when the epoch is
    not a new best, despite its name -- confirm that is intended.

    On a new best, the epoch, validation loss/MAE and test MAE are recorded,
    an update line is logged via ``write_log``, and, if ``model_dir`` is not
    ``None``, ``baseline`` is saved as ``regress_model_seed_<seed>``.

    Returns:
        ``metric_dict``, mutated in place.
    """
    val_mae = valid_res[0]
    best_mae = metric_dict['metric/best_regrs_valid_mae']
    beats_best = val_mae < best_mae
    # The loss lookup only happens on an exact MAE tie.
    ties_with_lower_loss = val_mae == best_mae and valid_res[-1] < metric_dict['metric/best_regrs_valid_loss']

    metric_dict['metric/regrs_train_loss'] = train_res[-1]
    metric_dict['metric/best_regrs_train_mae'] = train_res[0]

    if not (beats_best or ties_with_lower_loss):
        return metric_dict

    metric_dict['metric/best_regrs_epoch'] = epoch
    metric_dict['metric/best_regrs_valid_loss'] = valid_res[-1]
    metric_dict['metric/best_regrs_valid_mae'] = valid_res[0]
    metric_dict['metric/best_regrs_test_mae'] = test_res[0]
    write_log(f'***Update in epoch {epoch}!***', log_file=file_path)

    if model_dir is not None:
        save_checkpoint(baseline, model_dir, model_name=f'regress_model_seed_{seed}')
    return metric_dict