import numpy
import torch
from torch import nn, optim
from modules.agents.Probability import ProbabilityMeasure, ProbabilityEmpiricalMeasure
from modules.agents.Policy import Policy
from modules.train.TrainHelper import discount, unpack_trajectories
from modules.train.Oracle import Oracle
from modules.utils.Log import Logger
from modules.nn.NNModels import VFunctionNN

class FiniteSpaceOracle(Oracle):
    """Base oracle for environments whose state and action spaces are discrete.

    The space sizes are read once from ``env.action_space.n`` and
    ``env.observation_space.n``; concrete subclasses in this module provide
    ``_predict``.
    """

    def __init__(self, env, discount_factor: float):
        super().__init__(env, discount_factor)
        self._n_action, self._n_state = env.action_space.n, env.observation_space.n

    def n_state(self) -> int:
        """Number of discrete states."""
        return self._n_state

    def n_action(self) -> int:
        """Number of discrete actions."""
        return self._n_action

    def save(self) -> dict:
        """Return the common serialisation fields.

        Subclasses extend the (initially empty) ``"parameters"`` sub-dict with
        their own hyper-parameters.
        """
        return dict(
            n_state=self.n_state(),
            n_action=self.n_action(),
            discount_factor=self.gamma(),
            name=self.name(),
            parameters={},
        )

    def find_discounted_visitation_frequency(self, trajectories) -> ProbabilityMeasure:
        """Estimate the discounted state-visitation distribution from trajectories.

        A visit to a state at step ``t`` of a trajectory contributes
        ``gamma**t / len(trajectories)`` to that state's mass; the accumulated
        measure is normalised before being returned.
        """
        count = len(trajectories)
        gamma = self.gamma()
        rho_k = ProbabilityEmpiricalMeasure(self.n_state(), numpy.zeros(self.n_state()))
        for tau in trajectories:
            visited, _, _ = unpack_trajectories([tau])
            for step, state in enumerate(visited):
                rho_k.add_probability(state, gamma ** step / count)
        rho_k.normalize()
        return rho_k


# neural network
class OracleNN(FiniteSpaceOracle):
    """Advantage oracle backed by a learned state-value network V(s).

    Each ``_predict`` call first regresses V onto the trajectories' discounted
    returns (one Adam step per trajectory), then builds an advantage table from
    discounted TD residuals of the updated network.
    """

    def __init__(self, env, discount_factor: float, lr: float = 1e-3):
        super().__init__(env, discount_factor)
        self._name = 'OracleNN'
        # The network input is the state index as a single float feature (see ``V``).
        # NOTE(review): arguments presumably mean (input size, hidden sizes, activations)
        # per VFunctionNN — confirm against its definition.
        self._value_function = VFunctionNN(1, [75, 75], [nn.Tanh(), nn.Tanh()])
        self._lr = lr
        self._opt = optim.Adam(self._value_function.parameters(), lr=self._lr)

    def save(self) -> dict:
        """Serialise in the same layout as the other finite-space oracles.

        The learning rate is stored under ``parameters["lr"]``, matching the
        constructor keyword. It is also kept at the top level under ``"lr"``
        so readers of the previous format keep working.
        """
        data = super().save()
        data["parameters"]["lr"] = self._lr
        data["lr"] = self._lr  # backward compatibility with the old {"lr": ...} output
        return data

    def V(self, states):
        """Evaluate the value network on a batch of state indices; returns a 1-D tensor."""
        return self._value_function(torch.FloatTensor(states).view(-1, 1)).view(-1)

    def generalized_advantage_estimator(self, trajectories):
        """Accumulate discounted TD-residual advantages into an (n_state, n_action) table.

        For every trajectory the residual
        ``delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)`` is formed, with the
        value after the final step taken as 0. The residuals are then passed
        through ``discount(..., gamma)``, presumably a reverse discounted
        cumulative sum. Each step's result is *added* to its (s_t, a_t) cell.

        NOTE(review): repeated visits accumulate a sum, not a mean, even though
        the visitation distribution is returned separately by ``_predict``.
        Confirm that this weighting is intended.
        """
        gamma = self.gamma()
        advantages = numpy.zeros([self.n_state(), self.n_action()])
        for tau in trajectories:
            states, actions, rewards = unpack_trajectories([tau])
            # Values act as a fixed baseline here, so gradients are detached.
            values = self.V(states).detach().numpy()
            next_values = numpy.append(values[1:] * gamma, 0)
            tds = rewards - values + next_values
            advs = discount(tds, gamma)
            for s, a, adv in zip(states, actions, advs):
                advantages[s, a] += adv
        return advantages

    def fit(self, trajectories):
        """Regress V onto ``tau["discounted_returns"]`` with one Adam step (MSE) per trajectory."""
        for tau in trajectories:
            states, _, _ = unpack_trajectories([tau])
            # .copy() presumably guards against non-contiguous / negatively strided
            # arrays, which torch cannot wrap directly.
            targets = torch.FloatTensor(tau["discounted_returns"].copy()).view(-1)
            loss = torch.nn.functional.mse_loss(self.V(states), targets.detach())

            self._opt.zero_grad()
            loss.backward()
            self._opt.step()

    def _predict(self, episode: int, policy: Policy, trajectories: list, logger: Logger):
        """Fit V on ``trajectories``, then return ``(advantage_table, rho_k)``.

        The fit happens first, so the advantages use the freshly updated network.
        ``episode``, ``policy`` and ``logger`` are not used by this oracle.
        """
        self.fit(trajectories)

        rho_k = self.find_discounted_visitation_frequency(trajectories)
        advantage_function = self.generalized_advantage_estimator(trajectories)

        return advantage_function, rho_k

# SARSA
class OracleSarsa(FiniteSpaceOracle):
    """Advantage oracle based on tabular SARSA(0) over the sampled trajectories.

    A fresh all-zero Q table is built on every ``_predict`` call, so nothing
    carries over between calls.
    """

    def __init__(self, env, discount_factor: float, learning_rate: float):
        super().__init__(env, discount_factor)
        self._name = 'OracleSARSA'
        self._alpha = learning_rate

    def alpha(self) -> float:
        """Constant SARSA step size."""
        return self._alpha

    def save(self) -> dict:
        """Serialise the base fields plus ``parameters["learning_rate"]``."""
        serialised = super().save()
        serialised["parameters"].update(learning_rate=self.alpha())
        return serialised

    def _predict(self, episode: int, policy: Policy, trajectories: list, logger: Logger):
        """Run SARSA over ``trajectories`` and return ``(advantage, rho_k)``.

        ``advantage[s, a] = Q[s, a] - V(s)``, where ``V(s)`` is the expectation
        of ``Q[s]`` under ``policy.policy(s)``.
        """
        rho_k = self.find_discounted_visitation_frequency(trajectories)

        step_size = self.alpha()
        gamma = self.gamma()
        q_table = numpy.zeros([self.n_state(), self.n_action()])
        for tau in trajectories:
            states, actions, rewards = unpack_trajectories([tau])
            horizon = len(states)
            for t in range(horizon):
                s, a, r = states[t], actions[t], rewards[t]
                if t + 1 < horizon:
                    # Bootstrap from the next (state, action) actually taken.
                    target = r + gamma * q_table[states[t + 1], actions[t + 1]]
                else:
                    # Last recorded step: nothing left to bootstrap from.
                    target = r
                q_table[s, a] = (1 - step_size) * q_table[s, a] + step_size * target

        value_function = numpy.array([policy.policy(s).expected_value_function(q_table[s])
                                      for s in range(self.n_state())])
        advantage = numpy.array([[q_row[a] - v for a in range(self.n_action())]
                                 for q_row, v in zip(q_table, value_function)])
        return advantage, rho_k


# "MC" oracle: persistent tabular Q updated with a bootstrapped (TD(0)/SARSA-style) target
class OracleMC(FiniteSpaceOracle):
    """Advantage oracle with a persistent tabular Q estimate.

    ``self.Q`` is created once, filled with ``q_init``, and keeps being updated
    across ``_predict`` calls. Despite the class name, the update bootstraps
    from ``Q[s_{t+1}, a_{t+1}]`` (a SARSA/TD(0)-style target). The
    trajectories' Monte-Carlo discounted returns are not used.
    """

    def __init__(self, env, discount_factor: float, learning_rate: float, q_init: float = 0):
        super().__init__(env, discount_factor)
        self._name = 'OracleMC'
        self._alpha = learning_rate
        self.Q = q_init * numpy.ones([self.n_state(), self.n_action()])

    def alpha(self) -> float:
        """Constant step size of the Q update."""
        return self._alpha

    def save(self) -> dict:
        """Serialise the base fields plus ``parameters["learning_rate"]``."""
        data = super().save()
        data["parameters"]["learning_rate"] = self.alpha()
        return data

    def _predict(self, episode: int, policy: Policy, trajectories: list, logger: Logger):
        """Update ``self.Q`` in place from ``trajectories``; return ``(advantage, rho_k)``.

        Non-final steps use the target ``r_t + gamma * Q[s_{t+1}, a_{t+1}]``.
        The final step of each trajectory uses ``r_T`` alone. Empty
        trajectories are skipped. ``advantage[s, a] = Q[s, a] - V(s)``, where
        ``V(s)`` is the expectation of ``Q[s]`` under ``policy.policy(s)``.
        """
        rho_k = self.find_discounted_visitation_frequency(trajectories)

        alpha = self.alpha()
        gamma = self.gamma()
        Q = self.Q  # alias: updates below modify self.Q in place
        for tau in trajectories:
            states, actions, rewards = unpack_trajectories([tau])
            if len(states) == 0:
                # No transition to learn from; also avoids indexing states[-1].
                continue
            for t in range(len(states) - 1):
                s_t, a_t = states[t], actions[t]
                target = rewards[t] + gamma * Q[states[t + 1], actions[t + 1]]
                Q[s_t, a_t] = (1 - alpha) * Q[s_t, a_t] + alpha * target
            # Final recorded step: nothing to bootstrap from.
            s_last, a_last = states[-1], actions[-1]
            Q[s_last, a_last] = (1 - alpha) * Q[s_last, a_last] + alpha * rewards[-1]

        value_function = numpy.array([policy.policy(s).expected_value_function(Q[s])
                                      for s in range(self.n_state())])
        advantage = numpy.array([[Q[s, a] - value_function[s]
                                  for a in range(self.n_action())] for s in range(self.n_state())])
        return advantage, rho_k


# Omniscient oracle
class OracleTrue(FiniteSpaceOracle):
    """Omniscient oracle computing exact quantities from the model ``env.P``.

    Each entry of ``env.P[s][a]`` is unpacked as a
    ``(probability, next_state, reward, _)`` tuple.
    """

    def __init__(self, env, discount_factor):
        super().__init__(env, discount_factor)
        self._name = 'OracleTrue'

    def save(self) -> dict:
        """Serialise the base fields; this oracle has no extra parameters."""
        return super().save()

    def _predict(self, episode: int, policy: Policy, trajectories: list, logger: Logger):
        """Return the exact ``(advantage, rho)`` for ``policy``.

        The start distribution is a point mass on ``states[0]`` as returned by
        ``unpack_trajectories(trajectories)``, presumably the first trajectory's
        initial state. The trajectories are not otherwise used.
        """
        P, R = self._transition_model(policy)
        states, _, _ = unpack_trajectories(trajectories)
        initial_state = numpy.zeros([self.n_state()])
        initial_state[states[0]] = 1
        rho = self._visitation_frequency(P, initial_state)
        advantage = self._advantage(P, R)
        return advantage, rho

    def _transition_model(self, pi):
        """Build the policy-induced transition matrix and expected one-step reward.

        Returns:
            P: (n_state, n_state) array where ``P[j, i]`` is the probability of
               moving from state ``i`` to state ``j`` under ``pi``.
            R: (n_state,) array where ``R[i]`` is the expected immediate reward
               in state ``i`` under ``pi``.
        """
        P = numpy.zeros([self.n_state(), self.n_state()])
        R = numpy.zeros([self.n_state()])
        for i in range(self.n_state()):
            action_dist = pi(i)  # queried once per state instead of per outcome
            for a in range(self.n_action()):
                pi_a = action_dist.get_probability(a)
                for p, j, r, _ in self._env.P[i][a]:
                    P[j, i] += pi_a * p
                    R[i] += pi_a * p * r
        return P, R

    def _visitation_frequency(self, P, x=None, max_it=1000):
        """Return the normalised discounted state-visitation distribution.

        Accumulates ``sum_{t=0}^{max_it-1} gamma**t * P^t x`` and normalises it.
        ``x`` is the start distribution; the default is a point mass on state 0.
        The ``t = 0`` term is included, so the start distribution itself is
        counted, as in ``find_discounted_visitation_frequency``.
        """
        if x is None:
            x = numpy.zeros([self.n_state()])
            x[0] = 1

        gamma_t = 1
        rho = ProbabilityEmpiricalMeasure(numpy.arange(self.n_state()),
                                          numpy.zeros(self.n_state()))
        for _ in range(max_it):
            # Record the distribution at step t *before* propagating it to t + 1;
            # propagating first would drop the start distribution entirely.
            rho.add_probability(numpy.arange(self.n_state()), gamma_t * x)
            x = numpy.dot(P, x)
            gamma_t = self.gamma() * gamma_t
        rho.normalize()
        return rho

    def _advantage(self, P, R, decimals=3) -> numpy.ndarray:
        """Return the exact advantage table ``Q(s, a) - V(s)``, rounded to ``decimals``.

        ``V`` solves ``V = R + gamma * P.T @ V``, where ``P.T[i, j]`` is the
        probability of moving i -> j. ``Q(s, a)`` is a one-step look-ahead
        through ``env.P`` using that ``V``.
        """
        V = numpy.linalg.solve(numpy.identity(self.n_state()) - self.gamma() * P.T, R)
        Q = numpy.zeros([self.n_state(), self.n_action()])
        for s in range(self.n_state()):
            for a in range(self.n_action()):
                for p, s_next, r, _ in self._env.P[s][a]:
                    Q[s, a] += p * (r + self.gamma() * V[s_next])

        return numpy.around(Q - V[:, numpy.newaxis], decimals)


# Omniscient oracle with noise
class OracleTrueNoisy(OracleTrue):
    """``OracleTrue`` whose exact estimates are perturbed by Gaussian noise.

    The visitation measure gets zero-mean noise with standard deviation
    ``noise_visitation`` and is then re-normalised. The advantage table gets
    zero-mean noise with standard deviation ``noise_advantage`` times the
    table's range (max - min).

    NOTE(review): ``_name`` and ``save()`` are inherited unchanged, so this
    oracle serialises as 'OracleTrue' without its noise levels. Confirm that
    is intended.
    """

    def __init__(self, env, discount_factor, noise_visitation=0.1, noise_advantage=0.2):
        super().__init__(env, discount_factor)
        self._noise_rho = noise_visitation
        self._noise_adv = noise_advantage

    def _visitation_frequency(self, P, x=None, max_it=1000):
        """Exact visitation measure plus Gaussian noise, re-normalised.

        NOTE(review): additive noise can push entries below zero; whether
        ``normalize()`` handles negative mass depends on
        ProbabilityEmpiricalMeasure — TODO confirm.
        """
        exact = super()._visitation_frequency(P, x, max_it)
        n = self.n_state()
        perturbation = ProbabilityEmpiricalMeasure(numpy.arange(n), numpy.random.normal(0, self._noise_rho, n))
        noisy = exact + perturbation
        noisy.normalize()
        return noisy

    def _advantage(self, P, R, decimals=3) -> numpy.ndarray:
        """Exact advantage table plus Gaussian noise scaled by the table's range."""
        exact = super()._advantage(P, R, decimals)
        spread = exact.max() - exact.min()
        noise = numpy.random.normal(0, self._noise_adv * spread, exact.shape)
        return exact + noise
