import gurobipy
from gurobipy import GRB
import numpy as np
from utils import sharpen_to_onehot
from core import TabularCMDP, TabularMDP, DataBuffer

### helper functions

def policy_iteration(mdp, starting_policy=None, early_stopping=100):
    """Run policy iteration on a tabular MDP and return the final Q-values.

    Each round evaluates the current policy exactly by solving the linear
    system ``(I - γ P_π) V = r_π`` and then derives a new policy from the
    resulting Q-values with ``sharpen_to_onehot`` (presumably the greedy
    one-hot policy -- confirm against ``utils``).

    Args:
        mdp: object exposing ``state_space_size``, ``action_space_size``,
            ``γ``, ``R_bar`` (indexed ``[s, a]``) and ``P`` (indexed
            ``[s, a, s']``).
        starting_policy: optional ``[S, A]`` array of action probabilities to
            start from.  When ``None`` the policy derived from randomly drawn
            Q-values is used.
        early_stopping: maximum number of improvement rounds; ``None``
            disables the cap.

    Returns:
        ``[S, A]`` array holding the Q-values of the last evaluated policy.
    """
    n_states = mdp.state_space_size
    n_actions = mdp.action_space_size

    # The random initialisation is performed unconditionally so the amount of
    # randomness consumed does not depend on whether a starting policy is given.
    Q_values = np.random.uniform(0, 1/(1-mdp.γ), [n_states, n_actions])
    policy = sharpen_to_onehot(Q_values, n_idx=n_actions)
    if starting_policy is not None:
        policy = starting_policy

    update_n = 0
    # Evaluate-then-improve, always at least once: the returned Q-values must
    # belong to an evaluated policy even if `starting_policy` happens to
    # coincide with the random initial policy.
    while True:
        # Policy evaluation: expected one-step reward and state-to-state
        # transition matrix under the current policy.
        state_mean_rewards = np.sum(policy * mdp.R_bar, axis=1)
        next_state_dist = np.sum(np.expand_dims(policy, -1) * mdp.P, axis=1)
        state_values = np.linalg.solve(
            np.eye(n_states) - mdp.γ * next_state_dist, state_mean_rewards)

        # Policy improvement.
        Q_values = mdp.R_bar + mdp.γ * np.sum(
            mdp.P * state_values.reshape([1, 1, n_states]), axis=2)
        new_policy = sharpen_to_onehot(Q_values, n_idx=n_actions)
        update_n += 1

        if np.allclose(policy, new_policy):
            break  # converged: improvement left the policy unchanged
        if early_stopping is not None and update_n >= early_stopping:
            break
        policy = new_policy
    return Q_values

def value_iteration(mdp, n=2000, get_intermediate_policies=False):
    """Apply `n` synchronous Bellman optimality backups to random Q-values.

    Args:
        mdp: object exposing ``state_space_size``, ``action_space_size``,
            ``γ``, ``R_bar`` (indexed ``[s, a]``) and ``P`` (indexed
            ``[s, a, s']``).
        n: number of backups to perform.
        get_intermediate_policies: when truthy, also record the result of
            ``sharpen_to_onehot`` on the Q-values *before* each backup.

    Returns:
        The ``[S, A]`` Q-value array, or ``(Q_values, policies)`` when
        `get_intermediate_policies` is truthy.
    """
    n_states = mdp.state_space_size
    n_actions = mdp.action_space_size
    q_table = np.random.uniform(0, 1/(1-mdp.γ), [n_states, n_actions])

    if get_intermediate_policies:
        policy_trace = []

    for _ in range(n):
        if get_intermediate_policies:
            policy_trace.append(sharpen_to_onehot(q_table, n_idx=n_actions))
        # Value of the best action in every successor state, broadcast over (s, a).
        best_next_value = q_table.max(axis=-1).reshape([1, 1, n_states])
        q_table = mdp.R_bar + mdp.γ * np.sum(mdp.P * best_next_value, axis=2)

    if get_intermediate_policies:
        return q_table, policy_trace
    return q_table

def eps_greedify(policy, ε):
    """Mix `policy` with the uniform distribution over its last axis.

    Returns ``(1 - ε) * policy + ε / n`` where ``n = policy.shape[-1]``.
    """
    n_choices = policy.shape[-1]
    kept_mass = (1. - ε) * policy
    spread_mass = ε * np.ones_like(policy) / n_choices
    return kept_mass + spread_mass

def linear_programming(mdp, backup=False):
    """Solve a tabular (constrained) MDP as an occupancy-measure LP with Gurobi.

    The decision variables ``rho[s, a] >= 0`` satisfy, for every state ``s``,

        sum_a rho(s, a) = (1 - γ) * s0(s) + γ * sum_{s', a'} P(s | s', a') rho(s', a')

    i.e. the flow constraint of the *normalised* discounted occupancy measure.

    Args:
        mdp: object exposing ``s0`` (indexable by state), ``state_space_size``,
            ``action_space_size``, ``γ``, ``P`` (indexed ``[s, a, s']``) and
            ``R``; ``C`` and ``C_limit`` are read as described below.
        backup: when ``False`` the objective is to maximise
            ``sum rho * R`` and, if `mdp` is a ``TabularCMDP``, the constraint
            ``sum rho * C <= C_limit`` is added.  When ``True`` the cost
            constraint is dropped and ``sum rho * C`` is minimised instead
            (intended for when the constrained LP is infeasible).

    Returns:
        On success a dict with ``'feasible': True``, ``'obj'`` (objective
        value), ``'rho_star'`` (``[S, A]`` array) and ``'pi_star'`` (result of
        ``sharpen_to_onehot`` on ``rho_star``).  If the solver reports the
        model infeasible, a dict with ``'feasible': False``, ``'obj': 0`` and
        ``'pi'``, ``'rho_star'``, ``'pi_star'`` all ``None``.

    Raises:
        RuntimeError: for any other solver status.
    """
    # https://adityam.github.io/stochastic-control/inf-mdp/linear-programming/
    # https://people.eecs.berkeley.edu/~pabbeel/cs287-fa12/slides/mdps-exact-methods.pdf

    s0 = mdp.s0
    S = mdp.state_space_size
    A = mdp.action_space_size
    gamma = mdp.γ

    lp = gurobipy.Model('Linear Programming for MDP')
    lp.setParam('OutputFlag', False)

    # useful constants
    Sr = range(S)
    Ar = range(A)

    # occupancy measure, bounded below by 0
    rho = lp.addVars(S, A, lb=0, name='rho')

    # flow constraint (see docstring); quicksum builds the linear expressions
    # without the quadratic re-copying that builtin sum() causes.
    for s in Sr:
        inflow = gurobipy.quicksum(
            rho[s_prime, a_prime] * gamma * mdp.P[s_prime, a_prime, s]
            for s_prime in Sr for a_prime in Ar)
        outflow = gurobipy.quicksum(rho[s, a] for a in Ar)
        lp.addConstr((1-gamma) * s0[s] + inflow == outflow)

    if not backup:
        if isinstance(mdp, TabularCMDP):
            # expected cost must stay within the budget
            expected_cost = gurobipy.quicksum(
                rho[s, a] * mdp.C[s, a] for s in Sr for a in Ar)
            lp.addConstr(expected_cost <= mdp.C_limit)
        # objective: max \sum_{s,a} rho(s,a) r(s,a)
        lp.setObjective(
            gurobipy.quicksum(rho[s, a] * mdp.R[s, a] for s in Sr for a in Ar),
            GRB.MAXIMIZE)
    else:
        # original LP was infeasible, so now we just solve for cost minimization
        lp.setObjective(
            gurobipy.quicksum(rho[s, a] * mdp.C[s, a] for s in Sr for a in Ar),
            GRB.MINIMIZE)

    lp.optimize()

    if lp.status == GRB.Status.OPTIMAL:
        rho_star = np.array(
            [[rho[s, a].X for a in Ar] for s in Sr], dtype=float).reshape(S, A)
        pi_star = sharpen_to_onehot(rho_star, n_idx=A)
        return {'feasible': True, 'obj': lp.objVal, 'rho_star': rho_star, 'pi_star': pi_star}
    if lp.status in (GRB.Status.INFEASIBLE, GRB.Status.INF_OR_UNBD):
        # Depending on whether presolve or the simplex detects it, Gurobi
        # reports an infeasible model with either status.  The keys mirror the
        # success case; 'pi' is kept for backward compatibility.
        return {'feasible': False, 'obj': 0, 'pi': None,
                'rho_star': None, 'pi_star': None}
    raise RuntimeError('error status: %d' % lp.status)

### Algorithms

class BaseAlgorithm(object):
    """Common base for the algorithms in this module.

    Stores the four entries of `mdp_info` as attributes and, if an ``mdp``
    keyword argument is supplied, keeps it as ``self.mdp``.  ``ingest_data``
    and ``get_policy`` are no-op hooks for subclasses to override.
    """
    # NOTE(review): presumably flags algorithms that use the true MDP rather
    # than sampled data -- confirm against the callers that read it.
    CHEATING = False

    def __init__(self, mdp_info, *args, **kwargs):
        # mdp_info must unpack into exactly four items.
        (self.state_space_size,
         self.action_space_size,
         self.γ,
         self.s0) = mdp_info
        # The attribute is only created when the keyword is actually passed.
        try:
            self.mdp = kwargs['mdp']
        except KeyError:
            pass

    def ingest_data(self, data):
        """Hook for consuming a batch of data; does nothing here."""
        return None

    def get_policy(self):
        """Hook for producing a policy; returns ``None`` here."""
        return None

class LinearProgramming(BaseAlgorithm):
    """Solves the LP directly on ``self.mdp`` (the MDP passed as ``mdp=...``)."""
    CHEATING = True

    def __init__(self, mdp_info, kappa=0., *args, **kwargs):
        super().__init__(mdp_info, **kwargs)
        self.mdp_info = mdp_info
        self.kappa = kappa

    @staticmethod
    def _occupancy_to_policy(rho):
        """Turn an ``[S, A]`` occupancy measure into per-state action probabilities.

        States with zero total occupancy get a uniform distribution instead of
        the ``0 / 0 = NaN`` row a plain division would produce.
        """
        state_mass = rho.sum(axis=1, keepdims=True)
        policy = np.full(rho.shape, 1. / rho.shape[1])
        np.divide(rho, state_mass, out=policy, where=state_mass > 0)
        return policy

    def get_policy(self):
        """Solve the LP and return ``(policy, rho_star, feasible)``.

        When the LP has no feasible solution a message is printed and
        ``(None, None, False)`` is returned.
        """
        lp_solution = linear_programming(self.mdp)
        rho_star = lp_solution.get('rho_star')
        if lp_solution['feasible'] is not True or rho_star is None:
            # There is no occupancy measure to normalise in this case.
            print("No feasible solution!")
            return None, None, False
        policy = self._occupancy_to_policy(rho_star)
        return policy, rho_star, True

    def evaluate_policy(self, policy, rho_star):
        """Return ``(reward, cost)`` of the occupancy `rho_star` under ``self.mdp``.

        `policy` is accepted for interface parity with the other algorithms
        and is not used.
        """
        reward = np.sum(rho_star * self.mdp.R)
        cost = np.sum(rho_star * self.mdp.C)
        return reward, cost

class ReplayLP(BaseAlgorithm):
    """LP planner on an empirical CMDP built from buffered transitions.

    Transitions come from a ``DataBuffer``; rewards, costs and the cost limit
    are copied from ``self.mdp``.  When ``kappa > 0`` the costs are inflated
    by ``kappa / sqrt(data_counts)`` before being clipped to ``[0, 10]``.
    """

    def __init__(self, mdp_info, kappa=0., *args, **kwargs):
        super().__init__(mdp_info, **kwargs)
        self.mdp_info = mdp_info
        self.kappa = kappa
        self.rb = DataBuffer(self.state_space_size, self.action_space_size, self.mdp_info[-1])

    @staticmethod
    def _occupancy_to_policy(rho):
        """Turn an ``[S, A]`` occupancy measure into per-state action probabilities.

        States with zero total occupancy get a uniform distribution instead of
        the ``0 / 0 = NaN`` row a plain division would produce.
        """
        state_mass = rho.sum(axis=1, keepdims=True)
        policy = np.full(rho.shape, 1. / rho.shape[1])
        np.divide(rho, state_mass, out=policy, where=state_mass > 0)
        return policy

    def ingest_data(self, data):
        """Add every datum to the buffer and rebuild ``self.empirical_mdp``."""
        for datum in data:
            self.rb.add_data(*datum)
        # Only the empirical transition model is used from the buffer.
        _, P_hat = self.rb.to_empirical_R_P()
        # Assumes reward and cost given.  Costs are copied as floats so the
        # in-place penalty below also works for integer cost tables.
        C_hat = np.array(self.mdp.C, dtype=float)
        R_hat = np.array(self.mdp.R)
        C_limit = self.mdp.C_limit
        if self.kappa > 0.:
            # Zero counts give an infinite penalty, which the clip below caps;
            # the divide warning is therefore expected and silenced.
            # NOTE(review): assumes data_counts broadcasts against [S, A] -- confirm.
            with np.errstate(divide='ignore'):
                C_hat += self.kappa / np.sqrt(self.rb.data_counts)
        C_hat = np.clip(C_hat, 0., 10.)
        self.empirical_mdp = TabularCMDP(self.state_space_size, self.action_space_size,
         R_hat, C_hat, P_hat, self.γ, self.s0, C_limit=C_limit)

    def stationary_dist(self, policy, empirical=False):
        """Return the stationary state-action distribution of `policy`.

        Solves ``d^T (I - P_π) = 0`` together with ``sum(d) = 1`` in the
        least-squares sense, multiplies the state distribution by `policy`,
        zeroes entries below ``1e-10`` and renormalises.  ``γ`` and ``s0`` are
        not used here.

        Args:
            policy: ``[S, A]`` array of action probabilities.
            empirical: use ``self.empirical_mdp.P`` instead of ``self.mdp.P``.
        """
        P = self.empirical_mdp.P if empirical else self.mdp.P
        assert policy.shape == (self.state_space_size, self.action_space_size)
        policy_transition_mat = np.sum(P * np.expand_dims(policy, -1), axis=1)
        n = policy_transition_mat.shape[0]
        # Stack the balance equations with the normalisation row.
        a = np.vstack(((np.eye(n) - policy_transition_mat).T, np.ones(n)))
        b = np.zeros(n + 1)
        b[-1] = 1.
        dist = np.linalg.lstsq(a, b, rcond=-1)[0]
        dist = dist.reshape([self.state_space_size]) / np.sum(dist)
        dist = np.expand_dims(dist, 1) * policy
        dist[dist < 1e-10] = 0.
        dist /= np.sum(dist)
        return dist

    def get_policy(self):
        """Return ``(policy, rho_star, feasible)`` for the empirical CMDP.

        If the constrained LP is infeasible (or the solver raises), the
        minimum-cost policy from the backup LP is returned with a
        non-``True`` `feasible`.
        """
        try:
            lp_solution = linear_programming(self.empirical_mdp)
            feasible = lp_solution['feasible']
        except Exception:
            # linear_programming signals unexpected solver statuses with a
            # plain Exception; treat those as infeasible but let
            # KeyboardInterrupt / SystemExit propagate.
            feasible = False
        if feasible is not True:
            # get the minimum cost policy
            lp_solution = linear_programming(self.empirical_mdp, backup=True)
        rho_star = lp_solution['rho_star']
        policy = self._occupancy_to_policy(rho_star)
        return policy, rho_star, feasible

    def evaluate_policy(self, policy, rho_star):
        """Return ``(reward, cost)`` of `policy` under ``self.mdp``.

        The passed `rho_star` is ignored; the stationary distribution of
        `policy` under ``self.mdp.P`` is used instead.
        """
        rho_star = self.stationary_dist(policy)
        reward = np.sum(rho_star * self.mdp.R)
        cost = np.sum(rho_star * self.mdp.C)
        return reward, cost

class ReplayLPAdaptive(BaseAlgorithm):
    """LP planner on an empirical CMDP with a self-adjusting cost penalty.

    Works like a replay-buffer LP planner whose penalty weight ``kappa`` is
    halved while the constrained LP is infeasible (``get_policy``) and
    re-estimated from the observed constraint violation (``update_kappa``).
    """

    def __init__(self, mdp_info, kappa=0., *args, **kwargs):
        super().__init__(mdp_info, **kwargs)
        self.mdp_info = mdp_info
        self.kappa = kappa
        self.rb = DataBuffer(self.state_space_size, self.action_space_size, self.mdp_info[-1])

        # PID
        self.past_kappa = []

    @staticmethod
    def _occupancy_to_policy(rho):
        """Turn an ``[S, A]`` occupancy measure into per-state action probabilities.

        States with zero total occupancy get a uniform distribution instead of
        the ``0 / 0 = NaN`` row a plain division would produce.
        """
        state_mass = rho.sum(axis=1, keepdims=True)
        policy = np.full(rho.shape, 1. / rho.shape[1])
        np.divide(rho, state_mass, out=policy, where=state_mass > 0)
        return policy

    def ingest_data(self, data):
        """Add every datum to the buffer and rebuild ``self.empirical_mdp``."""
        for datum in data:
            self.rb.add_data(*datum)
        self.build_empirical_mdp()

    def build_empirical_mdp(self):
        """(Re)build ``self.empirical_mdp`` from the buffer and current ``kappa``."""
        # Only the empirical transition model is used from the buffer.
        _, P_hat = self.rb.to_empirical_R_P()
        # Assumes reward and cost given.  Costs are copied as floats so the
        # in-place penalty below also works for integer cost tables.
        C_hat = np.array(self.mdp.C, dtype=float)
        R_hat = np.array(self.mdp.R)
        C_limit = self.mdp.C_limit

        if self.kappa > 0.:
            # Zero counts give an infinite penalty, which the clip below caps;
            # the divide warning is therefore expected and silenced.
            # NOTE(review): assumes data_counts broadcasts against [S, A] -- confirm.
            with np.errstate(divide='ignore'):
                C_hat += self.kappa / np.sqrt(self.rb.data_counts)
        C_hat = np.clip(C_hat, 0., 10.)
        self.empirical_mdp = TabularCMDP(self.state_space_size, self.action_space_size,
         R_hat, C_hat, P_hat, self.γ, self.s0, C_limit=C_limit)

    def stationary_dist(self, policy, empirical=False):
        """Return the stationary state-action distribution of `policy`.

        Solves ``d^T (I - P_π) = 0`` together with ``sum(d) = 1`` in the
        least-squares sense, multiplies the state distribution by `policy`,
        zeroes entries below ``1e-10`` and renormalises.  ``γ`` and ``s0`` are
        not used here.

        Args:
            policy: ``[S, A]`` array of action probabilities.
            empirical: use ``self.empirical_mdp.P`` instead of ``self.mdp.P``.
        """
        P = self.empirical_mdp.P if empirical else self.mdp.P
        assert policy.shape == (self.state_space_size, self.action_space_size)
        policy_transition_mat = np.sum(P * np.expand_dims(policy, -1), axis=1)
        n = policy_transition_mat.shape[0]
        # Stack the balance equations with the normalisation row.
        a = np.vstack(((np.eye(n) - policy_transition_mat).T, np.ones(n)))
        b = np.zeros(n + 1)
        b[-1] = 1.
        dist = np.linalg.lstsq(a, b, rcond=-1)[0]
        dist = dist.reshape([self.state_space_size]) / np.sum(dist)
        dist = np.expand_dims(dist, 1) * policy
        dist[dist < 1e-10] = 0.
        dist /= np.sum(dist)
        return dist

    def get_policy(self):
        """Return ``(policy, rho_star, feasible)`` for the empirical CMDP.

        While the constrained LP is infeasible and ``kappa`` is positive,
        ``kappa`` is halved and the empirical CMDP rebuilt.  Once no penalty
        is left to relax, the minimum-cost backup LP is solved instead and
        `feasible` is returned as ``False``.
        """
        while True:
            lp_solution = linear_programming(self.empirical_mdp)
            feasible = lp_solution['feasible']
            if feasible is True:
                break
            if not self.kappa > 0.:
                # Halving a non-positive kappa cannot change the model, so
                # retrying would loop forever; fall back to cost minimisation.
                lp_solution = linear_programming(self.empirical_mdp, backup=True)
                break
            self.kappa = self.kappa / 2
            self.build_empirical_mdp()

        rho_star = lp_solution['rho_star']
        policy = self._occupancy_to_policy(rho_star)
        return policy, rho_star, feasible

    def evaluate_policy(self, policy, rho_star):
        """Return ``(reward, cost)`` and update ``kappa`` as a side effect.

        NOTE(review): reward and cost are computed from the passed-in
        `rho_star`, not from the stationary distribution under ``self.mdp``
        (which is only forwarded to ``update_kappa``) -- confirm this is
        intended.
        """
        rho_star_true = self.stationary_dist(policy)
        reward = np.sum(rho_star * self.mdp.R)
        cost = np.sum(rho_star * self.mdp.C)

        self.update_kappa(rho_star_true, rho_star)
        return reward, cost

    def update_kappa(self, rho_star_true, rho_star_fake):
        """Set ``kappa`` to the cost-limit violation per unit of expected penalty.

        Args:
            rho_star_true: state-action distribution used to compute the cost
                under ``self.mdp.C``.
            rho_star_fake: state-action distribution used to weight the
                per-pair penalty ``min(1, 1 / sqrt(data_counts))``.
        """
        # compute expected penalty (prevent inf.): zero counts divide to inf
        # and are capped at 1 by the clip, so the warning is silenced.
        with np.errstate(divide='ignore'):
            expected_penalty = 1 / np.sqrt(self.rb.data_counts)
        expected_penalty = np.clip(expected_penalty, 0., 1)
        expected_penalty = np.sum(rho_star_fake * expected_penalty)

        true_cost = np.sum(rho_star_true * self.mdp.C)

        self.kappa = (max(0, true_cost - self.mdp.C_limit)) / expected_penalty

class UniformRandom(BaseAlgorithm):
    """Baseline whose policy gives every action equal probability in every state."""

    def get_policy(self, get_intermediate_policies=False):
        """Return the uniform ``[S, A]`` policy.

        When `get_intermediate_policies` is truthy, returns
        ``(policy, [])`` instead -- there are no intermediate policies.
        """
        table_shape = [self.state_space_size, self.action_space_size]
        uniform_policy = np.ones(table_shape) / self.action_space_size
        return (uniform_policy, []) if get_intermediate_policies else uniform_policy

if __name__ == "__main__":
    # Smoke test: solve a tiny 2-state / 2-action MDP.  (Passing None, as
    # before, fails immediately because linear_programming reads mdp.s0.)
    from types import SimpleNamespace

    # Action 0 stays in the current state, action 1 moves to the other one;
    # only state 1 is rewarding.
    demo_P = np.zeros((2, 2, 2))
    demo_P[0, 0, 0] = 1.
    demo_P[0, 1, 1] = 1.
    demo_P[1, 0, 1] = 1.
    demo_P[1, 1, 0] = 1.
    demo_mdp = SimpleNamespace(
        state_space_size=2,
        action_space_size=2,
        γ=0.9,
        s0=np.array([1., 0.]),
        P=demo_P,
        R=np.array([[0., 0.], [1., 1.]]),
    )
    print(linear_programming(demo_mdp))