from typing import Any

import numpy as np


def get_corrupt_params(corrupt: str, means: np.ndarray, H_context: int):
    """Parse a corruption spec and build everything needed to apply it.

    The spec has the form ``"<type>frac<fraction>"``, e.g. ``"change2topfrac0.3"``.
    An empty string means "no corruption".

    Args:
        corrupt: Corruption spec string, or ``""`` for none.
        means: Means array handed to ``get_corrupted_means``.
        H_context: Number of steps to choose corrupted steps from.

    Returns:
        Tuple ``(corrupt_type, corrupted_steps, corrupt_magnitude, corrupt_frac,
        corrupted_means)``. For an empty spec this is ``("", [], 0, 0, means)``.
    """
    if corrupt == "":
        return corrupt, [], 0, 0, means

    type_part, frac_text = corrupt.split("frac")
    frac_value = float(frac_text)

    # Steps are drawn before the means are corrupted; both may consume the
    # global NumPy RNG, so this order is kept deliberately.
    steps = get_corrupted_steps(frac_value, H_context)
    new_means, magnitude = get_corrupted_means(type_part, means)

    return type_part, steps, magnitude, frac_value, new_means


def get_corrupted_steps(frac: float, H_context: int) -> np.ndarray:
    """Pick ``int(H_context * frac)`` distinct step indices in random order.

    Indices come from ``np.arange(H_context)`` shuffled with the global NumPy
    RNG. Assumes ``0 <= frac``; a fraction above 1 simply yields all steps.
    """
    n_corrupt = int(H_context * frac)
    # permutation() of an array shuffles a copy of it in place with the global RNG.
    shuffled = np.random.permutation(np.arange(H_context))
    return shuffled[:n_corrupt]


def _mix_two_corruptions(type_a: str, type_b: str, means: np.ndarray) -> np.ndarray:
    """Randomly choose between two corruptions of ``means``.

    For 2-D ``means`` an independent 0/1 choice is drawn per row with the
    global NumPy RNG; for 1-D ``means`` a single choice is drawn.

    Raises:
        NotImplementedError: if ``means`` is neither 1-D nor 2-D.
    """
    corrupted_a, _ = get_corrupted_means(type_a, means)
    corrupted_b, _ = get_corrupted_means(type_b, means)
    candidates = np.stack((corrupted_a, corrupted_b), axis=0)

    if means.ndim == 2:
        choose_which = np.random.choice(np.array([0, 1]), means.shape[0])
        return candidates[choose_which, np.arange(means.shape[0]), :]
    if means.ndim == 1:
        choose_which = np.random.choice(np.array([0, 1]))
        return candidates[choose_which, :]
    raise NotImplementedError()


def _replace_matching(means: np.ndarray, targets: np.ndarray, replacement: np.ndarray) -> np.ndarray:
    """Return a copy of ``means`` where entries equal to ``targets`` take ``replacement``.

    ``targets`` and ``replacement`` must broadcast against ``means`` (e.g. a
    last-axis reduction with ``keepdims=True``). The replacement is broadcast to
    the full shape and indexed with the same mask, so every matched entry gets
    the value belonging to its own row, regardless of how many entries match
    in each row.
    """
    corrupted = np.array(means)
    mask = means == targets
    corrupted[mask] = np.broadcast_to(replacement, means.shape)[mask]
    return corrupted


def get_corrupted_means(corrupt_type: str, means: np.ndarray) -> tuple[np.ndarray, Any]:
    """Apply the mean corruption named by ``corrupt_type`` along the last axis.

    Supported types:
      * ``"gaussian<mag>"``: ``means`` is returned unchanged (same object) and
        ``<mag>`` is parsed as a float and returned as the magnitude.
      * ``"changemeanadv"``: reverses the ranking of each row (largest value
        moves to the position of the smallest, 2nd largest to the 2nd smallest, ...).
      * ``"change2top"``: swaps the largest and second-largest entries of each row.
      * ``"changelowest"``: every entry equal to its row minimum becomes
        ``10 * |row max|``.
      * ``"change2ndhighest"``: every entry equal to its row's largest value
        below the row maximum becomes ``10 * |row max|``. A row with no value
        below its maximum falls back to the global minimum of ``means``.
      * ``"special1"``: random pick between ``change2top`` and ``changemeanadv``
        (per row for 2-D input, once for 1-D input).
      * ``"special2"``: same, between ``changelowest`` and ``changemeanadv``.
      * anything else: ``means`` is returned unchanged (same object).

    Returns:
        ``(corrupted_means, magnitude)``; ``magnitude`` is a float for gaussian
        types and ``None`` otherwise. Non-passthrough types return a new array
        and never modify ``means``.

    Raises:
        ValueError: if the gaussian suffix is not a valid float.
        IndexError: for ``"change2top"`` when the last axis has fewer than 2 entries.
        NotImplementedError: for ``"special1"``/``"special2"`` when ``means`` is
            neither 1-D nor 2-D.
    """
    if corrupt_type.startswith("gaussian"):
        # The means are returned untouched; only the magnitude suffix is parsed.
        corrupt_magnitude = float(corrupt_type.removeprefix("gaussian"))
        return means, corrupt_magnitude

    if corrupt_type == "special1":
        return _mix_two_corruptions("change2top", "changemeanadv", means), None

    if corrupt_type == "special2":
        return _mix_two_corruptions("changelowest", "changemeanadv", means), None

    if corrupt_type == "changemeanadv":
        order = np.argsort(means, axis=-1)
        descending = np.take_along_axis(means, order, axis=-1)[..., ::-1]
        # argsort of a permutation is its inverse: ranks[i] is the ascending
        # rank of entry i, which then picks the value of opposite rank.
        ranks = np.argsort(order, axis=-1)
        return np.take_along_axis(descending, ranks, axis=-1), None

    if corrupt_type == "change2top":
        order = np.argsort(means, axis=-1)
        # Positions of the 2nd largest and the largest entry; indexing with
        # [-2, -1] raises IndexError if the last axis has fewer than 2 entries.
        top2 = order[..., [-2, -1]]
        top2_values = np.take_along_axis(means, top2, axis=-1)
        corrupted_means = np.array(means)
        # Write the two values back in swapped order.
        np.put_along_axis(corrupted_means, top2, top2_values[..., ::-1], axis=-1)
        return corrupted_means, None

    if corrupt_type == "changelowest":
        maxs: np.ndarray = means.max(axis=-1, keepdims=True)
        mins: np.ndarray = means.min(axis=-1, keepdims=True)
        return _replace_matching(means, mins, 10 * np.abs(maxs)), None

    if corrupt_type == "change2ndhighest":
        maxs: np.ndarray = means.max(axis=-1, keepdims=True)
        # Largest value strictly below each row's max; `initial` supplies a
        # value for rows where every entry equals the max.
        max2nd: np.ndarray = means.max(axis=-1, keepdims=True, where=(means != maxs), initial=means.min())
        return _replace_matching(means, max2nd, 10 * np.abs(maxs)), None

    # Unrecognised types are a deliberate pass-through.
    return means, None
