import os
from typing import Callable, List, Deque
import gym
import torch.nn as nn
import numpy as np
import wandb
import logging
from omegaconf import DictConfig, OmegaConf
from tensorboard.util import tb_logging

def make_env(gym_id: str, seed: int, idx: int, capture_video: bool) -> Callable[[], gym.Env]:
    """Build a zero-argument factory that creates one seeded environment.

    Args:
        gym_id: Identifier passed to ``gym.make``.
        seed: Seed applied to the environment and to both of its spaces.
        idx: Index of this environment; only index 0 is ever video-recorded.
        capture_video: Whether video recording is enabled at all.

    Returns:
        A callable that constructs, wraps and seeds the environment when invoked.
    """
    def thunk() -> gym.Env:
        env = gym.wrappers.RecordEpisodeStatistics(gym.make(gym_id))
        # Recording is restricted to the first environment; videos go to "videos/".
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, "videos/")
        env.seed(seed)
        for space in (env.action_space, env.observation_space):
            space.seed(seed)
        return env

    return thunk

def make_ff_net(inp_dim: int, out_dim: int, layers: List[int], activation_fn: str, final_fn: str):
    """Assemble a feed-forward network as an ``nn.Sequential``.

    Every ``nn.Linear`` is followed by a non-linearity: the hidden activation
    after each hidden layer and the final function after the output layer.

    Args:
        inp_dim: Number of input features.
        out_dim: Number of output features.
        layers: Hidden layer widths; must contain at least one entry.
        activation_fn: Name of the hidden activation
            (``'relu'``, ``'tanh'``, ``'selu'`` or ``'linear'``).
        final_fn: Name of the function applied after the output layer
            (same choices as ``activation_fn``).

    Returns:
        The constructed ``nn.Sequential`` module.

    Raises:
        AssertionError: If ``layers`` is empty.
        KeyError: If either function name is not one of the known choices.
    """
    assert len(layers) > 0
    available = {
        'relu': nn.ReLU(),
        'tanh': nn.Tanh(),
        'selu': nn.SELU(),
        'linear': nn.Identity(),
    }
    hidden_act = available[activation_fn]
    output_act = available[final_fn]

    widths = [inp_dim, *layers, out_dim]
    # Index of the last Linear layer, i.e. the one producing ``out_dim`` features.
    last_idx = len(widths) - 2
    modules = []
    for idx in range(len(widths) - 1):
        modules.append(nn.Linear(widths[idx], widths[idx + 1]))
        # The same activation instance is reused after every hidden layer.
        modules.append(output_act if idx == last_idx else hidden_act)
    return nn.Sequential(*modules)

def polyak_average(params, target_params, tau):
    """Blend ``params`` into ``target_params`` in place.

    Each target is overwritten with ``tau * param + (1 - tau) * target``.
    The two iterables are walked in lockstep via ``zip``.

    Args:
        params: Iterable of source tensors/parameters (left unmodified).
        target_params: Iterable of tensors/parameters updated in place.
        tau: Interpolation weight given to the source values.
    """
    for source, target in zip(params, target_params):
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)

class _IgnoreTensorboardPathNotFound(logging.Filter):
    """Logging filter that drops records containing "No path found after".

    Meant to be attached to the ``tensorboard`` logger; all other records
    pass through unchanged.
    """

    def filter(self, record):
        """Return ``False`` to suppress ``record``, ``True`` to keep it.

        Raises:
            AssertionError: If the record was not emitted by the
                ``tensorboard`` logger.
        """
        assert record.name == "tensorboard"
        # ``record.msg`` may be any object (for example an exception
        # instance), so coerce it to text first; a bare ``in`` test on a
        # non-string would raise TypeError inside the logging call.
        return "No path found after" not in str(record.msg)

def wandb_init_wrapper(cfg: DictConfig):
    """Start a wandb run from ``cfg`` and record bookkeeping files.

    The config is fully resolved into plain containers, tagged with the
    current host name and passed to ``wandb.init``. Afterwards a filter is
    added to the TensorBoard logger, the Hydra overrides file is saved with
    the run, and the run id is written to ``wandb_id.txt`` in the current
    working directory.

    Args:
        cfg: Configuration providing ``wandb.project``, ``wandb.entity`` and
            ``wandb.name``.
    """
    import socket

    run_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    run_config["host"] = socket.gethostname()

    init_kwargs = dict(
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        config=run_config,
        sync_tensorboard=True,
        name=cfg.wandb.name,
        monitor_gym=True,
        save_code=True,
    )
    wandb.init(**init_kwargs)

    # Silence the "No path found after" records emitted by TensorBoard.
    tb_logging.get_logger().addFilter(_IgnoreTensorboardPathNotFound())

    # Copy the overrides so that failed runs can be reproduced easily.
    overrides_path = os.path.join(os.getcwd(), ".hydra/overrides.yaml")
    wandb.run.save(overrides_path)

    with open('wandb_id.txt', 'w') as id_file:
        id_file.write(f'{wandb.run.id}')

def wandb_finish_wrapper():
    """Call ``wandb.finish()``; takes no arguments and returns nothing."""
    wandb.finish()

def avg_max_50p(q: Deque):
    """Return the mean of the upper half of the values in ``q``.

    The values are copied into an array and sorted ascending; the mean is
    taken from index ``size // 2`` onwards, so for an odd number of values
    the middle element is included. ``q`` itself is left untouched.

    Args:
        q: Collection of numeric values.

    Returns:
        The mean of the largest half of the values.
    """
    ordered = np.sort(np.array(q))
    midpoint = ordered.size // 2
    upper_half = ordered[midpoint:]
    return upper_half.mean()
