import numpy as np
from scipy.optimize import linprog
import matplotlib.pyplot as plt


def calculate_transition_function(env):
    """Calculate the true transition function for the environment.

    A uniform occupancy measure is assigned to every (x, a, x') triple of a
    layer (one over the number of triples leaving that layer, as given by
    consecutive differences of ``env.sas_count``).  Normalising each triple by
    the total occupancy of its (x, a) pair gives the transition probability
    P(x' | x, a).

    When the widest layer in ``env.layers`` has exactly two states, the
    probabilities of the first two successors of each (x, a) pair are then
    skewed by an action-dependent amount of at most 0.4, so that different
    actions lead to different transition distributions.

    Args:
        env: environment object exposing ``sas_ind`` and ``sa_ind`` (dicts
            keyed by (x, a, x') and (x, a)), ``sas_count``, ``layers``,
            ``actions`` and ``get_layer(state)``.

    Returns:
        dict mapping every (x, a, x') key of ``env.sas_ind`` to its
        transition probability.
    """
    sas_ind = env.sas_ind

    # Uniform occupancy measure inside each layer.
    # NOTE(review): triples whose index is -1 get no entry here, yet every
    # triple of sas_ind is looked up below, so a -1 entry ends in a KeyError.
    # Confirm whether such triples can actually occur.
    om = {}
    for triple, index in sas_ind.items():
        if index != -1:
            layer = env.get_layer(triple[0]) + 1
            om[triple] = 1 / (env.sas_count[layer] - env.sas_count[layer - 1])

    # Group the triples by their (x, a) prefix once, keeping sas_ind order,
    # instead of rescanning every triple for every state-action pair.
    successors = {}
    for triple in sas_ind:
        successors.setdefault(triple[0:2], []).append(triple)

    state_action_pairs = env.sa_ind.keys()

    # Unnormalized induced policy: sum_(x') q(x, a, x') for every (x, a)
    un_induced_policy = {}
    for state_action in state_action_pairs:
        keys = successors.get(state_action[0:2], [])
        un_induced_policy[state_action] = sum([om[triple] for triple in keys])

    # Transition function induced by the occupancy measure
    tf = {
        triple: om[triple] / un_induced_policy[triple[0:2]] for triple in sas_ind
    }

    if max([len(x) for x in env.layers]) == 2:
        for state_action in state_action_pairs:
            keys = successors.get(state_action[0:2], [])
            # Probability mass can only be moved between two successors.
            # Checking this first also covers pairs with no outgoing triple.
            if len(keys) < 2:
                continue

            # Calculate shift amount based on action
            action = keys[0][1]
            a2 = len(env.actions[0]) // 2
            if keys[0][0] == 1:
                # For state 1 the position of the action in env.actions[1]
                # is used instead of the raw action value.
                action = env.actions[1].index(action)
            if action == a2:
                shift_amount = 0
            elif action < a2:
                t = action / a2
                shift_amount = -0.4 * (1 - t) ** 2
            else:
                t = (action - a2) / a2
                shift_amount = 0.4 * t**2

            # Move the mass from the first successor to the second one, so
            # the two probabilities still sum to the same total.
            tf[keys[1]] += shift_amount
            tf[keys[0]] -= shift_amount

    return tf


def calculate_optimal_q(
    rewards_coefficients, lse, rse, g_const_lse, g_const_rse, sa_ind, sas_count, sas_ind
):
    """
    Calculates the optimal q for the environment

    The optimal occupancy measure is the solution of a linear program that
    maximises the reward (linprog minimises, hence the negated objective)
    subject to the equality constraints ``lse @ q == rse`` and the inequality
    constraints ``g_const_lse.T @ q <= g_const_rse``.  Every variable is
    bounded to ``[1e-10, 1]``.

    Args:
        rewards_coefficients: one reward coefficient per (x, a) pair, in the
            iteration order of ``sa_ind``; each is repeated once per triple
            of that pair to build the objective.
        lse, rse: left- and right-hand side of the equality constraints.
        g_const_lse, g_const_rse: left- and right-hand side of the inequality
            constraints (``g_const_lse`` is transposed before use).
        sa_ind: mapping keyed by (x, a) pairs.
        sas_count: sequence whose last element is the number of LP variables.
        sas_ind: mapping keyed by (x, a, x') triples; its iteration order is
            taken to be the order of the LP variables.

    Returns:
        Tuple ``(q, q_pairs)``: the LP solution, and an array holding for each
        (x, a) pair of ``sa_ind`` the sum of the solution over its triples.

    Raises:
        ValueError: if the linear program could not be solved.
    """
    # Calculate optimal q via linear programming
    # Count the triples of every (x, a) pair in a single pass over sas_ind
    triples_per_pair = {}
    for triple in sas_ind:
        pair = triple[0:2]
        triples_per_pair[pair] = triples_per_pair.get(pair, 0) + 1

    # Obj function coefficients: the reward of a pair applies to each triple
    coeff_repeat = [triples_per_pair.get(pair, 0) for pair in sa_ind]
    coefficients = np.repeat(rewards_coefficients, coeff_repeat)

    # Calculate q_star
    opt = linprog(
        c=-coefficients,
        A_eq=lse,
        b_eq=rse,
        A_ub=g_const_lse.T,
        b_ub=g_const_rse,
        bounds=[(1e-10, 1)] * sas_count[-1],
    )

    # Truthiness rather than `is False`: an identity test would let any
    # falsy value that is not the bool singleton slip through.
    if not opt.success:
        print("Optimization failed")
        print(opt.message)
        print(
            "Check the parametr constraints_difficulty in env_config.json and try to lower it"
        )
        raise ValueError("Optimization failed")

    print("Calculating optimal q - Optimization successful: ", opt.success)

    # Marginalise q(x, a, x') over x' to get q(x, a), again in a single pass
    pair_totals = {}
    for triple, value in zip(sas_ind.keys(), opt.x):
        pair = triple[0:2]
        pair_totals[pair] = pair_totals.get(pair, 0) + value
    optimal_q_pairs = np.array([pair_totals.get(pair, 0) for pair in sa_ind])

    return opt.x, optimal_q_pairs


def generate_valid_occupancy_measure_constraints(
    layers, sas_count, sas_ind, transition_function
):
    """
    Generate valid occupancy measure constraints.

    An occupancy measure is valid if it satisfies the following constraints:
    1. sum_(x,a,x' in l_k) q(x,a,x') = 1 for every layer l_k, k = 0...L-1
    2. flux constraint sum_(x' in l_k-1) q(x',a,x) = sum_(x' in l_k+1) q(x,a,x') for every x in l_k
    3. transition function P^q = P (must induce the same transition function)

    Args:
        layers: sequence of layers, each a sequence of states; flux
            constraints are built for every layer except the first and last.
        sas_count: cumulative triple counts; triples of layer k occupy the
            positions ``sas_count[k]:sas_count[k + 1]`` and ``sas_count[-1]``
            is the total number of variables.
        sas_ind: mapping from (x, a, x') triples to their variable position.
        transition_function: mapping from (x, a, x') to P(x' | x, a).

    Returns:
        Tuple ``(lse, rse)``: a list of coefficient rows (layer rows first,
        then flux rows, then transition rows) and the matching right-hand
        side array, describing the system ``lse @ q == rse``.
    """

    tf = transition_function
    n_vars = sas_count[-1]

    # Index the triples once so the loops below do not rescan sas_ind
    incoming = {}  # state x'  -> positions of the triples ending in it
    outgoing = {}  # state x   -> positions of the triples leaving it
    siblings = {}  # pair (x,a) -> positions of all its triples
    for triple, position in sas_ind.items():
        outgoing.setdefault(triple[0], []).append(position)
        incoming.setdefault(triple[2], []).append(position)
        siblings.setdefault(triple[0:2], []).append(position)

    # First constraint (layers sum to one)
    layer_constraints_lse = []
    num_layers = len(sas_count[:-1])
    for i in range(num_layers):
        layer_constraints = np.zeros(n_vars, dtype=float)
        layer_constraints[sas_count[i] : sas_count[i + 1]] = 1
        layer_constraints_lse.append(layer_constraints)
    layer_constraints_rse = np.ones(num_layers, dtype=float)

    # Second constraint (flux constraint): inflow minus outflow is zero
    flux_const_lse = []
    for layer in layers[1:-1]:
        for state in layer:
            node_lse = np.zeros(n_vars, dtype=float)
            node_lse[incoming.get(state, [])] = 1
            node_lse[outgoing.get(state, [])] = -1
            flux_const_lse.append(node_lse)
    flux_const_rse = np.zeros(len(flux_const_lse), dtype=int)

    # Third constraint (transition function):
    # P(x'|x,a) * sum_(x'') q(x,a,x'') - q(x,a,x') = 0
    tf_const_lse = []
    for key, position in sas_ind.items():
        group = siblings[key[0:2]]
        if len(group) == 1:
            continue  # No need to add constraint if there is only one possible transition
        tf_lse = np.zeros(n_vars, dtype=float)
        tf_lse[group] = tf[key]
        # The triple itself is part of the group, hence P - 1 on its own entry
        tf_lse[position] = tf[key] - 1
        tf_const_lse.append(tf_lse)
    tf_const_rse = np.zeros(len(tf_const_lse), dtype=int)

    # Concatenate all constraints
    lse = layer_constraints_lse + flux_const_lse + tf_const_lse
    rse = np.concatenate((layer_constraints_rse, flux_const_rse, tf_const_rse))

    return lse, rse


def generate_constraints_satisfaction_cons(
    constraints_mean, n_constraints, sa_ind, sas_ind
):
    """
    Build the two sides of the constraint-satisfaction inequalities.

    The rows of ``constraints_mean`` follow the iteration order of ``sa_ind``.
    Each row is duplicated once per (x, a, x') triple of ``sas_ind`` that
    starts with that (x, a) pair, which expands the constraint coefficients
    from pairs to triples.

    Returns:
        Tuple ``(g_const_lse, g_const_rse)``: the expanded coefficient array
        and an integer zero vector of length ``n_constraints``.
    """
    # How many triples share each (x, a) prefix, in sa_ind order
    repeats_per_pair = []
    for pair in sa_ind:
        n_triples = sum(1 for triple in sas_ind if triple[0:2] == pair)
        repeats_per_pair.append(n_triples)

    left_side = np.repeat(constraints_mean, repeats_per_pair, axis=0)
    right_side = np.zeros(n_constraints, dtype=int)
    return left_side, right_side


def _run_mean_and_ci(runs):
    """Return the mean over runs and the 1.96-standard-error half-width."""
    half_width = 1.96 * np.std(runs, axis=0) / np.sqrt(len(runs))
    return np.mean(runs, axis=0), half_width


def _plot_mean_with_band(axis, x, mean, half_width, fmt, color, label, title, legend_loc):
    """Draw one mean curve and its shaded band on ``axis`` and label the panel."""
    axis.plot(x, mean, fmt, label=label)
    axis.fill_between(x, mean - half_width, mean + half_width, color=color, alpha=0.1)
    axis.legend(loc=legend_loc)
    axis.set_xlabel("t")
    axis.set_title(title)


def plot_cumulative(y_vals_regret, y_vals_constraints, save, save_name, n_constraints):
    """
    Plot the results of the algorithm

    Draws the mean cumulative regret across runs with a shaded band of
    1.96 standard errors.  When ``n_constraints`` is positive the cumulative
    constraints violation and the sum of the two curves are drawn as well.

    Up to three figures are produced: a combined one (one panel without
    constraints, three panels with them), the regret alone, and, only when
    there are constraints, the violation alone.  The figures are left open.

    Args:
        y_vals_regret: one sequence of cumulative regret values per run.
        y_vals_constraints: one sequence of cumulative violation values per
            run; only read when ``n_constraints`` is positive.
        save: when truthy, every figure is written to disk at 600 dpi.
        save_name: prefix prepended to the output file names.
        n_constraints: number of constraints of the problem.
    """
    x = np.arange(len(y_vals_regret[0]))
    y_mean_regret, ci_regret = _run_mean_and_ci(y_vals_regret)

    has_constraints = n_constraints > 0
    if has_constraints:
        # Only touch the constraint data when it is going to be drawn
        y_mean_constraints, ci_constraints = _run_mean_and_ci(y_vals_constraints)

    # Combined figure: regret, plus violation and their sum when constrained
    fig, ax = plt.subplots(
        1,
        3 if has_constraints else 1,
        figsize=(20 if has_constraints else 5, 5),
        squeeze=False,
    )
    _plot_mean_with_band(
        ax[0, 0], x, y_mean_regret, ci_regret,
        "-r", "r", r"$R_{t}$", "Cumulative Regret", "upper left",
    )
    if has_constraints:
        _plot_mean_with_band(
            ax[0, 1], x, y_mean_constraints, ci_constraints,
            "-b", "b", r"$V_{t}$", "Cumulative Constraints Violation", "upper right",
        )

        # Plot the sum of constraints and regrets in ax[0,2]
        ax[0, 2].plot(
            x,
            y_mean_regret + y_mean_constraints,
            "-g",
            label=r"$R_{t} + V_{t}$",
        )
        ax[0, 2].fill_between(
            x,
            y_mean_regret + y_mean_constraints - ci_regret - ci_constraints,
            y_mean_regret + y_mean_constraints + ci_regret + ci_constraints,
            color="g",
            alpha=0.1,
        )
        ax[0, 2].legend(loc="upper left")
        ax[0, 2].set_title("Regret + Constraints Violation")
        ax[0, 2].set_xlabel("t")
    if save:
        # Save through the figure itself instead of pyplot's current figure
        fig.savefig(save_name + "plot_0_regret_constraints.png", dpi=600)

    # Save all the single plots
    fig, ax = plt.subplots(1, 1, figsize=(5, 5), squeeze=False)
    _plot_mean_with_band(
        ax[0, 0], x, y_mean_regret, ci_regret,
        "-r", "r", r"$R_{t}$", "Cumulative Regret", "upper left",
    )
    if save:
        fig.savefig(save_name + "plot_1_regret.png", dpi=600)

    # The violation figure is only created when there is something to draw,
    # so no blank figure is left behind.
    if has_constraints:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5), squeeze=False)
        _plot_mean_with_band(
            ax[0, 0], x, y_mean_constraints, ci_constraints,
            "-b", "b", r"$V_{t}$", "Cumulative Constraints Violation", "upper right",
        )
        if save:
            fig.savefig(save_name + "plot_2_constraints.png", dpi=600)



def plot_lagrangian(y_vals_lagrangian, n_constraints, save, save_name):
    """
    Plot the results of the algorithm

    Draws one panel per constraint with the mean Lagrangian multiplier across
    runs and a shaded band of 1.96 standard errors.

    Args:
        y_vals_lagrangian: multiplier values indexed as
            (run, step, constraint).
        n_constraints: number of panels to draw; it is passed to
            ``plt.subplots`` as the column count, so it is expected to be at
            least 1.
        save: when truthy, the figure is written to disk at 600 dpi.
        save_name: prefix prepended to the output file name.
    """
    x = np.arange(len(y_vals_lagrangian[0]))
    ci_lagrangian = (
        1.96 * np.std(y_vals_lagrangian, axis=0) / np.sqrt(len(y_vals_lagrangian))
    )
    y_mean_lagrangian = np.mean(y_vals_lagrangian, axis=0)

    # squeeze=False always returns a 2-D array of axes, even for one panel
    fig, ax = plt.subplots(1, n_constraints, figsize=(12, 5), squeeze=False)
    panels = ax[0]

    for i in range(n_constraints):
        # Brace the subscript: without braces mathtext only subscripts the
        # first character, so an index of 10 or more would render wrongly.
        label = r"$\lambda_{{{}}}$".format(i)
        panels[i].plot(x, y_mean_lagrangian[:, i], "-g", label=label)
        panels[i].legend(loc="upper left")
        panels[i].fill_between(
            x,
            y_mean_lagrangian[:, i] - ci_lagrangian[:, i],
            y_mean_lagrangian[:, i] + ci_lagrangian[:, i],
            color="g",
            alpha=0.1,
        )

    if save:
        # Save through the figure itself instead of pyplot's current figure
        fig.savefig(save_name + "plot_3_lagrangian.png", dpi=600)


def plot_no_constraints(save, save_name):
    """Draw a placeholder figure stating that there are no constraints.

    The figure has no ticks or tick labels, only a title and a centred
    caption.  When ``save`` is truthy it is written to
    ``save_name + "plot_3_lagrangian.png"`` at 600 dpi.
    """
    caption = "No Constraints"

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.set_title(caption)

    # Clear ticks first, then their labels, one axis at a time
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])

    # Caption placed at data coordinates (0.5, 0.5)
    ax.text(0.5, 0.5, caption, fontsize=20, va="center", ha="center")

    if not save:
        return
    plt.savefig(save_name + "plot_3_lagrangian.png", dpi=600)


def plot_reward(y_vals_rewards, save, save_name):
    """
    Plot the mean reward across runs with a shaded band.

    ``y_vals_rewards`` holds one sequence of values per run.  The curve is
    the mean over runs at each step and the band spans 1.96 standard errors
    either side of it.  When ``save`` is truthy the figure is written to
    ``save_name + "plot_4_reward.png"`` at 600 dpi.
    """
    n_runs = len(y_vals_rewards)
    steps = np.array(range(len(y_vals_rewards[0])))

    mean_reward = np.mean(y_vals_rewards, axis=0)
    half_width = 1.96 * np.std(y_vals_rewards, axis=0) / np.sqrt(n_runs)

    fig, ax = plt.subplots(1, 1, squeeze=False)
    panel = ax[0, 0]

    curve_label = r"$\frac{\sum_{t=1}^{t^*}r_{t}q_{t}}{t^*}$"
    panel.plot(steps, mean_reward, "-b", label=curve_label)
    panel.legend(loc="upper left")
    panel.fill_between(
        steps,
        mean_reward - half_width,
        mean_reward + half_width,
        color="b",
        alpha=0.1,
    )
    panel.set_title("Mean Reward")
    panel.set_xlabel("t")

    if not save:
        return
    plt.savefig(save_name + "plot_4_reward.png", dpi=600)
