import numpy as np
import random

from copy import copy, deepcopy
from pathlib import Path
from scipy.special import logsumexp



class ImplicitMirrorDescentValueIteration:
    """Tabular Implicit Mirror Descent Value Iteration (IMDVI).

    Maintains a preference table ``psi`` of shape
    ``(size_of_state_space, size_of_action_space)`` and the softmax policy
    ``pi(a|s) proportional to exp(psi[s, a] / alpha)``. Each iteration applies

        psi <- T_alpha psi + beta * (psi - L_alpha psi)

    where ``T_alpha`` is the soft-optimal Bellman backup at temperature
    ``alpha`` and ``L_alpha psi = alpha * logsumexp(psi / alpha)`` row-wise.

    Hyper-parameters are given either as ``(alpha, beta)`` or as
    ``(tau, kappa)``, related by ``alpha = tau + kappa`` and
    ``beta = kappa / (tau + kappa)``.

    ``env`` must expose ``size_of_state_space``, ``size_of_action_space``,
    ``reward_map``, ``possible_next_states``, ``state_transition_law(s, a)``
    and ``reward_function(s, a, ns)``; ``get_V_2D`` additionally needs
    ``grid_height``, ``grid_width`` and ``xytos(x, y)``.
    """

    def __init__(
        self,
        env,
        gamma: float = 0.99,
        alpha: float = None,
        beta: float = None,
        tau: float = None,
        kappa: float = None,
        num_pe_backups: int = np.inf,
        backup_error_magnitude: float = 0,
        seed: int = 0,
        initial_psi: str = 'zero',
        eps: float = 1e-8,
    ):
        """
        Args:
            env: tabular environment (see class docstring).
            gamma: discount factor.
            alpha, beta: IMDVI temperature and advantage coefficient; used
                when both are given.
            tau, kappa: alternative parameterization, used when ``alpha`` or
                ``beta`` is missing; both must then be given.
            num_pe_backups: cap on backups per policy evaluation
                (``np.inf`` means "until converged").
            backup_error_magnitude: std of the Gaussian noise added to ``psi``
                after every backup in ``fit`` (0 disables it).
            seed: seed of the internal ``np.random.default_rng``.
            initial_psi: 'zero', 'optimistic' or 'random'.
            eps: added inside ``log(pi)`` to avoid ``log(0)``.
        """
        self.random = np.random.default_rng(seed)

        self.env = env
        self.gamma = gamma

        if alpha is not None and beta is not None:
            self.tau = alpha * (1 - beta)
            self.kappa = alpha * beta
            self.alpha = alpha
            self.beta = beta
        else:
            assert tau is not None and kappa is not None, \
                'give either (alpha, beta) or (tau, kappa)'
            self.tau = tau
            self.kappa = kappa
            self.alpha = tau + kappa
            self.beta = kappa / (tau + kappa)
        print(f'(tau, kappa)  = ({self.tau}, {self.kappa})')
        print(f'(alpha, beta) = ({self.alpha}, {self.beta})')

        self.eps = eps
        self.backup_error_magnitude = backup_error_magnitude
        self.num_pe_backups = num_pe_backups
        self.initialization(initial_psi)

        print('initial policy = ', self.policy)
        print('initial value = ', self.psi)

        # Rows of [step, max |V_opt - V_pi|, mean |V_opt - V_pi|]; filled by fit().
        self.error_curve = None

    def initialization(self, initial_psi):
        """Initialize ``psi`` and derive the corresponding softmax policy.

        'zero' fills ``psi`` with zeros. 'optimistic' fills it with
        ``max_V = (max(reward_map) + tau * log|A|) / (1 - gamma)``. 'random'
        samples uniformly from ``[-|max_V|, |max_V|]``.

        Raises:
            ValueError: for any other ``initial_psi``.
        """
        shape = (self.env.size_of_state_space, self.env.size_of_action_space)
        if initial_psi == 'zero':
            self.psi = np.zeros(shape)
        elif initial_psi in ('optimistic', 'random'):
            max_V = (self.env.reward_map.max()
                     + self.tau * np.log(self.env.size_of_action_space)) / (1 - self.gamma)
            if initial_psi == 'optimistic':
                self.psi = np.ones(shape) * max_V
            else:
                # abs() keeps low <= high even when max_V is negative
                # (Generator.uniform is undefined for high < low).
                bound = abs(max_V)
                self.psi = self.random.uniform(-bound, bound, shape)
        else:
            raise ValueError(f'unknown initial_psi: {initial_psi!r}')
        # Reset so the improvement step does not diff against a stale policy.
        self.policy = None
        self.soft_policy_improvement()

    def soft_policy_improvement(self):
        """Set ``self.policy`` to the row-wise softmax of ``psi / alpha``.

        Returns:
            Max absolute change of the policy, or ``None`` when there was no
            previous policy.
        """
        logits = self.psi / self.alpha
        # Subtract the row max before exponentiating for numerical stability.
        weights = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        policy_new = weights / np.sum(weights, axis=1, keepdims=True)

        diff = None if self.policy is None else np.max(np.abs(policy_new - self.policy))
        self.policy = policy_new
        return diff

    @staticmethod
    def L(Q_s, coeff):
        """Return ``coeff * logsumexp(Q_s / coeff)`` (soft maximum of ``Q_s``)."""
        return coeff * logsumexp(Q_s / coeff)

    def compute_V_opt_from_Q(self, Q, coeff):
        """Return ``V[s] = L(Q[s, :], coeff)`` for every state ``s``."""
        soft_max = type(self).L
        return np.array(
            [soft_max(Q[s, :], coeff) for s in range(self.env.size_of_state_space)],
            dtype=float,
        )

    def _one_step_lookahead(self, V_next):
        """Return ``Q[s, a] = sum_ns P(ns|s,a) * (r(s, a, ns) + gamma * V_next[ns])``.

        Only states in ``env.possible_next_states[s]`` with non-zero
        transition probability contribute; ``Q[s, :]`` stays 0 when that list
        is empty.
        """
        n_states = self.env.size_of_state_space
        n_actions = self.env.size_of_action_space
        Q = np.zeros((n_states, n_actions))
        for s in range(n_states):
            next_states = self.env.possible_next_states[s]
            for a in range(n_actions):
                P = self.env.state_transition_law(s, a)
                if not next_states:
                    continue
                q_sa = 0.0
                for ns in next_states:
                    if P[ns] != 0.0:
                        q_sa += P[ns] * (self.env.reward_function(s, a, ns) + self.gamma * V_next[ns])
                Q[s, a] = q_sa
        return Q

    def backup_V_by_soft_Bellman(self, V_old, pi, coeff):
        """Apply one soft Bellman backup to the state value of policy ``pi``.

        ``V_new[s] = sum_a pi[s, a] * (Q[s, a] - coeff * log(pi[s, a] + eps))``,
        where ``Q`` is the one-step lookahead on ``V_old``.

        Returns:
            ``(V_new, max |V_new - V_old|)``.
        """
        Q = self._one_step_lookahead(V_old)
        V_new = np.sum(pi * (Q - coeff * np.log(pi + self.eps)), axis=1)
        diff = np.max(np.abs(V_new - V_old))
        return V_new, diff

    def backup_Q_by_soft_Bellman(self, Q_old, pi, coeff):
        """Apply one soft Bellman backup to the action value of policy ``pi``.

        Returns:
            ``(Q_new, max |Q_new - Q_old|)``.
        """
        # The next-state value depends only on ns: compute it once per state.
        V_next = np.sum(pi * (Q_old - coeff * np.log(pi + self.eps)), axis=1)
        Q_new = self._one_step_lookahead(V_next)
        diff = np.max(np.abs(Q_new - Q_old))
        return Q_new, diff

    def backup_Q_by_soft_Bellman_opt(self, Q_old, coeff):
        """Apply one soft-optimal Bellman backup (next value ``L(Q_old[ns, :], coeff)``).

        Returns:
            ``(Q_new, max |Q_new - Q_old|)``.
        """
        V_next = self.compute_V_opt_from_Q(Q_old, coeff)
        Q_new = self._one_step_lookahead(V_next)
        diff = np.max(np.abs(Q_new - Q_old))
        return Q_new, diff

    def backup_psi(self, psi_old):
        """IMDVI update ``psi_new = T_alpha psi + beta * (psi - L_alpha psi)``.

        Returns:
            ``(psi_new, max |psi_new - psi_old|)``.
        """
        V = self.compute_V_opt_from_Q(psi_old, self.alpha)
        A = psi_old - V.reshape(-1, 1)
        TQ, _ = self.backup_Q_by_soft_Bellman_opt(psi_old, self.alpha)
        psi_new = TQ + self.beta * A

        diff = np.max(np.abs(psi_new - psi_old))
        return psi_new, diff

    def policy_evaluation(self, pi, max_diff, coeff, verbose):
        """Evaluate ``pi`` (entropy coefficient ``coeff``) by repeated V backups.

        Starts from ``V = 0`` and stops once the max change drops below
        ``max_diff`` or after ``num_pe_backups`` backups.
        TODO: could be expressed via ``backup_Q_by_soft_Bellman``.
        """
        step_evaluation = 0
        V = np.zeros(self.env.size_of_state_space)
        while True:
            V, epsilon = self.backup_V_by_soft_Bellman(V, pi, coeff)
            step_evaluation += 1
            if epsilon < max_diff or step_evaluation >= self.num_pe_backups:
                break
        if verbose:
            print('PE step: ', step_evaluation, epsilon)
        return V

    def evaluate(self, pi, max_diff_value, verbose):
        """Return the values of ``pi`` with entropy coefficients ``tau`` and ``alpha``."""
        V_pi_tau = self.policy_evaluation(pi, max_diff_value, coeff=self.tau, verbose=verbose)
        V_pi_alpha = self.policy_evaluation(pi, max_diff_value, coeff=self.alpha, verbose=verbose)
        return V_pi_tau, V_pi_alpha

    def _error_row(self, step, V_opt, max_diff_value, verbose):
        """Evaluate the current policy (coefficient tau); return one error-curve row."""
        V_pi = self.policy_evaluation(self.policy, max_diff_value, coeff=self.tau, verbose=verbose)
        gap = np.abs(V_opt - V_pi)
        return np.array([[step, np.max(gap), np.average(gap)]])

    def fit(
        self,
        max_diff_value=0.001,
        num_iterations=100,
        record_curve=False,
        record_frequency=None,
        V_opt=None, psi_opt=None,
        verbose=False,
    ):
        """Run ``num_iterations`` IMDVI iterations (backup + policy improvement).

        Args:
            max_diff_value: convergence threshold of the policy evaluations
                used for the error curve.
            num_iterations: number of iterations; values <= 0 run none.
            record_curve: if True, every ``record_frequency`` steps evaluate
                the policy (coefficient ``tau``) and append
                ``[step, max|V_opt - V_pi|, mean|V_opt - V_pi|]`` to
                ``self.error_curve``.
            record_frequency: snapshot/print period in iterations; ``None``
                means every iteration.
            V_opt, psi_opt: reference value for the error curve; if ``V_opt``
                is missing it is derived from ``psi_opt`` with ``alpha``.
            verbose: print progress.

        Returns:
            ``(record_psi, record_V)``: copies of ``psi`` and its alpha-soft
            maximum, taken initially and every ``record_frequency`` steps.

        Raises:
            ValueError: if ``record_frequency`` is < 1.
        """
        if record_frequency is None:
            record_frequency = 1
        if record_frequency < 1:
            raise ValueError(f'record_frequency must be >= 1, got {record_frequency!r}')

        if record_curve:
            if V_opt is None:
                assert psi_opt is not None, 'record_curve needs V_opt or psi_opt'
                V_opt = self.compute_V_opt_from_Q(psi_opt, self.alpha)
            self.error_curve = self._error_row(0, V_opt, max_diff_value, verbose)

        record_psi = [self.psi.copy()]
        record_V = [self.compute_V_opt_from_Q(self.psi, self.alpha)]

        step_improvement = 0
        self.soft_policy_improvement()
        while step_improvement < num_iterations:
            step_improvement += 1
            is_record_step = step_improvement % record_frequency == 0

            if record_curve and is_record_step:
                print(self.error_curve)

            self.psi, diff_psi = self.backup_psi(self.psi)

            if self.backup_error_magnitude != 0.0:
                error = self.random.normal(np.zeros(self.psi.shape), self.backup_error_magnitude)
                self.psi += error
            diff_policy = self.soft_policy_improvement()

            if not is_record_step:
                continue
            if verbose:
                print(f'PI step: {step_improvement}, {diff_psi=}, {diff_policy=}')
            record_psi.append(self.psi.copy())
            record_V.append(self.compute_V_opt_from_Q(self.psi, self.alpha))
            if record_curve:
                row = self._error_row(step_improvement, V_opt, max_diff_value, verbose)
                self.error_curve = np.append(self.error_curve, row, axis=0)

        if record_curve:
            print(self.error_curve)
        print('min value = ', np.min(self.psi))
        print('max value = ', np.max(self.psi))

        return record_psi, record_V

    def get_V_2D(self, V=None):
        """Lay a per-state value vector out on the environment grid.

        Args:
            V: per-state values; defaults to the alpha-soft maximum of ``psi``.
                TODO(review): unclear whether ``tau`` rather than ``alpha`` is
                intended for this default.

        Returns:
            Array of shape ``(grid_height, grid_width)`` with
            ``V_2D[y, x] = V[env.xytos(x, y)]``.
        """
        if V is None:
            V = self.compute_V_opt_from_Q(self.psi, self.alpha)
        V_2D = np.zeros((self.env.grid_height, self.env.grid_width))
        for x in range(self.env.grid_width):
            for y in range(self.env.grid_height):
                V_2D[y, x] = V[self.env.xytos(x, y)]
        return V_2D

    def get_psi(self):
        return self.psi

    def get_policy(self):
        return self.policy

    def get_curve_error_max(self):
        """Return ``[steps; max errors]`` (shape ``(2, T)``) from ``error_curve``, or None."""
        if self.error_curve is None:
            return None
        return self.error_curve[:, [0, 1]].T

    def get_curve_error_ave(self):
        """Return ``[steps; mean errors]`` (shape ``(2, T)``) from ``error_curve``, or None."""
        if self.error_curve is None:
            return None
        return self.error_curve[:, [0, 2]].T

    def get_error_curve(self):
        return self.error_curve



class BoundedAdvantageLearning(ImplicitMirrorDescentValueIteration):
    """Tabular Bounded Advantage Learning (BAL).

    Like IMDVI, except that the advantage ``A = psi - L_alpha psi`` is passed
    through bounding functions:
      * ``bound_f`` on the added term: ``psi_new = T psi + beta * f(A)``;
      * ``bound_g`` inside the next-state value
        ``V(ns) = sum_a pi(a|ns) * (psi(ns, a) - g(A(ns, a)))``.

    Supported bound names: 'identity', 'nclip' (clip to [-1, 1]), 'rclip'
    (clip ``x / 10`` to [-1, 1]), 'ntanh' (``tanh(x)``), 'rtanh'
    (``tanh(x / 10)``).
    """

    # Name -> element-wise bounding function, shared by bound_f and bound_g.
    _BOUNDS = {
        'identity': lambda x: x,
        'nclip': lambda x: np.clip(x, -1., 1.),
        'rclip': lambda x: np.clip(x / 10, -1., 1.),
        'ntanh': lambda x: np.tanh(x),
        'rtanh': lambda x: np.tanh(x / 10),
    }

    def __init__(
        self,
        env,
        gamma: float = 0.99,
        alpha: float = None,
        beta: float = None,
        tau: float = None,
        kappa: float = None,
        num_pe_backups: int = np.inf,
        backup_error_magnitude: float = 0,
        seed: int = 0,
        initial_psi: str = 'zero',
        bound_f: str = 'identity',
        bound_g: str = 'identity',
        eps: float = 1e-8,
    ):
        """Same arguments as the parent, plus the names of the two bounds.

        Raises:
            ValueError: if ``bound_f`` or ``bound_g`` is not a supported name.
        """
        super().__init__(
            env=env, gamma=gamma, alpha=alpha, beta=beta, tau=tau, kappa=kappa,
            num_pe_backups=num_pe_backups, backup_error_magnitude=backup_error_magnitude,
            seed=seed, initial_psi=initial_psi, eps=eps)

        self.bound_f = self._get_bound(bound_f)
        self.bound_g = self._get_bound(bound_g)

    @classmethod
    def _get_bound(cls, name):
        """Look up a bounding function by name; ValueError if unknown."""
        try:
            return cls._BOUNDS[name]
        except KeyError:
            raise ValueError(
                f'unknown bound: {name!r}; expected one of {sorted(cls._BOUNDS)}') from None

    def backup_Q_by_soft_Bellman_opt(self, Q_old, A):
        """BAL backup ``Q_new[s, a] = sum_ns P(ns|s,a) * (r + gamma * V(ns))``.

        ``V(ns) = sum_a policy[ns, a] * (Q_old[ns, a] - bound_g(A[ns, a]))``
        uses the current ``self.policy``. Unlike the parent's method, the
        second argument is the advantage table, not a temperature.

        Returns:
            ``(Q_new, max |Q_new - Q_old|)``.
        """
        # V(ns) depends only on ns, so compute it once for every state.
        V_next = np.sum(self.policy * (Q_old - self.bound_g(A)), axis=1)

        n_states = self.env.size_of_state_space
        n_actions = self.env.size_of_action_space
        Q_new = np.zeros((n_states, n_actions))
        for s in range(n_states):
            next_states = self.env.possible_next_states[s]
            for a in range(n_actions):
                P = self.env.state_transition_law(s, a)
                if not next_states:
                    continue
                q_sa = 0.0
                for ns in next_states:
                    if P[ns] != 0.0:
                        q_sa += P[ns] * (self.env.reward_function(s, a, ns) + self.gamma * V_next[ns])
                Q_new[s, a] = q_sa

        diff = np.max(np.abs(Q_new - Q_old))
        return Q_new, diff

    def backup_psi(self, psi_old):
        """BAL update ``psi_new = T_BAL psi + beta * bound_f(psi - L_alpha psi)``.

        Returns:
            ``(psi_new, max |psi_new - psi_old|)``.
        """
        V = self.compute_V_opt_from_Q(psi_old, self.alpha)
        A = psi_old - V.reshape(-1, 1)
        TQ, _ = self.backup_Q_by_soft_Bellman_opt(psi_old, A)
        psi_new = TQ + self.beta * self.bound_f(A)

        diff = np.max(np.abs(psi_new - psi_old))
        return psi_new, diff
