import time
import random
import logging
import warnings
from typing import Any, Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt
import tqdm
import matplotlib
import torch

import gym
import gym.utils

import tools
import decomp
import caching
import plotting
import pytorch_models
import path_config

"""
from: https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
"""

# Log record layout: source line number, timestamp, then the message.
logging_format = "%(lineno)4s: %(asctime)s: %(message)s"

# Numeric level 15 lies between logging.DEBUG (10) and logging.INFO (20):
# INFO records are emitted, DEBUG records are suppressed.
logging_level = 15
logging.basicConfig(level=logging_level,
                    format=logging_format)

logger = logging.getLogger(__name__)

# Request deterministic cuDNN kernels and switch off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Half-widths of the symmetric state box used by the analysis code below.
LIMIT_X = 2.4
LIMIT_THETA = 12 * (2 * np.pi) / 360  # 12 degrees expressed in radians
LIMIT_XDOT = +2.75
LIMIT_THETADOT = +3.75

# NOTE(review): presumably the integration time step of the gym CartPole
# dynamics; it is not referenced anywhere in the visible part of this file.
tau = 0.02

# Per-dimension state bounds, ordered as (x, xdot, theta, thetadot).
STATE_LOWER = np.array([-LIMIT_X, -LIMIT_XDOT, -LIMIT_THETA, -LIMIT_THETADOT])
STATE_UPPER = np.array([+LIMIT_X, +LIMIT_XDOT, +LIMIT_THETA, +LIMIT_THETADOT])


def box_repr(lower: np.ndarray, upper: np.ndarray) -> str:
    """Render an axis-aligned box as a product of closed intervals.

    Each (lower, upper) pair becomes ``[+l.llll, +u.uuuu]`` (explicit sign,
    four decimals) and the per-dimension intervals are joined with `` x ``.

    Args:
        lower: Lower bounds; each element must support ``.item()``.
        upper: Upper bounds, paired element-wise with ``lower``.

    Returns:
        The formatted string; empty when there are no dimensions.
    """
    intervals = []
    for low_entry, high_entry in zip(lower, upper):
        interval = "[{:+.4f}, {:+.4f}]".format(low_entry.item(), high_entry.item())
        intervals.append(interval)
    return " x ".join(intervals)


def _softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x``, normalised over all of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing to ``inf``
    (and the quotient from becoming NaN) when the logits are large.

    Args:
        x: Array of logits.

    Returns:
        Array of the same shape as ``x`` whose entries sum to one. An empty
        input yields an empty output.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max is undefined on empty input; mirror the trivial empty result.
        return np.exp(x)
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)


def predict(
    state: np.ndarray,
    policy: torch.nn.Module,
    is_deterministic_action: bool,
    polytope_rules: List[tuple],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose an action for ``state`` and return it with its log-probability.

    The policy's logits are optionally overridden by polytope rules: when the
    state satisfies ``C @ state >= 0`` element-wise for a rule's matrix ``C``,
    the corresponding logit is raised by ``1e20`` so that action is selected
    with probability one. Rule 0 is checked before rule 1.

    Args:
        state: Observation vector fed to the policy.
        policy: Module mapping a float tensor of the state to two logits.
        is_deterministic_action: If true take the arg-max logit, otherwise
            sample from the softmax distribution.
        polytope_rules: Either empty (no overrides) or a list whose first two
            entries are ``(matrix, 0)`` and ``(matrix, 1)``.

    Returns:
        ``(action, action_logprob)``; the log-probability has shape ``(1,)``.
    """
    state_torch = torch.from_numpy(state).type(torch.FloatTensor)
    action_logits = policy(state_torch)

    forced_action = None
    if len(polytope_rules) > 0:
        polytope_rule0 = polytope_rules[0]
        polytope_rule1 = polytope_rules[1]

        # The rules are positional: entry i must prescribe action i.
        assert 0 == polytope_rule0[1]
        assert 1 == polytope_rule1[1]

        state_column = np.vstack(state)
        if np.all(polytope_rule0[0] @ state_column >= 0):
            forced_action = 0
        elif np.all(polytope_rule1[0] @ state_column >= 0):
            forced_action = 1

    if forced_action is not None:
        # A huge additive bonus drives the softmax to a one-hot distribution.
        bonus = [0.0, 0.0]
        bonus[forced_action] = 1e20
        action_logits = action_logits + torch.Tensor(bonus)

    action_probs = torch.nn.functional.softmax(action_logits, dim=-1)
    distribution = torch.distributions.Categorical(action_probs)

    if is_deterministic_action:
        _, action = action_logits.max(0)
    else:
        action = distribution.sample()

    if forced_action is not None:
        # Sanity check: an override must always win.
        assert action.item() == forced_action

    action_logprob = distribution.log_prob(action).reshape(1)
    return action, action_logprob


def flip_cumsum_flip(x: np.ndarray) -> np.ndarray:
    """Return the reverse cumulative sum of ``x``.

    The input is reversed, cumulatively summed, and reversed again, so for
    1-D input entry ``i`` of the result is ``sum(x[i:])``.
    """
    reversed_input = np.flip(x)
    running_total = np.cumsum(reversed_input)
    return np.flip(running_total)


def update_policy(optimizer: torch.optim.Optimizer,
                  episode_rewards: List[float],
                  episode_action_logprobs: List[torch.Tensor]) -> Tuple[float, float]:
    """Apply one REINFORCE gradient step for a finished episode.

    For every time step ``t`` the discounted return-to-go
    ``G_t = sum_{k >= t} gamma**(k - t) * r_k`` is computed, the returns are
    standardised, and the loss ``-sum_t log pi(a_t | s_t) * G_t`` is minimised
    with a single optimizer step.

    The previous implementation weighted reward ``r_k`` by
    ``gamma**(n - 1 - k)``, i.e. it discounted from the end of the episode.
    The two formulas agree when all rewards are equal, but not otherwise.

    Args:
        optimizer: Optimizer holding the policy parameters.
        episode_rewards: Reward received at each step of the episode.
        episode_action_logprobs: Per-step log-probability tensors, each of
            shape ``(1,)`` and attached to the policy's autograd graph.

    Returns:
        ``(policy_loss, policy_reward)``: the scalar loss value and the
        undiscounted sum of the episode rewards.
    """
    gamma = 0.99

    # Accumulate returns backwards: G_t = r_t + gamma * G_{t+1}.
    returns_to_go = np.empty(len(episode_rewards))
    running_return = 0.0
    for t in range(len(episode_rewards) - 1, -1, -1):
        running_return = episode_rewards[t] + gamma * running_return
        returns_to_go[t] = running_return

    rewards = torch.FloatTensor(returns_to_go.tolist())
    centered = rewards - rewards.mean()
    # The sample standard deviation of a single value is NaN, which would
    # poison the weights; treat it as zero spread instead.
    spread = rewards.std() if rewards.numel() > 1 else 0.0
    rewards = centered / (spread + float(np.finfo(np.float32).eps))

    catted_actions = torch.cat(episode_action_logprobs)
    episode_losses = torch.mul(catted_actions, rewards).mul(-1)
    loss = torch.sum(episode_losses, -1)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    policy_loss = loss.item()
    policy_reward = np.sum(episode_rewards)
    return policy_loss, policy_reward


def compute_history_sign_distribution(episode_state_signs: np.ndarray) -> np.ndarray:
    """Count how many rows of a sign history match each +/-1 sign pattern.

    Args:
        episode_state_signs: 2-D array with one row of state signs per step.

    Returns:
        Float array of length ``2 ** dim`` whose entry ``k`` is the number of
        rows equal to the ``k``-th pattern; patterns are the rows produced by
        ``tools._gen_all_01_rows(dim)`` mapped from {0, 1} to {-1, +1}.
    """
    n_dims = episode_state_signs.shape[1]
    n_patterns = 2 ** n_dims
    patterns = tools._gen_all_01_rows(n_dims) * 2 - 1

    counts = np.empty((n_patterns,))
    for pattern_idx in range(n_patterns):
        row_matches = np.all(patterns[pattern_idx, :] == episode_state_signs, axis=1)
        counts[pattern_idx] = np.sum(row_matches)
    return counts


def train_policy(env_name: str,
                 num_hidden: int,
                 episodes: int,
                 polytope_rules: List[tuple]) -> Dict[str, Any]:
    """Train a one-hidden-layer policy on ``env_name`` with ``update_policy``.

    One episode is rolled out per iteration (at most 1000 steps), after which
    a single policy-gradient update is applied. Training stops early once the
    running mean of the last 100 episode scores exceeds the environment's
    reward threshold.

    Args:
        env_name: Gym environment id.
        num_hidden: Width of the single hidden layer.
        episodes: Maximum number of training episodes.
        polytope_rules: Action-override rules forwarded to ``predict``.

    Returns:
        Dict with the trained ``policy`` and the per-episode histories
        ``octant_inclusion_history``, ``frac_wrong_action_history``,
        ``loss_history``, ``reward_history`` and ``extreme_state_history``.
    """
    tabulate_wrong_moves = True
    max_steps_per_episode = 1000
    layer_list = build_layers(num_hidden)

    env = gym.make(env_name)
    env.seed(0)
    assert env.observation_space.shape[0] == layer_list[0].in_features
    assert env.action_space.n == layer_list[-1].out_features

    learning_rate = 0.01

    policy = pytorch_models.Net(layer_list)
    optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)

    log_every = 100
    reward_threshold = env.spec.reward_threshold

    frac_wrong_action_history = []
    loss_history = []
    reward_history = []
    extreme_state_history = []
    scores = []
    octant_inclusion_history = []

    replace_initial_state = False
    # replace_initial_state = True

    for episode in range(episodes):
        episode_states = []           # states reached after each step
        episode_decision_states = []  # states the policy actually acted on
        episode_rewards = []
        episode_action_logprobs = []
        episode_actions = []

        state = env.reset()
        if replace_initial_state:
            max_angle = 12 * 2 * np.pi / 360
            max_position = 2.4

            position_to_replace = np.random.uniform(
                -1 * max_position, +1 * max_position
            )
            angle_to_replace = np.random.uniform(-1 * max_angle, +1 * max_angle)
            env.state[0] = position_to_replace
            env.state[2] = angle_to_replace
            # Keep the observation in sync with the overwritten env state,
            # otherwise the first action is chosen from a stale observation.
            state = np.array(env.state)

        for idx in range(max_steps_per_episode):
            action, action_logprob = predict(state, policy, False, polytope_rules)
            episode_decision_states.append(state)
            state, reward, done, info = env.step(action.item())
            episode_states.append(state)

            episode_rewards.append(reward)
            episode_action_logprobs.append(action_logprob)
            episode_actions.append(action.item())

            if done:
                break

            # A non-terminal state is expected to lie inside the hard-coded
            # box; warn (rather than fail) when it does not.
            for dim in range(4):
                if not STATE_LOWER[dim] <= state[dim] <= STATE_UPPER[dim]:
                    warnings.warn(
                        "Episode {}, step {}, dim {}: Not {} <= {} <= {}".format(
                            episode, idx, dim,
                            STATE_LOWER[dim], state[dim], STATE_UPPER[dim]
                        )
                    )

        episode_states_array = np.array(episode_states)
        episode_state_signs = np.sign(episode_states_array)

        num_sign_row = compute_history_sign_distribution(episode_state_signs)
        octant_inclusion_history.append(num_sign_row)
        # The last recorded state is dropped before taking the extremes.
        episode_states_before_failure = episode_states_array[:-1, :]
        extreme_state = np.max(np.abs(episode_states_before_failure), axis=0)

        policy_loss, policy_reward = update_policy(
            optimizer, episode_rewards, episode_action_logprobs
        )
        loss_history.append(policy_loss)

        if tabulate_wrong_moves:
            # Pair every action with the state it was chosen in; pairing it
            # with the state reached afterwards would be off by one step.
            es = np.array(episode_decision_states)
            ea = np.array(episode_actions)
            is_wrong, is_right = classify_definite_states(es, ea)
            frac_wrong_action = np.mean(is_wrong)
        else:
            frac_wrong_action = np.nan

        frac_wrong_action_history.append(frac_wrong_action)
        reward_history.append(policy_reward)
        extreme_state_history.append(extreme_state)
        # NOTE(review): idx is the index of the last step, i.e. one less than
        # the episode length; kept so the stopping rule is unchanged.
        scores.append(idx)

        mean_score = np.mean(scores[-100:])

        if episode % log_every == 0:
            msg = "Episode {:>4}\tAverage length: {:>.2f}".format(episode, mean_score)
            logger.info(msg)

        if mean_score > reward_threshold:
            msg = "Solved after {} episodes! Running average is now {}. Last episode ran to {} time steps.".format(
                episode, mean_score, len(episode_rewards)
            )
            logger.info(msg)
            break

    env.close()

    training_results = {
        "policy": policy,
        "octant_inclusion_history": octant_inclusion_history,
        "frac_wrong_action_history": frac_wrong_action_history,
        "loss_history": loss_history,
        "reward_history": reward_history,
        "extreme_state_history": extreme_state_history,
    }
    return training_results


def show_correct_balancing(env_name: str,
                           policy: pytorch_models.Net,
                           polytope_rules: List[tuple]) -> None:
    """Render a fixed-length rollout of the policy's arg-max actions.

    The environment is seeded with 0 and stepped 300 times; the ``done`` flag
    returned by the environment is ignored.
    """
    env = gym.make(env_name)
    env.seed(0)

    state = env.reset()
    steps_remaining = 300
    while steps_remaining > 0:
        greedy_action, _ = predict(state, policy, True, polytope_rules)
        env.render()
        state, _, _, _ = env.step(greedy_action.item())
        steps_remaining -= 1
    env.close()


def compute_definite_states(state: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Flag rows of ``state`` whose four coordinates all clear a small margin.

    Columns are read as (x, xdot, theta, thetadot). Each margin is one
    thousandth of a per-coordinate range constant.

    NOTE(review): ``must_be_0`` here requires x, xdot < 0 and
    theta, thetadot > 0, which is the opposite sign pattern to the matrix
    paired with action 0 in ``get_polytope_rules`` -- confirm which labelling
    is intended.

    Args:
        state: 2-D array with at least four columns, one state per row.

    Returns:
        ``(must_be_0, must_be_1)`` boolean arrays with one entry per row.
    """
    eps_scale = 1 / 1000
    x_margin = 4.8 * eps_scale
    theta_margin = 2 * np.pi / 15 * eps_scale
    xdot_margin = 2 * LIMIT_XDOT * eps_scale
    thetadot_margin = 2 * LIMIT_THETADOT * eps_scale

    x, xdot, theta, thetadot = (state[:, col] for col in range(4))

    # Cart left and moving left, pole right and moving right.
    must_be_0 = (
        (x < -1 * x_margin)
        & (xdot < -1 * xdot_margin)
        & (theta > +1 * theta_margin)
        & (thetadot > +1 * thetadot_margin)
    )
    # Cart right and moving right, pole left and moving left.
    must_be_1 = (
        (x > +1 * x_margin)
        & (xdot > +1 * xdot_margin)
        & (theta < -1 * theta_margin)
        & (thetadot < -1 * thetadot_margin)
    )
    return must_be_0, must_be_1


def classify_definite_states(state: np.ndarray,
                             action: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Mark which actions agree or disagree with the definite-state labels.

    Args:
        state: States, one per row, passed to ``compute_definite_states``.
        action: Action taken for each row (0 or 1).

    Returns:
        ``(is_wrong, is_right)`` boolean arrays. A row is wrong when its
        definite label and its action differ, right when they coincide, and
        neither when the state is not definite.
    """
    must_be_0, must_be_1 = compute_definite_states(state)
    chose_0 = action == 0
    chose_1 = action == 1

    is_wrong = np.logical_or(
        np.logical_and(must_be_1, chose_0),
        np.logical_and(must_be_0, chose_1),
    )
    is_right = np.logical_or(
        np.logical_and(must_be_0, chose_0),
        np.logical_and(must_be_1, chose_1),
    )
    return is_wrong, is_right


def _run_trial(trial_length: int,
               env_name: str,
               policy: pytorch_models.Net) -> Dict[str, Any]:
    """Roll out the arg-max policy once and summarise the trajectory.

    Fixes relative to the earlier version:
      * statistics are computed only over the steps actually taken, never
        over the uninitialised tail of the pre-allocated buffers;
      * ``num_steps`` is the number of environment steps taken, also when the
        trial runs to ``trial_length``;
      * each action is classified against the state it was chosen in, not
        the state reached after taking it.

    Args:
        trial_length: Maximum number of environment steps.
        env_name: Gym environment id.
        policy: Policy network queried through ``predict``.

    Returns:
        Dict of summary statistics: ``num_steps``, ``maxabs_x``,
        ``maxabs_theta``, ``maxabs_state``, ``avg_must_be_0``,
        ``avg_must_be_1``, ``avg_wrong``, ``avg_wrong_and_high_prob``,
        ``avg_right``, ``avg_definite`` and ``mean_abs_definite_state``.
    """
    decision_states_np = np.empty((trial_length, 4))  # state the action was chosen in
    states_np = np.empty((trial_length, 4))           # state reached after the action
    actions_np = np.empty((trial_length,))
    actions_logprobs_np = np.empty((trial_length,))

    env = gym.make(env_name)
    state = env.reset()
    polytope_rules = []

    num_steps = 0
    for step in range(trial_length):
        action, action_logprob = predict(state, policy, True, polytope_rules)
        decision_states_np[step, :] = state
        state, reward, done, info = env.step(action.item())
        states_np[step, :] = state
        actions_np[step] = action.item()
        actions_logprobs_np[step] = action_logprob.item()
        num_steps = step + 1
        if done:
            break

    env.close()

    # Discard the part of the buffers that was never written.
    decision_states_np = decision_states_np[:num_steps, :]
    states_np = states_np[:num_steps, :]
    actions_np = actions_np[:num_steps]
    actions_logprobs_np = actions_logprobs_np[:num_steps]

    action_probs_np = np.exp(actions_logprobs_np)
    is_high_prob = action_probs_np > .51

    must_be_0, must_be_1 = compute_definite_states(decision_states_np)
    is_definite = np.logical_or(must_be_0, must_be_1)

    if np.any(is_definite):
        abs_definite_state = np.abs(decision_states_np[is_definite, :])
        mean_abs_definite_state = np.mean(abs_definite_state, axis=0)
    else:
        # Same value np.mean would give on an empty selection, minus the warning.
        mean_abs_definite_state = np.full((4,), np.nan)

    is_wrong, is_right = classify_definite_states(decision_states_np, actions_np)
    is_wrong_and_high_prob = np.logical_and(is_high_prob, is_wrong)

    maxabs_state = np.max(np.abs(states_np), axis=0)

    trial_results = {
        "num_steps": num_steps,
        "maxabs_x": maxabs_state[0],
        "maxabs_theta": maxabs_state[2],
        "maxabs_state": maxabs_state,
        "avg_must_be_0": np.mean(must_be_0),
        "avg_must_be_1": np.mean(must_be_1),
        "avg_wrong": np.mean(is_wrong),
        "avg_wrong_and_high_prob": np.mean(is_wrong_and_high_prob),
        "avg_right": np.mean(is_right),
        "avg_definite": np.mean(is_definite),
        "mean_abs_definite_state": mean_abs_definite_state
    }
    return trial_results


def analyze_policy(env_name: str,
                   policy: pytorch_models.Net,
                   polytope_rules: List[tuple]) -> Dict[str, Any]:
    """Run 100 trials of up to 200 steps and collect per-trial statistics.

    ``polytope_rules`` is accepted but not used by this function.

    Returns:
        Dict mapping each ``trial_*`` key to an array with one entry (or one
        row of four values) per trial, filled from ``_run_trial``.
    """
    num_trials = 100
    trial_length = 200

    # (key in the returned dict, key in the per-trial dict, per-trial shape)
    layout = [
        ("trial_lengths", "num_steps", ()),
        ("trial_maxabsxs", "maxabs_x", ()),
        ("trial_maxabsthetas", "maxabs_theta", ()),
        ("trial_max_absstate", "maxabs_state", (4,)),
        ("trial_avgright", "avg_right", ()),
        ("trial_avgwrong", "avg_wrong", ()),
        ("trial_avgdefinite", "avg_definite", ()),
        ("trial_must_be_0", "avg_must_be_0", ()),
        ("trial_must_be_1", "avg_must_be_1", ()),
        ("trial_mean_abs_definite_state", "mean_abs_definite_state", (4,)),
        ("trial_avg_wrong_and_high_prob", "avg_wrong_and_high_prob", ()),
    ]
    policy_analysis = {
        out_key: np.empty((num_trials,) + per_trial_shape)
        for out_key, _, per_trial_shape in layout
    }

    for trial_idx in tqdm.tqdm(range(num_trials), total=num_trials):
        trial_results = _run_trial(trial_length, env_name, policy)
        for out_key, trial_key, _ in layout:
            policy_analysis[out_key][trial_idx] = trial_results[trial_key]

    # Optional spot check, disabled by default: sample points in a
    # hard-coded box and require the policy's arg-max to be 0 on all of them.
    do_monte_carlo_analysis = False
    if do_monte_carlo_analysis:
        lowers0 = np.array([-2.4000, -2.7500, +0.0000, +1.3262])
        uppers0 = np.array([-1.4722, -2.0967, +0.1056, +1.6522])

        monte_carlo_samples = 10000
        weights_mat = np.random.uniform(size=(monte_carlo_samples, 4))
        points = np.hstack(lowers0) * weights_mat + \
            np.hstack(uppers0) * (1 - weights_mat)
        logits_np = policy(torch.Tensor(points)).detach().numpy()

        assert np.all(0 == np.argmax(logits_np, axis=1))

    return policy_analysis


def find_empirical_counterexamples(env_name: str, policy):
    """Roll out the arg-max policy and report violations of the sign rules.

    A rollout of at most 600 steps is recorded. Condition 0 holds for a state
    when ``state @ diag(+1, +1, -1, -1) >= 0`` element-wise and is expected to
    go with action 0; condition 1 is the mirrored rule for action 1. The
    empirical frequencies are logged, together with the step indices at which
    a condition held but the other action was chosen. Earlier versions
    computed those indices and then discarded them.

    Args:
        env_name: Gym environment id.
        policy: Policy network queried through ``predict``.
    """
    # render = True
    render = False
    plot_states = True

    env = gym.make(env_name)
    env.seed(0)

    num_steps = 600
    polytope_rules = []
    state_space_size = env.observation_space.shape[0]
    states = np.full((num_steps, state_space_size), np.nan)
    actions = np.full((num_steps,), np.nan)

    state = env.reset()
    for idx in range(num_steps):
        # Record the state the action is chosen in, then the action itself.
        states[idx, :] = state
        action, _ = predict(state, policy, True, polytope_rules)
        actions[idx] = action.item()
        if render:
            env.render()

        state, reward, done, info = env.step(action.item())
        if done:
            logger.info("Breaking since done")
            break

    env.close()

    # Keep only the steps that were actually recorded.
    states = states[: idx + 1, :]
    actions = actions[: idx + 1]

    if plot_states:
        plot_state_history(states)

    c0 = np.diag([+1, +1, -1, -1])
    c1 = np.diag([-1, -1, +1, +1])

    is_condition0 = np.all(states @ c0 >= 0, axis=1)
    is_condition1 = np.all(states @ c1 >= 0, axis=1)

    is_action0 = actions == 0
    is_action1 = actions == 1

    avg0 = np.mean(is_condition0)
    avg1 = np.mean(is_condition1)
    avg00 = np.mean(is_condition0 & is_action0)
    avg11 = np.mean(is_condition1 & is_action1)

    bad_steps0 = np.flatnonzero(is_condition0 & ~is_action0)
    bad_steps1 = np.flatnonzero(is_condition1 & ~is_action1)
    if bad_steps0.size > 0:
        logger.info("Condition 0 held but action was not 0 at steps: {}".format(bad_steps0.tolist()))
    if bad_steps1.size > 0:
        logger.info("Condition 1 held but action was not 1 at steps: {}".format(bad_steps1.tolist()))

    logger.info("P[condition 0 and action 0] = {}".format(avg00))
    logger.info("P[condition 0] = {}".format(avg0))

    logger.info("P[condition 1 and action 1] = {}".format(avg11))
    logger.info("P[condition 1] = {}".format(avg1))


def analyze_weird_state(env_name: str, policy):
    """Probe a hard-coded state and render ten left pushes starting from it.

    First the cart position is decreased in steps of 0.002 (at most 1000
    probes) until the policy's arg-max action is no longer 0; the outcome of
    that search is logged (it used to be discarded). Then the environment is
    placed in the original state and action 0 is applied ten times while
    rendering.

    Args:
        env_name: Gym environment id.
        policy: Policy network queried through ``predict``.
    """
    polytope_rules = []
    weird_state = np.array([-1.18991656, -0.13554511, 0.07258647, 0.04905184])

    max_steps = 1000
    # step_size = .0005
    step_size = 0.002
    step_direction = -1 * np.array([1, 0, 0, 0])

    probe_state = weird_state.copy()
    action_changed = False
    num_probes = 0
    while num_probes < max_steps:
        action_at_state, _ = predict(probe_state, policy, True, polytope_rules)
        num_probes += 1
        if action_at_state.item() != 0:
            action_changed = True
            break
        probe_state = probe_state + step_size * step_direction

    if action_changed:
        logger.info("Action differs from 0 at state {} (probe {})".format(probe_state, num_probes))
    else:
        logger.info("Action stayed 0 for all {} probes".format(num_probes))

    env = gym.make(env_name)
    env.seed(0)
    # Reset before overriding the state, as the other helpers in this file
    # do; the state drawn by reset is replaced immediately.
    env.reset()
    env.env.state = weird_state

    for idx in range(10):
        env.step(0)
        env.render()
        time.sleep(0.15)
    env.close()


def analyze_state_transition(env_name: str,
                             state: np.ndarray,
                             action: int) -> np.ndarray:
    """Apply one action from a prescribed state and return the new state.

    A fresh environment is seeded with 0 and reset, its inner state is
    overwritten with ``state``, ``action`` is applied once and the result is
    rendered. The environment is left open.
    """
    env = gym.make(env_name)
    env.seed(0)
    env.reset()

    inner_env = env.env
    inner_env.state = state
    env.step(action)
    env.render()
    return inner_env.state


def build_layers(num_hidden: int) -> List[torch.nn.Module]:
    """Build the layers of a 4-input, 2-output net with one hidden layer.

    Args:
        num_hidden: Width of the single hidden layer.

    Returns:
        Whatever ``pytorch_models.build_relu_layers`` returns for input
        dimension 4, hidden widths ``[num_hidden]``, output dimension 2 and
        biases enabled.
    """
    state_space_dim = 4
    action_space_dim = 2
    include_bias = True
    return pytorch_models.build_relu_layers(
        state_space_dim, [num_hidden], action_space_dim, include_bias
    )


def _replace_numerical_inf_with_actual_inf(x: np.ndarray) -> np.ndarray:
    """Replace entries of huge magnitude with a signed infinity.

    Any entry whose absolute value exceeds ``1e35`` becomes ``+inf`` or
    ``-inf`` according to its sign; all other entries are returned unchanged.

    ``np.copysign`` is used to build the signed infinities. The former
    ``np.sign(x) * np.inf`` evaluated ``0 * inf`` for zero entries, which
    produced NaN intermediates and a RuntimeWarning.
    """
    is_numerical_inf = np.abs(x) > 1e35
    signed_inf = np.copysign(np.inf, x)
    return np.where(is_numerical_inf, signed_inf, x)


def build_definite_spaces(
    input_layer_bounds: Tuple[np.ndarray, np.ndarray]
) -> Tuple[np.ndarray, np.ndarray]:
    """Build inequality matrices for the two "definite action" regions.

    Each region stacks four sign constraints (a zero column followed by a
    diagonal sign matrix) on top of the bounding-box rows produced by
    ``tools.build_bounding_box_h_form``.

    NOTE(review): rows are presumably read as ``b + A @ x >= 0`` -- confirm
    against the convention used by ``tools``.

    Returns:
        ``(space0_ineq, space1_ineq)`` for "definitely choose 0" and
        "definitely choose 1" respectively.
    """
    lower, upper = input_layer_bounds
    box_rows = tools.build_bounding_box_h_form(lower, upper)
    zero_column = np.zeros((4, 1))

    # Position and velocity >= 0, angle and angular velocity <= 0:
    # definitely push left / choose 0.
    signs0 = np.diag([+1, +1, -1, -1])
    # Position and velocity <= 0, angle and angular velocity >= 0:
    # definitely push right / choose 1.
    signs1 = np.diag([-1, -1, +1, +1])

    space0_ineq = np.vstack((np.hstack((zero_column, signs0)), box_rows))
    space1_ineq = np.vstack((np.hstack((zero_column, signs1)), box_rows))
    return space0_ineq, space1_ineq


def _volume_box_analysis(analyze_preimage: List[dict],
                         analyze_space: np.ndarray) -> Any:
    """Intersect every preimage piece with a region and measure the overlap.

    Args:
        analyze_preimage: Pieces, each a dict whose ``"h"`` entry holds an
            ``"is_empty"`` flag and, when not empty, an ``"inequality"``
            matrix.
        analyze_space: Inequality matrix of the region to intersect with.

    Returns:
        ``(intersection_volumes, boxes, intersection_v_reprs)``: per piece,
        the volume of the intersection (0 for empty pieces), an inner box
        ``(lower, upper)`` when the volume is positive (otherwise ``None``),
        and the vertex representation of the intersection.
    """
    num_pieces = len(analyze_preimage)
    dim = analyze_space.shape[1] - 1

    volumes = np.full((num_pieces,), np.nan)
    inner_boxes = [None] * num_pieces
    vertex_reprs = [None] * num_pieces

    for piece_idx, piece in enumerate(analyze_preimage):
        piece_h = piece["h"]
        if piece_h["is_empty"]:
            # Nothing to intersect: zero volume and an empty vertex list.
            volumes[piece_idx] = 0
            vertex_reprs[piece_idx] = np.empty((0, dim + 1))
            continue

        h_repr = np.vstack((analyze_space, piece_h["inequality"]))
        no_equalities = np.empty((0, h_repr.shape[1]))
        v_repr = tools.h_to_v(h_repr, no_equalities)

        # Every row is expected to carry a leading 1.
        assert np.all(v_repr[:, 0] == 1)
        volume = tools.compute_hull_volume(v_repr[:, 1:])

        volumes[piece_idx] = volume
        if 0 < volume:
            box_lower, box_upper = tools.compute_maximum_volume_inner_box(h_repr)
            inner_boxes[piece_idx] = (box_lower, box_upper)
        vertex_reprs[piece_idx] = v_repr

    return volumes, inner_boxes, vertex_reprs

    # coords = [None] * num_in_preimage
    # boxes = [None] * num_in_preimage
    #
    # for idx, p in enumerate(examine_preimage):
    #     # idx = 0; p = preimage0[idx]
    #     v_repr = p["v"]
    #     vertices = v_repr["vertices"]
    #
    #     h_ineq = p["h"]["inequality"]
    #     v_repr = p["v"]["vertices"]
    #     assert np.all(1 == v_repr[:, 0])
    #
    #     volume = tools.compute_hull_volume(v_repr[:, 1:])
    #     if 0 < volume:
    #         h_both = np.vstack((examine_space, h_ineq))
    #         v_both = tools.h_to_v(h_both, h_lin)
    #         assert np.all(1 == v_both[:, 0])
    #         if 0 < tools.compute_hull_volume(v_both[:, 1:]):
    #             lower, upper = tools.compute_maximum_volume_inner_box(h_both)
    #             # print(box_repr(lower, upper))
    #             boxes[idx] = (lower, upper)
    #
    #     assert np.all(1 == vertices[:, 0])
    #     coords[idx] = vertices[:, 1:


def inversion_analysis(policy: pytorch_models.Net,
                       input_layer_bounds: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]:
    """Intersect each class preimage with the opposite "definite" region.

    The preimages of classes 0 and 1 are computed with ``decomp`` and each is
    intersected with the region in which the other action is considered
    definitely correct, so a positive volume indicates a misclassified region.

    Returns:
        Dict with ``input_layer_bounds`` and, for suffixes ``0`` and ``1``,
        ``intersection_volumes*``, ``boxes*`` and ``intersection_v_reprs*``.
        Suffix 0 is the class-0 preimage inside the "definitely 1" region;
        suffix 1 is the class-1 preimage inside the "definitely 0" region.
    """
    inversion_par = {
        "do_inversion": True,
        "cache_inversion": True,
        # desired_margin = -.1 was also tried
        "desired_margin": 0.0,
        "need_initial_v": True,
        "input_layer_bounds": input_layer_bounds,
        "invert_classes": [0, 1],
        "is_rational": False,
    }
    layer_info = decomp.build_layer_info(policy.layers, inversion_par)
    decomps = decomp.compute_decomps(layer_info, inversion_par)
    preimage_decomps = [entry[0] for entry in decomps]

    # See build_definite_spaces for the sign pattern of each region.
    definite_space0_ineq, definite_space1_ineq = build_definite_spaces(
        input_layer_bounds
    )

    preimage0 = preimage_decomps[0]
    preimage1 = preimage_decomps[1]

    # Classed 1, but inside the region where 0 is definitely correct.
    volumes1, boxes1, v_reprs1 = _volume_box_analysis(preimage1, definite_space0_ineq)
    # Classed 0, but inside the region where 1 is definitely correct.
    volumes0, boxes0, v_reprs0 = _volume_box_analysis(preimage0, definite_space1_ineq)

    return {
        "input_layer_bounds": input_layer_bounds,
        "intersection_volumes0": volumes0,
        "boxes0": boxes0,
        "intersection_v_reprs0": v_reprs0,
        "intersection_volumes1": volumes1,
        "boxes1": boxes1,
        "intersection_v_reprs1": v_reprs1
    }


def get_polytope_rules() -> List[tuple]:
    """Return the two sign-pattern override rules consumed by ``predict``.

    Each rule is ``(matrix, action)``: the action is forced whenever
    ``matrix @ state >= 0`` holds element-wise. The two matrices are diagonal
    sign matrices that are negatives of each other.
    """
    signs_for_action0 = np.array([+1, +1, -1, -1])
    signs_for_action1 = np.array([-1, -1, +1, +1])
    return [
        (np.diag(signs_for_action0), 0),
        (np.diag(signs_for_action1), 1),
    ]


def build_input_layer_bounds(training_results: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """Return column-vector lower/upper bounds for the network input.

    The scheme is selected by the local ``layer_scheme`` switch:
      * ``"hardcoded"`` (active): the module-level ``STATE_LOWER`` /
        ``STATE_UPPER`` box;
      * ``"actual_from_env"``: the CartPole-v1 observation-space limits, with
        huge finite values replaced by infinities;
      * ``"empirical"``: a symmetric box from the largest entries of
        ``training_results["extreme_state_history"]``.

    The gym environment is created only by the scheme that needs it, and is
    closed afterwards; it used to be created on every call and never closed.

    Raises:
        ValueError: If ``layer_scheme`` is not one of the known schemes.
    """
    # layer_scheme = "actual_from_env"
    # layer_scheme = "empirical"
    layer_scheme = "hardcoded"

    if "actual_from_env" == layer_scheme:
        env = gym.make("CartPole-v1")
        try:
            lower = np.vstack(env.observation_space.low)
            upper = np.vstack(env.observation_space.high)
        finally:
            env.close()

        lower = _replace_numerical_inf_with_actual_inf(lower)
        upper = _replace_numerical_inf_with_actual_inf(upper)
    elif "hardcoded" == layer_scheme:
        lower = np.vstack(STATE_LOWER)
        upper = np.vstack(STATE_UPPER)
    elif "empirical" == layer_scheme:
        extreme_state_history_array = np.array(
            training_results["extreme_state_history"]
        )

        empirical_scale = np.max(extreme_state_history_array, axis=0)
        lower = -1 * np.vstack(empirical_scale)
        upper = +1 * np.vstack(empirical_scale)
    else:
        raise ValueError("Not configured")
    input_layer_bounds = (lower, upper)
    return input_layer_bounds


def exhibit_strange_behavior(env_name: str,
                              policy: pytorch_models.Net,
                              analysis_results: Dict[str, Any]):
    """Demonstrate a state where the policy picks 0 inside the "choose 1" region.

    All non-empty boxes in ``analysis_results["boxes0"]`` are printed. One box
    is selected, a point near its upper corner is taken as the state, and the
    policy's arg-max action is applied once in a rendered environment.

    Box selection: index 3874 (previously hard-coded unconditionally) is used
    when it exists and holds a box; otherwise the last non-empty box is used,
    and if there is none the function returns without rendering. The old
    unconditional override raised on shorter lists or a ``None`` entry.

    Raises:
        AssertionError: If the step does not move the cart state down and the
            pole state up as the demonstration expects.
    """
    boxes0 = analysis_results["boxes0"]
    preferred_idx = 3874

    examine_idx = None
    for idx, b in enumerate(boxes0):
        if b is not None:
            print("{}: {}".format(idx, box_repr(*b)))
            examine_idx = idx

    if preferred_idx < len(boxes0) and boxes0[preferred_idx] is not None:
        examine_idx = preferred_idx
    if examine_idx is None:
        return

    lower, upper = boxes0[examine_idx]
    # w = 0.95
    w = 0.05
    fragile_state = (lower * w + upper * (1 - w)).flatten()
    # Use the arg-max action: the text below and the asserts describe the
    # prescribed action, which a sampled action need not match.
    action, _ = predict(fragile_state, policy, True, [])

    print("State:")
    print(fragile_state)

    print("prescribed action is zero (push the cart left)")
    print(action)

    print("(but clearly pushing right is the correct move)")
    env = gym.make(env_name)
    env.seed(0)
    env.reset()
    env.env.state = fragile_state
    env.render()

    next_state, reward, done, info = env.step(action.item())
    env.render()

    state_diff = next_state - fragile_state
    assert np.all(state_diff[:2] < 0), \
        "Cart should be more to the left and leftward velocity should be greater"
    assert np.all(state_diff[2:] > 0), \
        "Pole should be more to the right and rightward velocity should be greater"


def present_fitting_results(training_results: Dict[str, Any],
                            experiment_name: str):
    """Plot the training histories and report time spent per sign pattern.

    Three figures are created (wrong-action fraction, loss, reward), the
    extreme-state history is passed to ``hist_states``, the average fraction
    of time per sign pattern is printed, and the per-dimension share of time
    with a positive sign is logged.
    """
    loss_history = training_results["loss_history"]
    reward_history = training_results["reward_history"]
    frac_wrong_action_history = training_results["frac_wrong_action_history"]

    figure_specs = [
        (frac_wrong_action_history, "Fraction of wrong actions: {}"),
        (loss_history, "Loss History: {}"),
        (reward_history, "Reward History {}"),
    ]
    for series, title_template in figure_specs:
        fig, ax = plt.subplots(1, 1)
        ax.plot(series)
        ax.set_title(title_template.format(experiment_name))

    hist_states(np.array(training_results["extreme_state_history"]))

    # Normalise each episode's pattern counts to fractions, then average
    # over episodes.
    pattern_counts = np.array(training_results["octant_inclusion_history"])
    episode_totals = np.vstack(np.sum(pattern_counts, axis=1))
    pattern_fractions = np.divide(pattern_counts, episode_totals)
    avg_time_in_octant = np.mean(pattern_fractions, axis=0)

    num_patterns = avg_time_in_octant.shape[0]
    dim = int(np.log2(num_patterns))
    sign_rows = tools._gen_all_01_rows(dim) * 2 - 1

    for pattern_idx in range(2 ** dim):
        print("{}: {:.4f}".format(sign_rows[pattern_idx, :],
                                  avg_time_in_octant[pattern_idx]))

    logger.info("1-D breakdowns: ")
    for axis_idx in range(dim):
        is_positive = sign_rows[:, axis_idx] > 0
        total_pos_time = np.sum(avg_time_in_octant[is_positive])
        logger.info("Dim {}: time positive {:.4f}".format(axis_idx, total_pos_time))


def compute_velocity_limits():
    """Explore how large the state can get under a fixed action schedule.

    These limits are obviously going to be loose in practice, but
    amongst all loose bounds, they enjoy the property that:
      There is an initialization and a path of actions which actually
      deliver them.

    For each of four corner initial states (every coordinate at +/-0.05) the
    cart is pushed left (action 0) for the first six steps and right
    (action 1) afterwards until the episode ends. The state history is
    plotted and the terminal state is printed next to the position and angle
    limits. Each environment is now closed once its rollout is finished.
    """
    try_initial_states = [np.array([+.05, -.05, -.05, +.05]),
                          np.array([+.05, -.05, -.05, -.05]),
                          np.array([+.05, +.05, -.05, +.05]),
                          np.array([+.05, +.05, -.05, -.05])]

    env_name = "CartPole-v1"
    x_limit = 2.4
    theta_limit = 12 * np.pi * 2 / 360
    num_left_pushes = 6

    for extreme_initial_state in try_initial_states:
        env = gym.make(env_name)
        env.seed(0)
        env.reset()
        # Override the randomly drawn initial state with the corner state.
        env.env.state = extreme_initial_state

        episode_states = []
        done = False
        num_steps = 0
        try:
            while not done:
                action = 0 if num_steps < num_left_pushes else 1
                state, reward, done, info = env.step(action)
                episode_states.append(state)
                num_steps += 1
        finally:
            env.close()

        states = np.array(episode_states)
        plot_state_history(states)

        print("Starting point: {}".format(extreme_initial_state))
        print("Terminal x: {:.4f} (limit: {:.4f})".format(states[-1, 0], x_limit))
        print("Terminal xdot: {:.4f}".format(states[-1, 1]))
        print("Terminal theta: {:.4f} (limit: {:.4f})".format(states[-1, 2], theta_limit))
        print("Terminal thetadot: {:.4f}".format(states[-1, 3]))


def _add_bad_regions_for_octant(ax,
                                volumes: np.ndarray,
                                intersection_vreprs: List[np.ndarray]):
    """
    Draw, on ``ax``, the 2-d projection of every polytope whose 4-d volume
    is strictly positive.

    Each V-representation in ``intersection_vreprs`` is expected to carry a
    leading column of ones (this is asserted); the remaining four columns
    are projected onto state coordinates 0 and 2, i.e. the axes bounded by
    LIMIT_X and LIMIT_THETA.  The plot colour of a polytope is scaled by the
    ratio of its 4-d volume to the volume of its 2-d projection.

    Nothing is returned.  The per-polytope volume ratios and the Monte Carlo
    coverage mask are filled in but not used after the loop.
    """
    # Columns 0 and 2 of the 4x4 identity: keeps coordinates 0 and 2.
    keep_x_theta = np.eye(4)[:, [0, 2]]
    x_range = (-LIMIT_X, +LIMIT_X)
    theta_range = (-LIMIT_THETA, +LIMIT_THETA)

    ratios = np.full((len(volumes),), np.nan)

    # Monte Carlo sample points, scaled so that the first coordinate lies in
    # (-LIMIT_X, 0] and the second in [0, LIMIT_THETA).
    n_samples = 1000
    sample_scale = np.array([-LIMIT_X, +LIMIT_THETA])
    samples = np.random.uniform(size=(n_samples, 2)) * sample_scale
    covered = np.full((n_samples), False)

    base_shade = .001
    opacity = .05
    for k, vol_4d in enumerate(volumes):
        # Only polytopes with strictly positive volume are drawn.
        if not (vol_4d > 0):
            continue

        v_repr = intersection_vreprs[k]
        assert np.all(v_repr[:, 0] == 1)

        flat = v_repr[:, 1:] @ keep_x_theta
        ones_column = np.ones((flat.shape[0], 1))
        flat_v_repr = np.hstack((ones_column, flat))

        # Mark the sample points that fall inside this projected polytope.
        hits = tools.points_in_polytope(samples, flat_v_repr).flatten()
        covered[hits] = True

        vol_2d = tools.compute_hull_volume(flat)
        ratio = vol_4d / vol_2d
        ratios[k] = ratio

        shade = (base_shade * ratio,) * 3
        plotting.convex_hull_plot(ax, flat_v_repr, x_range, theta_range, shade,
                                  alpha=opacity)


def do_proj_plot(intersection_volumes0: np.ndarray,
                 intersection_volumes1: np.ndarray,
                 intersection_v_reprs0: List[np.ndarray],
                 intersection_v_reprs1: List[np.ndarray]) -> Tuple[matplotlib.figure.Figure,
                                                                   np.ndarray]:
    """
    Build a single square figure showing the projected bad regions of both
    polytope families.

    The "0" family is drawn first, then the "1" family, both on the same
    axes.  Tick labels on both axes are formatted with an explicit sign and
    one decimal.

    Returns the figure together with the one Axes object created by
    ``plt.subplots(1, 1)`` (a single Axes, not an array of them).
    """
    side = 2.5
    fig, ax = plt.subplots(1, 1, figsize=(side, side))

    families = ((intersection_volumes0, intersection_v_reprs0),
                (intersection_volumes1, intersection_v_reprs1))
    for family_volumes, family_v_reprs in families:
        _add_bad_regions_for_octant(ax, family_volumes, family_v_reprs)

    signed_one_decimal = matplotlib.ticker.FormatStrFormatter("%+0.1f")
    for axis in (ax.yaxis, ax.xaxis):
        axis.set_major_formatter(signed_one_decimal)

    ax.set_xlabel(r"$x$")
    ax.set_ylabel(r"$\theta$")
    fig.tight_layout()
    return fig, ax
# https://tex.stackexchange.com/questions/222268/why-isnt-pgf-honouring-my-font-selections-for-matplotlib-generated-graphics
    
    
def print_inversion_results(par: Dict[str, Any],
                            inversion_results: Dict[str, Any]):
    """
    Log summary statistics of the "bad" intersection volumes and save two
    figures.

    Reads ``fig_format`` and ``paths`` from ``par``.  From
    ``inversion_results`` it reads the two intersection-volume arrays, the
    two lists of V-representations, the two box collections and
    ``input_layer_bounds`` (a ``(lower, upper)`` pair).

    Steps, in order:
      * log each family's bad volume, and their sum, as a fraction of the
        volume of the input box;
      * plot and save the cumulative distribution of the sorted normalised
        volumes, starting at the first strictly positive cumulative value;
      * for each family, log the volume of the largest polytope and the
        volume and extent of the box stored at the same index;
      * draw and save the projection plot.
    """
    fig_format = par["fig_format"]
    paths = par["paths"]

    vols0 = inversion_results["intersection_volumes0"]
    vols1 = inversion_results["intersection_volumes1"]

    v_reprs0 = inversion_results["intersection_v_reprs0"]
    v_reprs1 = inversion_results["intersection_v_reprs1"]

    boxes0 = inversion_results["boxes0"]
    boxes1 = inversion_results["boxes1"]

    lower, upper = inversion_results["input_layer_bounds"]
    domain_volume = np.prod((upper - lower))

    bad0 = np.sum(vols0)
    bad1 = np.sum(vols1)
    bad_both = bad0 + bad1

    frac0 = bad0 / domain_volume
    frac1 = bad1 / domain_volume
    frac_both = bad_both / domain_volume

    # Family 1 is "classed 1, but should be classed zero"; family 0 is the
    # converse (see the log messages below).
    share0 = vols0 / domain_volume
    share1 = vols1 / domain_volume
    share_all = np.concatenate((vols0, vols1)) / domain_volume

    logger.info("Box 1 (classed 1, should be 0) bad volume fraction {:.4f}".format(frac1))
    logger.info("Box 0 (classed 0, should be 1) bad volume fraction {:.4f}".format(frac0))
    logger.info("Total bad volume fraction {:.4f}".format(frac_both))

    # --- Figure 1: cumulative distribution of the normalised volumes ---
    grid = np.linspace(0, 1, share_all.size)
    cdf = np.cumsum(np.sort(share_all))
    start = np.argmax(cdf > 0)

    if "pgf" == fig_format:
        plotting.initialise_pgf_plots("pdflatex", "serif")

    fig, axs = plotting.wrapped_subplot(1, 1, plot_scale=1.5)
    cdf_ax = axs[0, 0]
    cdf_ax.plot(grid[start:], cdf[start:])
    cdf_ax.grid()

    ident_str = ""
    cdf_ident = "nonempty_volume_cdf_{}".format(ident_str)
    fig_path = plotting.smart_save_fig(fig, cdf_ident, fig_format, paths["plots"])
    logger.info("Saved to {}".format(fig_path))

    # --- Largest polytope of each family and the box at the same index ---
    top0 = np.argmax(share0)
    top1 = np.argmax(share1)

    box0 = boxes0[top0]
    box1 = boxes1[top1]

    box0_lo, box0_hi = box0[0], box0[1]
    box1_lo, box1_hi = box1[0], box1[1]

    box0_volume = np.prod(box0_hi - box0_lo)
    box1_volume = np.prod(box1_hi - box1_lo)

    poly0_volume = vols0[top0]
    poly1_volume = vols1[top1]

    logger.info("Polytope enclosing largest box0 volume: {:.4f}".format(poly0_volume))
    logger.info("Polytope enclosing largest box1 volume: {:.4f}".format(poly1_volume))

    logger.info("Largest Box0 volume: {:.4f}".format(box0_volume))
    logger.info("Largest Box1 volume: {:.4f}".format(box1_volume))

    box0_text = box_repr(box0_lo, box0_hi)
    box1_text = box_repr(box1_lo, box1_hi)

    logger.info("Largest box0: {}".format(box0_text))
    logger.info("Largest box1: {}".format(box1_text))

    # --- Figure 2: projection plot ---
    if "pgf" == fig_format:
        plotting.initialise_pgf_plots("pdflatex", "serif")

    fig, axs = do_proj_plot(vols0, vols1, v_reprs0, v_reprs1)

    fig_path = plotting.smart_save_fig(fig, "proj_plot", fig_format, paths["plots"])
    logger.info("Saved to {}".format(fig_path))


def present_policy_analysis(env_name: str,
                            policy: pytorch_models.Net,
                            policy_analysis: Dict[str, Any],
                            polytope_rules: List[tuple]):
    """
    Print the across-trial means of the statistics in ``policy_analysis``.

    The balancing demonstration (``show_correct_balancing``) is switched off
    by a local flag, so ``env_name``, ``policy`` and ``polytope_rules`` are
    only used if that flag is turned on.
    """
    show_balancing_demo = False
    if show_balancing_demo:
        show_correct_balancing(env_name, policy, polytope_rules)

    # Computed but not printed.
    mad_definite = np.mean(policy_analysis['trial_mean_abs_definite_state'], axis=0)

    def trial_mean(key):
        return np.mean(policy_analysis[key])

    # All means are evaluated before anything is printed.
    avg_right = trial_mean("trial_avgright")
    avg_wrong = trial_mean("trial_avgwrong")
    avg_definite = trial_mean("trial_avgdefinite")
    avg_must_be_0 = trial_mean("trial_must_be_0")
    avg_must_be_1 = trial_mean("trial_must_be_1")
    avg_wrong_high_prob = trial_mean("trial_avg_wrong_and_high_prob")

    report = (
        ("Wrong with high prob {}", avg_wrong_high_prob),
        ("Wrong {}", avg_wrong),
        ("Right {}", avg_right),
        ("Definite {}", avg_definite),
        ("m0 {}", avg_must_be_0),
        ("m1 {}", avg_must_be_1),
    )
    for template, value in report:
        print(template.format(value))


def stabilize_sim():
    """
    Print the state before each of 15 consecutive transitions taken with
    action 1, starting from a fixed hand-picked state.

    Each transition is computed by ``analyze_state_transition``.
    """
    env_name = "CartPole-v1"
    action = 1
    current_state = np.array([-1.462, -2.262, +1.798e-08, +1.399])

    for _ in range(15):
        print(current_state)
        current_state = analyze_state_transition(env_name, current_state, action)


def plot_state_history(states: np.ndarray) -> Tuple[matplotlib.figure.Figure, np.ndarray]:
    """
    Plot each of the first four columns of ``states`` in its own panel of a
    2x2 grid and return ``(fig, axs)``.

    Column order is x, x_dot, theta, theta_dot, laid out row by row.
    ``states`` must have at least one row.
    """
    fig, axs = plotting.wrapped_subplot(2, 2, plot_scale=2.5)

    assert states.shape[0] > 0, "Should have at least one row"

    panel_titles = ("$x$", "$\\dot{x}$", "$\\theta$", "$\\dot{\\theta}$")
    for column, title in enumerate(panel_titles):
        row, col = divmod(column, 2)
        panel = axs[row, col]
        panel.plot(states[:, column])
        panel.set_title(title)
        panel.grid()
    return fig, axs


def hist_states(states: np.ndarray) -> Tuple[matplotlib.figure.Figure, np.ndarray]:
    """
    Draw a histogram of each of the first four columns of ``states`` on a
    2x2 grid of axes and return ``(fig, axs)``.

    Panels are titled x, xdot, theta, thetadot, laid out row by row.
    ``states`` must have at least one row.
    """
    assert states.shape[0] > 0, "Should have at least one row"
    fig, axs = plt.subplots(2, 2)

    panel_titles = ("x", "xdot", "theta", "thetadot")
    for column, title in enumerate(panel_titles):
        row, col = divmod(column, 2)
        panel = axs[row, col]
        panel.hist(states[:, column])
        panel.set_title(title)
    return fig, axs


def compute_velocity_limits_simple(par: Dict[str, Any]):
    """
    Simulate one CartPole episode under a constant action 0 from a fixed
    extreme start, save a plot of the state history and print the terminal
    state.

    The resulting velocity limits will be loose in practice, but unlike
    other loose bounds they are actually attained: there is an initial state
    and a sequence of actions that delivers them.

    The start state is:
      x        = +2.39  (cart as far right as possible)
      xdot     =  0     (cart still)
      theta    = -.20   (pole as far left as possible)
      thetadot =  0     (pole still)

    ``par`` supplies ``fig_format`` and ``paths["plots"]``.
    """
    fig_format = par["fig_format"]
    paths = par["paths"]
    plots_path = paths["plots"]

    extreme_initial_state = np.array([+2.39, 0.0, -.20, 0.0])

    env = gym.make("CartPole-v1")
    env.seed(0)
    env.reset()
    # Overwrite the state produced by reset() with the chosen start.
    env.env.state = extreme_initial_state

    # Apply action 0 at every step until the environment reports the
    # episode is over.
    constant_action = 0
    trajectory = []
    finished = False
    while not finished:
        observation, _reward, finished, _info = env.step(constant_action)
        trajectory.append(observation)

    if "pgf" == fig_format:
        plotting.initialise_pgf_plots("pdflatex", "serif")

    states = np.array(trajectory)
    fig, axs = plot_state_history(states)

    fig_path = plotting.smart_save_fig(fig, "state_dynamics", fig_format, plots_path)
    logger.info("Plot at {}".format(fig_path))

    x_limit = 2.4
    theta_limit = 12 * np.pi * 2 / 360

    terminal = states[-1]
    print("Starting point: {}".format(extreme_initial_state))
    print("Terminal x: {:.4f} (limit: {:.4f})".format(terminal[0], x_limit))
    print("Terminal xdot: {:.4f}".format(terminal[1]))
    print("Terminal theta: {:.4f} (limit: {:.4f})".format(terminal[2], theta_limit))
    print("Terminal thetadot: {:.4f}".format(terminal[3]))


def set_par() -> Dict[str, Any]:
    """
    Assemble the run configuration.

    Returns a dict with:
      "fig_format": figure output format, currently "pgf"
      "paths":      whatever ``path_config.get_paths()`` returns
      "num_hidden": hidden-unit count passed on to training, currently 12
    """
    return {
        "fig_format": "pgf",
        "paths": path_config.get_paths(),
        "num_hidden": 12,
    }


if __name__ == "__main__":
    # Script entry point: obtain a trained CartPole policy, optionally run
    # some diagnostics on it, then run the inversion analysis and report it.
    # Commented-out assignments below are alternative settings kept for
    # manual switching.

    # compute_velocity_limits()

    # seeds = settings.set_seeds()

    # Seed the Python, torch and numpy random number generators.
    seed = 60
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    par = set_par()

    # compute_velocity_limits_simple(par)
    paths = par["paths"]
    cache_dir = paths["cached_calculations"]
    num_hidden = par["num_hidden"]

    env_name = "CartPole-v1"

    episodes = 1000
    # episodes = 50

    # The experiment name decides whether polytope rules are used.
    experiment_name = "baseline"
    # experiment_name = "rules"

    if experiment_name == "baseline":
        use_polytope_rules = False
    elif experiment_name == "rules":
        use_polytope_rules = True
    else:
        raise ValueError("")

    # An empty list means "no rules" (predict() checks len(polytope_rules) > 0).
    if use_polytope_rules:
        polytope_rules = get_polytope_rules()
    else:
        polytope_rules = []

    print("Use polytope rules: {}".format(use_polytope_rules))
    calc_fun = train_policy
    calc_args = (env_name, num_hidden, episodes, polytope_rules)

    force_regeneraton = False
    # force_regeneraton = True

    # Either go through caching.cached_calc (presumably reusing a stored
    # result from cache_dir unless regeneration is forced -- confirm against
    # the caching module) or call the training function directly.
    cache_training = True
    # cache_training = False
    if cache_training:
        calc_kwargs = {}
        training_results = caching.cached_calc(
            cache_dir, calc_fun, calc_args, calc_kwargs, force_regeneraton
        )
    else:
        training_results = calc_fun(*calc_args)

    policy = training_results["policy"]

    # Optional sampling-based check (currently disabled): draw uniform points
    # in the box [STATE_LOWER, STATE_UPPER] and measure how often the policy's
    # argmax action disagrees with the expected action in two orthants.
    do_simple_volume_check = False
    # do_simple_volume_check = True
    if do_simple_volume_check:
        n_points = 50000
        unifs = np.random.uniform(low=0.0, high=+1.0, size=(n_points, 4))
        inputs = STATE_LOWER + unifs * (STATE_UPPER - STATE_LOWER)

        inputs_torch = torch.from_numpy(inputs).float()
        outputs_torch = policy(inputs_torch)

        # Orthant x <= 0, xdot <= 0, theta >= 0, thetadot >= 0:
        # action 0 is counted as an error here.
        d1_rows = np.logical_and(inputs[:, 0] <= 0,
                  np.logical_and(inputs[:, 1] <= 0,
                  np.logical_and(inputs[:, 2] >= 0,
                                 inputs[:, 3] >= 0)))
        # Opposite orthant x >= 0, xdot >= 0, theta <= 0, thetadot <= 0:
        # action 1 is counted as an error here.
        d0_rows = np.logical_and(inputs[:, 0] >= 0,
                  np.logical_and(inputs[:, 1] >= 0,
                  np.logical_and(inputs[:, 2] <= 0,
                                 inputs[:, 3] <= 0)))

        # Each orthant should hold about 1/16 of the samples.
        # NOTE(review): the tolerance .001 * sqrt(n_points) grows with
        # n_points (about 0.22 for 50000 points), so this check is very
        # loose -- confirm that this is the intended bound.
        assert abs(np.mean(d1_rows) - 1 / 16) < .001 * np.sqrt(n_points)
        assert abs(np.mean(d0_rows) - 1 / 16) < .001 * np.sqrt(n_points)

        output_np = outputs_torch.detach().numpy()
        action = np.argmax(output_np, axis=1)

        # Within-orthant error rate divided by 16, i.e. scaled by the
        # expected share of samples in one orthant.
        d1_error_rate = np.mean(0 == action[d1_rows]) / 16
        d0_error_rate = np.mean(1 == action[d0_rows]) / 16

        logger.info("Error rate (classed 1, should be 0) {:.4f}".format(d0_error_rate))
        logger.info("Error rate (classed 0, should be 1) {:.4f}".format(d1_error_rate))

    # Optional policy analysis (currently disabled).
    do_policy_analysis = False
    # do_policy_analysis = True
    if do_policy_analysis:
        policy_analysis = analyze_policy(env_name, policy, polytope_rules)
        present_policy_analysis(env_name, policy, policy_analysis, polytope_rules)

    # Presentation of the fitting results is permanently disabled.
    if False:
        present_fitting_results(training_results, experiment_name)

    # Main analysis: build input bounds from the training results, run the
    # inversion analysis on the policy, then log and plot the outcome.
    input_layer_bounds = build_input_layer_bounds(training_results)
    inversion_results = inversion_analysis(policy, input_layer_bounds)
    print_inversion_results(par, inversion_results)

    exhibit_strange_behavior(env_name, policy, inversion_results)

