from stable_baselines3 import PPO
from stable_baselines3 import DQN
from algos.arppo.arppo import ARPPO
from algos.artrpo.artrpo import ARTRPO
from algos.arppo.arppo_ft import ARPPOFT
from algos.ardqn.ardqn import ARDQN
#from algos.arqrdqn.arqrdqn import ARQRDQN
#from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
import torch
from torch import nn
import pdb
import numpy as np

# longest queue
class LQ:
    def __init__(self, env):
        self.env = env

    def __call__(self, s):
        a = 0 # if all are empty, just choose the first 
        max_job_q = 0 
        for q_idx, q in enumerate(self.env.qs):
            q_num_jobs = q.num_jobs
            if q_num_jobs > max_job_q:
                max_job_q = q_num_jobs
                a = q_idx
        # if still -1, then no-op. may mean all queues are empty
        return a

# largest service rate
class LSR:
    def __init__(self, env):
        self.env = env

    def __call__(self, s):
        a = 0 # if nothing else selected, select first
        max_service_rate = -1

        for q_idx, q in enumerate(self.env.qs):
            q_num_jobs = q.num_jobs
            if q_num_jobs > 0 and q.p > max_service_rate:
                max_service_rate = q.p
                a = q_idx
        return a

# longest connected queue 
class LCQ:
    def __init__(self, env):
        self.env = env
    def __call__(self, obs, t = None):
        return self._get_action(obs)
    
    def _get_action(self, s):
        a = 0 # if all are empty, just choose the first 
        max_job_q = 0
        if isinstance(self.env, VecNormalize):
            qs = self.env.get_attr('qs')[0]
            lens = s[:len(qs)]
            cons = s[len(qs):]
        else:
            qs = self.env.qs
            lens = np.abs(s[:len(self.env.qs)])
            cons = s[len(self.env.qs):]
        #connections = np.split(cons, len(cons) / 2)
        connections = s[len(self.env.qs):]
        for q_idx, q in enumerate(qs):
            q_num_jobs = lens[q_idx]
            if q_num_jobs > max_job_q and connections[q_idx]: # check is_connected flag in one-hot vector for queue
            #if q_num_jobs > max_job_q and connections[q_idx][1]: # check is_connected flag in one-hot vector for queue
                max_job_q = q_num_jobs
                a = q_idx
        return a

    def get_action(self, obs):
        actions = []
        # TODO vectorized?
        for o in obs:
            actions.append(self._get_action(o))
        return np.array(actions)

# maxweight
class MaxWeight:
    def __init__(self, env):
        self.env = env
    def __call__(self, obs, t = None):
        return self._get_action(obs, t)
    
    def _get_action(self, s, t):
        a = 0 # if all are empty, just choose the first 
        max_qp = -1
        lens = np.abs(s[:len(self.env.qs)])
        cons = s[len(self.env.qs):]
        #connections = np.split(cons, len(cons) / 2)
        connections = s[len(self.env.qs):]
        for q_idx, q in enumerate(self.env.qs):
            qp = lens[q_idx] * q.get_service_prob(t)
            if qp > max_qp and connections[q_idx]: # check is_connected flag in one-hot vector for queue
            #if qp > max_qp and connections[q_idx][1]: # check is_connected flag in one-hot vector for queue
                max_qp = qp
                a = q_idx
        return a

# largest service connected queue 
class LSCQ:
    def __init__(self, env):
        self.env = env
    def __call__(self, s, t):
        a = 0 # if all are empty, just choose the first
        lens = np.abs(s[:len(self.env.qs)])
        cons = s[len(self.env.qs):]
        #connections = np.split(cons, len(cons) / 2)
        connections = s[len(self.env.qs):]
        max_service_rate = -1
        for q_idx, q in enumerate(self.env.qs):
            q_num_jobs = lens[q_idx]
            #if q_num_jobs > 0 and connections[q_idx][1] and q.get_service_prob(t) > max_service_rate:
            if q_num_jobs > 0 and connections[q_idx] and q.get_service_prob(t) > max_service_rate:
                max_service_rate = q.get_service_prob(t)
                a = q_idx
        return a

# largest arrival * service 
class LASQ:
    def __init__(self, env):
        self.env = env

    def __call__(self, obs, t = None):
        return self._get_action(obs, t)

    def _get_action(self, s, t):
        a = 0 # if all are empty, just choose the first 
        max_ap = -1
        lens = s[:len(self.env.qs)]
        cons = s[len(self.env.qs):]
        connections = np.split(cons, len(cons) / 2)
        for q_idx, q in enumerate(self.env.qs):
            qp = (1. - q.get_arrival_prob(t)) * q.get_service_prob(t)
            if qp > max_ap and connections[q_idx][1] and lens[q_idx] > 0: # check is_connected flag in one-hot vector for queue
                max_ap = qp
                a = q_idx
        return a

# largest success queue
class LSQ:
    def __init__(self, env):
        self.env = env

    def __call__(self, obs, t = None):
        return self._get_action(obs, t)

    def _get_action(self, s, t):
        a = 0 # if all are empty, just choose the first 
        max_ap = -1
        lens = s[:len(self.env.qs)]
        cons = s[len(self.env.qs):]
        connections = np.split(cons, len(cons) / 2)
        for q_idx, q in enumerate(self.env.qs):
            qp = lens[q_idx] * q.get_service_prob(t) * q.get_connection_prob(t)
            if qp > max_ap and connections[q_idx][1] and lens[q_idx] > 0: # check is_connected flag in one-hot vector for queue
                max_ap = qp
                a = q_idx
        return a

# random
class Random:
    def __init__(self, env, use_connections = False):
        self.env = env
        self.use_connections = use_connections

    def __call__(self, obs, t = None):
        return self._get_action(obs, t)
    
    def _get_action(self, s, t):
        a = 0 # if all are empty, just choose the first 
        lens = s[:len(self.env.qs)]
        connections = s[len(self.env.qs):]
        if self.use_connections:
            potential = (lens > 0) & (connections == 1)
        else:
            potential = (lens > 0)
        pot_queues = np.arange(len(self.env.qs))[potential]
        #if len(np.where((connections == 1) ==  True)[0]) != len(self.env.qs):
        #    pdb.set_trace()
        if len(pot_queues):
            a = np.random.choice(pot_queues)
        return a

class StablebaselinePolicy:
    def __init__(self, name, algo, env,\
        truncated_horizon = 2048,
        gamma = 0.99,
        learning_rate = 3e-4,
        batch_size = 64,
        moving_avg = 0.1,
        nu_bias = 0.1,
        variant = 'mark-algo3',
        pretrained_path = None,
        use_lcq = False,
        policy_kwargs = None,
        replay_epochs = 10,
        normalize_rewards = False,
        learning_starts = 0,
        behavior_policy = None,
        augment_data = False,
        exploration_fraction = 0.1,
        anneal_lr = False,
        train_policy = True):

        self.name = name
        self.env = env
        self.use_lcq = use_lcq
        self.lcq_pi = LCQ(env)
        print ('using LCQ {}'.format(self.use_lcq))
        if algo == 'PPO':
            self.pi = PPO(name,
            self.env, 
            verbose = 0, 
            gamma = gamma, 
            learning_rate = learning_rate,
            batch_size = batch_size, 
            n_steps = truncated_horizon, 
            policy_kwargs = policy_kwargs,
            n_epochs = replay_epochs)
        elif algo == 'AR-PPO':
            if pretrained_path:
                self.pi = ARPPOFT(name, self.env, 
                    batch_size = batch_size, 
                    n_steps = truncated_horizon, 
                    verbose = 0, 
                    learning_rate = learning_rate, 
                    gamma = gamma, 
                    moving_avg = moving_avg, 
                    nu_bias = nu_bias, 
                    variant = variant,
                    policy_kwargs = policy_kwargs,
                    n_epochs = replay_epochs,
                    normalize_rewards = normalize_rewards,
                    pretrained_path = pretrained_path,
                    use_pt_ref = True)
            else:
                self.pi = ARPPO(name, self.env, 
                    batch_size = batch_size, 
                    n_steps = truncated_horizon, 
                    verbose = 0, 
                    learning_rate = linear_schedule(learning_rate, use_schedule = anneal_lr, min_value_limit = 1e-5), 
                    gamma = gamma, 
                    moving_avg = moving_avg, 
                    nu_bias = nu_bias, 
                    variant = variant,
                    policy_kwargs = policy_kwargs,
                    n_epochs = replay_epochs,
                    normalize_rewards = normalize_rewards,
                    train_policy = train_policy)
        elif algo == 'AR-TRPO':
            self.pi = ARTRPO(name, self.env, 
                batch_size = batch_size, 
                n_steps = truncated_horizon,
                verbose = 0, 
                learning_rate = linear_schedule(learning_rate, use_schedule = anneal_lr, min_value_limit = 1e-5), 
                gamma = gamma, 
                policy_kwargs = policy_kwargs,
                variant = variant)
        elif algo == 'DQN':
            self.pi = ARDQN(name,
                self.env, 
                verbose = 0,
                gradient_steps = replay_epochs,
                gamma = gamma, 
                learning_rate = learning_rate,
                batch_size = batch_size, 
                policy_kwargs = policy_kwargs,
                learning_starts = learning_starts,
                train_freq = truncated_horizon,
                behavior_policy = behavior_policy,
                exploration_fraction = exploration_fraction,
                augment_data = augment_data)
        elif algo == 'QR-DQN':
            self.pi = ARQRDQN(name,
                self.env, 
                verbose = 0,
                gradient_steps = replay_epochs,
                gamma = gamma, 
                learning_rate = learning_rate,
                batch_size = batch_size, 
                policy_kwargs = policy_kwargs,
                learning_starts = learning_starts,
                train_freq = truncated_horizon,
                behavior_policy = behavior_policy,
                exploration_fraction = exploration_fraction,
                augment_data = augment_data)
        elif algo == 'Rec-PPO':
            self.pi = RecurrentPPO('MlpLstmPolicy', self.env, 
                batch_size = batch_size, 
                n_steps = truncated_horizon, 
                verbose = 0, 
                learning_rate = learning_rate, 
                gamma = gamma, 
                policy_kwargs = policy_kwargs,
                n_epochs = replay_epochs)

    def learn(self, total_timesteps, callback = None):
        self.pi.learn(total_timesteps = total_timesteps, callback = callback)

    def __call__(self, state, t = None):
        if self.use_lcq:
            ent = self.pi.get_entropy(state).item()
            if ent >= 0.5:
                return self.lcq_pi(state)
        return self.pi.predict(state, deterministic = True)[0]

def linear_schedule(initial_value: float, use_schedule = False,\
    min_value_limit = 0, progress_thresh = 1):
    """
    Linear learning rate schedule.

    :param initial_value: Initial learning rate.
    :return: schedule that computes
      current learning rate depending on remaining progress
    """
    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.

        :param progress_remaining:
        :return: current learning rate
        """
        if use_schedule:
            if progress_remaining <= progress_thresh:
                return max(progress_remaining * initial_value, min_value_limit)
            else:
                return initial_value
        else:
            return initial_value
    return func