import numpy as np
import os
from os.path import dirname, abspath
import sys
import yaml
import json
import datetime
import torch
#torch.autograd.set_detect_anomaly(True)
import random
from tensorboard_logger import configure, log_value # 对TF有依赖，后续替掉
from copy import deepcopy

# Parent of the directory that contains this file. Used below as the base
# for the config file, the dataset folder and the result folders.
project_direc = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put that directory on the import path -- presumably so the project-local
# imports that follow resolve; confirm against the repository layout.
sys.path.append(project_direc)

from replay_buffer import ReplayBuffer
from envs.Overcooked_Env_new import Overcooked_NEW
from BCAgent import BCAgent
# Best average evaluation reward seen so far. BehaviorClone.evaluate declares
# it `global`, compares each new average against it and overwrites it when the
# new average is greater than or equal to the stored value.
score = 0

def set_seed_everywhere(seed):
    """Seed every random number generator this script relies on.

    The same value is given to PyTorch (CPU generator, plus every CUDA
    device when CUDA is available), to NumPy's global generator and to
    Python's ``random`` module.

    Args:
        seed: Seed value handed unchanged to each library.
    """
    torch.manual_seed(seed)
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)
    # NumPy first, then the standard library, as before.
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
    
_TRUE_WORDS = ('1', 'true', 'yes', 'on')
_FALSE_WORDS = ('0', 'false', 'no', 'off')


def _coerce_override(default, raw):
    """Convert the command-line string *raw* to the type of *default*."""
    if isinstance(default, bool):  # must precede int: bool is an int subclass
        word = raw.strip().lower()
        if word in _TRUE_WORDS:
            return True
        if word in _FALSE_WORDS:
            return False
        raise ValueError("expected a boolean, got {!r}".format(raw))
    if isinstance(default, int):
        try:
            return int(raw)
        except ValueError:
            # Allow e.g. "1e3" or "0.5" for an option whose default is an int.
            return float(raw)
    if isinstance(default, float):
        return float(raw)
    # str, None, lists, ...: no safe conversion is known, keep the raw text.
    return raw


def config_dict_update(params, config_dict):
    """Apply ``--key=value`` command-line overrides to ``config_dict``.

    ``params`` is an argv-style list; its first element (the program name)
    is ignored. Each remaining token is handled as follows:

    * ``--key=value`` (any number of leading dashes, or none): the token is
      split on the FIRST ``=`` so values may themselves contain ``=``. When
      ``key`` already exists with a bool/int/float value, the string is
      converted to that type so overrides behave like the configured
      defaults; otherwise the value is stored as a string.
    * ``--flag`` (leading dash, no ``=``): stored as ``True``.
    * anything else (positional arguments): ignored, so they do not end up
      as junk keys in the configuration.

    Args:
        params: argv-style list of strings.
        config_dict: Configuration mapping; it is updated in place.

    Returns:
        The same ``config_dict`` object.

    Raises:
        ValueError: If an override cannot be converted to the type of the
            existing value for that key.
    """
    for param in params[1:]:
        name, separator, raw_value = param.partition('=')
        key = name.lstrip('-')
        if not key:
            continue
        if not separator:
            if name.startswith('-'):
                config_dict[key] = True
            continue
        if key in config_dict:
            try:
                config_dict[key] = _coerce_override(config_dict[key], raw_value)
            except ValueError as exc:
                raise ValueError(
                    "invalid value for '{}': {}".format(key, exc)) from exc
        else:
            config_dict[key] = raw_value
    return config_dict

class BehaviorClone(object):
    """Behavior-cloning trainer that owns one ``BCAgent`` policy per agent.

    Each agent is trained independently with a negative log-likelihood loss
    on (observation, action) pairs drawn from a replay buffer, and the joint
    policy is scored by rolling it out in an environment.
    """

    # Layout every flat observation is reshaped to before it reaches a
    # policy. In this file the last axis is 26 environment channels plus a
    # 2-channel instruction layer.
    OBS_SHAPE = (11, 5, 28)

    def __init__(self, params):
        """Read hyper-parameters from ``params`` and build the agents.

        Args:
            params: Mapping that provides ``batch_size``, ``num_agents``,
                ``device``, ``num_eval_episodes``, ``save_video`` and
                ``obs_dim``; it is also passed unchanged to ``BCAgent``.
        """
        self.batch_size = params['batch_size']
        self.num_agents = params['num_agents']
        self.device = params['device']
        self.num_eval = params['num_eval_episodes']
        self.save_video = params['save_video']
        self.obs_dim = params['obs_dim']

        self.nllloss_func = torch.nn.NLLLoss()
        self.agents = [BCAgent(params).to(self.device) for _ in range(self.num_agents)]

    def act(self, observations, deterministic=False):
        """Query every agent for an action.

        Args:
            observations: Per-agent observations; converted with
                ``torch.Tensor`` and split along the first axis, one slice
                per agent.
            deterministic: Forwarded to ``BCAgent.act``.

        Returns:
            list: One entry per agent, whatever ``BCAgent.act`` returns.

        Note:
            Agents are switched to eval mode and are NOT switched back here;
            ``evaluate`` restores train mode when it finishes.
        """
        observations = torch.Tensor(observations).to(self.device)
        actions = []
        for agent, obs in zip(self.agents, observations):
            agent.eval()
            actions.append(agent.act(obs, deterministic))
        return actions

    def train(self, replay_buffer, logger, step):
        """Run one optimisation step per agent on a sampled batch.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(obses, actions, rewards, next_obses, dones)``; only the
                first two are used, indexed as ``[:, agent_index]``.
            logger: Callable ``logger(name, value, step)``.
            step: Training step used as the logging index.

        Returns:
            list[float]: The loss of each agent, in agent order.
        """
        obses, actions, _rewards, _next_obses, _dones = replay_buffer.sample(self.batch_size)

        losses = []
        for agent_i, agent in enumerate(self.agents):
            agent.policy_optimizer.zero_grad()
            # -1 lets the batch axis follow whatever the buffer returned.
            policy_input = obses[:, agent_i].reshape((-1,) + self.OBS_SHAPE)
            # The policy output is treated as action probabilities; the small
            # constant keeps log() finite when a probability is exactly 0.
            action_probs = agent.policy(policy_input)
            log_probs = torch.log(action_probs + 1e-19)
            # reshape(-1) instead of squeeze(): squeeze() would turn a batch
            # of one into a 0-dim tensor, which NLLLoss rejects for a
            # (batch, classes) input.
            targets = actions[:, agent_i].reshape(-1).long()
            loss = self.nllloss_func(log_probs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 10)
            agent.policy_optimizer.step()
            loss_value = loss.item()
            logger('agent' + str(agent_i) + '_loss', loss_value, step)
            losses.append(loss_value)
        return losses

    @staticmethod
    def _with_instruction(agent_obs, instruct_layer):
        """Stack per-agent observations and append the instruction channels."""
        stacked = np.stack(agent_obs)
        tiled = np.tile(instruct_layer[None, ...], (stacked.shape[0], 1, 1, 1))
        return np.concatenate((stacked, tiled), -1)

    def evaluate(self, env, logger, step, save_path='/results', meta_data=None, model_path=None):
        """Roll out the current policies and report the mean episode reward.

        For every entry of ``meta_data``, ``num_eval_episodes`` episodes are
        played with deterministic actions. The mean reward is logged and,
        when it is greater than or equal to the best value stored in the
        module-level ``score``, ``score`` is updated and the agents are saved.

        Args:
            env: Environment providing ``reset()`` and
                ``step(actions, test_mode=True)``.
            logger: Callable ``logger(name, value, step)``.
            step: Training step used for logging and in checkpoint names.
            save_path: Assigned to ``env.run_dir`` when ``save_video`` is set.
            meta_data: Non-empty mapping ``instance -> {'instruct_layer': array}``.
            model_path: Directory for best-model checkpoints. When ``None``
                the module-level global ``md_exp_direc`` is used if it
                exists (behaviour of the training script in this file);
                when neither is available no checkpoint is written.

        Returns:
            float: Average episode reward over all evaluation episodes.

        Raises:
            ValueError: If ``meta_data`` is empty/None or no episode would run.
        """
        global score
        if not meta_data:
            raise ValueError("evaluate() needs a non-empty meta_data mapping")
        total_episodes = self.num_eval * len(meta_data)
        if total_episodes <= 0:
            raise ValueError("evaluate() needs num_eval_episodes > 0")

        if self.save_video:
            env.use_render = True
            env.run_dir = save_path

        total_reward = 0
        for instance in meta_data:
            instruct_layer = meta_data[instance]['instruct_layer']
            for _ in range(self.num_eval):
                agent_obs, _share_obs, _available_actions = env.reset()
                agent_obs = self._with_instruction(agent_obs, instruct_layer)
                episode_reward = 0
                while True:
                    policy_obs = np.array(agent_obs).reshape((-1,) + self.OBS_SHAPE)
                    # The environment expects each action wrapped in a list.
                    actions = [[action] for action in self.act(policy_obs, True)]

                    agent_obs, _share_obs, reward, dones, _info, _available_actions = env.step(actions, test_mode=True)
                    agent_obs = self._with_instruction(agent_obs, instruct_layer)
                    episode_reward += reward[-1]
                    # The episode ends as soon as any agent reports done.
                    if any(bool(agent_done) for agent_done in dones):
                        break
                total_reward += episode_reward

        average_episode_reward = total_reward / total_episodes
        logger('average_episode_reward', average_episode_reward, step)
        if average_episode_reward >= score:
            score = average_episode_reward
            # Save through `self` rather than a script-level instance so the
            # class also works when imported from another module.
            model_dir = model_path if model_path is not None else globals().get('md_exp_direc')
            if model_dir is not None:
                self.save_models(model_dir, step)
        # act() left the agents in eval mode; go back to training mode.
        for agent in self.agents:
            agent.train()
        return average_episode_reward

    def save_models(self, path, step):
        """Write each agent's ``state_dict`` to ``<path>/agent_<i>_<step>.th``.

        The directory is created when it does not exist yet.
        """
        os.makedirs(path, exist_ok=True)
        for agent_i, agent in enumerate(self.agents):
            file_name = "agent_{}_{}.th".format(agent_i, step)
            torch.save(agent.state_dict(), os.path.join(path, file_name))

    def load_models(self, path, name):
        """Load each agent's weights from ``<path>/agent_<i>_<name>.th``.

        ``name`` may be a string or an int step, mirroring ``save_models``.
        Tensors are deserialised on the CPU; ``load_state_dict`` then copies
        them into each agent's existing parameters.
        """
        for agent_i, agent in enumerate(self.agents):
            file_name = "agent_{}_{}.th".format(agent_i, name)
            agent.load_state_dict(torch.load(os.path.join(path, file_name), map_location='cpu'))
        

if __name__ == '__main__':
    # Usage: <script> <skill_type: place|deliver> <exp_name> [--key=value ...]
    #
    # NOTE: this body deliberately stays at module level (no main() function).
    # Names such as `BC` and `md_exp_direc` must remain module-level globals
    # because code outside this block may look them up by name.

    # ---- command line ------------------------------------------------------
    params = deepcopy(sys.argv)
    if len(params) < 3:
        raise SystemExit(
            "usage: {} <skill_type: place|deliver> <exp_name> [--key=value ...]".format(params[0]))
    skill_type = params[1]
    exp_name = params[2]

    # ---- datasets and instruction layers per skill --------------------------
    data_direc = os.path.join(project_direc, "Layout_SoupCoordination_Datasets")
    deliver_list = {
        "pot1": "Soup_Coordination_deliver_soup_use_pot1",
    }
    place_list = {
        "onion_pot1": "Soup_Coordination_place_onion_in_pot1",
        "tomato_pot1": "Soup_Coordination_place_tomato_in_pot1",
    }

    # One 11x5 single-channel map per object type; 255 marks selected grid
    # cells -- presumably the positions of that object in the layout (confirm
    # against the soup_coordination layout). pot2_layer is created but not
    # used in this script.
    dish_layer, onion_layer, tomato_layer, pot1_layer, pot2_layer = (
        np.zeros((11, 5, 1)) for _ in range(5))
    dish_layer[3, 0] = dish_layer[7, 4] = 255.
    onion_layer[6, 0] = onion_layer[4, 4] = 255.
    tomato_layer[4, 0] = tomato_layer[6, 4] = 255.
    pot1_layer[5, 2] = 255.

    # An instruction is two maps stacked on the channel axis.
    deliver_layer = {
        "pot1": np.concatenate((dish_layer, pot1_layer), axis=-1),
    }
    place_layer = {
        "onion_pot1": np.concatenate((onion_layer, pot1_layer), axis=-1),
        "tomato_pot1": np.concatenate((tomato_layer, pot1_layer), axis=-1),
    }

    # Validate the skill before any directory or logger is created. A real
    # exception is used because `assert` is stripped under `python -O`.
    if skill_type == "deliver":
        skill_datasets, skill_layers = deliver_list, deliver_layer
    elif skill_type == "place":
        skill_datasets, skill_layers = place_list, place_layer
    else:
        raise ValueError(
            "unknown skill_type {!r}; expected 'place' or 'deliver'".format(skill_type))

    # ---- algorithm configuration -------------------------------------------
    with open(os.path.join(project_direc, "config", "bc.yaml"), "r") as f:
        try:
            config_dict = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            # The message now names the file that is actually read.
            raise RuntimeError("bc.yaml error: {}".format(exc)) from exc

    print('算法参数：')
    config_dict = config_dict_update(params, config_dict)
    print(config_dict)

    # Seed the random number generators.
    set_seed_everywhere(config_dict['seed'])

    # ---- logging -------------------------------------------------------------
    unique_token = "{}_{}__{}".format(
        config_dict['agent'], skill_type,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    results_direc = os.path.join(project_direc, "skill_results")
    tb_exp_direc = os.path.join(results_direc, "tb_logs", exp_name, unique_token)
    # config.json is written into this directory right below, so make sure it
    # exists regardless of what configure() does.
    os.makedirs(tb_exp_direc, exist_ok=True)
    configure(tb_exp_direc)
    # Keep a copy of the effective configuration next to the logs.
    with open(os.path.join(tb_exp_direc, "config.json"), "w") as f:
        f.write(json.dumps(config_dict, indent=4))

    # ---- model checkpoints ---------------------------------------------------
    # `== True` (rather than truthiness) is kept on purpose: a value that
    # arrives as a non-empty string would otherwise always count as enabled.
    # md_exp_direc is only defined when model saving is enabled, as before.
    if config_dict['save_model'] == True:
        md_exp_direc = os.path.join(results_direc, "models", exp_name, unique_token)
        os.makedirs(md_exp_direc, exist_ok=True)

    # ---- rendered replays ----------------------------------------------------
    # rp_exp_direc is passed to BC.evaluate on every evaluation, so it must
    # always be defined; previously it only existed when save_model was on.
    rp_exp_direc = os.path.join(results_direc, "replays", exp_name, unique_token)
    if config_dict['save_model'] == True or config_dict.get('save_video'):
        os.makedirs(rp_exp_direc, exist_ok=True)

    # ---- environment ---------------------------------------------------------
    config_dict['env'] = r'soup_coordination'
    env = Overcooked_NEW(config_dict['env'], seed=config_dict['seed'], featurize_type=("ppo", "ppo"))

    # Both skills use the same dimensions: a flattened 11x5 grid with 26
    # environment channels plus the 2 instruction channels.
    config_dict['obs_dim'] = 11*5*(26+2)
    config_dict['action_dim'] = 6
    config_dict['num_agents'] = 2
    config_dict['meta_dim'] = 2

    # ---- replay buffer -------------------------------------------------------
    obs_shape = [config_dict['num_agents'], config_dict['obs_dim']]
    action_shape = [config_dict['num_agents'], 1]
    reward_shape = [config_dict['num_agents'], 1]
    dones_shape = [config_dict['num_agents'], 1]
    # capacity=0 is passed as before; presumably the buffer grows as data is
    # appended -- TODO confirm against ReplayBuffer.
    replay_buffer = ReplayBuffer(obs_shape=obs_shape,
                                 action_shape=action_shape,
                                 reward_shape=reward_shape,
                                 dones_shape=dones_shape,
                                 capacity=0,
                                 device=config_dict['device'])

    # Per-instance metadata handed to the buffer loader and to evaluation.
    meta_data = {
        instance: {
            "obs_dim": (11, 5, 26),
            "instruct_layer": skill_layers[instance],
        }
        for instance in skill_datasets
    }

    for instance, dataset_name in skill_datasets.items():
        data_path = os.path.join(data_direc, dataset_name, 'buffer')
        # os.listdir order is arbitrary; sort so the buffer is filled in the
        # same order on every run and seeded sampling stays reproducible.
        for content in sorted(os.listdir(data_path)):
            replay_buffer.append_data_with_metadata(os.path.join(data_path, content), meta_data[instance])

    print("数据加载完成！,数据量：",replay_buffer.idx)

    # ---- training ------------------------------------------------------------
    BC = BehaviorClone(config_dict)
    # Steps run from 0 to num_train_steps inclusive.
    for step in range(config_dict['num_train_steps'] + 1):
        loss = BC.train(replay_buffer, log_value, step)

        # Periodic evaluation. Checkpoints are written from BC.evaluate when
        # the evaluation reward matches or beats the best one so far.
        if step % config_dict['eval_frequency'] == 0:
            average_episode_reward = BC.evaluate(env, log_value, step, save_path=rp_exp_direc, meta_data=meta_data)
            print('loss: ['+str(step)+r'/'+str(config_dict['num_train_steps'])+']=',np.mean(np.array(loss)),'agent0_loss=',loss[0],'agent1_loss=',loss[1],
            'average_episode_reward=', average_episode_reward)

    print("all process end!")
    

    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    