import glob
import logging
import os
import shutil
import time
from collections import deque
from os import path
from pathlib import Path
import threading
import numpy as np
import torch
# from torch.profiler import profile, record_function, ProfilerActivity
from sacred import Experiment
from sacred.observers import (  # noqa
    FileStorageObserver,
    MongoObserver,
    QueuedMongoObserver,
    QueueObserver,
)
from torch.utils.tensorboard import SummaryWriter

import utils
from a2c import A2C, algorithm
from envs import make_vec_envs
from wrappers import RecordEpisodeStatistics, SquashDones, GlobalizeReward, FlattenObservation
from model import Policy

import wandb

import rware # noqa
import lbforaging # noqa

ex = Experiment(ingredients=[algorithm])
# Sacred would otherwise store everything printed during a run; replace it with
# a fixed placeholder string instead.
ex.captured_out_filter = lambda _captured: "Output capturing turned off."
# Record each run's sacred metadata on local disk.
ex.observers.append(FileStorageObserver("./results/sacred"))

# Every log record is prefixed with the process id, the one-letter level name,
# a short timestamp and the logger name.
logging.basicConfig(
    format="(%(process)d) [%(levelname).1s] - (%(asctime)s) - %(name)s >> %(message)s",
    datefmt="%m/%d %H:%M:%S",
    level=logging.INFO,
)


# Default experiment configuration. Sacred turns every local variable assigned
# in this function into a config entry, which can be overridden from the
# command line or by the named YAML configs registered below.
@ex.config
def config():
    # Environment id passed to make_vec_envs; training/evaluation also branch
    # on substrings of it (e.g. "MarlGrid", "TrafficJunction", "MSMTC").
    env_name = None
    # Episode time limit; an env_configs["time_limit"] entry takes precedence.
    time_limit = None
    # Adds the GlobalizeReward wrapper and tags the wandb group name.
    share_reward = False
    # Wrapper classes handed to make_vec_envs.
    wrappers = [
        RecordEpisodeStatistics,
        SquashDones,
        FlattenObservation
    ]
    if(share_reward):
        wrappers.append(GlobalizeReward)
    wrappers = tuple(wrappers)
    # Forwarded to make_vec_envs; also adds "_subproc" to the wandb group name.
    dummy_vecenv = False

    # Total env-step budget, split evenly across algorithm["num_steps_schedule"].
    num_env_steps = 100e6
    # Extra env options forwarded to make_vec_envs (may contain "time_limit").
    env_configs = {}

    # Output directories; "{id}" is replaced with the seed. loss_dir is not
    # read by true_main at present.
    eval_dir = "./results/video/{id}"
    loss_dir = "./results/loss/{id}"
    save_dir = "./results/trained_models/{id}"

    # Logging / checkpoint / evaluation periods, counted in updates within the
    # current num_steps_schedule phase.
    log_interval = 2000
    save_interval = int(5e5)
    # save_interval = None  -> disables checkpointing entirely (final one too)
    eval_interval = int(1e4)
    # eval_interval = None  -> disables evaluation entirely (final one too)
    # Number of parallel evaluation envs, and the minimum number of finished
    # episodes collected per evaluation.
    episodes_per_eval = 20


# Register every configs/*.yaml file (relative to the working directory) as a
# sacred named config called after its file stem, so `with <stem>` selects it.
for conf in glob.glob("configs/*.yaml"):
    name = Path(conf).stem  # Path.stem is already a str
    ex.add_named_config(name, conf)

def _squash_info(info, env_name, eval = False):
    """Average a batch of per-episode info dicts into one dict of scalars.

    Falsy entries (``{}`` / ``None``) are dropped first.

    * MarlGrid envs: only ``episode_reward`` is reported, as the mean over
      episodes of each episode's summed ``episode_reward`` array; with
      ``eval=True`` the mean ``episode_length`` is added as well.
    * Any other env: every key found in the info dicts (except
      ``"TimeLimit.truncated"``) is first reduced per episode (summed, or
      averaged for MSMTC envs) and then averaged over the episodes that carry
      that key.

    Args:
        info: iterable (list or deque) of info dicts.
        env_name: environment id; its substrings select the reduction above.
        eval: also report ``episode_length`` for MarlGrid. The name shadows the
            builtin but is kept because callers pass it by keyword.

    Returns:
        dict mapping each reported key to a numpy scalar mean.
    """
    episodes = [i for i in info if i]

    if 'MarlGrid' in env_name:
        new_info = {
            'episode_reward': np.mean(np.array([i['episode_reward'].sum() for i in episodes]))
        }
        if eval:
            new_info['episode_length'] = np.mean(np.array([i['episode_length'] for i in episodes]))
        return new_info

    # MSMTC values are averaged within an episode; every other env sums them.
    use_mean = 'MSMTC' in env_name

    def _per_episode(value):
        arr = np.array(value)
        return arr.mean() if use_mean else arr.sum()

    keys = {k for d in episodes for k in d.keys()}
    keys.discard("TimeLimit.truncated")

    return {
        key: np.mean([_per_episode(d[key]) for d in episodes if key in d])
        for key in keys
    }


@ex.capture
def evaluate(
    agents,
    monitor_dir,
    episodes_per_eval,
    env_name,
    seed,
    wrappers,
    dummy_vecenv,
    time_limit,
    algorithm,
    env_configs,
    _log,
):
    """Roll out the agents' current policies and return mean episode statistics.

    ``@ex.capture`` lets sacred fill every argument not passed explicitly from
    the experiment config. ``episodes_per_eval`` fresh environments are stepped
    in parallel until at least that many finished-episode infos have been
    collected. The evaluation environments are closed even if a rollout raises.

    Args:
        agents: agents exposing ``agent_id``, ``model.act`` and
            ``model.recurrent_hidden_state_size``.
        monitor_dir: accepted for interface compatibility; not used here.
        episodes_per_eval: number of parallel eval envs and minimum number of
            finished episodes to collect.
        env_name, seed, wrappers, dummy_vecenv, env_configs: forwarded to
            ``make_vec_envs``.
        time_limit: used unless ``env_configs`` has a ``"time_limit"`` entry.
        algorithm: config dict; ``"device"`` and ``"env_properties"`` are read.
        _log: sacred logger (unused).

    Returns:
        ``(mean_reward, mean_success)`` for TrafficJunction envs,
        ``(mean_reward, mean_episode_length)`` for MarlGrid envs, otherwise
        ``mean_reward``, all as computed by ``_squash_info``.
    """
    device = algorithm["device"]
    is_marlgrid = "MarlGrid" in env_name
    # For these envs only infos carrying "episode_reward" count as a finished
    # episode; for every other env any non-empty info does.
    needs_episode_key = any(
        tag in env_name
        for tag in ("predator_prey", "PredatorPrey", "TrafficJunction", "MSMTC", "MarlGrid")
    )

    def _policy_input(agent_obs):
        # MarlGrid observations arrive as a two-part pair; element 0 is wrapped
        # with torch.tensor (presumably not a tensor yet -- TODO confirm).
        if is_marlgrid:
            return torch.tensor(agent_obs[0]).float(), agent_obs[1].float()
        return agent_obs.float()

    eval_envs = make_vec_envs(
        env_name,
        env_configs,
        seed,
        dummy_vecenv,
        episodes_per_eval,
        env_configs.get("time_limit", time_limit),
        wrappers,
        device,
        env_properties=algorithm['env_properties'],
    )

    all_infos = []
    try:
        n_obs = eval_envs.reset()
        n_recurrent_hidden_states = [
            torch.zeros(
                episodes_per_eval, agent.model.recurrent_hidden_state_size, device=device
            )
            for agent in agents
        ]
        n_masks = torch.zeros(episodes_per_eval, 1, device=device)

        while len(all_infos) < episodes_per_eval:
            with torch.no_grad():
                _, n_action, _, n_recurrent_hidden_states = zip(
                    *[
                        agent.model.act(
                            _policy_input(n_obs[agent.agent_id]),
                            hidden.float(),
                            n_masks,
                        )
                        for agent, hidden in zip(agents, n_recurrent_hidden_states)
                    ]
                )

            # Observe reward and next obs.
            n_obs, _, done, infos = eval_envs.step(n_action)

            # 0.0 where an episode just ended, 1.0 otherwise.
            n_masks = torch.tensor(
                [[0.0] if done_ else [1.0] for done_ in done],
                dtype=torch.float32,
                device=device,
            )
            for info in infos:
                if needs_episode_key:
                    if 'episode_reward' in info:
                        all_infos.append(info)
                elif info:
                    all_infos.append(info)
    finally:
        # Always release the (possibly subprocess-backed) eval environments.
        eval_envs.close()

    info = _squash_info(all_infos, env_name, eval=True)
    if "TrafficJunction" in env_name:
        return info['episode_reward'], info['success']
    if is_marlgrid:
        return info['episode_reward'], info['episode_length']
    return info['episode_reward']


def _run_group_name(algorithm, env_name, share_reward, dummy_vecenv):
    """Return the wandb group name for this algorithm / env / setting combination."""
    env_properties = algorithm['env_properties']
    group_name = algorithm["algorithm_name"] + "_" + env_name
    if env_properties is not None:
        group_name += (
            '_sr' + str(env_properties['sensor_range'])
            + '_' + 'rqs' + str(env_properties['request_queue_size'])
            + '_'
        )
    # NOTE(review): "_subproc" is appended when dummy_vecenv is True, which
    # reads as inverted; kept as-is so runs stay in the same wandb groups.
    return (
        group_name
        + ("_shared_reward" if share_reward else "")
        + ("_subproc" if dummy_vecenv else "")
    )


def _prepare_run_dir(dir_template, seed):
    """Resolve ``dir_template`` for ``seed``, create it (with parents) and clean it."""
    run_dir = path.expanduser(dir_template.format(id=str(seed)))
    os.makedirs(run_dir, exist_ok=True)
    utils.cleanup_log_dir(run_dir)
    return run_dir


def _env_steps_so_far(updates_schedule, num_steps_schedule, num_processes, phase, updates_in_phase):
    """Return env steps collected after ``updates_in_phase`` updates of ``phase``.

    All updates of earlier phases are counted in full; one update in phase ``p``
    gathers ``num_steps_schedule[p] * num_processes`` env steps.
    """
    finished = sum(updates_schedule[p] * num_steps_schedule[p] for p in range(phase))
    current = updates_in_phase * num_steps_schedule[phase]
    return int((finished + current) * num_processes)


def _update_agents_in_parallel(agents):
    """Run every agent's ``update(agents)`` on its own thread and wait for all.

    An exception raised in a worker thread is re-raised here; a bare
    ``threading.Thread`` only reports it and lets the caller carry on.
    """
    errors = []

    def _worker(agent):
        try:
            agent.update(agents)
        except Exception as exc:  # re-raised on the calling thread below
            errors.append(exc)

    threads = [threading.Thread(target=_worker, args=(agent,)) for agent in agents]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    if errors:
        raise errors[0]


def true_main(
    _run,
    _log,
    num_env_steps,
    env_name,
    seed,
    algorithm,
    dummy_vecenv,
    time_limit,
    env_configs,
    wrappers,
    save_dir,
    eval_dir,
    loss_dir,
    log_interval,
    save_interval,
    eval_interval,
    share_reward,
):
    """Train one A2C learner per environment agent and log progress to wandb.

    Training runs in phases, one per entry of ``algorithm['num_steps_schedule']``.
    The ``num_env_steps`` budget is split evenly across phases, and each phase
    uses its own rollout length.

    After every update the loop may do three things:

    * log, every ``log_interval`` updates;
    * checkpoint into the wandb run dir, every ``save_interval`` updates;
    * evaluate, every ``eval_interval`` updates.

    Checkpointing and evaluation also happen at the last update of each phase.
    The training environments are closed once, when training ends or fails.
    ``_run`` and ``loss_dir`` are accepted for interface compatibility but unused.
    """
    utils.set_seed(seed)

    group_name = _run_group_name(algorithm, env_name, share_reward, dummy_vecenv)
    print("Algo to run: {}".format(group_name))

    wandb.init(
        project='cacl',
        entity='ssl_ec_marl',
        group=group_name,
        config=algorithm,
        tags=["comm_alt_arch"],
    )
    wandb.config.update({"num_env_steps": num_env_steps, "seed": seed})

    eval_dir = _prepare_run_dir(eval_dir, seed)
    # Created and cleaned for layout compatibility; checkpoints go to wandb.run.dir.
    _prepare_run_dir(save_dir, seed)

    device = algorithm["device"]
    num_processes = algorithm["num_processes"]
    num_steps_schedule = algorithm['num_steps_schedule']
    is_marlgrid = 'MarlGrid' in env_name
    # For these envs only infos carrying "episode_reward" are kept.
    # NOTE(review): evaluate() also includes "MSMTC" in this list; confirm
    # whether training should as well.
    needs_episode_key = any(
        tag in env_name
        for tag in ("predator_prey", "PredatorPrey", "TrafficJunction", "MarlGrid")
    )

    def _policy_input(storage, step):
        # MarlGrid stores its two-part observation in img_obs / df_obs.
        if is_marlgrid:
            return (
                storage.img_obs[step].clone().detach(),
                storage.df_obs[step].clone().detach(),
            )
        return storage.obs[step]

    envs = make_vec_envs(
        env_name,
        env_configs,
        seed,
        dummy_vecenv,
        num_processes,
        env_configs.get("time_limit", time_limit),
        wrappers,
        device,
        env_properties=algorithm['env_properties'],
    )
    try:
        agents = [
            A2C(env_name, i, osp, asp)
            for i, (osp, asp) in enumerate(zip(envs.observation_space, envs.action_space))
        ]
        obs = envs.reset()
        for i, agent_obs in enumerate(obs):
            storage = agents[i].storage
            if is_marlgrid:
                storage.img_obs[0].copy_(agent_obs[0])
                storage.df_obs[0].copy_(agent_obs[1])
            else:
                storage.obs[0].copy_(agent_obs)
            storage.to(device)

        # Assumes env steps are split equally between the schedule phases.
        env_steps_per_phase = int(num_env_steps) / len(num_steps_schedule)
        updates_schedule = [
            env_steps_per_phase // ss // num_processes for ss in num_steps_schedule
        ]
        reward_scale = algorithm['reward_scale']

        all_infos = deque(maxlen=12)
        start = time.time()
        total_update_steps = 0
        for u_idx, update_step in enumerate(updates_schedule):
            num_steps = num_steps_schedule[u_idx]
            for j in range(1, int(update_step) + 1):
                total_update_steps += 1
                for step in range(num_steps):
                    # Sample actions.
                    with torch.no_grad():
                        n_value, n_action, n_action_log_prob, n_recurrent_hidden_states = zip(
                            *[
                                agent.model.act(
                                    _policy_input(agent.storage, step),
                                    agent.storage.recurrent_hidden_states[step],
                                    agent.storage.masks[step],
                                )
                                for agent in agents
                            ]
                        )
                    # n_action: num_agents x num_processes
                    obs, reward, done, infos = envs.step(n_action)

                    # masks: 0.0 where an episode ended; bad_masks: 0.0 where
                    # the info reports a time-limit truncation.
                    masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
                    bad_masks = torch.FloatTensor(
                        [
                            [0.0] if info.get("TimeLimit.truncated", False) else [1.0]
                            for info in infos
                        ]
                    )
                    for i, agent in enumerate(agents):
                        agent.storage.insert(
                            obs[i],
                            n_recurrent_hidden_states[i],
                            n_action[i],
                            n_action_log_prob[i],
                            n_value[i],
                            reward[:, i].unsqueeze(1) / reward_scale,
                            masks,
                            bad_masks,
                        )
                        agent.storage.to(device)

                    for info in infos:
                        if needs_episode_key:
                            if 'episode_reward' in info:
                                all_infos.append(info)
                        elif info:
                            all_infos.append(info)

                for agent in agents:
                    agent.compute_returns()
                _update_agents_in_parallel(agents)
                for agent in agents:
                    agent.storage.after_update()

                if j % log_interval == 0 and len(all_infos) > 1:
                    squashed = _squash_info(all_infos, env_name)
                    # j counts updates already completed in this phase (starts at 1).
                    total_num_steps = _env_steps_so_far(
                        updates_schedule, num_steps_schedule, num_processes, u_idx, j
                    )
                    end = time.time()
                    _log.info(
                        "Updates {}, num timesteps {}, FPS {}".format(
                            total_update_steps, total_num_steps, int(total_num_steps / (end - start))
                        )
                    )
                    _log.info(
                        "Last {} training episodes mean reward {:.3f}".format(
                            len(all_infos), squashed['episode_reward'].sum()
                        )
                    )
                    wandb.log({
                        'episode_reward': squashed['episode_reward'].sum(),
                        'total_num_steps': total_num_steps,
                        'updates': total_update_steps,
                    })
                    all_infos.clear()

                is_last_update = j == update_step
                if save_interval is not None and (j % save_interval == 0 or is_last_update):
                    for agent in agents:
                        save_at_wandb = os.path.join(
                            wandb.run.dir, group_name + "_agent{}".format(agent.agent_id)
                        )
                        os.makedirs(save_at_wandb, exist_ok=True)
                        # NOTE(review): update_step is the phase's total update
                        # count, so every save within a phase passes the same
                        # value; confirm j / total_update_steps was not intended.
                        agent.save(save_at_wandb, update_step)

                if eval_interval is not None and (j % eval_interval == 0 or is_last_update):
                    monitor_dir = os.path.join(eval_dir, "u{}".format(j))
                    if "TrafficJunction" in env_name:
                        eval_reward, eval_success = evaluate(agents, monitor_dir, env_configs=env_configs)
                        wandb.log({'eval_updates': total_update_steps, 'eval_reward': eval_reward, 'eval_success': eval_success})
                    elif is_marlgrid:
                        eval_reward, eval_episode_length = evaluate(agents, monitor_dir, env_configs=env_configs)
                        wandb.log({'eval_updates': total_update_steps, 'eval_reward': eval_reward, 'eval_episode_length': eval_episode_length})
                    else:
                        eval_reward = evaluate(agents, monitor_dir, env_configs=env_configs)
                        wandb.log({'eval_updates': total_update_steps, 'eval_reward': eval_reward})
    finally:
        # Close once, after all schedule phases (or on failure).
        envs.close()

@ex.automain
def main(
    _run,
    _log,
    num_env_steps,
    env_name,
    seed,
    algorithm,
    dummy_vecenv,
    time_limit,
    env_configs,
    wrappers,
    save_dir,
    eval_dir,
    loss_dir,
    log_interval,
    save_interval,
    eval_interval,
    share_reward,
):
    """Sacred entry point that forwards the injected config values to ``true_main``.

    ``@ex.automain`` has sacred fill these parameters by name from the
    experiment config (plus the special ``_run`` / ``_log`` objects).
    """
    true_main(
        _run=_run,
        _log=_log,
        num_env_steps=num_env_steps,
        env_name=env_name,
        seed=seed,
        algorithm=algorithm,
        dummy_vecenv=dummy_vecenv,
        time_limit=time_limit,
        env_configs=env_configs,
        wrappers=wrappers,
        save_dir=save_dir,
        eval_dir=eval_dir,
        loss_dir=loss_dir,
        log_interval=log_interval,
        save_interval=save_interval,
        eval_interval=eval_interval,
        share_reward=share_reward,
    )
