import gym
import parser
import numpy as np 
import pathos.multiprocessing as mp
from scipy.stats import t
import argparse
from copy import deepcopy
from policies import LinearPolicy, LLPolicy

class HDCA(object):
    """Manage HDCA training of a policy on a gym environment.

    Each training iteration perturbs a set of policy weight coordinates with
    uniformly sampled offsets, evaluates every perturbed copy in parallel via
    rollouts, and adopts a perturbed copy only if a one-sided, two-sample
    t-test says its mean (shifted) reward beats the current policy's.
    """

    def __init__(self, params):
        """Read hyperparameters, create the environment and build the policy.

        Args:
            params: dict providing the keys 'nPool', 'iters',
                'total_coordinates', 'decay_iter', 'decay_factor',
                'search_range', 'num_samples', 'alpha', 'solved_condition',
                'shift_condition', 'num_eps', 'max_ep_len', 'shift_',
                'weight_type', 'env_name', 'env_seed', 'numpy_seed' and
                'policy_type' ("Linear" or "LunarLander"); 'hl_size' is read
                only for the "LunarLander" policy.

        Raises:
            NotImplementedError: if params['policy_type'] is not supported.
        """
        # Define HDCA attributes
        self.nPool = params['nPool']
        self.iters = params['iters']
        self.total_coordinates = params['total_coordinates']
        self.decay_iter = params['decay_iter']
        self.decay_factor = params['decay_factor']
        self.search_range = params['search_range']
        self.num_samples = params['num_samples']
        self.alpha = params['alpha']
        self.solved_condition = params['solved_condition']
        self.shift_condition = params['shift_condition']
        self.policy_params = {}

        # Define rollout attributes
        self.num_eps = params['num_eps']
        self.max_ep_len = params['max_ep_len']
        self.shift = params['shift_']
        self.weight_type = params['weight_type']

        # Initiate training environment
        self.env_name = params['env_name']
        self.env = gym.make(params['env_name'])
        self.env.seed(params['env_seed'])
        self.np_seed = params['numpy_seed']

        # Assign policy to worker
        policy_type = params['policy_type']
        if policy_type not in ("Linear", "LunarLander"):
            raise NotImplementedError(f"Unsupported policy_type: {policy_type!r}")

        # Settings shared by every supported policy type
        self.policy_params = {"env_name": self.env_name,
                              "ac_dim": self.env.action_space.shape[0],
                              "ob_dim": self.env.observation_space.shape[0],
                              "weight_type": self.weight_type,
                              "info": True}
        if policy_type == "Linear":
            self.policy = LinearPolicy(self.policy_params)
        else:
            self.policy_params["start_layer"] = "W1"
            self.policy_params["hl_size"] = params['hl_size']
            self.policy = LLPolicy(self.policy_params)
        self.print_hyperparameters()

    def train(self):
        """Estimate initial performance, run the HDCA loop and save the policy."""
        ## Generate initial model performance
        init_mean, init_std = self.f(self.policy, num_episodes=100, shift=self.shift, render=False)
        print(self.env_name)
        print(f"Num Simulations: 100, Initial Mean: {init_mean}, Initial Std Dev: {init_std}")

        # Run training loop
        model_, num_changes = self.hdca(init_mean, init_std, model=self.policy)
        print(f"Found {num_changes} sets of weights to adjust with statistical significance")
        print()

        # Save trained policy
        model_.save_weights()

    def print_env_details(self):
        """Print the environment's observation/action spaces and one sample of each."""
        print("Environment Details")
        print(f"Observations dimension: {self.env.observation_space.shape}")
        print(f"Actions dimension: {self.env.action_space.shape}")
        print(f"Sample Action: {self.env.action_space.sample()}")
        print(f"Action Space High: {self.env.action_space.high}")
        print(f"Action Space Low: {self.env.action_space.low}")
        print(f"Sample Observations: {self.env.observation_space.sample()}")

    def print_hyperparameters(self):
        """Print the environment name and the main training hyperparameters."""
        print(f"{self.env_name}")
        print('---------------')
        print('Hyperparameters')
        print('---------------')
        print(f'Episodes per Weight Test: {self.num_eps}, search_range = {self.search_range}, Num Weight Samples per Iter = {self.num_samples}, Alpha = {self.alpha}, Shift = {self.shift}, Seed = {self.np_seed}')
        print(f'Total # Coordinates Perturbed at Each Iteration: {self.total_coordinates}')

    ##########################################
    #      Sample policy performance by      #
    #           performing rollout           #
    ##########################################
    def f(self, policy=None, num_episodes=20, shift=0, render=False):
        """Roll out a policy and summarise the per-episode (shifted) returns.

        Args:
            policy: object exposing get_action(state); defaults to self.policy
                when omitted.
            num_episodes: number of episodes to run.
            shift: constant subtracted from every step reward.
            render: if True, render the environment before each step.

        Returns:
            (mean, std) of the episode returns. Each episode runs for at most
            self.max_ep_len steps.
        """
        if policy is None:
            policy = self.policy

        rewards = np.zeros(num_episodes)
        for episode in range(num_episodes):
            state = self.env.reset()
            ep_reward = 0
            for _ in range(self.max_ep_len):
                if render:
                    self.env.render()

                action = policy.get_action(state)
                state, reward, done, _info = self.env.step(action)
                ep_reward += (reward - shift)
                if done:
                    break
            # Reached on termination or when the step budget is exhausted
            rewards[episode] = ep_reward
        return np.mean(rewards), np.std(rewards)

    ##############################
    #      Run Training Loop     #
    ##############################
    def hdca(self, model_mean, model_std, model):
        """Run the training loop and return the final policy.

        Args:
            model_mean: accepted for interface compatibility; it is replaced
                by a fresh estimate at every iteration before being read.
            model_std: same as model_mean, for the standard deviation.
            model: policy to improve. It must provide
                refresh_unseen_coordinates, pick_coords, update_weights,
                update_layer and save_weights.

        Returns:
            (model, num_changes): the final policy and the number of
            iterations in which a perturbed policy replaced the current one.
            Training stops early once the unshifted mean reward reaches
            self.solved_condition.
        """
        np.random.seed(self.np_seed)
        num_changes = 0     # Keep track of how many sets of weights get perturbed

        # Initialize multiprocessing pool
        pool = mp.ProcessingPool(self.nPool)      # If number of nodes is not given, will autodetect processors
        print(f"Running on {pool.nodes} Processors")

        # try/finally so the pool is also released on the early "solved"
        # returns and when a rollout raises.
        try:
            for it in range(self.iters):

                # Refresh unseen_coordinates if necessary
                model.refresh_unseen_coordinates(self.total_coordinates)

                if (it+1) % self.decay_iter == 0:
                    self.search_range *= self.decay_factor
                    print(f"Search Range: {self.search_range}")

                    # Save trained policy whenever search_range is updated
                    model.save_weights()

                # Choose n dimensions to search over, each layer consists of a different number
                coordinates_to_perturb = model.pick_coords(self.total_coordinates)

                # Batch of values to try, first will be best so far
                ws = [deepcopy(model) for _ in range(self.num_samples+1)]

                # Generate perturbed values
                perturbed_values = [np.random.uniform(low=-1.0*self.search_range, high=1.0*self.search_range, size=self.total_coordinates) for _ in range(self.num_samples)]

                # Update copied models to perturbed values (ws[0] stays unperturbed)
                for candidate, offsets in zip(ws[1:], perturbed_values):
                    candidate.update_weights(coordinates_to_perturb, offsets, self.total_coordinates)

                values = pool.map(self.f, ws, [self.num_eps]*len(ws), [self.shift]*len(ws))
                model_means = np.array([stats[0] for stats in values])
                model_variances = np.array([np.square(stats[1]) for stats in values])

                # Update model stats from the unperturbed copy
                model_mean = model_means[0]
                current_variance = model_variances[0]
                model_std = np.sqrt(current_variance)

                # Perform 1-Sided, 2-Sample T-Test
                potential_models, t_diff = self.perform_t_test(model_means, model_mean, model_variances, model_std**2, self.num_eps, self.alpha)

                # After t-test, pick optimal (mean, variance) model
                found_better = len(potential_models) > 0
                if found_better:
                    # Set optimal index as most extreme model (max{T - t})
                    optimal_index = self._select_best(potential_models, t_diff)

                    # Update current policy
                    model = ws[optimal_index]
                    model_mean = float(model_means[optimal_index])
                    model_std = float(np.sqrt(model_variances[optimal_index]))

                    num_changes += 1
                    print(f'Iter: {it+1}, Expected Reward: {model_mean:.2f}, SD Estimate: {model_std}')

                # Evaluate unshifted performance and end training if solved
                if model_mean > self.shift_condition and self._is_solved(pool, model, it):
                    return model, num_changes

                if not found_better:
                    print(f"Iter: {it+1}, No Model Found, Max Mean: {max(model_means)}, Corresponding SD: {np.sqrt(model_variances[np.argmax(model_means)])}")

                # Updates layer (if applicable) where coordinates are chosen from
                model.update_layer()

            return model, num_changes
        finally:
            pool.close()
            pool.join()

    @staticmethod
    def _select_best(potential_models, t_diff):
        """Return the candidate index with the largest T - t margin.

        Only indices that passed the t-test are considered, so a NaN margin
        of a rejected candidate can never be selected.
        """
        return max(potential_models, key=lambda index: t_diff[index])

    def _is_solved(self, pool, model, it):
        """Run 100 unshifted single-episode rollouts; return True if solved."""
        unshifted_performance = pool.map(self.f, [model]*100, [1]*100, [0]*100)
        episode_returns = np.array([res[0] for res in unshifted_performance])
        unshifted_mean = np.mean(episode_returns)
        unshifted_std = np.std(episode_returns)
        print(f"Unshifted Mean, SD for 100 Trials: {unshifted_mean}, {unshifted_std}")

        if unshifted_mean >= self.solved_condition:
            print(f'Iter (Solved): {it+1}, Expected Reward: {unshifted_mean:.2f}, SD Estimate: {unshifted_std}')
            return True
        return False

    def perform_t_test(self, model_means, model_mean, model_variances, current_variance, eps, alpha):
        """One-sided, two-sample t-test of each candidate against the current model.

        Args:
            model_means: array of candidate mean rewards.
            model_mean: mean reward of the current model.
            model_variances: array of candidate reward variances.
            current_variance: reward variance of the current model.
            eps: number of episodes behind every mean/variance estimate.
            alpha: significance level of the test.

        Returns:
            (potential_models, t_diff): indices of the candidates whose test
            statistic T is positive and exceeds the critical value t, and the
            list of T - t margins for every candidate. Candidates whose
            statistic or degrees of freedom are undefined (e.g. both
            variances are zero) yield NaN margins and are never selected.
        """
        # Zero variances produce 0/0 or x/0; the NaN/inf results are handled
        # by the comparisons below, so silence the floating point warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            candidate_term = model_variances/eps
            current_term = current_variance/eps

            # Calculate test-statistics T = (Y1-Y2)/(sqrt(s1^2/N1 + s2^2/N2))
            test_statistics = (model_means - model_mean) / np.sqrt(candidate_term + current_term)

            # Calculate degrees of freedom df = [(s1^2/N1 + s2^2/N2)^2/((s1^2/N1)^2/(N1-1)+(s2^2/N2)^2/(N2-1))]
            df = np.square(candidate_term + current_term) / (np.square(candidate_term)/(eps-1) + np.square(current_term)/(eps-1))

            # Generate critical value t
            t_values = t.ppf(1-alpha, df)

        # Candidates are those where T > t and T > 0. The T > 0 check matters
        # when alpha > 0.5 makes the critical value negative; otherwise the
        # unperturbed model (T = 0) or a worse one would count as improved.
        potential_models = [index for index, (stat, crit) in enumerate(zip(test_statistics, t_values))
                            if stat > crit and stat > 0]
        t_diff = [stat - crit for stat, crit in zip(test_statistics, t_values)]

        return potential_models, t_diff

if __name__ == '__main__':
    ####################################
    #   Define Necessary Parameters    #
    ####################################
    # Every option is a (flag, type, default) row; rows are registered in order.
    cli_options = [
        ('--nPool', int, 30),
        ('--iters', int, 150),
        ('--search_range', float, .15),
        ('--num_samples', int, 1000),
        ('--total_coordinates', int, 17),
        ('--num_eps', int, 20),
        ('--alpha', float, .3),
        ('--max_ep_len', int, 1000),
        ('--decay_iter', int, 20),
        ('--decay_factor', float, .7),

        # for Swimmer-v2 and HalfCheetah-v2 use shift = 0
        # for Hopper-v2 use shift = 1
        ('--shift_', int, 0),
        ('--env_name', str, "HalfCheetah-v2"),
        ('--solved_condition', float, 3430),
        # Condition on when to check if unshifted performance meets solved condition
        ('--shift_condition', float, 3430),

        # determine type of policy to be trained
        ('--policy_type', str, "Linear"),
        ('--weight_type', str, "zeros"),
        ('--hl_size', int, None),

        # Reproducibility
        ('--numpy_seed', int, 31),
        ('--env_seed', int, 31),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Parsed namespace is handed to the trainer as a plain dict
    args = parser.parse_args()
    params = vars(args)

    hdca = HDCA(params)
    hdca.train()


