
import datetime
import gym
import math
import numpy as np
import itertools
import torch
from sac import SAC_continue,MRF2
from tensorboardX import SummaryWriter
from replay_memory import ReplayMemory,replay_buffer#the last one is for her
import sys
#from config import configInit
from utils import set_seed,load_models
import qnyh
from qnyh.envs.qnyh_small.config import get_map
from mpi4py import MPI
import psutil

import time
import json
import os
import wandb


# Master switch: when True, training and evaluation metrics are sent to wandb.
wandb_flag = True

# Training stops once the number of gradient updates exceeds this value.
max_update_num = 1.2 * 10 ** 6

def reward_shaping(_state, idx):
    """Return the extra (shaping) reward that agent ``idx`` receives in ``_state``.

    Args:
        _state: a two-element sequence; the first element is used as the
            position, the second one is unpacked but ignored.
        idx: agent index selecting the shaping rule.

    Returns:
        * ``idx == 0``: ``abs(position + 1.2)`` -- grows with the distance from -1.2.
        * ``idx == 1``: ``-abs(0.6 - position)`` -- penalises the distance from 0.6.
        * any other index: ``0`` (no shaping).

    NOTE(review): -1.2 and 0.6 look like the position bounds of the
    MountainCar track -- confirm against the environment.
    """
    # Unpacking (rather than indexing) keeps the "exactly two values" check.
    position, _velocity = _state

    if idx == 0:
        return abs(position + 1.2)
    if idx == 1:
        return -abs(0.6 - position)
    return 0


def eval(env,agent,wanb,update_num,name,args):
    """Roll out the current policy/policies for 10 episodes each and report the mean return.

    Args:
        env: environment used only for evaluation; must expose ``reset()`` and
            a ``step(action)`` returning ``(next_state, reward, done, info)``.
        agent: object providing ``select_action(state, evaluate=True)`` and,
            when ``args.agentnum != 1``,
            ``select_action_workers(state, idx, evaluate=True)``.
        wanb: unused (kept for call compatibility).
        update_num: update counter attached to every log record.
        name: unused (kept for call compatibility).
        args: configuration object; only ``args.agentnum`` is read.

    With a single agent the result is logged under ``ep_reward/eval_baseline``.
    Otherwise every index ``idx`` in ``range(args.agentnum)`` is evaluated
    separately: all but the last index act through ``select_action_workers``,
    the last one through ``select_action``.  wandb logging and the result line
    are emitted on MPI rank 0 only; the closing separator line is printed on
    every rank.
    """
    n_episodes = 10
    separator = "----------------------------------------"

    def total_return(pick_action):
        # Sum of the episode returns over ``n_episodes`` full episodes.
        total = 0
        for _ in range(n_episodes):
            obs = env.reset()
            ep_return = 0
            finished = False
            while not finished:
                #env.render()
                obs, step_reward, finished, _info = env.step(pick_action(obs))
                ep_return += step_reward
            total += ep_return
        return total

    def main_policy(obs):
        return agent.select_action(obs, evaluate=True)

    if args.agentnum == 1:
        total = total_return(main_policy)

        if MPI.COMM_WORLD.Get_rank() == 0:
            if wandb_flag:
                wandb.log({'ep_reward/eval_{}'.format('baseline'): total / n_episodes, 'update': update_num})
            print(separator)
            print("Test Episodes: {}, Avg. Reward: {},update num:{}".format(n_episodes, round(total / n_episodes, 2), update_num))
        print(separator)
        return

    for idx in range(args.agentnum):
        if idx < args.agentnum - 1:
            def worker_policy(obs, worker=idx):
                return agent.select_action_workers(obs, worker, evaluate=True)
            policy = worker_policy
        else:
            policy = main_policy

        total = total_return(policy)
        tag = 'worker:{}'.format(idx)

        if MPI.COMM_WORLD.Get_rank() == 0:
            if wandb_flag:
                wandb.log({'ep_reward/eval_{}'.format(tag): total / n_episodes, 'update': update_num})
            print(separator)
            print("{}Test Episodes: {}, Avg. Reward: {},update num:{}".format(tag, n_episodes, round(total / n_episodes, 2), update_num))
        print(separator)
        

def trainner(args):
    """Train an agent on ``MountainCarContinuous-v0`` with periodic evaluation.

    The function builds a training and an evaluation environment, a replay
    memory and either a ``SAC_continue`` agent (``args.agentnum == 1``) or an
    ``MRF2`` agent (otherwise), then loops over environment steps until the
    number of parameter updates exceeds the module-level ``max_update_num``.

    Args:
        args: configuration object.  Attributes read here: ``agentnum``,
            ``experiment_name``, ``seed``, ``max_timesteps``, ``load_model``,
            ``replay_size``, ``start_steps``, ``batch_size`` and
            ``demonstration``.  ``args.env_name`` is overwritten.

    Side effects:
        Sets ``WANDB_MODE=offline`` in ``os.environ``, initialises and logs to
        wandb when ``wandb_flag`` is true, prints progress, and on MPI rank 0
        saves a checkpoint and runs ``eval`` every 1000 updates (once more
        than 10**4 environment steps have been taken and
        ``args.demonstration`` is false).
    """
    agent_num = args.agentnum
    # Alternatives noted by the original author: 'Hopper-v3', 'HalfCheetah-v2'.
    args.env_name = 'MountainCarContinuous-v0'

    env = gym.make(args.env_name)
    set_seed(env, args)

    # wandb runs are always recorded offline.
    os.environ["WANDB_MODE"] = "offline"

    if wandb_flag:
        wandb.init(
            project="Mountaincar_MRF_920",
            group=args.experiment_name,
            name=args.env_name,
            config={
                "seed": args.seed
            }
        )

    # Separate environment instance so evaluation never disturbs the training episode.
    env_eval = gym.make(args.env_name)
    env_eval.seed(args.seed)
    env_eval.action_space.seed(args.seed)

    # Non-Box spaces are treated as one-dimensional.
    if env.observation_space.__class__.__name__ != 'Box':
        obs_dim = 1
    else:
        obs_dim = env.observation_space.shape[0]

    if env.action_space.__class__.__name__ != 'Box':
        a_dim = 1
    else:
        # FIX: the action dimension comes from the action space (the original
        # read observation_space here, a copy-paste slip).
        a_dim = env.action_space.shape[0]

    env_params = {
        'obs': obs_dim,
        'goal': 1,
        'action': a_dim,
        'max_timesteps': args.max_timesteps,
        'distance': 1
    }

    load_flag = args.load_model

    # Replay memory shared by every agent/worker.
    memory = ReplayMemory(args.replay_size, args.seed)

    if agent_num == 1:
        agent = SAC_continue(env_params['obs'], env.action_space, args)
    else:
        agent = MRF2(env_params['obs'], env.action_space, args)

    load_models(load_flag, agent, args)

    updates = 0
    print('begain to train the AI')

    shaping_num = 0  # only reported in the per-episode log line

    state = env.reset()
    episode_reward = 0
    episode_steps = 0
    episode_idx = 0

    for steps in itertools.count(1):

        # Random warm-up actions for the first ``start_steps`` steps, unless a
        # model was loaded.  The ``== False`` comparison is kept on purpose so
        # non-bool values of ``args.load_model`` behave exactly as before.
        if args.start_steps > steps and load_flag == False:  # noqa: E712
            action = env.action_space.sample()
        else:
            action = agent.select_action(state)

        next_state, reward, done, info = env.step(action)

        episode_steps += 1
        episode_reward += reward
        # mask is 1 when the episode stopped because the step limit was hit,
        # otherwise 0 on termination and 1 while running.
        # NOTE(review): presumably consumed by the agent as the bootstrap
        # mask -- confirm in sac.py.
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)

        # One reward entry per agent: the raw reward for the single-agent
        # baseline, otherwise raw reward plus the agent-specific shaping term.
        if agent_num == 1:
            r = [reward]
        else:
            r = [reward + reward_shaping(state, i) for i in range(agent_num)]

        memory.push(
            np.array(state).reshape(-1),
            np.array(action).reshape(-1),
            np.array(r).reshape(-1),
            np.array(next_state).reshape(-1),
            np.array(mask).reshape(-1),
        )

        state = next_state

        if episode_steps >= env._max_episode_steps or done:

            # The (unshaped) training return is logged for every 10th episode.
            if episode_idx % 10 == 1 and wandb_flag:
                if agent_num == 1:
                    wandb.log({'ep_reward/train_baseline': episode_reward, 'update': updates})
                else:
                    wandb.log({'ep_reward/train': episode_reward, 'update': updates})
            if not args.demonstration and MPI.COMM_WORLD.Get_rank() == 0:
                if 'nextstate_shaping' in args.experiment_name:
                    print("nsShaping:Episode: {}, total numsteps: {}, episode steps: {}, reward: {},shaping_num:{}".format(episode_idx, steps, episode_steps, episode_reward, shaping_num))
                else:
                    print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {},shaping_num:{}".format(episode_idx, steps, episode_steps, episode_reward, shaping_num))
            # Start the next episode.
            state = env.reset()
            episode_reward = 0
            episode_steps = 0
            episode_idx += 1

        # One parameter update per environment step once the memory holds
        # more than a batch of transitions.
        if len(memory) > args.batch_size:
            if agent_num == 1:
                critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, args.batch_size, updates, inputTuple=False)
                if updates % 50 == 1 and wandb_flag:
                    key = 'baseline'
                    wandb.log({
                        "loss/critic_1_loss_{}".format(key): critic_1_loss,
                        "loss/critic_2_loss_{}".format(key): critic_2_loss,
                        "loss/policy_loss_{}".format(key): policy_loss,
                        "loss/entropy_loss_{}".format(key): ent_loss,
                        "loss/alpha_{}".format(key): alpha,
                        'update': updates
                    })
            else:
                critic_1_loss, critic_2_loss, worker_loss, policy_loss, ent_loss, alpha, lamda = agent.update_parameters(memory, args.batch_size, updates, inputTuple=False)
                if updates % 50 == 1 and wandb_flag:
                    wandb.log({
                        "loss/critic_1_loss": critic_1_loss,
                        "loss/critic_2_loss": critic_2_loss,
                        "loss/policy_loss": policy_loss,
                        "loss/entropy_loss": ent_loss,

                        "loss/lamda": lamda,
                        'loss/workers_loss': worker_loss,
                        'update': updates
                    })
                    # ``alpha`` is iterable here; each entry is logged separately.
                    for index, a in enumerate(alpha):
                        wandb.log({'loss/worker{}'.format(index): a, 'update': updates})
            updates += 1

            # Periodic checkpoint + evaluation, rank 0 only.
            if updates % 1000 == 1 and steps > pow(10, 4) and not args.demonstration and MPI.COMM_WORLD.Get_rank() == 0:

                agent.save_checkpoint(env_name=args.env_name, suffix='{}_{}_{}'.format(args.experiment_name, args.seed, 'distributeAgents'))

                eval(env_eval, agent, wandb, updates, None, args)
                print("save successful")

        # Stop once the update budget is exhausted.
        if updates > max_update_num:
            print('process finished')
            break
  

       

if __name__ == '__main__':
    # Script entry point: fetch the 'args' configuration from the qnyh config
    # map and start training with it.
    args = get_map('args')

    # Optional CPU pinning, left disabled by the original author:
    # p = psutil.Process()
    # p.cpu_affinity([110])  # or [args.start_cpu + MPI.COMM_WORLD.Get_rank()]
    # p.cpu_affinity([15])

    trainner(args)



