import os
import math
import contextlib
import numpy as np
from collections import defaultdict
from tensorboard.backend.event_processing import event_accumulator
from concurrent.futures import ProcessPoolExecutor as Pool, wait


@contextlib.contextmanager
def silence_stderr():
    """Temporarily redirect the file descriptor behind ``sys.stderr`` to devnull.

    The redirection happens at the OS file-descriptor level, so writes that
    bypass the Python stream object and go straight to the descriptor are
    silenced as well. The original descriptor is restored on exit, even if
    the body raises.

    Raises:
        Whatever ``sys.stderr.fileno()`` raises when the current stderr
        object is not backed by a real file descriptor.
    """
    import sys

    stream = sys.stderr
    stderr_fd = stream.fileno()
    # Push out text buffered before the block so it is not swallowed by the
    # redirect when it is eventually flushed.
    stream.flush()
    orig_fd = os.dup(stderr_fd)
    try:
        null_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(null_fd, stderr_fd)
            yield
        finally:
            # Flush while still pointing at devnull so text written inside
            # the block does not leak out after the descriptor is restored.
            stream.flush()
            os.dup2(orig_fd, stderr_fd)
            os.close(null_fd)
    finally:
        # Always release the saved duplicate, even if opening devnull failed.
        os.close(orig_fd)


def read_map_info(map_name):
    """Build the environment config for a generated map name.

    ``map_name`` is lower-cased and split on underscores: the first token is
    the map type and the remainder must have three tokens, of which the first
    is the number of allied units and the last the number of enemies (e.g.
    ``"protoss_5_vs_5"``; the middle token is ignored).

    The base config is loaded from ``configs/sc2_gen_<map_type>.yaml``
    (relative to the current working directory) and its
    ``capability_config`` unit counts are overridden.

    Args:
        map_name: Map identifier as described above.

    Returns:
        The ``env_args`` mapping from the YAML file with ``n_units`` and
        ``n_enemies`` set.

    Raises:
        NotImplementedError: If the map type is not protoss, terran or zerg.
            (A ``RuntimeError`` subclass, so existing handlers still match.)
        ValueError: If ``map_name`` does not have the expected token layout.
    """
    map_type, params = map_name.lower().split("_", 1)
    if map_type not in ("protoss", "terran", "zerg"):
        raise NotImplementedError(
            f"Unsupported map type {map_type!r} in map name {map_name!r}"
        )

    n_agents, _, n_enemy = params.split("_")
    # Imported lazily so that the module can be used without PyYAML when
    # map configs are not needed.
    import yaml
    with open(f"configs/sc2_gen_{map_type}.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)["env_args"]
    config["capability_config"]["n_units"] = int(n_agents)
    config["capability_config"]["n_enemies"] = int(n_enemy)
    return config


def analyze_logs(logs, n_agents, n_enemies):
    """Average episode statistics over a sequence of info dicts.

    Each entry of ``logs`` may contain any of the keys ``total_reward``,
    ``dead_allies``, ``dead_enemies`` and ``battle_won``; missing keys are
    simply not counted for that statistic.

    Args:
        logs: Iterable of dict-like episode infos.
        n_agents: Divisor applied to the mean number of dead allies.
        n_enemies: Divisor applied to the mean number of dead enemies.

    Returns:
        Tuple ``(total_reward, dead_allies, dead_enemies, battle_won)`` with
        the mean reward, the mean dead-ally and dead-enemy counts divided by
        ``n_agents`` / ``n_enemies``, and the mean of ``battle_won``. A
        statistic with no samples is reported as ``0``.
    """
    keys = ("total_reward", "dead_allies", "dead_enemies", "battle_won")
    samples = {key: [] for key in keys}

    # Single pass so that one-shot iterables are handled as well.
    for entry in logs:
        for key in keys:
            if key in entry:
                samples[key].append(entry[key])

    def average(values):
        return sum(values) / len(values)

    rewards = samples["total_reward"]
    allies = samples["dead_allies"]
    enemies = samples["dead_enemies"]
    wins = samples["battle_won"]

    mean_reward = average(rewards) if rewards else 0
    ally_ratio = average(allies) / n_agents if allies else 0
    enemy_ratio = average(enemies) / n_enemies if enemies else 0
    win_rate = average(wins) if wins else 0
    return mean_reward, ally_ratio, enemy_ratio, win_rate


def chunk(ids, n_chunks=256):
    """Group positions of ``ids`` into at most ``n_chunks`` contiguous groups.

    Values are mapped onto ``n_chunks`` equal-width buckets spanning
    ``[min(ids), max(ids)]``. Positions are accumulated into the current
    group, and the group is closed (including the current position) whenever
    a value falls into a higher bucket than any seen so far. Afterwards the
    largest group is repeatedly split in half until there are ``n_chunks``
    groups or the largest group has a single element.

    ``ids`` is assumed to be in ascending order (the caller in this file
    sorts by step before calling).

    Args:
        ids: Sequence of numeric values (e.g. training steps).
        n_chunks: Target number of groups.

    Returns:
        A list of lists of positions (indices into ``ids``). Every position
        appears in exactly one group. Empty input yields ``[]``.
    """
    ids = list(ids)
    if not ids:
        return []

    min_id = min(ids)
    span = max(ids) - min_id

    groups = []
    pending = []
    highest_bucket = 0
    for position, value in enumerate(ids):
        pending.append(position)
        # With identical values there is only one bucket; avoid dividing by 0.
        bucket = int((value - min_id) / span * n_chunks) if span else 0
        if bucket > highest_bucket:
            highest_bucket = bucket
            groups.append(pending)
            pending = []

    # Positions still pending share the bucket of the last closed group
    # (e.g. several entries at the maximum value); keep them instead of
    # silently dropping them.
    if pending:
        if groups:
            groups[-1].extend(pending)
        else:
            groups.append(pending)

    while len(groups) < n_chunks:
        largest = max(range(len(groups)), key=lambda pos: len(groups[pos]))
        members = groups[largest]
        if len(members) == 1:
            break
        half = len(members) // 2
        groups[largest:largest + 1] = [members[:half], members[half:]]
    return groups


def read_tensorboard_events(log_dir, env_name, algo, eps, seeds):
    """Load evaluation and expert win rates for one configuration.

    For every seed, the event file inside
    ``{log_dir}/{eps}_{seed}/{env_name}/{algo}`` (any file whose name
    contains ``"events"``) is parsed and its ``game_eval/win_rate`` and
    ``game_expert/win_rate`` scalars are pooled across seeds.

    Args:
        log_dir: Root directory of the logs.
        env_name: Environment sub-directory name.
        algo: Algorithm sub-directory name.
        eps: Prefix of the run directory name.
        seeds: Iterable of seed suffixes of the run directory name.

    Returns:
        Tuple ``(steps, winrates, expert_winrate)``: ``steps`` and
        ``winrates`` are parallel tuples of the pooled evaluation points
        sorted by step, ``expert_winrate`` is the mean of all expert values.

    Raises:
        RuntimeError: If a run folder contains more than one event file.
        AssertionError: If a required scalar tag is missing or no evaluation
            point was found at all.
    """
    logs = []
    expert_logs = []
    for seed in seeds:
        folder = f"{log_dir}/{eps}_{seed}/{env_name}/{algo}"
        event_files = [name for name in os.listdir(folder) if "events" in name]
        # Fail before doing any (slow) parsing when the folder is ambiguous.
        if len(event_files) > 1:
            raise RuntimeError(
                f"Expected a single events file in {folder}, found {sorted(event_files)}"
            )
        for name in event_files:
            path = f"{folder}/{name}"
            print("Reading ...", path)
            event = event_accumulator.EventAccumulator(path)
            event.Reload()
            scalar_tags = event.Tags()["scalars"]
            assert "game_eval/win_rate" in scalar_tags, f"Missing game_eval/win_rate in {path}"
            assert "game_expert/win_rate" in scalar_tags, f"Missing game_expert/win_rate in {path}"
            logs += [(x.step, x.value) for x in event.Scalars("game_eval/win_rate")]
            expert_logs += [x.value for x in event.Scalars("game_expert/win_rate")]
    assert len(logs) > 0, f"Error {log_dir}/{eps}_xxx/{env_name}/{algo}"
    # Interleave the points of all seeds in step order.
    logs.sort(key=lambda x: x[0])
    steps, winrates = zip(*logs)
    expert_winrate = np.mean(expert_logs)
    return steps, winrates, expert_winrate

def smooth(scalars, weight):
    """Return a bias-corrected exponential moving average of ``scalars``.

    The running average is updated as ``avg * weight + (1 - weight) * value``
    starting from zero, and each output is divided by ``1 - weight ** n``
    (``n`` being the number of values seen so far) to compensate for the
    zero initialisation.

    Args:
        scalars: Iterable of numbers.
        weight: Smoothing factor; a value of exactly 1 makes the correction
            term zero and therefore raises ``ZeroDivisionError``.

    Returns:
        List of smoothed values, one per input value.
    """
    complement = 1 - weight
    running = 0
    result = []
    for seen, value in enumerate(scalars, start=1):
        running = running * weight + complement * value
        correction = 1 - math.pow(weight, seen)
        result.append(running / correction)
    return result


def _read_tensorboard_events(*args):
    try:
        return read_tensorboard_events(*args)
    except Exception as e:
        print("Error", e)


def read_logs(log_dir):
    """Read every run below ``log_dir`` and aggregate win-rate curves.

    The directory layout walked is ``{log_dir}/{run}/{env_name}/{algo}``,
    where ``run`` is split on ``"_"`` and its first two tokens are used as
    ``eps`` and seed. Runs that share ``(env_name, algo, eps)`` are read
    together in a worker process and their pooled evaluation points are
    chunked and averaged.

    Args:
        log_dir: Root directory of the logs.

    Returns:
        Nested dict ``data[eps][env_name][ALGO]`` (algorithm name
        upper-cased) whose values are dicts with keys ``steps``,
        ``winrates``, ``expert_winrate``, ``winrate`` and ``winrate_std``;
        the last two are the mean and standard deviation of the final four
        chunked win rates. Configurations whose logs could not be read are
        omitted.
    """
    inputs = defaultdict(list)
    for run_name in os.listdir(f"{log_dir}"):
        parts = run_name.split("_")
        for env_name in os.listdir(f"{log_dir}/{run_name}"):
            for algo in os.listdir(f"{log_dir}/{run_name}/{env_name}"):
                inputs[(env_name, algo, parts[0])].append(parts[1])

    data = {}
    if not inputs:
        # Nothing to read; do not spin up a process pool for no work.
        return data

    with Pool() as p:
        all_tasks = [
            (env_name, algo, eps,
             p.submit(_read_tensorboard_events, log_dir, env_name, algo, eps, seeds))
            for (env_name, algo, eps), seeds in inputs.items()
        ]
        for env_name, algo, eps, task in all_tasks:
            result = task.result()
            if result is None:
                # The worker wrapper swallows read errors and returns None;
                # skip this configuration instead of crashing on unpacking.
                print(f"Skipping {env_name}/{algo}/{eps}: logs could not be read")
                continue
            all_steps, all_winrates, expert_winrate = result
            groups = chunk(all_steps)
            steps = [np.mean([all_steps[i] for i in group]) for group in groups]
            winrates = [np.mean([all_winrates[i] for i in group]) for group in groups]
            data.setdefault(eps, {}).setdefault(env_name, {})[algo.upper()] = dict(
                steps=steps,
                winrates=winrates,
                expert_winrate=expert_winrate,
                winrate=np.mean(winrates[-4:]),
                winrate_std=np.std(winrates[-4:]),
            )
    return data