import logging
import torch
txt_logger = logging.getLogger("sfda_reg")


def evaluate_reg_mae(y_pred_reg, y_true_reg):
    """Return the mean absolute error between predictions and targets.

    Both inputs are flattened first so that a prediction column of shape
    ``(N, 1)`` is compared element-wise with targets of shape ``(N,)``
    instead of silently broadcasting to an ``(N, N)`` difference matrix.

    Args:
        y_pred_reg: Predicted values (floating-point tensor).
        y_true_reg: Ground-truth values; should have the same number of
            elements as ``y_pred_reg``.

    Returns:
        0-d tensor holding the MAE.
    """
    residuals = y_pred_reg.reshape(-1) - y_true_reg.reshape(-1)
    return torch.mean(torch.abs(residuals))


def evaluate_reg_rmse(y_pred_reg, y_true_reg):
    """Return the root mean squared error between predictions and targets.

    Both inputs are flattened first so that a ``(N, 1)`` prediction column
    is compared element-wise with ``(N,)`` targets rather than broadcasting
    to an ``(N, N)`` matrix.

    Args:
        y_pred_reg: Predicted values (floating-point tensor).
        y_true_reg: Ground-truth values; should have the same number of
            elements as ``y_pred_reg``.

    Returns:
        0-d tensor holding the RMSE.
    """
    residuals = y_true_reg.reshape(-1) - y_pred_reg.reshape(-1)
    mse = torch.mean(residuals ** 2)  # MSE: Mean Squared Error
    return torch.sqrt(mse)


def evaluate_reg_r2(y_pred_reg, y_true_reg):
    """Return the coefficient of determination R^2 = 1 - RSS / TSS.

    Both inputs are flattened first so that a ``(N, 1)`` prediction column
    is compared element-wise with ``(N,)`` targets instead of broadcasting.

    When the targets are constant (TSS == 0), R^2 is undefined: the plain
    formula yields NaN (perfect predictions) or -inf (otherwise). Following
    scikit-learn's ``force_finite`` convention, 1.0 or 0.0 is returned
    instead so non-finite values do not leak into logged metrics.

    Args:
        y_pred_reg: Predicted values (floating-point tensor).
        y_true_reg: Ground-truth values; should have the same number of
            elements as ``y_pred_reg``.

    Returns:
        0-d tensor holding R^2 (not a Python float).
    """
    y_true_reg = y_true_reg.reshape(-1)
    y_pred_reg = y_pred_reg.reshape(-1)

    tss = torch.sum((y_true_reg - torch.mean(y_true_reg)) ** 2)  # total sum of squares
    rss = torch.sum((y_true_reg - y_pred_reg) ** 2)  # residual sum of squares
    if tss == 0:
        return torch.ones_like(tss) if rss == 0 else torch.zeros_like(tss)
    return 1 - (rss / tss)


def evaluate_reg_r(y_pred_reg, y_true_reg):
    """Return the Pearson correlation coefficient between predictions and targets.

    Args:
        y_pred_reg: Predicted values (any shape; flattened).
        y_true_reg: Ground-truth values (any shape; flattened).

    Returns:
        0-d tensor in [-1, 1]. If either input is constant, the correlation
        is undefined and 0 is returned (same dtype/device as the inputs).
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # sliced tensors, are accepted.
    y_true_reg = y_true_reg.reshape(-1)
    y_pred_reg = y_pred_reg.reshape(-1)

    true_centered = y_true_reg - torch.mean(y_true_reg)
    pred_centered = y_pred_reg - torch.mean(y_pred_reg)

    # numerator: unnormalised covariance
    numerator = torch.sum(true_centered * pred_centered)

    # denominator: product of root sums of squares; the 1/N factors of the
    # covariance and standard deviations cancel in the ratio.
    denominator = torch.sqrt(torch.sum(true_centered ** 2)) * torch.sqrt(torch.sum(pred_centered ** 2))
    if denominator == 0:
        return torch.zeros_like(numerator)
    return numerator / denominator  # pearson corr in [-1, 1]

def evaluate_REG(y_pred_reg, y_true_reg):
    """Compute every regression metric for one set of predictions.

    Args:
        y_pred_reg: Predicted values.
        y_true_reg: Ground-truth values.

    Returns:
        Dict with keys ``'mae'``, ``'rmse'``, ``'r2'`` and ``'R'`` (Pearson
        correlation), each mapped to the value returned by the matching
        ``evaluate_reg_*`` helper.
    """
    metric_table = (
        ('mae', evaluate_reg_mae),
        ('rmse', evaluate_reg_rmse),
        ('r2', evaluate_reg_r2),
        ('R', evaluate_reg_r),
    )
    return {name: metric(y_pred_reg, y_true_reg) for name, metric in metric_table}

@torch.no_grad()
def forward_for_info_onlyReg(net, val_dl, device="cuda"):
    """Run ``net`` over ``val_dl`` and collect per-sample predictions and targets.

    Each batch is unpacked as ``(x, y, idx)``. ``idx`` holds the dataset
    positions of the batch's samples, and results are scattered into banks
    of length ``len(val_dl.dataset)``. The output order therefore follows
    dataset indices rather than iteration order.

    The network is switched to eval mode and is left in eval mode on return.

    Args:
        net: Model exposing ``feature(x)`` and ``predict_from_feature(feature)``.
        val_dl: Iterable of ``(x, y, idx)`` batches with a ``dataset`` attribute.
        device: Device the inputs are moved to before the forward pass.
            The default ``"cuda"`` reproduces the former hard-coded ``.cuda()``.

    Returns:
        Tuple ``(y_pred_reg_bank, y_true_reg_bank)`` of 1-D float32 CPU
        tensors. Positions not covered by any ``idx`` remain 0.
    """
    dataset_len = len(val_dl.dataset)

    y_pred_reg_bank = torch.zeros(dataset_len)
    y_true_reg_bank = torch.zeros(dataset_len)

    net.eval()
    for x, y, idx in val_dl:
        feature = net.feature(x.to(device))
        y_pred_reg_batch = net.predict_from_feature(feature)

        # Flatten and cast exactly like the targets so a (B, 1) regression
        # head fits the 1-D float32 bank. The indexed assignment copies into
        # the bank, so detach()/clone() are unnecessary under no_grad.
        y_pred_reg_bank[idx] = y_pred_reg_batch.float().cpu().reshape(-1)
        y_true_reg_bank[idx] = y.float().flatten()
    return y_pred_reg_bank, y_true_reg_bank

def reg_eval_print_flexible(re_dict, prefix):
    """Format a regression-metric dict as one pipe-separated line.

    The core metrics ``mae``, ``rmse``, ``r2`` and ``R`` are printed first,
    in that order. Any extra keys follow in the dict's insertion order, so
    the output is deterministic. All values use 4 decimal places.

    Args:
        re_dict: Mapping containing at least ``'mae'``, ``'rmse'``, ``'r2'``
            and ``'R'``; values must support ``:.4f`` formatting (floats or
            0-d tensors).
        prefix: Text placed before the first metric.

    Returns:
        The formatted line.

    Raises:
        KeyError: if one of the four core metrics is missing.
    """
    core_keys = ('mae', 'rmse', 'r2', 'R')
    # Previously mae used ':4f' (width 4, six decimals) by mistake.
    parts = [f"{prefix} mae = {re_dict['mae']:.4f}"]
    parts += [f"{k} = {re_dict[k]:.4f}" for k in core_keys[1:]]
    parts += [f"{k} = {v:.4f}" for k, v in re_dict.items() if k not in core_keys]
    return " | ".join(parts)

def evaluation_logPrint(metric_logger_eval_dict, phase):
    """Build an evaluation report for ``phase``, log it, and return it.

    The report starts with an ``EVALUATION`` header line followed by
    ``[phase]``. If ``metric_logger_eval_dict`` has a truthy
    ``"reg_metrics"`` entry, a regressor line formatted by
    ``reg_eval_print_flexible`` is appended.

    The report is logged at INFO level through ``txt_logger`` with an extra
    trailing newline. The returned string omits that newline.
    """
    report_lines = ["EVALUATION", f"[{phase}]"]

    reg_metrics = metric_logger_eval_dict.get("reg_metrics", False)
    if reg_metrics:
        report_lines.append(reg_eval_print_flexible(reg_metrics, "[regressor]"))

    log_str = "\n".join(report_lines)
    txt_logger.info(f"{log_str}\n")
    return log_str

def evaluate_net_reg_Metric_logger(
        net, val_dl, info, return_net_pred=False):
    """Evaluate ``net`` on ``val_dl`` and package the regression metrics.

    Args:
        net: Model passed through to ``forward_for_info_onlyReg``.
        val_dl: Data loader passed through to ``forward_for_info_onlyReg``.
        info: Arbitrary caller-supplied value stored under
            ``["other_metrics_dict"]["info"]``.
        return_net_pred: When True, also return the raw prediction and
            target tensors.

    Returns:
        ``metric_logger_dict``, with keys ``"reg_metrics"`` (the output of
        ``evaluate_REG``) and ``"other_metrics_dict"``. When
        ``return_net_pred`` is True, the result is instead the tuple
        ``(metric_logger_dict, {'y_pred_reg': ..., 'y_true_reg': ...})``.
    """
    preds, targets = forward_for_info_onlyReg(net, val_dl)

    metric_logger_dict = {
        "reg_metrics": evaluate_REG(preds, targets),
        "other_metrics_dict": {'note': 'regressor evaluate', 'info': info},
    }

    if not return_net_pred:
        return metric_logger_dict
    return metric_logger_dict, {'y_pred_reg': preds, 'y_true_reg': targets}

