from torch.multiprocessing import Manager
import core.utils as utils
import random, sys
import torch, threading
from gp.gp import GP_Population
import numpy as np
from torch.multiprocessing import Process, Pipe
from core.runner import rollout_worker
import copy


def rollout_func(env, popn_id, team, fitness_averaging=5):
    """Evaluate one team by running complete episodes in the environment and averaging the outcome

        Parameters:
            env (object): Environment exposing reset(), step(joint_action) and an is_discrete flag.
                          reset() must return an object that supports state[i:i+1, :] indexing
                          (e.g. a 2-D numpy array with one row per agent)
            popn_id (int): Index passed as the second argument to every forward() call
            team (sequence): One entry per agent; entry i exposes forward(obs, popn_id) and the first
                             element of what it returns is used as agent i's action
            fitness_averaging (int): Number of episodes to average over (defaults to 5)

        Returns:
            tuple: (fitness, frames) where fitness is the mean over episodes of the summed,
                   agent-averaged step rewards and frames is the mean episode length in steps

        Raises:
            ValueError: If fitness_averaging is smaller than 1
    """

    if fitness_averaging < 1:
        raise ValueError('fitness_averaging must be at least 1')

    fitness = 0
    frames = 0
    for _ in range(fitness_averaging):

        joint_state = env.reset()
        done = False
        while not done:
            # Each agent only sees its own row of the joint state, kept 2-D via the slice
            joint_action = [member.forward(joint_state[agent_id:agent_id + 1, :], popn_id)[0]
                            for agent_id, member in enumerate(team)]

            # Discrete environments get a 0/1 action by thresholding at 0.5
            if env.is_discrete:
                joint_action = [int(action > 0.5) for action in joint_action]

            # reward is indexed by agent; the fourth return value is not used here
            next_state, reward, done, _ = env.step(joint_action)  # Simulate one step in environment

            # Team fitness accumulates the mean reward across agents for this step
            fitness += sum(reward) / len(reward)

            joint_state = np.array(next_state)
            frames += 1

    fitness /= float(fitness_averaging)
    frames /= float(fitness_averaging)

    return fitness, frames

class Agent:
    """Holds one GP population per agent together with the environment used to score them

        Parameters:
        args (object): Run configuration. Reads popn_size, elite_ratio, kill_ratio, crossover_prob,
                       mutation_prob and config (num_agents, state_dim, action_dim, env_choice plus
                       the settings required by the selected environment)
        model_constructior (object): Not used by this class; kept so existing callers keep working
        env_constructor (object): Not used by this class; kept so existing callers keep working

        Population of Models self.popn --> [agent_id] (each a GP_Population whose .popn is indexed by popn_id)
        Fitnesses: self.fitnesses --> [popn_id, agent_id, []]
        Shaped fitnesses: self.shaped_fitnesses --> [popn_id, agent_id, []]
        Champ_ID: self.champ_ind --> [agent_id]

        Raises:
        ValueError: If args.config.env_choice is not 'gfootball', 'gym' or 'test'

    """

    def __init__(self, args, model_constructior, env_constructor):
        self.args = args
        self.manager = Manager()

        ########Initialize population
        self.popn = [
            GP_Population(args.config.state_dim, args.config.action_dim, self.args.popn_size,
                          args.elite_ratio, args.kill_ratio)
            for _ in range(args.config.num_agents)
        ]

        #Agent metrics
        self.fitnesses = self._new_metric_table()
        self.shaped_fitnesses = self._new_metric_table()

        ###Best Policy HOF####
        self.champ_ind = [0 for _ in range(args.config.num_agents)]

        # Environment modules are imported lazily so only the selected backend has to be installed
        env_choice = args.config.env_choice
        if env_choice == 'gfootball':
            from envs.gfootball import GFootball
            self.env = GFootball(args, T=args.config.T)

        elif env_choice == 'gym':
            from envs.gym import GymWrapper
            self.env = GymWrapper(args.config.config, frameskip=args.config.frameskip)

        elif env_choice == 'test':
            # NOTE(review): Navigation is imported from envs.gfootball -- confirm this is the intended module
            from envs.gfootball import Navigation
            self.env = Navigation(args.config.config)

        else:
            # Fail here instead of leaving self.env undefined until the first rollout
            raise ValueError('Wrong Env Choice')

    def _new_metric_table(self):
        """Build an empty [popn_id, agent_id, []] table with independent inner lists"""
        return [[[] for _ in range(self.args.config.num_agents)] for _ in range(self.args.popn_size)]

    def evo_rollouts(self):
        """Score every team index once, then re-run the best-scoring one

            Team team_id is formed by member team_id of every agent's population. The fitness
            obtained by a team is written to fitness.values of each of its members.

            Returns:
                tuple: (fitnesses, total_frames, test_fit) -- the per-team fitness list, the sum of
                       the frame counts returned by the scoring rollouts, and the fitness of the
                       extra rollout of the best team
        """
        total_frames = 0
        fitnesses = []
        for team_id in range(self.args.popn_size):
            fit, frames = rollout_func(self.env, team_id, self.popn)
            total_frames += frames
            fitnesses.append(fit)
            for agent_id in range(self.args.config.num_agents):
                self.popn[agent_id].popn[team_id].fitness.values = (fit,)

        # First index holding the maximum fitness
        champ_team_id = fitnesses.index(max(fitnesses))
        test_fit, _ = rollout_func(self.env, champ_team_id, self.popn)

        print(self.popn[0].popn[champ_team_id])

        return fitnesses, total_frames, test_fit

    def evolve(self):
        """Evolve every agent's population independently and clear the fitness tables"""

        #Evolve each recipe independently
        for gp_pop in self.popn:
            gp_pop.evolve(self.args.crossover_prob, self.args.mutation_prob)

        # Reset fitness metrics
        self.fitnesses = self._new_metric_table()
        self.shaped_fitnesses = self._new_metric_table()

    def update_champ_ind(self):
        """Store, for every agent, the argmax over that agent's column of self.fitnesses

            np.argmax is applied to the squeezed slice without an axis, so the stored index refers
            to the flattened slice.
        """
        # The table does not change inside the loop, so convert it once
        fitness_table = np.array(self.fitnesses)
        for agent_id in range(self.args.config.num_agents):
            fit = np.squeeze(fitness_table[:, agent_id, :])
            self.champ_ind[agent_id] = np.argmax(fit)

    def terminate(self):
        """Best-effort shutdown: send 'TERMINATE' on every pg and evo task pipe present on this object"""
        for attr_name in ('pg_task_pipes', 'evo_task_pipes'):
            # This class does not create these pipes itself; a missing attribute means nothing to stop
            for pipe in getattr(self, attr_name, None) or ():
                try:
                    pipe[0].send('TERMINATE')
                except Exception:
                    # A closed or broken pipe must not prevent the remaining workers from being notified
                    continue


class TestAgent:
    """Evaluates a champion team on test workers and keeps a copy of the best-scoring team

        Parameters:
        args (object): Run configuration (metric_save, log_fname, num_test, model_save, actor_fname, config)
        source (str): Where the champion team comes from: 'pg' or 'evo'

        NOTE: __init__ currently returns before doing any setup, so a constructed object carries
        none of the attributes the other methods read. The setup code is kept below the return
        for when the class is re-enabled.

    """
    def __init__(self, args, source):
        # Construction is short-circuited: the setup below is unreachable and relies on
        # GumbelPolicy, which is not imported in the visible part of this file.
        return
        self.args = args
        self.source = source  # read by join_test_rollout and make_champ_team
        self.logger = utils.Tracker(args.metric_save, [source + args.log_fname], '.csv')

        #### Rollout Actor is a template used for MP #####
        self.manager = Manager()
        self.rollout_team = self.manager.list()
        self.rollout_team.append([GumbelPolicy(args.config.state_dim, args.config.action_dim) for _ in range(args.config.num_agents)])

        ### Best Team ####
        self.best_team = [GumbelPolicy(args.config.state_dim, args.config.action_dim) for _ in range(args.config.num_agents)]

        ######### TEST WORKERS ############
        self.test_task_pipes = [Pipe() for _ in range(args.num_test)]
        self.test_result_pipes = [Pipe() for _ in range(args.num_test)]
        self.test_workers = [Process(target=rollout_worker,
                                     args=(self.args, worker_id, 'test', self.test_task_pipes[worker_id][1],
                                           self.test_result_pipes[worker_id][0], self.rollout_team, False, False))
                             for worker_id in range(args.num_test)]
        for worker in self.test_workers:
            worker.start()

        self.trace = []

        self.it = 0
        self.best_score = -float('inf')

    def start_test_rollout(self, agent):
        """Sync the champion policies from agent, then signal every test worker to start"""

        self.make_champ_team(agent) # Sync the champ policies into the TestAgent
        for p in self.test_task_pipes:
            p[0].send(0)

    def join_test_rollout(self, total_frames):
        """Collect one result per test worker, log the mean score and save the team on a new best

            Parameters:
                total_frames (int): Frame count passed to the logger alongside the mean test score
        """

        # Blocks on every result pipe; element 1 of each entry holds that worker's fitness values
        test_fits = [utils.list_mean(p[1].recv()[1]) for p in self.test_result_pipes]
        test_mean = utils.list_mean(test_fits)

        self.logger.update([test_mean], total_frames)
        self.trace.append(test_mean)

        self.it += 1

        #Save best test score
        if test_mean > self.best_score:
            self.best_score = test_mean
            # Read the team out of the managed list once instead of once per agent
            current_team = self.rollout_team[0]
            for agent_id in range(len(current_team)):
                utils.hard_update(self.best_team[agent_id], current_team[agent_id])
                self.best_team[agent_id].cpu()
                torch.save(self.best_team[agent_id].state_dict(), self.args.model_save + str(agent_id) + '_best_' + self.source + '_' + self.args.actor_fname + '.pth')

            print("Best Team Saved with Score", test_mean)

    def make_champ_team(self, agent):
        """Copy the champion policies of agent into this object's rollout team

            Parameters:
                agent (object): For source 'pg' it must provide update_rollout_team() and
                                rollout_team; for source 'evo' it must provide hof

            Raises:
                ValueError: If self.source is neither 'pg' nor 'evo'
        """

        if self.source == 'pg':  #Testing without Evo
            agent.update_rollout_team()
            for agent_id, model in enumerate(agent.rollout_team[0]):
                utils.hard_update(self.rollout_team[0][agent_id], model)

        elif self.source == 'evo':
            for agent_id, champ_net in enumerate(agent.hof):
                utils.hard_update(self.rollout_team[0][agent_id], champ_net)

        else:
            raise ValueError('Unknown source for Test Agent champ team')

    def terminate(self):
        """Best-effort shutdown: send 'TERMINATE' on every test task pipe present on this object"""
        # The pipes do not exist while __init__ is short-circuited; treat that as nothing to stop
        for pipe in getattr(self, 'test_task_pipes', None) or ():
            try:
                pipe[0].send('TERMINATE')
            except Exception:
                # A closed or broken pipe must not prevent the remaining workers from being notified
                continue


