# -*- coding: utf-8 -*-
import numpy as np
from typing import Dict

import torch
from torch import nn
from torch.distributions import Normal
import torch.nn.functional as F


def compute_quantile_loss_instance_wise(outputs: torch.Tensor,
                                        targets: torch.Tensor,
                                        desired_quantiles: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the quantile (pinball) loss element-wise, without any reduction.

    Parameters
    ----------
    outputs: torch.Tensor
        Predicted quantiles, [num_samples x num_horizons x num_quantiles].
    targets: torch.Tensor
        Observed values, [num_samples x num_horizons]; a trailing axis is added
        so they broadcast against every predicted quantile.
    desired_quantiles: torch.Tensor
        The quantile levels, of shape (num_quantiles,).

    Returns
    -------
    torch.Tensor
        [num_samples x num_horizons x num_quantiles] tensor holding the loss of
        each sample, time-step and quantile.
    """
    # Signed residual of the observation with respect to each predicted quantile.
    residuals = targets.unsqueeze(-1) - outputs

    # The two linear branches of the pinball loss; the larger one is the loss.
    under_branch = (desired_quantiles - 1) * residuals
    over_branch = desired_quantiles * residuals

    return torch.max(under_branch, over_branch)


def get_quantiles_loss_and_q_risk(outputs: Dict[str, Dict[str, torch.Tensor]],
                                  targets: Dict[str, Dict[str, torch.Tensor]],
                                  desired_quantiles: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean quantile loss over a nested collection of tensors.

    Despite the function name, only the quantile loss is computed and returned;
    no q-risk metric is produced.

    Parameters
    ----------
    outputs: Dict[str, Dict[str, torch.Tensor]]
        Model outputs indexed as ``outputs[obj_type][obj_id]``. Each tensor is
        passed as the prediction to ``compute_quantile_loss_instance_wise``
        (documented there as [num_samples x num_horizons x num_quantiles]).
        Must contain every ``(obj_type, obj_id)`` pair present in ``targets``.
    targets: Dict[str, Dict[str, torch.Tensor]]
        Observed targets with the same nesting; its keys drive the iteration.
    desired_quantiles: torch.Tensor
        A tensor representing the desired quantiles, of shape (num_quantiles,).

    Returns
    -------
    loss: torch.Tensor
        A scalar tensor: the unweighted mean, over all ``(obj_type, obj_id)``
        pairs, of each pair's quantile loss (summed over quantiles, averaged
        over the remaining axes).

    Raises
    ------
    KeyError
        If ``outputs`` lacks an entry that exists in ``targets``.
    RuntimeError
        If ``targets`` holds no tensors at all (``torch.stack`` rejects an
        empty list).
    """
    loss_list = []
    for obj_type, targets_by_id in targets.items():
        for obj_id, targets_tensor in targets_by_id.items():
            losses_array = compute_quantile_loss_instance_wise(
                outputs=outputs[obj_type][obj_id],
                targets=targets_tensor,
                desired_quantiles=desired_quantiles,
            )
            # Sum over quantiles, then average over horizons and samples.
            loss_list.append(losses_array.sum(dim=-1).mean(dim=-1).mean())

    # Every (type, id) pair contributes equally, regardless of its sample count.
    return torch.mean(torch.stack(loss_list), dim=0)


def get_mse_loss(outputs: Dict[str, Dict[str, torch.tensor]], targets: Dict[str, Dict[str, torch.tensor]]
                 ) -> torch.Tensor:
    """
    Average the mean-squared error over every slice of a nested tensor collection.

    Parameters
    ----------
    outputs: Dict[str, Dict[str, torch.tensor]]
        Predictions indexed as ``outputs[obj_type][obj_id]``; each tensor is
        indexed with four axes, the third being the slice axis.
    targets: Dict[str, Dict[str, torch.tensor]]
        Targets with the same nesting; a trailing axis is appended to each
        tensor before it is sliced the same way as the predictions.

    Returns
    -------
    torch.Tensor
        A scalar tensor: the unweighted mean of the per-slice MSE values
        gathered across all entries of ``targets``.
    """
    slice_losses = []
    for group_key, group_targets in targets.items():
        for item_key in group_targets:
            expected = targets[group_key][item_key].unsqueeze(-1)
            predicted = outputs[group_key][item_key]
            # One MSE value per index along the third axis.
            slice_losses.extend(
                nn.functional.mse_loss(predicted[:, :, idx, :], expected[:, :, idx, :])
                for idx in range(expected.size(2))
            )
    return torch.stack(slice_losses).mean(dim=0)



def compute_mse_loss(predict: torch.Tensor, target: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    """
    Mean-squared error between a prediction and its target.

    Parameters
    ----------
    predict: torch.Tensor
        Predicted values.
    target: torch.Tensor
        Reference values.
    reduction: str
        Reduction mode forwarded to PyTorch's MSE loss (for example 'mean',
        'sum' or 'none').

    Returns
    -------
    torch.Tensor
        The squared error reduced according to ``reduction``.
    """
    return nn.functional.mse_loss(predict, target, reduction=reduction)


def compute_log_likelihood_loss(predict: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Gaussian log-likelihood based loss on the prediction errors.

    The errors ``target - predict`` of every row are scored under a zero-mean
    normal distribution whose scale is that row's own standard deviation
    (taken along dim 1). The per-row mean log-probability is passed through
    an absolute value and the results are averaged.

    Parameters
    ----------
    predict: torch.Tensor
        Predicted values; needs at least two dimensions because statistics
        are taken along dim 1.
    target: torch.Tensor
        Observed values, broadcast-compatible with ``predict``.
    eps: float
        Lower bound applied to the per-row standard deviation. Without it a
        row of constant errors (e.g. a perfect prediction) gives a zero scale,
        which ``Normal`` rejects.

    Returns
    -------
    torch.Tensor
        The averaged absolute mean log-likelihood.
    """
    errors = target - predict
    # Guard against a degenerate (zero) scale for rows with constant errors.
    sigma = errors.std(dim=1).clamp_min(eps)
    dist = Normal(0, sigma.unsqueeze(1))
    # NOTE(review): the absolute value (rather than a negation) is kept from the
    # original implementation; it treats positive and negative mean
    # log-likelihoods alike - confirm this is intended.
    log_likelihoods = torch.abs(dist.log_prob(errors).mean(dim=1))
    return log_likelihoods.mean()


def compute_dtw_loss(predict: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Dynamic-time-warping (DTW) cost between two scalar sequences.

    The local distance is the absolute difference between elements, and the
    allowed moves are the three classic DTW steps (advance in ``predict``,
    advance in ``target``, or advance in both).

    Parameters
    ----------
    predict: torch.Tensor
        Predicted sequence with one scalar per step along dim 0; its length T
        defines the size of the T x T alignment grid.
    target: torch.Tensor
        Reference sequence; its first T steps along dim 0 are used.

    Returns
    -------
    torch.Tensor
        A 0-dim tensor with the accumulated cost of the cheapest warping path.
        It lives on the device of the inputs and stays attached to the
        autograd graph, so it can be back-propagated.
    """
    T = predict.size(0)
    predict_seq = predict.reshape(T)
    target_seq = target[:T].reshape(T)

    # Pairwise local distances: D[i, j] = |predict[i] - target[j]|.
    D = torch.abs(predict_seq.unsqueeze(1) - target_seq.unsqueeze(0))
    if not D.is_floating_point():
        D = D.to(torch.get_default_dtype())

    # Grid borders can only be reached by moving along a single axis,
    # so they are plain running sums of the local distances.
    first_col = torch.cumsum(D[:, 0], dim=0)
    prev_row = list(torch.cumsum(D[0], dim=0).unbind(0))

    # Only the final accumulated cost is needed, so one row of history is
    # enough and no warping path is reconstructed. Working on Python lists of
    # 0-dim tensors avoids in-place writes and NumPy conversions, keeping the
    # computation differentiable.
    for i in range(1, T):
        row = [first_col[i]]
        for j in range(1, T):
            row.append(D[i, j] + min(prev_row[j], row[j - 1], prev_row[j - 1]))
        prev_row = row

    return prev_row[-1]


def compute_huber_loss(predict: torch.Tensor, target: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """
    Mean Huber loss between a prediction and its target.

    Parameters
    ----------
    predict: torch.Tensor
        Predicted values.
    target: torch.Tensor
        Actual values.
    delta: float
        Threshold at which the loss switches from the quadratic (L2) regime
        to the linear (L1) regime; defaults to 1.0.

    Returns
    -------
    torch.Tensor
        The Huber loss averaged over all elements.
    """
    residual = target - predict
    magnitude = residual.abs()
    threshold = torch.tensor(delta, device=residual.device, dtype=residual.dtype)

    # Quadratic near zero, linear beyond the threshold.
    quadratic_branch = 0.5 * residual ** 2
    linear_branch = threshold * (magnitude - 0.5 * threshold)

    return torch.where(magnitude <= threshold, quadratic_branch, linear_branch).mean()


def compute_FDS_loss(time_series_1: torch.Tensor, time_series_2: torch.Tensor, alpha: float = 0.5,
                     eps: float = 1e-12) -> torch.Tensor:
    """
    Frequency-domain similarity loss between two time series.

    Both series are transformed with an FFT (``torch.fft.fft``, along the last
    dimension) and compared through two terms:

    * a magnitude term ``1 - min(mse(|F1|, |F2|) / mean(|F1|^2), 1)``
    * a phase term, the mean cosine of the phase difference

    The similarity ``alpha * magnitude + (1 - alpha) * phase`` lies in
    ``[-(1 - alpha), 1]`` and the returned loss is ``1 - |similarity|``.

    Parameters
    ----------
    time_series_1: torch.Tensor
        First series; its spectrum energy normalizes the magnitude error.
    time_series_2: torch.Tensor
        Second series, compared against the first.
    alpha: float
        Weight of the magnitude term; the phase term is weighted ``1 - alpha``.
    eps: float
        Lower bound for the normalizing spectrum energy. Without it an
        all-zero first series divides by zero and can yield a NaN loss.

    Returns
    -------
    torch.Tensor
        A scalar tensor with the similarity loss.
    """
    time_series_1 = time_series_1.to(torch.float32)
    time_series_2 = time_series_2.to(torch.float32)

    fft_1 = torch.fft.fft(time_series_1)
    fft_2 = torch.fft.fft(time_series_2)
    mag_1, phase_1 = torch.abs(fft_1), torch.angle(fft_1)
    mag_2, phase_2 = torch.abs(fft_2), torch.angle(fft_2)

    mse_magnitude = torch.mean((mag_1 - mag_2) ** 2)
    # Guard the denominator: a silent (all-zero) first series has zero energy.
    max_mse = torch.mean(mag_1 ** 2).clamp_min(eps)
    cosine_similarity_phase = torch.mean(torch.cos(phase_1 - phase_2))

    magnitude_similarity = 1 - torch.clamp(mse_magnitude / max_mse, max=1)
    similarity = alpha * magnitude_similarity + (1 - alpha) * cosine_similarity_phase
    # NOTE(review): the absolute value makes a strongly negative similarity
    # score as well as a positive one - kept as in the original; confirm intended.
    return 1 - torch.abs(similarity)


def compute_CRS_loss(predict_time_series_1: torch.Tensor, predict_time_series_2: torch.Tensor,
                     actual_time_series_1: torch.Tensor, actual_time_series_2: torch.Tensor) -> torch.Tensor:
    """
    Correlation-structure loss between a predicted pair and an actual pair of series.

    The Pearson correlation of the two predicted series is compared with the
    Pearson correlation of the two actual series.

    Parameters
    ----------
    predict_time_series_1: torch.Tensor
        First predicted series.
    predict_time_series_2: torch.Tensor
        Second predicted series.
    actual_time_series_1: torch.Tensor
        First observed series.
    actual_time_series_2: torch.Tensor
        Second observed series.

    Returns
    -------
    torch.Tensor
        A scalar tensor: the absolute difference between the predicted and the
        actual correlation coefficients.
    """
    def _correlation(series_a: torch.Tensor, series_b: torch.Tensor):
        # NaNs are replaced by zero (infinities by the dtype's extreme finite
        # values, per torch.nan_to_num defaults) before any statistic is taken.
        series_a = torch.nan_to_num(series_a, nan=0.0)
        series_b = torch.nan_to_num(series_b, nan=0.0)

        centered_a = series_a - series_a.mean()
        centered_b = series_b - series_b.mean()

        covariance = (centered_a * centered_b).mean()
        spread_a = centered_a.pow(2).mean().sqrt()
        spread_b = centered_b.pow(2).mean().sqrt()
        # The small constant keeps the ratio finite for constant series.
        return covariance / ((spread_a * spread_b) + 0.0000001)

    predicted_corr = _correlation(predict_time_series_1, predict_time_series_2)
    observed_corr = _correlation(actual_time_series_1, actual_time_series_2)
    return torch.abs(predicted_corr - observed_corr)


def cal_loss_rule_multi_reverse_scale(predict_features, predict_map, device, relation_accc_zone, src_features, src_map):
    """
    Rule-based penalty on the spread of predicted 'condWaterOutTemp' across equipment.

    For rule 95, the predicted 'condWaterOutTemp' of every equipment of types
    'ACCCCC' and 'ACCCSC' is masked by that equipment's 'runStatus' (> 0.5),
    the per-position spread (max - min over equipment) is computed wherever at
    least one equipment is running, and spreads outside the [-2, 2] band are
    penalised with a soft-plus barrier (optionally log-compressed).

    Parameters
    ----------
    predict_features:
        Nested mapping read as ``predict_features[eq_type][eq_id]['condWaterOutTemp']``;
        each leaf is a tensor (presumably [batch_size, seq_len] - TODO confirm).
    predict_map:
        Mapping read as ``predict_map[eq_type]['target']``; an equipment type is
        only processed when 'condWaterOutTemp' is contained in that entry.
    device:
        Device on which all intermediate tensors and the results are placed.
    relation_accc_zone:
        Not used by this function.
    src_features:
        Nested mapping read as ``src_features[eq_type][eq_id]['runStatus']``;
        the trailing ``data_eq.shape[1]`` entries along dim 1 are used.
    src_map:
        Not used by this function.

    Returns
    -------
    return_tag: bool
        True when at least one rule produced a loss value.
    loss: torch.Tensor
        Sum of all rule losses, or 0.0 when no rule produced a value.
    check_loss: dict
        Per-rule loss keyed by check id. It stays 0.0 when the rule is skipped
        through one of the early ``continue`` paths, and is -9998.0 when the
        rule was evaluated but no finite penalty value was left.
    """
    # check id -> [equipment types, [unused here, lower bound, upper bound]]
    dict_check_threshold = {
        95: [['ACCCCC', 'ACCCSC'], [0.5, -2, 2]]
    }

    check_loss = {k: torch.tensor(0.0, device=device) for k in dict_check_threshold}
    # When True the penalty is compressed as log(penalty + delta_min).
    log_cal_diff = True
    delta_min = torch.tensor(1.0, device=device)
    run_name = 'runStatus'
    # An equipment counts as running when its run status exceeds this value.
    run_threshold = 0.5

    loss_list = []
    for check_id, val in dict_check_threshold.items():
        # Sentinel reported for a rule that was evaluated without a usable value.
        sub_loss = torch.tensor(-9998.0, device=device)
        ls_loss = []
        if  check_id == 95:
            all_temps = []
            all_conditions = []
            # NOTE(review): the equipment types are hard-coded here instead of
            # being read from val[0]; the two lists are currently identical.
            for eq_type in ['ACCCCC', 'ACCCSC']:
                if eq_type not in predict_map or 'condWaterOutTemp' not in predict_map[eq_type]['target']:
                    continue
                valid_eqs = list(predict_features.get(eq_type, {}).keys())
                for eq_id in valid_eqs:
                    data_eq = predict_features[eq_type][eq_id]['condWaterOutTemp'].to(device)
                    # Align the run status with the prediction window (last steps along dim 1).
                    run_status = src_features[eq_type][eq_id][run_name][:, -data_eq.shape[1]:].to(device)
                    eq_condition = (run_status > run_threshold).float()
                    # Predictions of non-running equipment are zeroed, not removed.
                    masked_data_eq = data_eq * eq_condition
                    all_temps.append(masked_data_eq)
                    all_conditions.append(eq_condition)
            # Nothing to compare: leave check_loss[check_id] at its initial 0.0.
            if not all_temps:
                continue
            stacked_temps = torch.stack(all_temps, dim=0)  # Shape: [num_devices, batch_size, seq_len]
            stacked_conditions = torch.stack(all_conditions, dim=0)  # Shape: [num_devices, batch_size, seq_len]
            # True wherever at least one equipment is running.
            combined_condition = stacked_conditions.sum(dim=0).bool()
            if not combined_condition.any():
                continue
            # Keep only the positions with a running equipment; all equipment rows are retained.
            masked_stacked_temps = stacked_temps[:, combined_condition]
            if masked_stacked_temps.numel() == 0:
                continue
            # NOTE(review): equipment that is not running contributes 0 at the
            # selected positions, so it takes part in max/min - confirm intended.
            max_temp_diff = masked_stacked_temps.max(dim=0)[0] - masked_stacked_temps.min(dim=0)[0]
            # Soft barrier: grows when the spread leaves the [lower, upper] band.
            diff_condition = F.softplus(dict_check_threshold[check_id][1][1] - max_temp_diff) + F.softplus(max_temp_diff - dict_check_threshold[check_id][1][2])
            if log_cal_diff:
                diff_condition = torch.log(diff_condition + delta_min)
            # Drop non-finite entries before averaging.
            valid_mask = torch.isfinite(diff_condition)
            valid_diff_condition = diff_condition[valid_mask]
            if valid_diff_condition.numel() > 0:
                sub_loss = torch.mean(valid_diff_condition)
                loss_list.append(sub_loss)
                ls_loss.append(sub_loss)
        check_loss[check_id] = sub_loss
        if ls_loss:
            check_loss[check_id] = torch.sum(torch.stack(ls_loss))

    return_tag = False
    loss = torch.tensor(0.0, device=device)
    if loss_list:
        return_tag = True
        loss = torch.sum(torch.stack(loss_list))
    return return_tag, loss, check_loss
