import torch
import numpy as np

from fast_scaler import FastScaler
from nn_util import load_expert_data
from torch.utils.checkpoint import checkpoint

def forward_with_checkpoint(module, *args):
    """Call ``module(*args)`` through ``torch.utils.checkpoint.checkpoint``.

    The non-reentrant checkpoint implementation is always requested
    (``use_reentrant=False``).

    Args:
        module: Callable handed to ``checkpoint`` as the function to run.
        *args: Positional arguments forwarded unchanged to ``module``.

    Returns:
        Whatever ``checkpoint`` returns for ``module(*args)``.
    """
    checkpointed_output = checkpoint(module, *args, use_reentrant=False)
    return checkpointed_output

def set_attributes_from_args(obj, default_config, args):
    """Resolve configuration values and set each one as an attribute on ``obj``.

    Precedence, lowest to highest:
    ``default_config`` < ``args["config"]`` < remaining entries of ``args``.
    Keys that do not appear in ``default_config`` are reported with a printed
    message and otherwise ignored.

    Args:
        obj: Object that receives one attribute per key of ``default_config``.
        default_config: Dict of every recognized key and its default value.
            A default of ``None`` marks the key as required. The dict itself
            is not modified (a shallow copy is used).
        args: Dict of explicitly provided values. It may contain a
            ``"config"`` entry holding a dict of overrides; that entry is
            popped from ``args`` in place. ``config=None`` is treated the
            same as no config at all.

    Raises:
        AssertionError: If a key still resolves to ``None`` after all
            overrides have been applied. Attributes for keys that precede
            the offending one have already been set at that point.
    """
    # First just populate the working dict with the default values
    resolved = default_config.copy()

    # Extract and remove the config dict if present
    config_dict = args.pop("config", None)
    if config_dict is None:
        config_dict = {}

    # Apply the override layers in increasing order of priority.
    for overrides in (config_dict, args):
        for key, value in overrides.items():
            if key in resolved:
                resolved[key] = value
            else:
                print(f"Key {key} not recognized!")

    for key, value in resolved.items():
        # Identity check on purpose: `!= None` is evaluated elementwise for
        # NumPy arrays / tensors and cannot be used as a truth value.
        # Raised explicitly (rather than via `assert`) so the validation is
        # not stripped when Python runs with -O.
        if value is None:
            raise AssertionError(f"{key} must be explicitly set, it has no default!")
        setattr(obj, key, value)

def get_scalers_from_data_path(path, dan=False, dan_no_delta=False, dan_no_action=False, dan_delta_magnitude=False, dan_scalar_output=False):
    """Build the observation and action scalers for the expert data at ``path``.

    Two ``FastScaler`` objects are fitted on the concatenation of every
    trajectory's ``'observations'`` and ``'actions'`` arrays respectively.

    With ``dan`` set, the observation scaler's statistics are overwritten so
    that it accepts the layout [state, action, delta_state]:

    * state part: mean 0 / scale 1 (the state is expected to arrive already
      normalized).
    * action part: the fitted action mean / scale, or empty when
      ``dan_no_action`` is set.
    * delta part: mean 0 and a per-dimension scale of 1 (``dan_no_delta``) or
      sqrt(2) otherwise. With ``dan_delta_magnitude`` the delta part is a
      single entry whose scale is the L2 norm of the per-dimension scales.

    With ``dan_scalar_output`` set, the action scaler is replaced by a fresh
    one-element scaler with mean 0 and scale 1.

    Returns:
        Tuple ``(obs_scaler, act_scaler)``.
    """
    trajectories = load_expert_data(path)

    obs_scaler = FastScaler()
    stacked_observations = np.concatenate([t['observations'] for t in trajectories])
    obs_scaler.fit(stacked_observations)

    act_scaler = FastScaler()
    stacked_actions = np.concatenate([t['actions'] for t in trajectories])
    act_scaler.fit(stacked_actions)

    if dan:
        # Keep handles on the fitted statistics; they are replaced below.
        fitted_mean = obs_scaler.mean_np
        fitted_scale = obs_scaler.scale_np

        # Action segment: fitted action statistics, or nothing at all.
        if dan_no_action:
            action_mean = np.array([], dtype=fitted_mean.dtype)
            action_std = np.array([], dtype=fitted_scale.dtype)
        else:
            action_mean = act_scaler.mean_np
            action_std = act_scaler.scale_np

        # Delta segment mean: always zero, one entry in magnitude mode.
        if dan_delta_magnitude:
            delta_mean = np.array([0], dtype=fitted_mean.dtype)
        else:
            delta_mean = np.zeros_like(fitted_mean)

        # Delta segment scale: per-dimension factor, collapsed to its norm
        # in magnitude mode.
        delta_factor = 1 if dan_no_delta else np.sqrt(2)
        delta_std = np.ones_like(fitted_scale) * delta_factor
        if dan_delta_magnitude:
            delta_std = np.linalg.norm(delta_std)

        state_mean = np.zeros_like(fitted_mean)
        state_std = np.ones_like(fitted_scale)

        obs_scaler.mean_np = np.hstack((state_mean, action_mean, delta_mean))
        obs_scaler.scale_np = np.hstack((state_std, action_std, delta_std))
        obs_scaler.mean_torch = torch.as_tensor(obs_scaler.mean_np)
        obs_scaler.scale_torch = torch.as_tensor(obs_scaler.scale_np)

    if dan_scalar_output:
        # Identity scaler for a single scalar output.
        act_scaler = FastScaler()
        identity_dtype = obs_scaler.mean_np.dtype
        act_scaler.mean_np = np.array([0], dtype=identity_dtype)
        act_scaler.scale_np = np.array([1], dtype=identity_dtype)
        act_scaler.mean_torch = torch.as_tensor(act_scaler.mean_np)
        act_scaler.scale_torch = torch.as_tensor(act_scaler.scale_np)

    return obs_scaler, act_scaler

def get_io_size_from_data_path(path, classifier=False, dan=False, dan_no_action=False, dan_delta_magnitude=False, dan_scalar_output=False):
    """Return ``(input_size, output_size)`` derived from the expert data at ``path``.

    Sizes are read from the first step of the first trajectory.

    * ``classifier``: input is the observation length; output is the largest
      value found in any trajectory's ``'actions'`` plus one (actions are
      assumed to be class indices).
    * ``dan`` (only consulted when ``classifier`` is false): input is the
      observation length, plus a delta part (one entry with
      ``dan_delta_magnitude``, otherwise another observation length), plus
      the action length unless ``dan_no_action``. Output is 1 with
      ``dan_scalar_output``, otherwise the action length.
    * otherwise: observation length and action length.
    """
    trajectories = load_expert_data(path)

    if classifier:
        # Assume actions are class indices
        highest_class = 0
        for trajectory in trajectories:
            highest_class = np.max((highest_class, np.max(trajectory['actions'])))

        observation_len = len(trajectories[0]['observations'][0])
        return observation_len, highest_class + 1

    observation_len = len(trajectories[0]['observations'][0])

    if not dan:
        return observation_len, len(trajectories[0]['actions'][0])

    # State part plus delta part.
    if dan_delta_magnitude:
        input_size = observation_len + 1
    else:
        input_size = observation_len * 2

    # The action length is only looked up when it is actually needed.
    if not dan_no_action:
        input_size += len(trajectories[0]['actions'][0])

    if dan_scalar_output:
        output_size = 1
    else:
        output_size = len(trajectories[0]['actions'][0])

    return input_size, output_size

def get_action_min_max_len(path):
    """Return per-dimension action bounds and the action length for ``path``.

    All trajectories' ``'actions'`` arrays are concatenated, then reduced
    along axis 0.

    Returns:
        Tuple ``(minimum, maximum, length)`` where ``minimum`` and ``maximum``
        are torch tensors built from the axis-0 min / max, and ``length`` is
        ``len()`` of the first concatenated action.
    """
    trajectories = load_expert_data(path)
    stacked_actions = np.concatenate([t['actions'] for t in trajectories])

    lower_bound = torch.tensor(stacked_actions.min(axis=0))
    upper_bound = torch.tensor(stacked_actions.max(axis=0))
    action_len = len(stacked_actions[0])

    return lower_bound, upper_bound, action_len

def hash_tensor(tensor: torch.Tensor):
    """Return the SHA-256 hex digest of a tensor's raw element bytes.

    The tensor is copied to the CPU, detached, converted to a NumPy array and
    serialized with ``tobytes()``. Only those bytes are hashed; shape and
    dtype are not mixed into the digest.
    """
    import hashlib

    raw_bytes = tensor.cpu().detach().numpy().tobytes()

    hasher = hashlib.sha256()
    hasher.update(raw_bytes)
    return hasher.hexdigest()

# Per-channel mean / std constants (named for ImageNet), each stored with
# shape (1, 3, 1, 1).
imagenet_mean = torch.tensor((0.485, 0.456, 0.406)).reshape(1, 3, 1, 1)
imagenet_std = torch.tensor((0.229, 0.224, 0.225)).reshape(1, 3, 1, 1)

