"""Common functions you may find useful in your implementation."""
import os
import random
import torch
import numpy as np
from utils.arguments import get_args
from typing import Optional
from torch.distributions.categorical import Categorical
from torch import einsum
from einops import reduce

import scipy.sparse as sps

# Module-level default device: the first CUDA GPU ("cuda:0") when CUDA is
# available, otherwise the CPU. Evaluated once, at import time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class CategoricalMasked(Categorical):
    """Categorical distribution in which masked-out actions get zero probability.

    Entries where ``mask`` is False have their logit replaced by the most
    negative finite value of the logits dtype before the distribution is
    built, so their probability underflows to zero. When ``mask`` is None the
    class behaves exactly like ``Categorical``.

    Attributes:
        mask: the boolean mask (True = action allowed), or None.
        batch, nb_action: the two dimensions of ``logits``.
        mask_value: fill value used for masked logits (only set when a mask
            is given).
    """

    def __init__(self, logits: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Build the distribution.

        Args:
            logits: 2-D tensor of unnormalized log-probabilities; its size is
                unpacked into exactly two dimensions. It is NOT modified.
            mask: optional tensor of the same shape as ``logits``; truthy
                entries mark actions that remain available.
        """
        self.batch, self.nb_action = logits.size()
        if mask is None:
            self.mask = None
        else:
            # `~` on a non-bool tensor is a bitwise NOT, so normalise to bool
            # first (a no-op when the mask is already boolean).
            self.mask = mask.to(dtype=torch.bool)
            self.mask_value = torch.finfo(logits.dtype).min
            # Out-of-place fill: the caller's logits tensor must not be
            # overwritten as a side effect of constructing a distribution.
            logits = logits.masked_fill(~self.mask, self.mask_value)
        super().__init__(logits=logits)

    def entropy(self):
        """Return the per-row entropy, summing over allowed actions only."""
        if self.mask is None:
            return super().entropy()
        # Elementwise p * log(p); self.logits holds normalized log-probs.
        p_log_p = self.logits * self.probs
        # Masked entries are forced to contribute exactly zero to the sum.
        p_log_p = p_log_p.masked_fill(~self.mask, 0.0)
        return -p_log_p.sum(dim=-1)


def append_new_dimension(args, transition_matrix, reward_matrix, state_action_matrix, all_indices, delta):
    """Grow the NumPy model arrays by one entry and record the new index.

    - ``transition_matrix`` (3-D ndarray) gains one slice along axis 0 and one
      along axis 2, both filled with ``delta``.
    - ``reward_matrix`` (3-D ndarray) gains one slice along axis 0 only, filled
      with ``delta``.
    - ``state_action_matrix`` (2-D torch tensor) gains one all-zero row.
    - ``all_indices`` is appended to IN PLACE with the index of the new
      axis-0 slice of the transition array.

    Presumably axis 0 / axis 2 index states and axis 1 indexes actions
    (``args.num_actions`` is used as the axis-1 length) — confirm with callers.

    Args:
        args: object exposing ``num_actions``.
        transition_matrix: 3-D ndarray.
        reward_matrix: 3-D ndarray.
        state_action_matrix: 2-D torch tensor.
        all_indices: list of indices; mutated and also returned.
        delta: fill value for the new transition and reward entries.

    Returns:
        Tuple ``(transition_matrix, reward_matrix, state_action_matrix,
        all_indices)`` with the enlarged arrays.
    """
    num_actions = args.num_actions

    # Add the axis-0 slice first so that the axis-2 slice created next
    # already spans the enlarged axis 0.
    new_transitions_row = np.full(
        (1, num_actions, transition_matrix.shape[2]), delta, dtype=np.float64
    )
    row_concat = np.concatenate([transition_matrix, new_transitions_row], axis=0)
    new_transitions_col = np.full(
        (row_concat.shape[0], num_actions, 1), delta, dtype=np.float64
    )
    transition_matrix = np.concatenate([row_concat, new_transitions_col], axis=2)

    new_reward_row = np.full(
        (1, num_actions, reward_matrix.shape[2]), delta, dtype=np.float64
    )
    reward_matrix = np.concatenate([reward_matrix, new_reward_row], axis=0)

    # Allocate the zero row on the tensor's own device: torch.cat refuses to
    # mix devices, so relying on a module-wide default device breaks whenever
    # the input lives elsewhere (e.g. a CPU tensor on a CUDA-capable host).
    new_state_action_row = torch.zeros(
        (1, state_action_matrix.shape[1]), device=state_action_matrix.device
    )
    state_action_matrix = torch.cat(
        [state_action_matrix, new_state_action_row], dim=0
    )

    all_indices.append(transition_matrix.shape[0] - 1)
    return transition_matrix, reward_matrix, state_action_matrix, all_indices


def append_new_dimension_torch(args, transition_matrix, reward_matrix, state_action_matrix, all_indices, delta):
    """Grow the torch model tensors by one entry and record the new index.

    Torch counterpart of the NumPy-based helper:

    - ``transition_matrix`` (3-D tensor) gains one slice along dim 0 and one
      along dim 2, both filled with ``delta`` (created as float64).
    - ``reward_matrix`` (3-D tensor) gains one slice along dim 0 only, filled
      with ``delta`` (created as float64).
    - ``state_action_matrix`` (2-D tensor) gains one all-zero row.
    - ``all_indices`` is appended to IN PLACE with the index of the new
      dim-0 slice of the transition tensor.

    Every new slice is allocated on the device of the tensor it is
    concatenated to, so the function works for inputs on any device.

    Args:
        args: object exposing ``num_actions`` (used as the dim-1 length).
        transition_matrix: 3-D tensor.
        reward_matrix: 3-D tensor.
        state_action_matrix: 2-D tensor.
        all_indices: list of indices; mutated and also returned.
        delta: fill value for the new transition and reward entries.

    Returns:
        Tuple ``(transition_matrix, reward_matrix, state_action_matrix,
        all_indices)`` with the enlarged tensors.
    """
    num_actions = args.num_actions
    transition_device = transition_matrix.device

    # Add the dim-0 slice first so that the dim-2 slice created next already
    # spans the enlarged dim 0.
    new_transitions_row = torch.full(
        (1, num_actions, transition_matrix.shape[2]),
        delta,
        dtype=torch.float64,
        device=transition_device,
    )
    row_concat = torch.cat([transition_matrix, new_transitions_row], dim=0)
    new_transitions_col = torch.full(
        (row_concat.shape[0], num_actions, 1),
        delta,
        dtype=torch.float64,
        device=transition_device,
    )
    transition_matrix = torch.cat([row_concat, new_transitions_col], dim=2)

    new_reward_row = torch.full(
        (1, num_actions, reward_matrix.shape[2]),
        delta,
        dtype=torch.float64,
        device=reward_matrix.device,
    )
    reward_matrix = torch.cat([reward_matrix, new_reward_row], dim=0)

    new_state_action_row = torch.zeros(
        (1, state_action_matrix.shape[1]), device=state_action_matrix.device
    )
    state_action_matrix = torch.cat(
        [state_action_matrix, new_state_action_row], dim=0
    )

    all_indices.append(transition_matrix.shape[0] - 1)
    return transition_matrix, reward_matrix, state_action_matrix, all_indices


def append_new_dim_sparse(args, transition_matrix, state_action_matrix, all_indices):
    """Grow a sparse transition matrix and the state-action tensor by one entry.

    - ``transition_matrix`` (2-D SciPy sparse matrix) gains
      ``args.num_actions`` empty rows and one empty column; the result is CSR.
      Presumably rows are flattened (state, action) pairs and columns are
      next states — confirm with callers.
    - ``state_action_matrix`` (2-D torch tensor) gains one all-zero row,
      allocated on the tensor's own device.
    - ``all_indices`` is appended to IN PLACE with the index of the new
      ``state_action_matrix`` row.

    Args:
        args: object exposing ``num_actions``.
        transition_matrix: 2-D SciPy sparse matrix.
        state_action_matrix: 2-D torch tensor.
        all_indices: list of indices; mutated and also returned.

    Returns:
        Tuple ``(transition_matrix, state_action_matrix, all_indices)``.
    """
    # Stack the empty rows first so the empty column added afterwards already
    # has the enlarged row count.
    new_transitions_row = sps.csr_matrix(
        (args.num_actions, transition_matrix.shape[1]), dtype=np.float64
    )
    row_concat = sps.vstack([transition_matrix, new_transitions_row], format="csr")
    new_transitions_col = sps.csr_matrix((row_concat.shape[0], 1), dtype=np.float64)
    transition_matrix = sps.hstack([row_concat, new_transitions_col], format="csr")

    # torch.cat cannot mix devices, so create the row where the tensor lives.
    new_state_action_row = torch.zeros(
        (1, state_action_matrix.shape[1]), device=state_action_matrix.device
    )
    state_action_matrix = torch.cat(
        [state_action_matrix, new_state_action_row], dim=0
    )

    all_indices.append(state_action_matrix.shape[0] - 1)
    return transition_matrix, state_action_matrix, all_indices


def total_variation_distance(
        estimated_distribution: np.ndarray, true_distribution: np.ndarray, use_max=True
) -> float:
    """
    Return a distance between two discrete distributions given as probability vectors.

    With ``use_max=False`` the result is the total variation distance
    1/2 * sum_x |P(x) - Q(x)|. In the finite case this equals
    sup{|P(A) - Q(A)|: A in F}, where F is the sigma algebra of events.
    Reference: Proposition 4.2 on page 48 of Markov chains and mixing times by
    Levin, Peres, and Wilmer.

    With ``use_max=True`` (the default) the result is max_x |P(x) - Q(x)|, the
    largest gap over single outcomes. NOTE: this maximum ranges over individual
    outcomes only, not over all events, so for probability vectors it is a
    lower bound on the total variation distance rather than the same quantity.

    :param estimated_distribution: array-like of estimated probabilities.
    :param true_distribution: array-like of reference probabilities, broadcastable
        against ``estimated_distribution``.
    :param use_max: choose the per-outcome maximum (True) or half the L1 distance (False).
    :return: the selected distance as a Python scalar.
    """
    # asarray is a no-op for ndarrays and lets plain sequences be passed too.
    gap = np.abs(np.asarray(estimated_distribution) - np.asarray(true_distribution))
    if use_max:
        tvd = np.max(gap)
    else:
        tvd = np.sum(gap) / 2
    return tvd.item()


def relative_distance(
        estimated_distribution: np.ndarray,
        true_distribution: np.ndarray,
        use_max: bool = True,
) -> float:
    """
    Return the elementwise relative error of an estimate against a reference distribution.

    The relative error of each entry is |(estimated - true) / true|.

    :param estimated_distribution: array-like of estimated probabilities.
    :param true_distribution: array-like of reference probabilities, broadcastable
        against ``estimated_distribution``. Zero entries lead to division by zero,
        which NumPy turns into inf/nan (with a RuntimeWarning) rather than an error.
    :param use_max: return the largest relative error (True, default) or half the
        sum of the relative errors (False).
    :return: the selected aggregate as a Python scalar.
    """
    estimated = np.asarray(estimated_distribution)
    reference = np.asarray(true_distribution)
    # The difference must be formed BEFORE dividing; dividing first would
    # reduce to `estimated - 1` and ignore the reference values entirely.
    relative_error = np.abs((estimated - reference) / reference)
    if use_max:
        rel_dist = np.max(relative_error)
    else:
        rel_dist = np.sum(relative_error) / 2
    return rel_dist.item()


def set_seeds(seed: int, env):
    """
    Seed every random number source used here with the same value.

    Sets the ``PYTHONHASHSEED`` environment variable, then seeds NumPy,
    torch and the standard ``random`` module, and finally calls
    ``env.seed(seed)``.

    :param seed: seed value shared by all generators.
    :type seed: int
    :param env: object exposing a ``seed(int)`` method.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Same order as listed above: NumPy, torch, stdlib random.
    for seed_library in (np.random.seed, torch.manual_seed, random.seed):
        seed_library(seed)
    env.seed(seed)

def set_device():
    """Return ``torch.device("cuda")`` when CUDA is available, else ``torch.device("cpu")``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
