import os
import warnings
import random

import numpy as np
import torch
from omegaconf import open_dict

from ltsgns_mp.algorithms import get_algorithm
from ltsgns_mp.envs import get_env, Env
from ltsgns_mp.evaluation import get_evaluator
from ltsgns_mp.recording.recorder import Recorder


def main_initialization(config, env: Env | None = None):
    """
    Prepare every runtime component needed for a run.

    Order of operations: the config is first augmented with the SLURM job
    identifiers, the global RNGs are seeded from ``config.seed`` and the compute
    device is resolved (falling back to cpu when CUDA is unavailable). Only then
    are the environment, algorithm, evaluator and recorder constructed.

    :param config: Full run configuration. Reads ``env``, ``algorithm``
        (incl. ``algorithm.train_iterator``), ``evaluation`` (incl.
        ``evaluation.evaluator`` and ``evaluation.eval_only``), ``loading`` and ``seed``.
    :param env: Environment to reuse. When None, a new one is built with ``get_env``.
    :return: Tuple ``(env, algorithm, evaluator, recorder)``.
    """
    # Config / RNG / device setup must run before any component is built.
    initialize_config(config)
    initialize_seed(config)
    device = initialize_and_get_device(config)

    environment = env if env is not None else get_env(
        config.env,
        train_iterator_config=config.algorithm.train_iterator,
        evaluation_config=config.evaluation,
        device=device,
    )

    algorithm = get_algorithm(
        config=config.algorithm,
        env=environment,
        loading_config=config.loading,
        device=device,
    )
    evaluation_cfg = config.evaluation
    evaluator = get_evaluator(
        config=evaluation_cfg.evaluator,
        algorithm=algorithm,
        env=environment,
        eval_only=evaluation_cfg.eval_only,
        seed=config.seed,
    )
    recorder = Recorder(config=config, algorithm=algorithm)

    return environment, algorithm, evaluator, recorder


def initialize_config(config):
    """
    Store the SLURM job identifiers of the current process on ``config``.

    ``config.slurm_array_job_id`` and ``config.slurm_job_id`` receive the values of
    the ``SLURM_ARRAY_JOB_ID`` and ``SLURM_JOB_ID`` environment variables, or None
    when a variable is not set. The assignments happen inside omegaconf's
    ``open_dict`` so that the new keys can be added to a struct-mode config.
    A ``KeyError`` raised during this step is deliberately ignored (best effort).

    :param config: The run configuration (an omegaconf config), modified in place.
    """
    # attribute on the config -> environment variable it is read from
    slurm_fields = {
        "slurm_array_job_id": "SLURM_ARRAY_JOB_ID",
        "slurm_job_id": "SLURM_JOB_ID",
    }
    try:
        with open_dict(config):
            for field_name, env_var in slurm_fields.items():
                setattr(config, field_name, os.environ.get(env_var, None))
    except KeyError:
        pass


def initialize_seed(config):
    """
    Seed PyTorch, NumPy's global RNG and Python's ``random`` module.

    All three generators receive the same value, ``config.seed``.

    :param config: Object exposing an integer ``seed`` attribute.
    """
    seed_value = config.seed
    for apply_seed in (torch.manual_seed, np.random.seed, random.seed):
        apply_seed(seed_value)


def initialize_and_get_device(config) -> str:
    """
    Resolve the torch device string requested by ``config.device``.

    Accepted values are ``"cpu"`` and ``"cuda"``. Any other value triggers a
    warning and falls back to ``"cpu"``. Requesting ``"cuda"`` on a machine where
    ``torch.cuda.is_available()`` is False also falls back to ``"cpu"`` and warns,
    so the downgrade is never silent.

    When the resulting device is ``"cuda"``, cuDNN is put into deterministic mode
    and ``cudnn.benchmark`` is disabled for reproducibility.

    :param config: Object exposing a ``device`` attribute.
    :return: ``"cuda"`` or ``"cpu"``.
    """
    requested = config.device
    if requested == "cpu":
        device = "cpu"
    elif requested == "cuda":
        if torch.cuda.is_available():
            device = "cuda"
        else:
            # Make the fallback visible, consistent with the unknown-device branch.
            warnings.warn("CUDA was requested but is not available. Using cpu instead...")
            device = "cpu"
    else:
        warnings.warn(f"Unknown device: {requested}")
        warnings.warn("Using cpu instead...")
        device = "cpu"

    if device == "cuda":
        # deterministic setup for reproducibility
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return device
