import numpy as np

from numpy import ndarray as Array


def skill(ground_truth: Array, particles: Array) -> float:
    """
    Compute the mean of the skills (RMSE of the ensemble mean) over the filtering steps.

    At filtering step k (1 <= k <= num_assim_steps) the skill is the root-mean-square
    error, taken over all state components, between ground_truth[k - 1] and the mean of
    particles[k] over the particle axis. particles[0] is not used.

    Input(s):
        - ground_truth (Array): ground truth with dimension (num_assim_steps, *state_shape),
          e.g. (num_assim_steps, 3) or (num_assim_steps, 1, 128, 128).
        - particles (Array): particles of the filter at each filtering step with dimension
          (num_assim_steps + 1, num_particles, *state_shape).
    Returns:
        - skill (float): mean of the skills at each filtering step.
    Raises:
        - ZeroDivisionError: if ground_truth contains no assimilation step.
    """
    K = ground_truth.shape[0]
    # Ensemble mean at every filtering step: (K, *state_shape).
    ensemble_means = np.mean(particles[1:K + 1], axis=1)
    squared_errors = (ground_truth - ensemble_means) ** 2
    # RMSE per step, averaging over every state component (no squeeze needed, so
    # any state shape works, e.g. (3,) or (1, 128, 128)).
    state_axes = tuple(range(1, squared_errors.ndim))
    rmse_per_step = np.sqrt(np.mean(squared_errors, axis=state_axes))
    return float(np.sum(rmse_per_step)) / K

def spread(particles: Array) -> float:
    """
    Compute the mean of the spreads at each fitering steps.
    Input(s):
        - particles (Array): particles of the filter at each filtering step with dimension (num_assim_steps + 1, num_particles, 3).
    Returns:
        - spread (float): mean of spreads at each filtering step.
    """
    all_spreads, K, N = 0., particles.shape[0] - 1, particles.shape[1]
    for k in range(1, K + 1):
        particles_step = particles[k]
        ensemble_mean = np.mean(particles_step, axis=0)
        var_step = (1.0 / (N - 1.0)) * np.sum((particles_step - ensemble_mean[None,:]) ** 2, axis=0)
        spread_step = np.sqrt(np.mean(var_step))
        all_spreads += spread_step
    spread = all_spreads / K
    return spread

def spread_to_skill(ground_truth: Array, particles: Array) -> float:
    """
    Return the ratio of the average spread to the average skill.

    The numerator is `spread(particles)` and the denominator is
    `skill(ground_truth, particles)`; no guard against a zero skill is applied.
    Input(s):
        - ground_truth (Array): ground truth with dimension (num_assim_steps, 3).
        - particles (Array): particles of the filter at each filtering step with dimension (num_assim_steps + 1, num_particles, 3).
    Returns:
        - ratio (float): average spread divided by average skill.
    """
    average_skill = skill(ground_truth=ground_truth, particles=particles)
    average_spread = spread(particles=particles)
    return average_spread / average_skill

def compute_L1_distance(u: Array, v: Array) -> float:
    """
    Return the L1 distance between two states, i.e. the sum over all components of |u - v|.
    Input(s):
        - u (Array): first state, e.g. with dimension (3,).
        - v (Array): second state, with the same dimension as u.
    Returns:
        - dist (float): L1 distance between the two states.
    """
    abs_diff = np.abs(u - v)
    return float(abs_diff.sum())


def CRPS(ground_truth: Array, particles: Array) -> float:
    """
    Compute the average of the CRPS (Continuous Ranked Probability Score) over the filtering steps.

    At filtering step k (1 <= k <= num_assim_steps) the ensemble estimator
        CRPS_k = (1/N) sum_i ||x_i - y||_1 - 1 / (2 N (N - 1)) sum_{i,j} ||x_i - x_j||_1
    is used, with y = ground_truth[k - 1], x_i = particles[k, i] and ||.||_1 the L1 norm
    over all state components (so CRPS_k is the sum of the per-component scores).
    particles[0] is not used.

    Input(s):
        - ground_truth (Array): ground truth with dimension (num_assim_steps, *state_shape),
          e.g. (num_assim_steps, 3).
        - particles (Array): particles of the filter at each filtering step with dimension
          (num_assim_steps + 1, num_particles, *state_shape).
    Returns:
        - crps (float): mean over the filtering steps of the per-step CRPS.
    Raises:
        - ZeroDivisionError: if there is no assimilation step or fewer than two particles.
    """
    K, N = ground_truth.shape[0], particles.shape[1]
    # Flatten every state to a vector so both terms use the same L1 norm whatever
    # the state shape; summing only the last axis would mis-scale e.g. (1, 128, 128).
    state_size = int(np.prod(particles.shape[2:]))
    # Normalisation of the pairwise term (undefined, and raises, when N < 2).
    pair_weight = 1.0 / (2 * N * (N - 1.0))
    all_crps = 0.0
    for k in range(1, K + 1):
        truth = np.reshape(ground_truth[k - 1], state_size)
        members = np.reshape(particles[k], (N, state_size))
        # First term: mean L1 distance between the ensemble members and the truth.
        crps_step = float(np.mean(np.sum(np.abs(members - truth[None, :]), axis=1)))
        # Second term: L1 distances over all ordered pairs (i, j), one row at a time
        # so memory stays O(N * state_size) instead of O(N^2 * state_size).
        sum_distances = 0.0
        for member in members:
            sum_distances += float(np.sum(np.abs(members - member[None, :])))
        crps_step -= pair_weight * sum_distances
        all_crps += crps_step
    return all_crps / K
