from typing import Any, Dict

import numpy as np
import torch

from args import DatasetConfig

# Module-wide default device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def worker_init_fn(worker_id):
    """
    Re-seed torch and NumPy for one data-loading worker.

    Presumably passed as ``DataLoader(worker_init_fn=...)`` (confirm at call
    sites). The new seed is the current torch initial seed reduced modulo
    2**32, plus ``worker_id``. NumPy's global RNG receives that same value,
    folded back into its accepted 32-bit range.

    Args:
        worker_id: worker index; it is added to the base seed so workers differ.
    """
    base_seed = torch.initial_seed() % 2**32
    seed = base_seed + worker_id
    torch.manual_seed(seed)
    # Only matters if NumPy randomness is used inside the DataLoader. NumPy's
    # legacy seeder needs a value below 2**32, and ``seed`` can exceed that
    # once ``worker_id`` is added, hence the fold.
    np.random.seed(int(seed % (2**32 - 1)))


def build_bandit_data_filename(env: str, n_envs: int, config: Dict[str, Any], mode: int) -> str:
    """
    Builds the filename for the bandit data.
    Mode is either 0: train, 1: test, 2: eval.

    The result has the form
    ``datasets/trajs_<env>_envs<n_envs>[_hists<n_hists>_samples<n_samples>]``
    ``_H<H>_d<dim>_var<var>_cov<cov>_<train|test|eval>``. The bracketed
    ``hists``/``samples`` part is omitted for eval data.

    Args:
        env: environment name, embedded verbatim.
        n_envs: number of environments, embedded verbatim.
        config: must provide ``H``, ``dim``, ``var`` and ``cov``. Modes 0 and 1
            also read ``n_hists`` and ``n_samples``.
        mode: 0 (train), 1 (test) or 2 (eval).

    Returns:
        The dataset path. No file extension is appended.

    Raises:
        ValueError: if ``mode`` is not 0, 1 or 2.
        KeyError: if ``config`` lacks a required key.
    """
    # Fail loudly on an unknown mode instead of silently returning a path
    # with no split suffix.
    if mode not in (0, 1, 2):
        raise ValueError(f"mode must be 0 (train), 1 (test) or 2 (eval); got {mode!r}")

    filename = "datasets/trajs_"
    filename += env
    filename += f"_envs{n_envs}"
    if mode != 2:
        # Eval files are named without the history/sample counts.
        filename += f"_hists{config['n_hists']}"
        filename += f"_samples{config['n_samples']}"
    filename += f"_H{config['H']}"
    filename += f"_d{config['dim']}"
    filename += f"_var{config['var']}"
    filename += f"_cov{config['cov']}"
    if mode == 0:
        filename += "_train"
    elif mode == 1:
        filename += "_test"
    else:
        filename += "_eval"
    return filename


def build_bandit_model_filename(env, config: dict[str, Any]):
    """
    Return the checkpoint name stem for a bandit model.

    The model section is ``_arch<arch>`` when ``config["arch"]`` is present and
    not None. Otherwise it spells out ``shuffle``, ``lr``, ``dropout``,
    ``n_embd``, ``n_layer`` and ``n_head``.

    The run fields ``n_envs``, ``n_hists``, ``n_samples``, ``var``, ``cov``,
    ``H``, ``dim`` and ``seed`` always follow. ``_corr_train<value>`` is
    appended when ``corrupt_train`` is present and is not the empty string.
    """
    arch = config.get("arch", None)
    if arch is None:
        model_part = (
            f"_shuf{config['shuffle']}"
            f"_lr{config['lr']}"
            f"_do{config['dropout']}"
            f"_embd{config['n_embd']}"
            f"_layer{config['n_layer']}"
            f"_head{config['n_head']}"
        )
    else:
        model_part = f"_arch{arch}"

    run_part = (
        f"_envs{config['n_envs']}"
        f"_hists{config['n_hists']}"
        f"_samples{config['n_samples']}"
        f"_var{config['var']}"
        f"_cov{config['cov']}"
        f"_H{config['H']}"
        f"_d{config['dim']}"
        f"_seed{config['seed']}"
    )

    name = env + model_part + run_part
    corrupt = config.get("corrupt_train", "")
    if corrupt != "":
        name += f"_corr_train{corrupt}"
    return name


def build_linear_bandit_data_filename(env: str, n_envs: int, config: Dict[str, Any], mode: int) -> str:
    """
    Builds the filename for the linear bandit data.
    Mode is either 0: train, 1: test, 2: eval.

    The result has the form
    ``datasets/trajs_<env>_envs<n_envs>[_hists<n_hists>_samples<n_samples>]``
    ``_H<H>_d<dim>_lind<lin_d>_var<var>_cov<cov>_<train|test|eval>``. The
    bracketed ``hists``/``samples`` part is omitted for eval data.

    Args:
        env: environment name, embedded verbatim.
        n_envs: number of environments, embedded verbatim.
        config: must provide ``H``, ``dim``, ``lin_d``, ``var`` and ``cov``.
            Modes 0 and 1 also read ``n_hists`` and ``n_samples``.
        mode: 0 (train), 1 (test) or 2 (eval).

    Returns:
        The dataset path. No file extension is appended.

    Raises:
        ValueError: if ``mode`` is not 0, 1 or 2.
        KeyError: if ``config`` lacks a required key.
    """
    # Fail loudly on an unknown mode instead of silently returning a path
    # with no split suffix.
    if mode not in (0, 1, 2):
        raise ValueError(f"mode must be 0 (train), 1 (test) or 2 (eval); got {mode!r}")

    filename = "datasets/trajs_"
    filename += env
    filename += f"_envs{n_envs}"
    if mode != 2:
        # Eval files are named without the history/sample counts.
        filename += f"_hists{config['n_hists']}"
        filename += f"_samples{config['n_samples']}"
    filename += f"_H{config['H']}"
    filename += f"_d{config['dim']}"
    filename += f"_lind{config['lin_d']}"
    filename += f"_var{config['var']}"
    filename += f"_cov{config['cov']}"
    if mode == 0:
        filename += "_train"
    elif mode == 1:
        filename += "_test"
    else:
        filename += "_eval"
    return filename


def build_linear_bandit_model_filename(env, config):
    """
    Return the checkpoint name stem for a linear bandit model.

    ``env`` is followed by one ``_<tag><value>`` field per config key, in this
    fixed order: shuffle, lr, dropout, n_embd, n_layer, n_head, n_envs,
    n_hists, n_samples, var, cov, H, dim, lin_d, seed.
    """
    # (tag written into the name, config key it reads), in output order.
    field_order = (
        ("shuf", "shuffle"),
        ("lr", "lr"),
        ("do", "dropout"),
        ("embd", "n_embd"),
        ("layer", "n_layer"),
        ("head", "n_head"),
        ("envs", "n_envs"),
        ("hists", "n_hists"),
        ("samples", "n_samples"),
        ("var", "var"),
        ("cov", "cov"),
        ("H", "H"),
        ("d", "dim"),
        ("lind", "lin_d"),
        ("seed", "seed"),
    )
    filename = env
    filename += "".join(f"_{tag}{config[key]}" for tag, key in field_order)
    return filename


def build_darkroom_data_filename(
    env: str, n_envs: int, config: Dict[str, Any], mode: int, reward_type: str = "sparse"
) -> str:
    """
    Builds the filename for the darkroom data.
    Mode is either 0: train, 1: test, 2: eval.

    The result has the form
    ``datasets/trajs_<env>_envs<n_envs>[_hists<n_hists>_samples<n_samples>]``
    ``_H<H>_d<dim>_R<reward_type>`` followed by ``_train``, ``_test`` or
    ``_<rollin_type>_eval``. The bracketed part is omitted for eval data.

    Args:
        env: environment name, embedded verbatim.
        n_envs: number of environments, embedded verbatim.
        config: must provide ``H`` and ``dim``. Modes 0 and 1 also read
            ``n_hists`` and ``n_samples``; mode 2 reads ``rollin_type``.
        mode: 0 (train), 1 (test) or 2 (eval).
        reward_type: embedded after ``_R``; defaults to ``"sparse"``.

    Returns:
        The dataset path. No file extension is appended.

    Raises:
        ValueError: if ``mode`` is not 0, 1 or 2.
        KeyError: if ``config`` lacks a required key.
    """
    # Fail loudly on an unknown mode instead of silently returning a path
    # with no split suffix.
    if mode not in (0, 1, 2):
        raise ValueError(f"mode must be 0 (train), 1 (test) or 2 (eval); got {mode!r}")

    filename = "datasets/trajs_"
    filename += env
    filename += f"_envs{n_envs}"
    if mode != 2:
        # Eval files are named without the history/sample counts.
        filename += f"_hists{config['n_hists']}"
        filename += f"_samples{config['n_samples']}"
    filename += f"_H{config['H']}"
    filename += f"_d{config['dim']}"
    filename += f"_R{reward_type}"
    if mode == 0:
        filename += "_train"
    elif mode == 1:
        filename += "_test"
    else:
        # Eval files are additionally keyed by config["rollin_type"]. The
        # f-string also accepts non-str values, unlike "_" + value.
        filename += f"_{config['rollin_type']}"
        filename += "_eval"
    return filename


def build_darkroom_model_filename(env, config, reward_type="sparse"):
    """
    Return the checkpoint name stem for a darkroom model.

    The pieces are ``env``, then the config fields shuffle, lr, dropout,
    n_embd, n_layer, n_head, n_envs, n_hists, n_samples, H and dim. These are
    followed by ``_R<reward_type>`` and finally ``_seed<seed>``.
    """
    pieces = [
        env,
        f"_shuf{config['shuffle']}",
        f"_lr{config['lr']}",
        f"_do{config['dropout']}",
        f"_embd{config['n_embd']}",
        f"_layer{config['n_layer']}",
        f"_head{config['n_head']}",
        f"_envs{config['n_envs']}",
        f"_hists{config['n_hists']}",
        f"_samples{config['n_samples']}",
        f"_H{config['H']}",
        f"_d{config['dim']}",
        f"_R{reward_type}",
        f"_seed{config['seed']}",
    ]
    return "".join(pieces)


def build_miniworld_data_filename(env: str, env_id_start, env_id_end, config: Dict[str, Any], mode: int) -> str:
    """
    Builds the filename for the miniworld data.
    Mode is either 0: train, 1: test, 2: eval.

    The result has the form
    ``datasets/trajs_<env>_start<env_id_start>_end<env_id_end>``
    ``[_hists<n_hists>_samples<n_samples>]_H<H>`` followed by ``_train``,
    ``_test`` or ``_<rollin_type>_eval``. The bracketed part is omitted for
    eval data.

    Args:
        env: environment name, embedded verbatim.
        env_id_start: embedded after ``_start``.
        env_id_end: embedded after ``_end``.
        config: must provide ``H``. Modes 0 and 1 also read ``n_hists`` and
            ``n_samples``; mode 2 reads ``rollin_type``.
        mode: 0 (train), 1 (test) or 2 (eval).

    Returns:
        The dataset path. No file extension is appended.

    Raises:
        ValueError: if ``mode`` is not 0, 1 or 2.
        KeyError: if ``config`` lacks a required key.
    """
    # Fail loudly on an unknown mode instead of silently returning a path
    # with no split suffix.
    if mode not in (0, 1, 2):
        raise ValueError(f"mode must be 0 (train), 1 (test) or 2 (eval); got {mode!r}")

    filename = "datasets/trajs_"
    filename += env
    filename += f"_start{env_id_start}_end{env_id_end}"
    if mode != 2:
        # Eval files are named without the history/sample counts.
        filename += f"_hists{config['n_hists']}"
        filename += f"_samples{config['n_samples']}"
    filename += f"_H{config['H']}"
    if mode == 0:
        filename += "_train"
    elif mode == 1:
        filename += "_test"
    else:
        # Eval files are additionally keyed by config["rollin_type"].
        filename += f"_{config['rollin_type']}"
        filename += "_eval"
    return filename


def build_miniworld_model_filename(env, config):
    """
    Return the checkpoint name stem for a miniworld model.

    ``env`` is followed by the config fields shuffle, lr, dropout, n_embd,
    n_layer, n_head, n_envs, n_hists, n_samples, H and seed. Each is written
    as ``_<tag><value>``.
    """
    suffix = (
        f"_shuf{config['shuffle']}"
        f"_lr{config['lr']}"
        f"_do{config['dropout']}"
        f"_embd{config['n_embd']}"
        f"_layer{config['n_layer']}"
        f"_head{config['n_head']}"
        f"_envs{config['n_envs']}"
        f"_hists{config['n_hists']}"
        f"_samples{config['n_samples']}"
        f"_H{config['H']}"
        f"_seed{config['seed']}"
    )
    return env + suffix


def convert_to_tensor(x, store_gpu=True):
    """
    Convert ``x`` into a new float32 tensor.

    Args:
        x: anything ``np.asarray`` accepts (list, array, scalar, ...).
        store_gpu: if True, the tensor is created on the module-level
            ``device``. Otherwise no device is requested, matching the
            original ``torch.tensor(...)`` call.

    Returns:
        A float32 tensor. ``torch.tensor`` always copies its input, so the
        result never shares memory with ``x``.
    """
    # Cast and place in one construction. The previous form first built a
    # tensor in the source dtype, then a float32 copy, then a device copy.
    target = device if store_gpu else None
    return torch.tensor(np.asarray(x), dtype=torch.float32, device=target)


class _DotDictMissingError(KeyError, AttributeError):
    """Raised when a ``DotDict`` attribute name is not a key.

    Subclassing AttributeError lets ``hasattr``, ``getattr(obj, name, default)``
    and ``copy.deepcopy`` treat the name as missing. Subclassing KeyError keeps
    existing ``except KeyError`` handlers working.
    """


class DotDict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    ``d.x`` reads ``d["x"]``, ``d.x = v`` sets it, and ``del d.x`` removes it.
    Regular attributes such as dict methods take precedence, because
    ``__getattr__`` is only consulted when normal lookup fails.
    """

    def __getattr__(self, name):
        try:
            return dict.__getitem__(self, name)
        except KeyError:
            # A bare KeyError here would escape hasattr()/getattr(default)
            # and make copy.deepcopy fail on its __deepcopy__ probe.
            raise _DotDictMissingError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            dict.__delitem__(self, name)
        except KeyError:
            raise _DotDictMissingError(name) from None
