from __future__ import annotations

import ast
import json
import typing
import numpy as np
from scipy.stats import bernoulli
from core.environment import Environment
from simulations.utils import (
    generate_constraints_satisfaction_cons,
    generate_valid_occupancy_measure_constraints,
    calculate_optimal_q,
    calculate_transition_function,
)

if typing.TYPE_CHECKING:
    from typing import Dict, Tuple, List
    from core.transition_function import TransitionFunction


class DataEnhancedEnvironment(Environment):
    """
    A class used to represent the environment of the Markov Decision Process
    in the case of a simulation with constraints and rewards sampled using real data
    """

    def __init__(
        self,
        time_horizon: int,
        constraints_difficulty=0.004,
        n_shifts=0,
        path="../config/data_env_config.json",
    ):
        """
        Builds the environment from the JSON configuration found at `path`.

        The parent environment is initialised first; then the configuration is
        loaded and every per-shift quantity (conversion rates, constraint means,
        optimal occupancy measures) is precomputed.

        Parameters
        ----------
        time_horizon : int
            Number of time steps; used to map a time step to its shift
        constraints_difficulty : float
            Offset added to the negated conversion rates when building constraints
        n_shifts : int
            Number of shifts between the initial and the final coefficients
        path : str
            Location of the JSON configuration file
        """
        # The parent class must be set up before anything else is computed
        super().__init__(data=True, path=path)

        # Plain simulation parameters
        self.time_horizon = time_horizon
        self.n_shifts = n_shifts
        self.constraints_difficulty = constraints_difficulty
        self.constraints_list = []
        self.rewards_list = []

        # Values loaded from the configuration file
        loaded = self.read_config(path)
        (
            transition_function,
            conversion_rates,
            price_scaling_coeffs,
            fixed_reward_constant,
            coeff_post,
        ) = loaded
        self.transition_function = transition_function
        self.conversion_rates = conversion_rates
        self.price_scaling_coeffs = price_scaling_coeffs
        self.fixed_reward_constant = fixed_reward_constant
        self.coeff_post = coeff_post

        # Constraint means and fixed rewards for the initial conversion rates
        initial_means = self._calculate_mean()
        self.constraints_mean, self.fixed_reward = initial_means

        # One row of conversion rates / constraint means per shift
        per_shift = self._calculate_coeff_arrays(self.conversion_rates, self.coeff_post)
        self.conversion_rate_arrays, self.constraint_mean_arrays = per_shift

        # Optimal occupancy measures for every shift
        optimal = self._calculate_optimal_q_arrays()
        self.optimal_q_arrays, self.optimal_q_pairs_arrays = optimal

    def read_config(self, path):
        """
        Reads the configuration file and returns the variables
        """

        with open(path, "r", encoding="utf-8") as f:
            variables = json.load(f)
            transition_function = variables["transition_function"]
            conversion_rates = variables["conversion_rates"]
            mean_revenues = variables["mean_revenues"]
            fixed_reward_constant = variables["fixed_reward"]
            coeff_post = variables["coeff_post"]

            # Convert transition function keys from string to tuples
            transition_function = {
                ast.literal_eval(k): v for k, v in transition_function.items()
            }

            # Flatten the arrays
            conversion_rates = np.concatenate(conversion_rates).ravel()
            price_scaling_coeffs = np.concatenate(mean_revenues).ravel()
            coeff_post = np.concatenate(coeff_post).ravel()

        return (
            transition_function,
            conversion_rates,
            price_scaling_coeffs,
            fixed_reward_constant,
            coeff_post,
        )

    def _calculate_mean(self):
        """
        Calculates the mean of the constraints and rewards for every (s,a) pair
        """

        # Calculate the vector of fixed rewards
        fixed_reward = np.zeros(self.sa_count[-1])
        start_index = min([self.sa_ind[(1, x)] for x in self.actions[1]])
        stop_index = max([self.sa_ind[(2, x)] for x in self.actions[2]])
        fixed_reward[start_index : stop_index + 1] = self.fixed_reward_constant

        final_page_index = min([self.sa_ind[(4, x)] for x in self.actions[4]])
        fixed_reward[final_page_index] = self.fixed_reward_constant

        # Calculate the mean of the constraints
        conversion_rates_required = np.zeros(self.sa_count[-1])
        conversion_rates_required[self.sa_count[0] : self.sa_count[1]] = (
            -self.conversion_rates[self.sa_count[0] : self.sa_count[1]]
            + self.constraints_difficulty
        )
        constraint_value = conversion_rates_required

        # Duplicating the constraints
        constraint_value = np.repeat(
            constraint_value[:, np.newaxis], self.n_constraints, axis=1
        )

        return constraint_value, fixed_reward

    def _calculate_coeff_arrays(self, coeff_pre, coeff_post):
        """
        This function handles the case in which the number of shifts is greater than 0.
        In this case the coefficients are gradually shifted for every shift, interpolating
        between the initial coefficients (coeff_pre) and final coefficients (coeff_post).
        """
        lengths = [0] + [len(action_list) for action_list in self.actions]
        lengths = np.cumsum(lengths)

        coeff_arrays = np.empty((self.n_shifts + 1, len(coeff_pre)), dtype=float)
        for i in range(
            1, self.n_shifts + 1
        ):  # If n_shifts = 0, only the initial coefficients are used
            coeff_arrays[i] = coeff_pre + (i) * (coeff_post - coeff_pre) / self.n_shifts
        coeff_arrays[0] = coeff_pre

        constraint_mean_arrays = np.empty(
            (self.n_shifts + 1, self.sa_count[-1], self.n_constraints), dtype=float
        )

        for i in range(
            0, self.n_shifts + 1
        ):  # If n_shifts = 0, only the initial coefficients are used
            # Calculate the mean of the constraints
            conversion_rates_required = np.zeros(self.sa_count[-1])
            conversion_rates_required[self.sa_count[0] : self.sa_count[1]] = (
                -coeff_arrays[i][self.sa_count[0] : self.sa_count[1]]
                + self.constraints_difficulty
            )
            constraint_value = conversion_rates_required

            # Duplicating the constraints
            constraint_value = np.repeat(
                constraint_value[:, np.newaxis], self.n_constraints, axis=1
            )
            constraint_mean_arrays[i] = constraint_value

        return coeff_arrays, constraint_mean_arrays

    def _calculate_optimal_q_arrays(self):
        """
        Computes, for every shift, the optimal q over (s,a,s) triples and over (s,a) pairs.

        Returns a tuple of two arrays with one row per shift: the first has
        sas_count[-1] columns, the second has sa_count[-1] columns.
        """
        n_periods = self.n_shifts + 1
        q_triples = np.empty((n_periods, self.sas_count[-1]), dtype=float)
        q_pairs = np.empty((n_periods, self.sa_count[-1]), dtype=float)

        # These constraints do not depend on the shift: build them only once
        occupancy_lse, occupancy_rse = generate_valid_occupancy_measure_constraints(
            self.layers, self.sas_count, self.sas_ind, self.transition_function
        )

        for shift in range(n_periods):
            # Mean reward of the shift: conversion rate * price coefficient + fixed reward
            shift_rewards_mean = (
                self.conversion_rate_arrays[shift] * self.price_scaling_coeffs
                + self.fixed_reward
            )

            # Constraint-satisfaction terms built from the constraint means of the shift
            satisfaction_terms = generate_constraints_satisfaction_cons(
                self.constraint_mean_arrays[shift],
                self.n_constraints,
                self.sa_ind,
                self.sas_ind,
            )
            satisfaction_lse, satisfaction_rse = satisfaction_terms

            solution = calculate_optimal_q(
                shift_rewards_mean,
                occupancy_lse,
                occupancy_rse,
                satisfaction_lse,
                satisfaction_rse,
                self.sa_ind,
                self.sas_count,
                self.sas_ind,
            )
            q_triples[shift], q_pairs[shift] = solution

        return q_triples, q_pairs

    def _sample(self, time_horizon: int):
        """Generates constraints and rewards stochastically for every (s,a) pair sampling from
          a beta distribution with parameter a depending on action, for every t:0->time horizon.

        Returns
        -------
        np.array[np.2Darray[float]], np.array[np.array[float]]
            Constraints and rewards for every (s,a) pair
        """
        constraints_list = np.empty(
            (time_horizon, self.sa_count[-1], self.n_constraints)
        )
        rewards_list = np.empty((time_horizon, self.sa_count[-1]))

        for t in range(time_horizon):
            rewards = np.empty(self.sa_count[-1])
            constraints_matrix = np.zeros((self.sa_count[-1], self.n_constraints))

            # Draw the action samples (sold or not) using the means of the current shift
            action_sample = bernoulli.rvs(
                self.get_conversion_rates(t), size=self.sa_count[-1]
            )
            # Calculate rewards using: price * I(sold)
            rewards = self.price_scaling_coeffs * action_sample + self.fixed_reward

            # Calculate constraints using: -conversion_rate + difficulty
            # (satisfied if sold, not satisfied if not sold)
            constraints_matrix = np.zeros(self.sa_count[-1])
            constraints_matrix[self.sa_count[0] : self.sa_count[1]] = (
                -action_sample[self.sa_count[0] : self.sa_count[1]]
                + self.constraints_difficulty
            )
            constraints_matrix = np.repeat(
                constraints_matrix[:, np.newaxis], self.n_constraints, axis=1
            )

            rewards_list[t, :] = rewards
            if self.n_constraints != 0:
                constraints_list[t, :, :] = constraints_matrix
            else:
                constraints_list[t] = None

        return constraints_list, rewards_list

    def get_conversion_rates(self, t: int):
        """
        Returns the conversion rates for every (s,a) pair at time t, taking into account the number of shifts
        """
        divisor = np.floor(self.time_horizon / (self.n_shifts + 1))
        return self.conversion_rate_arrays[
            min(int(np.floor(t / divisor)), self.n_shifts)
        ]

    def get_constraint_mean(self, t: int):
        """
        Returns the constraint mean for every (s,a) pair at time t, taking into account the number of shifts
        """
        divisor = np.floor(self.time_horizon / (self.n_shifts + 1))
        return self.constraint_mean_arrays[
            min(int(np.floor(t / divisor)), self.n_shifts)
        ]

    def get_optimal_q(self, t: int):
        """
        Returns the optimal q for every (s,a,s) triple at time t, taking into account the number of shifts
        """
        divisor = np.floor(self.time_horizon / (self.n_shifts + 1))
        return self.optimal_q_arrays[min(int(np.floor(t / divisor)), self.n_shifts)]

    def get_optimal_q_pairs(self, t: int):
        """
        Returns the optimal q for every (s,a) pair at time t, taking into account the number of shifts
        """
        divisor = np.floor(self.time_horizon / (self.n_shifts + 1))
        return self.optimal_q_pairs_arrays[
            min(int(np.floor(t / divisor)), self.n_shifts)
        ]

    def get_reward_mean(self, t: int):
        """
        Returns the reward mean for every (s,a) pair at time t, taking into account the number of shifts
        """
        return (
            self.get_conversion_rates(t) * self.price_scaling_coeffs + self.fixed_reward
        )

    def reset(self, time_horizon: int, seed: int):
        """
        Resets the environment and returns the initial constraints and rewards
        """
        np.random.seed(seed)
        self.constraints_list, self.rewards_list = self._sample(time_horizon)
        return self.constraints_list, self.rewards_list

    def play_policy(
        self, policy: Dict[Tuple[int, float], int], time_step: int
    ) -> Tuple[List, List, List]:
        """Plays policy and returns result for given episode

        Parameters
        ----------
        Policy : Dict[Tuple[int,float],int]
            Policy to be played

        Returns
        -------
        Tuple[List,List,List]
            Path, losses, constraints of the episode
        """
        path = []
        rewards = self.rewards_list[time_step]
        constraints = self.constraints_list[time_step]

        state = self.layers[0][0]

        while state != self.layers[-1][-1]:
            # Sample policy
            sa_pairs = {
                pair: policy[pair] for pair in policy.keys() if pair[0] == state
            }
            keys = list(sa_pairs)
            index = np.random.choice(len(keys), 1, p=list(sa_pairs.values()))[0]
            played_pair = keys[index]

            # Sample transition
            sas_triples = {
                triple: self.transition_function[triple]
                for triple in self.transition_function.keys()
                if triple[0:2] == played_pair
            }

            keys = list(sas_triples)
            index = np.random.choice(len(keys), 1, p=list(sas_triples.values()))[0]
            selected_transition = keys[index]

            if state == 0 and rewards[self.sa_ind[played_pair]] != 0:
                selected_transition = (
                    selected_transition[0],
                    selected_transition[1],
                    1,
                )
            elif state == 0 and rewards[self.sa_ind[played_pair]] == 0:
                overall_prob_transition = (
                    self.transition_function[
                        (selected_transition[0], selected_transition[1], 2)
                    ]
                    + self.transition_function[
                        (selected_transition[0], selected_transition[1], 3)
                    ]
                )
                prob_transition_two = (
                    self.transition_function[
                        (selected_transition[0], selected_transition[1], 2)
                    ]
                    / overall_prob_transition
                )
                prob_transition_three = (
                    self.transition_function[
                        (selected_transition[0], selected_transition[1], 3)
                    ]
                    / overall_prob_transition
                )
                assert (
                    prob_transition_two + prob_transition_three == 1
                ), "Probabilities do not sum to 1"

                if np.random.binomial(1, prob_transition_two):
                    selected_transition = (
                        selected_transition[0],
                        selected_transition[1],
                        2,
                    )
                else:
                    selected_transition = (
                        selected_transition[0],
                        selected_transition[1],
                        3,
                    )

            # Store values
            path.append(selected_transition)
            state = selected_transition[2]

        return path, rewards, constraints