import numpy as np
import tonic
from catatonic.env_wrappers import apply_wrapper
import torch


# Half-width of the value range [-BOUND, BOUND] over which compute_entropy
# builds its per-dimension histograms (approximately pi).
BOUND = 3.14


def env_tonic_compat(env, preid=5, parallel=1, sequential=1):
    """
    Return a factory that evaluates ``env`` and wraps the result with
    ``apply_wrapper``.

    ``env`` is a Python expression string (typically a constructor call) that
    is evaluated with ``eval`` each time the factory is called. For strings
    that contain "biped" (and not "ostrich", which is checked first), the call
    is rewritten to pass
    ``identifier=preid * (parallel * sequential) + identifier`` as an extra
    keyword argument, so each worker can be given a distinct id.
    NOTE(review): the previous docstring described this value as a random
    seed; how the id is used is up to the environment constructor — confirm.

    Args:
        env: Expression string to evaluate. In the "biped" case it must end
            with the closing parenthesis of the call (trailing whitespace is
            tolerated).
        preid: Base id, multiplied by ``parallel * sequential``.
        parallel: Number of parallel workers, used in the id offset.
        sequential: Number of sequential workers, used in the id offset.

    Returns:
        A callable ``build_env(identifier=0)`` returning the wrapped
        environment.
    """
    # SECURITY: `env` is passed to eval(); it must come from trusted
    # configuration, never from external or user-supplied input.
    if "ostrich" in env:
        return lambda identifier=0: apply_wrapper(eval(env))

    if "biped" in env:
        # Drop the closing parenthesis so the identifier kwarg can be spliced in.
        call_prefix = env.rstrip()[:-1]
        # With an empty argument list ("Foo()") or a trailing comma ("Foo(a,)"),
        # an extra leading comma would turn the spliced call into a SyntaxError.
        separator = "" if call_prefix.rstrip().endswith(("(", ",")) else ","

        def build_env(identifier=0):
            id_eff = preid * (parallel * sequential) + identifier
            build = call_prefix + f"{separator}identifier={id_eff})"
            return apply_wrapper(eval(build))

        return build_env

    return lambda identifier=0: apply_wrapper(eval(env))


def print_data(entropy, state_buff):
    """
    Print ``entropy`` followed by the per-dimension (minimum, maximum) of a
    buffer of states.

    Args:
        entropy: Value printed as-is on the first line.
        state_buff: Non-empty sequence of array-like states. The number of
            dimensions reported is ``state_buff[0].shape[0]``; only that many
            leading entries of each state are considered.

    Returns:
        None. All output goes to stdout.
    """
    n_dims = state_buff[0].shape[0]
    # Stack to (n_states, n_dims) so min/max are single vectorised reductions
    # instead of one Python pass over the buffer per dimension.
    stacked = np.asarray([np.asarray(x)[:n_dims] for x in state_buff])
    # Cast to float64 so values print identically regardless of input dtype.
    min_state = stacked.min(axis=0).astype(np.float64)
    max_state = stacked.max(axis=0).astype(np.float64)
    print(f"Entropy: {entropy}")
    print("(Minimum, Maximum):")
    for idx, (lo, hi) in enumerate(zip(min_state, max_state)):
        print(f"j_{idx}: ({lo}, {hi})")


def compute_entropy(states, *, n_bins=99, bound=None):
    """
    Sum, over state dimensions, the exponentiated entropy of each dimension's
    value histogram.

    For every dimension ``jdx``, the values ``x[jdx]`` of all states are binned
    into ``n_bins`` equal-width bins spanning ``[-bound, bound]``. Values
    outside that range, as well as NaNs, are ignored. The per-dimension
    entropy ``-sum(p * log(p + 1e-4))`` is exponentiated and the results are
    summed over dimensions.

    Args:
        states: Non-empty sequence of indexable states; ``states[0].shape[0]``
            gives the number of dimensions.
        n_bins: Number of histogram bins. The default (99 bins, i.e. 100
            edges) matches the previously hard-coded grid.
        bound: Half-width of the binned range; defaults to the module-level
            ``BOUND``.

    Returns:
        Sum over dimensions of ``exp(entropy)``.

    Raises:
        ValueError: If ``n_bins < 1``, or if some dimension has no value
            inside ``[-bound, bound]``.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if bound is None:
        bound = BOUND
    # n_bins + 1 edges -> n_bins bins. np.histogram bins are half-open [a, b)
    # except the last, which is closed, so a value lying exactly on an edge
    # (including +/-bound) is counted exactly once instead of being dropped.
    edges = np.linspace(-bound, bound, n_bins + 1)
    total_entr = 0
    for jdx in range(states[0].shape[0]):
        values = np.asarray([x[jdx] for x in states], dtype=np.float64)
        # NaNs can never fall into a bin; drop them explicitly.
        counts, _ = np.histogram(values[~np.isnan(values)], bins=edges)
        total_count = counts.sum()
        if total_count == 0:
            raise ValueError(
                "total count zero. Probably because the values exceed the BOUNDs of"
                f" the probed x. Last value was: {values[-1]}. Bounds are: [{-bound}, {bound}]"
            )
        p = counts / total_count
        # The 1e-4 offset is kept for numerical compatibility with previously
        # reported values; empty bins (p == 0) contribute nothing either way.
        entr = -np.sum(p * np.log(p + 1e-4))
        total_entr += np.exp(entr)
    return total_entr


def compute_mc(states, action, next_states):
    """Placeholder with no implementation yet: ignores its arguments and returns ``None``."""
    return None

def reduce_actor_observations(observations):
    """
    Return an independent copy of the actor observations.

    Torch tensors come back as a detached clone (no autograd history); any
    other object is copied with its own ``.copy()`` method. No entries are
    masked at present: the column-zeroing variants that were tried
    (``[:, :-97] = 0`` and ``[:, :-37] = 0``) are disabled.
    """
    if not isinstance(observations, torch.Tensor):
        return observations.copy()
    # Detach before cloning so the copy itself is never tracked by autograd.
    return observations.detach().clone()
