import numpy as np
import random
from nn import flaxNNX, flaxNN, MLP
import jax.numpy as jnp
from gym import spaces
from math import prod
import ray
from jax import jit
import pickle
from flax import serialization
import time
import json
from pathlib import Path

# PS (parameter sharing) pairing mode learning schedule
class PSPair():
    """Parameter-sharing schedule: every agent of an episode runs the same solution.

    The schedule is a list of runs. Each run holds one entry per worker, and
    each entry has the form ``[[s, s, ..., s]]`` with the solution index ``s``
    repeated once per agent.
    """

    def __init__(self,enviroments,min_sampling,solutions,seed=80333):
        """Build the schedule once, up front.

        enviroments  -- list of per-worker descriptors; only its length (the
                        number of workers) and ``enviroments[0][0]`` (agents
                        per episode) are read.
        min_sampling -- number of complete passes over all solutions that must
                        be scheduled before the schedule may end.
        solutions    -- number of candidate solutions (indices 0..solutions-1).
        seed         -- accepted for signature parity with CoEvopair; unused.
        """
        self.agents_n = enviroments[0][0]
        self.workers = len(enviroments)
        self.schedule = []

        handed_out = 0   # total worker slots filled so far
        full_passes = 0  # completed sweeps over range(solutions)
        pending = []     # the run currently being filled

        # Keep going until enough passes are done AND the last run is full.
        while full_passes < min_sampling or handed_out % self.workers != 0:
            for solution in range(solutions):
                handed_out += 1
                pending.append([[solution for _ in range(self.agents_n)]])
                if handed_out % self.workers != 0:
                    continue
                # Every worker has a job for this run: commit it.
                self.schedule.append(pending)
                pending = []
                if full_passes >= min_sampling:
                    # This sweep only padded out the final run; stop here.
                    break
            full_passes += 1

    def get_idx_worker(self):
        """Return the precomputed schedule (the same list object on every call)."""
        return self.schedule




# Creates a co-evolutionary schedule for the workers
class CoEvopair():
    """Co-evolutionary schedule: agents of one episode get randomly mixed solutions.

    Every call to ``get_idx_worker`` reshuffles the solution indices and cuts
    them into runs -> environments -> workers, so that each solution index
    appears at least ``min_sampling`` times per call.
    """

    def __init__(self,enviroments,min_sampling,solutions,seed=80333):
        """Precompute the slice sizes and the index pool.

        enviroments  -- list of ``[n_agents, n_workers]`` pairs, one per
                        environment; agents are split as evenly as possible
                        over that environment's workers.
        min_sampling -- minimum number of samples wanted per solution.
        solutions    -- population size (indices 0..solutions-1).
        seed         -- seed for this schedule's own random generator.
        """
        self.seed = seed
        # Use a private generator instead of reseeding the module-level
        # `random` state: two schedules (or any other user of `random`) no
        # longer perturb each other, while the shuffle sequence for a given
        # seed stays what `random.seed(seed)` would have produced.
        self._rng = random.Random(self.seed)
        self.params = []
        self.agents_n = 0
        population = solutions
        for envi in enviroments:
            agents, workers = envi[0], envi[1]
            self.agents_n += agents
            base, extra = divmod(agents, workers)
            # The first `extra` workers take one additional agent each.
            self.params.append([base + 1 if slot < extra else base for slot in range(workers)])

        # Agents available per solution in a single run over all environments.
        self.asolfrac = self.agents_n/population
        self.sample_runs = int(np.ceil(min_sampling/self.asolfrac))

        # Total slots handed out per schedule; always a multiple of agents_n.
        self.min_samples = self.sample_runs*self.agents_n
        self.solution_runs = self.min_samples // solutions
        self.solution_remainder = self.min_samples % solutions
        print("minimum samples "+str(self.min_samples))
        print("runs "+str(self.sample_runs))
        # Pool with every index repeated `solution_runs` times, plus a
        # separate single copy from which the remainder is drawn.
        self.solution_seq = list(range(solutions)) * self.solution_runs
        self.solution_seq_r = list(range(solutions))

    def get_idx(self):
        """Return a freshly shuffled flat list of ``min_samples`` solution indices.

        Both internal pools are shuffled in place, so successive calls yield
        different orderings.
        """
        self._rng.shuffle(self.solution_seq)
        self._rng.shuffle(self.solution_seq_r)
        solution_list = self.solution_seq.copy()
        if self.solution_remainder > 0:
            solution_list.extend(self.solution_seq_r[:self.solution_remainder])
        return solution_list

    def get_idx_worker(self):
        """Return the shuffled indices nested as ``[run][environment][worker] -> list``."""
        solution_list = self.get_idx()
        solution_worker_list = []
        cursor = 0
        for _ in range(self.sample_runs):
            run = []
            for envi in self.params:
                env_list = []
                for share in envi:
                    env_list.append(solution_list[cursor:cursor+share])
                    cursor += share
                run.append(env_list)
            solution_worker_list.append(run)
        return solution_worker_list

def create_fnnx_for_agent(pz_env,nn_type): #TODO: More types and layers configs
    """Build the action post-processing function and the policy container.

    The action and observation spaces of the first agent listed in
    ``pz_env.action_spaces`` / ``pz_env.observation_spaces`` are used for all
    agents. ``nn_type`` is currently ignored: the network is always
    ``MLP([64, 64, lsize])``.

    Returns ``(actf, fnnx)`` where ``actf`` maps a network output to an
    environment action and ``fnnx`` is ``flaxNNX(MLP(...), batch)`` built with
    a dummy input batch of the expected width.

    Raises NotImplementedError for action spaces other than Discrete and Box.
    """
    aspace = list(pz_env.action_spaces.values())[0]
    ospace = list(pz_env.observation_spaces.values())[0]

    if isinstance(aspace, spaces.Discrete):
        # Greedy action: index of the largest output.
        actf = lambda x : np.array(jnp.argmax(x))
        lsize = aspace.n
    elif isinstance(aspace, spaces.Box):
        # Bounds are passed positionally: the `a_min`/`a_max` keywords are
        # deprecated in newer JAX releases, positional works on old and new.
        actf = lambda x : jnp.reshape(jnp.clip(x, aspace.low, aspace.high), aspace.shape)
        lsize = prod(aspace.shape)
    else:
        # Fail loudly here instead of hitting an unbound `actf` at the return.
        raise NotImplementedError("action type not supported: "+type(aspace).__name__)

    # Dummy observation with one leading scalar for the agent identifier that
    # callers prepend to every observation before applying the network.
    batch = np.zeros(ospace.shape,dtype=ospace.dtype)
    batch = np.append([0.5],batch)
    batch = jnp.array([batch.flatten()])

    return actf, flaxNNX(MLP([64, 64, lsize]) ,batch)


# function for specific environments

@ray.remote
class evalZooEnv():
    """Ray actor that owns one environment instance and evaluates scheduled episodes.

    The actor keeps its own policy container (``self.fnnx``); the trainer
    pushes the current population into it with ``update_solutions`` and then
    calls ``evaluate`` with this worker's share of the schedule.
    """

    def __init__(self,env,nn_arch,min_sampling=10,seed=0):
        """Create the environment via ``env.env()``, seed it, and build the policies."""
        self.env = env.env()
        actf, self.fnnx = create_fnnx_for_agent(self.env, nn_arch)
        self.env.reset(seed=seed)
        #self.population = population
        self.actf = actf #jit(actf) not working
        self.min_sampling=min_sampling
        self.steps = 0  # environment steps taken over the actor's lifetime
        self.fitnesses = [[] for _ in range(self.fnnx.variables_len)]

    def evaluate(self,work):
        """Play one episode per entry of ``work`` and collect per-solution rewards.

        work -- list of runs; ``run[0]`` is the list of solution indices, one
                per agent (further workers per run are not implemented).

        Returns ``[fitnesses]`` where ``fitnesses[idx]`` lists the episode
        returns obtained by the agents that used solution ``idx``.
        """
        policies = self.fnnx

        # NOTE(review): sized by `variables_len` while indexed by solution
        # index below -- presumably variables_len >= population; confirm.
        fitnesses = [[] for _ in range(policies.variables_len)]
        for run in work: #run = [worker1, worker2] but no parallel  workers implemented yet
            assignment = run[0]
            self.env.reset()
            rewards = np.zeros(len(self.env.agents))
            for agent in self.env.agent_iter():
                # Parse the whole trailing number of the agent name so that
                # e.g. "pursuer_12" maps to 12 rather than 2. Without an ASCII
                # digit suffix, defer to int() on the last character.
                suffix = agent[len(agent.rstrip("0123456789")):]
                agent_idx = int(suffix or agent[-1])

                observation, reward, done, info = self.env.last()
                rewards[agent_idx] += reward
                # Parameter-sharing index: prepend the normalised agent id.
                observation = np.append([agent_idx/len(self.env.agents)],observation)

                if done:
                    action = None  # finished agents must step with None
                else:
                    action = self.actf(policies.apply(assignment[agent_idx], jnp.array([observation.flatten()])))
                self.env.step(action)
                self.steps += 1
            for agent,idx in enumerate(assignment):
                fitnesses[idx].append(rewards[agent])
        return [fitnesses]

    def update_solutions(self,solutions):
        """Store ``solutions[i]`` as parameter set ``i`` of the policy container."""
        for idx,solution in enumerate(solutions):
            self.fnnx.update_parameters(idx,solution)

    def get_vars_len(self):
        """Return ``variables_len`` of the policy container."""
        return self.fnnx.variables_len

    def get_total_steps(self):
        """Return the number of environment steps taken by this actor so far."""
        return self.steps


# Homogeneous agents, sequential PettingZoo envs; meant for networks that are not too big (no memory-saving measures)
class rayTrainer_v1():
    """Evolutionary trainer that distributes episode evaluation over Ray actors.

    One ``evalZooEnv`` actor is created per worker. Each generation the solver
    proposes a population, the pairing schedule decides which solution
    controls which agent in which episode, the actors play the episodes, and
    the mean return per solution is fed back to the solver. Logs are written
    as JSON files under ``experiments/<timestamp>/``.
    """

    def __init__(self,pz_env,nn_arch,solver,num_workers=20,population=64, min_sampling=10,seed=0,log_pop_on=True):
        """Create the experiment directory, the actors, the solver and the schedule.

        pz_env       -- object exposing ``env()`` and ``__name__``.
        nn_arch      -- forwarded to the actors and to ``create_fnnx_for_agent``.
        solver       -- class called as ``solver(n_vars, popsize=..., seed=...)``.
        num_workers  -- number of ``evalZooEnv`` actors.
        population   -- number of solutions per generation.
        min_sampling -- minimum samples per solution, passed to the schedule.
        seed         -- seed for the actors' environments and the solver.
        log_pop_on   -- also write the per-solution fitness of every generation.
        """
        self.experiment = time.strftime("%H%M%S_%d%m%y",time.gmtime())
        self.path = Path("experiments/"+self.experiment)
        self.path.mkdir(parents=True,exist_ok=True)
        self.trainer = "rayTrainer_v1"
        self.env = pz_env
        self.population = population
        tmp_env = self.env.env()
        self.nn_arch = nn_arch
        self.envs = [ evalZooEnv.remote(pz_env,nn_arch,min_sampling,seed=seed) for _ in range(num_workers)]
        self.solver = solver(ray.get(self.envs[0].get_vars_len.remote()),popsize=self.population,seed=seed)
        self.num_workers = num_workers
        self.generations = 0
        # A throwaway environment is reset only to read the agent count.
        tmp_env.reset()
        self.pairing = CoEvopair([[tmp_env.num_agents,1] for _ in range(num_workers)], min_sampling, self.population)
        self.best_model = None
        self.best_mean_max_reward = None  # [population mean, max] of the best generation
        self.log = []
        self.log_pop = []
        self.eval_log = []
        self.log_pop_on=log_pop_on
        self.total_timesteps = 0
        argspath = self.path / "args.json"
        argspath.write_text(json.dumps({"env":self.env.__name__, "population":self.population,"solver":type(self.solver).__name__, "nnarch":str(self.nn_arch), "trainer":self.trainer, "min_sampling":min_sampling}))
        del tmp_env

    @staticmethod
    def _agent_index(agent):
        """Return the numeric suffix of an agent name, e.g. ``"pursuer_12"`` -> 12."""
        suffix = agent[len(agent.rstrip("0123456789")):]
        # Without an ASCII digit suffix, defer to int() on the last character,
        # which raises ValueError for non-numeric names.
        return int(suffix or agent[-1])

    def train(self,num_generations):
        """Run ``num_generations`` ask / evaluate / tell cycles and log each one.

        After every generation ``log.json`` (and ``pop_log.json`` when
        ``log_pop_on`` is set) is rewritten in the experiment directory.
        """
        learn_time_start = time.perf_counter()
        for gen in range(num_generations):
            solutions = self.solver.ask()
            # Block until every actor has received the new population.
            ray.get([env.update_solutions.remote(solutions) for env in self.envs])

            # Regroup the schedule from run-major to worker-major order.
            env_works = [[] for _ in range(self.num_workers)]
            for run in self.pairing.get_idx_worker():
                for i,env_work in enumerate(run):
                    env_works[i].append(env_work)

            learn_time  = time.perf_counter() - learn_time_start
            sample_time_start = time.perf_counter()
            fitnessll = ray.get([env.evaluate.remote(work) for work,env in zip(env_works,self.envs)])

            # Average all samples gathered for each solution across the actors.
            fitness = np.zeros(self.population)
            count = np.zeros(self.population,dtype=int)
            for fitnessl in fitnessll:
                for i,samples in enumerate(fitnessl[0]):
                    for sample in samples:
                        # Accepts every number including 0.0; skips other
                        # falsy placeholders such as None.
                        if sample or sample==0.0:
                            fitness[i] += sample
                            count[i] += 1
            for i in range(self.population):
                fitness[i] = fitness[i]/count[i]

            sample_time = time.perf_counter() - sample_time_start
            learn_time_start = time.perf_counter()

            self.solver.tell(fitness)

            max_fitness = np.max(fitness)
            mean_fitness = np.mean(fitness)
            best = self.best_mean_max_reward
            # A higher maximum always wins; on an equal maximum the generation
            # with the better (or equal) population mean is preferred.
            if best is None or max_fitness > best[1] or (max_fitness == best[1] and mean_fitness >= best[0]):
                self.best_mean_max_reward = [mean_fitness,max_fitness]
                self.best_model = solutions[np.argmax(fitness)]

            step_counts = ray.get([env.get_total_steps.remote() for env in self.envs])
            self.total_timesteps = int(np.sum(step_counts,dtype=np.int64)) #care for int32 in evaluators
            self.log.append({"gen":(self.generations+1),"reward_mean":mean_fitness,"reward_max":max_fitness,"reward_min":np.min(fitness),"reward_std":np.std(fitness),"timesteps":self.total_timesteps,"learntime":learn_time,"sampletime":sample_time,"best_max_mean":self.best_mean_max_reward[0],"best_max":self.best_mean_max_reward[1]})
            print(json.dumps(self.log[-1], indent=4))
            logpath = self.path / "log.json"
            logpath.write_text(json.dumps(self.log,indent=2))
            if(self.log_pop_on):
                self.log_pop.append(fitness.tolist())
                logpoppath = self.path / "pop_log.json"
                logpoppath.write_text(json.dumps(self.log_pop,indent=2))
            self.generations += 1

    def save(self):
        """Pickle the trainer's ``__dict__`` into the experiment directory; return the path.

        NOTE(review): ``__dict__`` holds the Ray actor handles and the
        environment object -- confirm that these pickle in your setup.
        """
        path = "experiments/"+self.experiment+"/"+time.strftime("%d_%H%M%S",time.gmtime())+"_"+self.trainer
        with open(path,"wb") as file:
            pickle.dump(self.__dict__,file)
        return path

    def load(self,path):
        """Update this trainer's attributes from a file written by ``save``.

        Unpickling can execute arbitrary code; only load files you trust.
        """
        with open(path,"rb") as file:
            self.__dict__.update(pickle.load(file))

    def save_model(self):
        """Pickle ``best_model`` to ``experiments/<experiment>/<generations>_model``; return the path."""
        assert self.best_model is not None , "self.best_model is None"
        path = "experiments/"+str(self.experiment)+"/"+str(self.generations)+"_model"
        with open(path,"wb") as file:
            pickle.dump(self.best_model, file)
        return path

    def load_model(self,path):
        """Replace ``best_model`` with the object pickled at ``path`` (trusted files only)."""
        with open(path,"rb") as file:
            self.best_model = pickle.load(file)

    def render(self, episodes=10, seed=1, delay=0.0):
        """Play ``episodes`` episodes with ``best_model`` and render every step.

        delay -- pause after each rendered step, in milliseconds.
        seed  -- currently unused.
        """
        assert self.best_model is not None , "self.best_model is None"
        env_render = self.env.env()

        actf, fnn = create_fnnx_for_agent(env_render,self.nn_arch)
        fnn.update_parameters(0, self.best_model)
        for episode in range(episodes):
            env_render.reset()
            for agent in env_render.agent_iter():
                agent_idx = self._agent_index(agent)

                observation, reward, done, info = env_render.last()
                # Parameter-sharing index: prepend the normalised agent id.
                observation = np.append([agent_idx/len(env_render.agents)],observation)

                action = actf(fnn.apply(0, jnp.array([observation.flatten()]))) if not done else None
                env_render.step(action)
                env_render.render()
                if delay > 0.0 :
                    time.sleep(delay/1000.0)


    def eval(self, episodes=32, seed=0): #TODO: integrate seeds and run eval maybe in parallel?
        """Evaluate ``best_model`` for ``episodes`` episodes and append to ``eval_log.json``.

        ``reward_mean`` is the mean over episodes of the per-episode mean
        agent return; ``reward_std`` is the square root of the mean
        per-episode variance across agents.
        """
        assert self.best_model is not None , "self.best_model is None"
        env_eval = self.env.env()
        env_eval.reset(seed=seed)
        actf, fnn = create_fnnx_for_agent(env_eval,self.nn_arch)
        fnn.update_parameters(0, self.best_model)
        episode_mean_reward = []
        episode_variance_reward = []
        for episode in range(episodes):
            env_eval.reset()
            rewards = np.zeros(len(env_eval.agents))
            for agent in env_eval.agent_iter():
                agent_idx = self._agent_index(agent)

                observation, reward, done, info = env_eval.last()
                rewards[agent_idx] += reward
                # Parameter-sharing index: prepend the normalised agent id.
                observation = np.append([agent_idx/len(env_eval.agents)],observation)

                action = actf(fnn.apply(0, jnp.array([observation.flatten()]))) if not done else None
                env_eval.step(action)
            episode_mean_reward.append(np.mean(rewards))
            # Store the variance so that sqrt(mean(...)) below is a pooled
            # standard deviation.
            episode_variance_reward.append(np.var(rewards))
        # A model brought in through load_model has no training statistics.
        if self.best_mean_max_reward is None:
            best_mean, best_max = None, None
        else:
            best_mean, best_max = self.best_mean_max_reward[0], self.best_mean_max_reward[1]
        self.eval_log.append({"gen":(self.generations),"reward_mean":np.mean(episode_mean_reward),"reward_std":np.sqrt(np.mean(episode_variance_reward)),"timesteps":self.total_timesteps,"best_max_mean_sampling":best_mean,"best_max_sampling":best_max,"episodes":episodes})
        print(json.dumps(self.eval_log[-1], indent=4))
        evalpath = self.path / "eval_log.json"
        evalpath.write_text(json.dumps(self.eval_log,indent=2))

#small hacky class to enable parameter sharing for the experiment
class rayTrainer_ps(rayTrainer_v1):
    """Variant of ``rayTrainer_v1`` that schedules episodes with ``PSPair``.

    Only the constructor differs: the pairing schedule is a ``PSPair``, and
    the trainer records itself as ``"rayTrainer_ps"`` so that ``args.json``
    and saved-state file names identify the run type. All other methods are
    inherited.
    """

    def __init__(self,pz_env,nn_arch,solver,num_workers=20,population=64, min_sampling=10,seed=0,log_pop_on=True):
        """Set up the experiment directory, actors, solver and the PSPair schedule.

        Parameters have the same meaning as in ``rayTrainer_v1.__init__``. The
        parent constructor is deliberately not called; this body repeats its
        setup with a different schedule.
        """
        self.experiment = time.strftime("%H%M%S_%d%m%y",time.gmtime())
        self.path = Path("experiments/"+self.experiment)
        self.path.mkdir(parents=True,exist_ok=True)
        # Label runs with this class's own name so that experiment metadata
        # tells parameter-sharing runs apart from rayTrainer_v1 runs.
        self.trainer = "rayTrainer_ps"
        self.env = pz_env
        self.population = population
        tmp_env = self.env.env()
        self.nn_arch = nn_arch
        self.envs = [ evalZooEnv.remote(pz_env,nn_arch,min_sampling,seed=seed) for _ in range(num_workers)]
        self.solver = solver(ray.get(self.envs[0].get_vars_len.remote()),popsize=self.population,seed=seed)
        self.num_workers = num_workers
        self.generations = 0
        # A throwaway environment is reset only to read the agent count.
        tmp_env.reset()
        self.pairing = PSPair([[tmp_env.num_agents,1] for _ in range(num_workers)], min_sampling, self.population)
        self.best_model = None
        self.best_mean_max_reward = None
        self.log = []
        self.eval_log = []
        self.log_pop = []
        self.log_pop_on=log_pop_on
        self.total_timesteps = 0
        argspath = self.path / "args.json"
        argspath.write_text(json.dumps({"env":self.env.__name__, "population":self.population,"solver":type(self.solver).__name__, "nnarch":str(self.nn_arch), "trainer":self.trainer, "min_sampling":min_sampling}))
        del tmp_env

def load_trainer(path):
    """Return the object unpickled from the file at ``path``.

    For files written by ``rayTrainer_v1.save`` this is the trainer's
    attribute dictionary, not a trainer instance.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    # The context manager closes the file even if unpickling fails.
    with open(path,"rb") as file:
        return pickle.load(file)


#demo
from pettingzoo.sisl import pursuit_v4, multiwalker_v9
from ea import OpenES
import jax.numpy as jnp
from gym import spaces
from math import prod
if __name__ == '__main__':
    ray.init()

    # Comparability experiments: 500 generations on each environment with
    # three seeds, without per-solution population logging.
    comparison_seeds = [126291,241241516,10274]
    for env_module in (pursuit_v4, multiwalker_v9):
        for seed in comparison_seeds:
            trainer = rayTrainer_v1(env_module,MLP,OpenES,num_workers=30,population=64,min_sampling=10,seed=seed,log_pop_on=False)
            trainer.train(500)

    # Long multiwalker runs with one seed: first the co-evolutionary trainer,
    # then the parameter-sharing trainer (population logging left at default).
    for trainer_cls in (rayTrainer_v1, rayTrainer_ps):
        trainer = trainer_cls(multiwalker_v9,MLP,OpenES,num_workers=30,population=64,min_sampling=10,seed=241241516)
        trainer.train(2000)


    