import numpy as np
import argparse
# from stable_baselines.common.env_checker import check_env
import asyncio
from longroad.envs import IntegerRoadEnv,ZooIntegerRoadEnv,IntegerRoadRaw

import ray

from es import SimpleGA, CMAES, OpenES
from model import make_model
import random
import time
import json
import pandas as pd
import matplotlib.pyplot as plt
from collections import namedtuple #modelconfig from estools
# Immutable record describing one model configuration; it is handed to
# `make_model` further down in this file.
Game = namedtuple(
    'Game',
    [
        'env_name',
        'time_factor',
        'input_size',
        'output_size',
        'layers',
        'activation',
        'noise_bias',
        'output_noise',
        'rnn_mode',
    ],
)


class CoEvopair():
    """Randomly assign solution indices to the agent slots of several environments.

    Each entry of ``enviroments`` is indexed as ``[agent_count, worker_count]``.
    The agents of one environment are split as evenly as possible over its
    workers. Every call to :meth:`get_idx` produces a freshly shuffled sequence
    of solution indices that is long enough to fill all agent slots for
    ``sample_runs`` rounds.

    Args:
        enviroments: sequence of ``[agent_count, worker_count]`` entries.
        min_sampling: minimum number of times each solution should be sampled.
        solutions: number of candidate solutions (the population size).
        seed: seed applied to the global ``random`` module.

    Note:
        The constructor seeds the *global* ``random`` generator, so it affects
        every other user of ``random`` in the process.
    """

    def __init__(self,enviroments,min_sampling,solutions,seed=80333):
        self.seed = seed
        self.params = []
        self.agents_n = 0
        random.seed(self.seed)

        for envi in enviroments:
            agents, workers = envi[0], envi[1]
            self.agents_n += agents
            # Even split; the first `modulo` workers take one extra agent.
            div, modulo = divmod(agents, workers)
            self.params.append([div + 1] * modulo + [div] * (workers - modulo))

        # Agent slots available per solution in one round over all environments.
        # Uses the `solutions` argument rather than a module-level global so
        # the class works on its own.
        self.asolfrac = self.agents_n / solutions
        self.sample_runs = int(np.ceil(min_sampling / self.asolfrac))

        # Total number of slots to fill across all rounds.
        self.min_samples = self.sample_runs * self.agents_n
        self.solution_runs, self.solution_remainder = divmod(self.min_samples, solutions)
        # `solution_runs` full copies of every index ...
        self.solution_seq = list(range(solutions)) * self.solution_runs
        # ... plus one extra copy from which the remainder is drawn.
        self.solution_seq_r = list(range(solutions))

    def get_idx(self):
        """Return a shuffled list of ``min_samples`` solution indices."""
        random.shuffle(self.solution_seq)
        random.shuffle(self.solution_seq_r)
        solution_list = self.solution_seq.copy()
        if self.solution_remainder > 0:
            solution_list.extend(self.solution_seq_r[:self.solution_remainder])
        return solution_list

    def get_idx_worker(self):
        """Return the shuffled indices grouped as ``[run][environment][worker]``.

        The innermost lists hold the solution indices assigned to one worker;
        their lengths follow ``self.params``.
        """
        solution_list = self.get_idx()
        solution_worker_list = [[] for _ in range(self.sample_runs)]
        counter = 0
        for run in range(self.sample_runs):
            for envi in self.params:
                env_list = []
                for worker in envi:
                    env_list.append(solution_list[counter:counter + worker])
                    counter += worker
                solution_worker_list[run].append(env_list)
        return solution_worker_list

# Start Ray at import time so the @ray.remote actors and tasks defined below
# can be instantiated later in this script.
ray.init()

@ray.remote
class mlp_serveOne():
    """Ray actor holding one model per population slot and serving binary actions."""

    def __init__(self,modelparam,popsize):
        # One independent model per slot, keyed by its integer index.
        self.models = {slot: make_model(modelparam) for slot in range(popsize)}

    def set_sol(self,solutions):
        """Load the i-th parameter vector into the i-th model."""
        for slot, params in enumerate(solutions):
            self.models[slot].set_model_params(params)

    def get_act(self,obs,sol):
        """Return 0 when the first output of model ``sol`` is strictly larger than the second, else 1."""
        scores = self.models[sol].get_action(obs)
        if scores[0] > scores[1]:
            return 0
        return 1

@ray.remote
class mlp_serve():
    """Ray actor wrapping a single model and serving binary actions."""

    def __init__(self,modelparam):
        self.model = make_model(modelparam)

    def set_sol(self,solution):
        """Load a parameter vector into the wrapped model."""
        self.model.set_model_params(solution)

    def get_act(self,obs):
        """Return 0 when the model's first output is strictly larger than the second, else 1."""
        scores = self.model.get_action(obs)
        if scores[0] > scores[1]:
            return 0
        return 1
@ray.remote
class mlp_model():
    """Ray actor that loads a parameter vector on every call and returns the raw model output."""

    def __init__(self,modelparam):
        self.model = make_model(modelparam)

    def get_action(self,obs,solution):
        """Set ``solution`` as the model parameters, then evaluate the model on ``obs``."""
        net = self.model
        net.set_model_params(solution)
        return net.get_action(obs)
@ray.remote(num_cpus=24)
def mlpget(obs,solution):
    """Build a fresh model from the module-level ``modelparam``, load ``solution`` and pick an action.

    Returns 0 when the first model output is strictly larger than the second, else 1.
    """
    net = make_model(modelparam)
    net.set_model_params(solution)
    scores = net.get_action(obs)
    if scores[0] > scores[1]:
        return 0
    return 1
@ray.remote(num_cpus=64)
def modelact(obs,solution):
    """Query an already-configured model object (``solution``) and pick an action.

    Returns 0 when the first output is strictly larger than the second, else 1.
    """
    scores = solution.get_action(obs)
    if scores[0] > scores[1]:
        return 0
    return 1



@ray.remote
class parallelEnv():
    """Ray actor owning one ``IntegerRoadRaw`` environment.

    The actor runs episodes in which every agent slot is driven by one
    candidate solution, and accumulates the per-agent episode returns in
    ``self.fitnesses`` (one list per solution index).

    Args:
        agentsize, yellow, global_re1, global_re2, episode_length: forwarded
            unchanged to ``IntegerRoadRaw``.
        population: number of candidate solutions; sizes the fitness buffer.
        seed: passed to ``env.seed``.
    """

    def __init__(self,agentsize=10,yellow=False,global_re1=0.01,global_re2=0.1,episode_length=50,population=100,seed=80333):
        self.env = IntegerRoadRaw(agentsize=agentsize,yellow=yellow,global_re1=global_re1,global_re2=global_re2,episode_length=episode_length)
        self.env.seed(seed)
        # Remember the population size so the buffer can be rebuilt without
        # depending on a module-level global being visible inside the actor.
        self.population = population
        self.fitnesses = self._empty_fitnesses()
        self.indexar = np.array(range(agentsize))

    def _empty_fitnesses(self):
        """Return a fresh buffer with one independent empty list per solution index."""
        return [[] for _ in range(self.population)]

    def _with_agent_index(self, obs):
        """Append each agent's index, scaled by the env's agent count, as an extra observation column."""
        return np.c_[obs, self.indexar / self.env.agentsize]

    def eval_env(self,work,solutions):
        """Run one episode and record each agent's summed reward.

        Args:
            work: list of index lists; only ``work[0]`` is used. Its entry at
                position ``a`` is the solution index controlling agent ``a``.
            solutions: mapping from solution index to an object exposing
                ``get_action(obs)``.
        """
        agent_count = self.env.agentsize
        assignment = work[0]
        obs = self._with_agent_index(self.env.reset())

        rewards = np.zeros(agent_count)
        done = False
        while not done:
            # Builtin `int` replaces the `np.int` alias, which NumPy 1.24 removed.
            actions = np.zeros(agent_count, dtype=int)
            for agent in range(agent_count):
                x = solutions[assignment[agent]].get_action(obs[agent])
                actions[agent] = 0 if x[0] > x[1] else 1

            obs, reward, done, info = self.env.step(actions)
            obs = self._with_agent_index(obs)
            for i, r in enumerate(reward):
                rewards[i] += r

        for agent, idx in enumerate(assignment):
            self.fitnesses[idx].append(rewards[agent])

    def get_fitnesses(self,idx):
        """Return the recorded returns of solution ``idx`` and clear that entry."""
        fitness = self.fitnesses[idx]
        self.fitnesses[idx] = []
        return fitness

    def get_fitty(self):
        """Return the whole fitness buffer and replace it with an empty one."""
        fitness = self.fitnesses
        self.fitnesses = self._empty_fitnesses()
        return fitness

    def reset_fitnesses(self):
        """Discard all recorded returns."""
        self.fitnesses = self._empty_fitnesses()

# Command-line interface: every option has a default, so the script also runs
# without any arguments.
parser = argparse.ArgumentParser(description='Trail Settings')
parser.add_argument('--agentsize', default=100, type=int, help="agentsize")
parser.add_argument('--yellow', default=0, type=int, help="yellow phase")
parser.add_argument('--seeds', default=13232, type=int, help="seeds option")
parser.add_argument('--it', default=500, type=int, help="generation option")
parser.add_argument('--envs', default=10, type=int, help="envs option")
parser.add_argument('--population', default=250, type=int, help="population option")
parser.add_argument('--sampling', default=16, type=int, help="sampling option")
# `--method` is kept as a plain string; it is used as a lookup key later on.
parser.add_argument('--method', default="ga", help="method option")


# Resolve the run configuration from the command line.
args = parser.parse_args()
print(args)

method = args.method            # key into the solver table built below
population = args.population    # number of candidate solutions per generation
min_sampling = args.sampling    # minimum evaluations per solution
env_n = args.envs               # number of parallel environments
generations = args.it           # training generations
agentsize = args.agentsize      # agents per environment
seed = args.seeds               # seed handed to every environment
seeds = 15225                   # NOTE(review): not read anywhere in the visible code

# `--yellow` arrives as an int flag; it also selects the value of global_re1.
yellow = bool(args.yellow)
if yellow:
    global_re1 = 0.01
else:
    global_re1 = 0.1

# One [agent_count, worker_count] entry per environment, one worker each.
environments = [[agentsize, 1] for _ in range(env_n)]


# Configuration passed to `make_model`: 16 inputs, 2 outputs, two hidden
# layers of 16 units with tanh activation, no noise, no RNN mode.
modelparam = Game(
    env_name='train',
    time_factor=0,
    input_size=16,
    output_size=2,
    layers=[16, 16],
    activation='tanh',
    noise_bias=0.0,
    output_noise=[False, False, False],
    rnn_mode=False,
)

# Build a throwaway model only to learn how many parameters a solution has.
dummy_model = make_model(modelparam)
param_count = dummy_model.param_count
del dummy_model

# All three solvers are constructed up front; `--method` selects one later.
cmaes = CMAES(
    param_count,
    popsize=population,
    weight_decay=0.0,
    sigma_init=0.1,
)

ga = SimpleGA(
    param_count,            # number of model parameters
    sigma_init=0.5,         # initial standard deviation
    popsize=population,     # population size
    elite_ratio=0.1,        # fraction of the population kept as elites
    forget_best=False,      # keep the historical best elites
    weight_decay=0.005,     # weight decay coefficient
)

oes = OpenES(
    param_count,                # number of model parameters
    sigma_init=0.1,             # initial standard deviation
    sigma_decay=0.999,          # per-step decay factor for the standard deviation
    learning_rate=0.1,          # initial learning rate
    learning_rate_decay=1.0,    # 1.0 leaves the learning rate constant
    popsize=population,         # population size
    antithetic=False,           # no antithetic sampling
    weight_decay=0.005,         # weight decay coefficient
    rank_fitness=False,         # use raw fitness values, not ranks
    forget_best=False,
)

# Lookup table from the `--method` string to the solver instance.
solverdic = {"ga": ga, "cmaes": cmaes, "oes": oes}

# Local import: `os` is only needed here to make sure the log directory exists.
import os

agents_n = 0
envs = []

# One dict per training generation; rewritten to `filejson` after every generation.
log = []

solver = solverdic[method]
pairgen = CoEvopair(environments,min_sampling,population)

# All output files share the prefix log/<MMDDHHMMSS><method>.
timestamp = time.strftime("%m%d%H%M%S")
# Create the directory up front; otherwise the first write fails only after a
# whole generation has already been evaluated.
os.makedirs("log", exist_ok=True)
filepath = "log/"+timestamp+method
filejson = filepath+"_log.json"
filepickle = filepath+"_eval"

# One remote environment actor per configured environment.
for envi in environments:
    envs.append(parallelEnv.remote(agentsize=envi[0],yellow=yellow,global_re1=global_re1,seed=seed,population=population))
learn_time_start = time.perf_counter()

# Driver-side models, one per population slot, keyed by solution index.
mlpl = [make_model(modelparam) for i in range(population)]
mlp = dict(enumerate(mlpl))


# Main loop: `generations` training iterations followed by one extra
# evaluation pass (gen == generations) in which every population slot runs the
# best solution found so far.
for gen in range(generations+1):
    iseval = gen == generations
    if iseval:
        solutions = [best_mean_max_sol for _ in range(population)]
    else:
        solutions = solver.ask()
    learn_time = time.perf_counter() - learn_time_start
    sample_time_start = time.perf_counter()

    # Load the candidate parameters into the driver-side models.
    for i, solution in enumerate(solutions):
        mlp[i].set_model_params(solution)

    # Each pairing holds one work item per environment; wait for all
    # environments to finish a round before starting the next one.
    for pairing in pairgen.get_idx_worker():
        ray.get([env.eval_env.remote(work, mlp) for work, env in zip(pairing, envs)])

    # Collect (and thereby reset) the per-solution returns of every environment.
    per_env_fitnesses = ray.get([env.get_fitty.remote() for env in envs])

    # Mean return per solution over all environments. Every sample counts,
    # including exact zeros, which a truthiness test would silently drop.
    fitness = np.zeros(population)
    count = np.zeros(population, dtype=int)
    for env_fitnesses in per_env_fitnesses:
        for i, samples in enumerate(env_fitnesses):
            for sample in samples:
                fitness[i] += sample
                count[i] += 1
    fitness = fitness / count

    sample_time = time.perf_counter() - sample_time_start
    learn_time_start = time.perf_counter()

    mean = np.mean(fitness)
    maxf = np.max(fitness)
    minf = np.min(fitness)

    if not iseval:
        solver.tell(fitness)
        # Track the generation with the best mean fitness. The evaluation pass
        # is skipped here because its "population" is only copies of one
        # solution and must not overwrite the stored best population.
        if gen == 0 or mean >= best_mean:
            best_mean = mean
            best_max = maxf
            best_mean_max_sol = np.copy(solutions[np.argmax(fitness)])
            best_mean_pop = np.copy(solutions)

    # 50 is the default episode_length of parallelEnv, which is not overridden
    # when the environment actors are created.
    total_timesteps = (gen+1)*pairgen.sample_runs*env_n*50

    if not iseval:
        log.append({"generation":(gen+1),"timesteps":total_timesteps,"reward_mean":mean,"reward_max":maxf,"reward_min":minf,"leartime_s":learn_time,"sampletime":sample_time,"best_mean":best_mean,"best_max":best_max, "params":vars(args)})
        print(log[-1])
        with open(filejson, 'wt') as out:
            json.dump(log, out, indent=4)
    else:
        column = ['reward_mean','reward_max',"reward_min",'total_timesteps','params','model','population']
        # Build the one-row frame directly; DataFrame.append no longer exists
        # in pandas 2.x.
        t = pd.DataFrame.from_records([{
            'reward_mean': mean,
            'reward_max': maxf,
            'reward_min': minf,
            'total_timesteps': total_timesteps,
            'params': args,
            'model': best_mean_max_sol,
            'population' : best_mean_pop
        }], columns=column)
        print(t)
        t.to_pickle(filepickle)
# Plot the mean reward per logged generation, with the min/max range shaded.
y = [entry["reward_mean"] for entry in log]
err1 = [entry["reward_max"] for entry in log]
err2 = [entry["reward_min"] for entry in log]
x = list(range(len(y)))

plt.plot(x, y, "r")
plt.xlabel("Generations")
plt.ylabel("Episodic Rewards")
plt.fill_between(x, err1, err2, color="skyblue", alpha=0.7)
plt.savefig(filepath+"img.png")

# Release the Ray resources started at the top of the script.
ray.shutdown()


    
