import os
import yaml
import torch
from torch.cuda.amp import autocast
import numpy
import random
from pprint import pprint
from dotmap import DotMap
from tensordict import unravel_key
from torchrl.envs import Transform


def makedirs(dir_list):
    """Create every directory in ``dir_list``, including missing parents.

    Directories that already exist are left untouched, so calling this
    repeatedly with the same paths is harmless.

    Args:
        dir_list: Iterable of directory paths (str or path-like).

    Raises:
        FileExistsError: If a path exists but is not a directory.
    """
    for path in dir_list:
        # exist_ok=True folds the "already exists?" check into the creation
        # itself, so a directory created concurrently (e.g. by another worker
        # process) between a separate check and the create no longer raises.
        os.makedirs(path, exist_ok=True)

def load_config(config_name):
    """Load a YAML configuration file and wrap it in a ``DotMap``.

    Args:
        config_name: Path to the YAML configuration file.

    Returns:
        DotMap: The parsed configuration, supporting attribute-style access
        such as ``config.logger.output_dir``.

    Raises:
        NameError: If ``config_name`` does not exist. This exception type is
            kept for backward compatibility; the underlying OS error is
            chained as ``__cause__``.
    """
    try:
        # Explicit UTF-8 so parsing does not depend on the platform's locale
        # encoding (e.g. cp1252 on Windows).
        f = open(config_name, encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError) as err:
        raise NameError("YAML configuration file does not exist, exiting!") from err

    with f:
        # safe_load only builds plain YAML types (no arbitrary Python objects).
        config_yaml = yaml.safe_load(f)

    return DotMap(config_yaml)

def process_config(config_name):
    """Load a configuration file, echo it to stdout, and return it.

    Prints the whole configuration, then a banner naming ``config.exp_name``.

    Args:
        config_name: Path to the YAML configuration file.

    Returns:
        The configuration object produced by ``load_config``.
    """
    config = load_config(config_name)

    print("Loaded configuration: ")
    pprint(config)

    rule = " *************************************** "
    print()
    print(rule)
    print(f"      Running experiment {config.exp_name}")
    print(rule)
    print()

    return config

def initialize_save_data(config):
    """Create the output directory and snapshot the configuration into it.

    Does nothing unless ``config.logger.save_data`` is truthy. Otherwise,
    ensures ``config.logger.output_dir`` exists and writes the configuration
    (converted with ``config.toDict()``) to ``<output_dir>/config.yaml``.

    Args:
        config: DotMap-like configuration exposing ``logger.save_data``,
            ``logger.output_dir`` and ``toDict()``.
    """
    logger_cfg = config.logger
    if not logger_cfg.save_data:
        return

    output_dir = logger_cfg.output_dir
    makedirs([output_dir])  # existing directories are left as they are

    with open(os.path.join(output_dir, 'config.yaml'), 'w') as outfile:
        yaml.dump(config.toDict(), outfile)

def seed_everything(seed):
    """Seed every global random number generator used by the experiment.

    Covers Python's ``random`` module, PyTorch's default generator and all
    CUDA device generators, and NumPy's legacy global RNG. The seed is also
    exported as ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read only at interpreter startup, so setting it
        here does not change hash randomization of the running process; it
        only takes effect in child processes launched afterwards.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        numpy.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

def swap_last(source, dest):
    """Return ``source`` with its last key element replaced by ``dest``'s.

    Both keys are normalized with ``unravel_key`` first. A plain string key
    counts as its own last element. Assuming ``unravel_key`` leaves these
    keys unchanged, for example:

        swap_last(("agents", "reward"), "done")  -> ("agents", "done")
        swap_last("reward", ("agents", "done"))  -> "done"

    Args:
        source: Key whose prefix (everything but the last element) is kept.
        dest: Key whose last element is used as the new leaf.

    Returns:
        The leaf string alone when ``source`` unravels to a string; otherwise
        ``source``'s prefix with the new leaf appended.
    """
    src = unravel_key(source)
    dst = unravel_key(dest)
    leaf = dst if isinstance(dst, str) else dst[-1]
    if isinstance(src, str):
        return leaf
    return src[:-1] + (leaf,)

class DoneTransform(Transform):
    """Expands the 'done' entries (incl. terminated) to match the reward shape.

    For each key in ``done_keys``, the tensor stored at ``("next", done_key)``
    receives a trailing singleton dimension and is expanded to the shape of
    ``("next", reward_key)``. The result is stored at
    ``("next", swap_last(reward_key, done_key))``, i.e. next to the reward.

    Can be appended to a replay buffer or a collector.

    Args:
        reward_key: Key (relative to ``"next"``) of the reward entry; its
            shape is the expansion target and its prefix locates the outputs.
        done_keys: Iterable of done-like keys (relative to ``"next"``) to
            expand.
    """

    def __init__(self, reward_key, done_keys):
        super().__init__()
        self.reward_key = reward_key
        self.done_keys = done_keys

    def forward(self, tensordict):
        """Write the expanded done entries into ``tensordict`` and return it."""
        for key in self.done_keys:
            target = ("next", swap_last(self.reward_key, key))
            flag = tensordict.get(("next", key)).unsqueeze(-1)
            # Looked up per key (not hoisted) so an empty ``done_keys`` never
            # touches the reward entry. expand() requires the unsqueezed flag
            # to be broadcast-compatible with the reward shape.
            reward_shape = tensordict.get(("next", self.reward_key)).shape
            tensordict.set(target, flag.expand(reward_shape))
        return tensordict