from collections import defaultdict
import torch
import uuid
from datetime import datetime
from tensordict.nn import TensorDictModule

from torchrl.modules import (
    ProbabilisticActor,
    ValueOperator,
    ActorCriticWrapper,
    OneHotCategorical
)

from models.actor_critic import Actor, Critic, SharedFeatureExtractor
from models.actor_critic import DisretizedProbabilisticActor, SeparateDisretizedProbabilisticActor
from minigrid.core.constants import OBJECT_TO_IDX

# Integer object-type ids taken from MiniGrid's OBJECT_TO_IDX table, resolved once at import time.
AGENT_IDX, KEY_IDX = OBJECT_TO_IDX["agent"], OBJECT_TO_IDX["key"]

def generate_exp_name(model_name: str, experiment_name: str) -> str:
    """Return a unique experiment identifier.

    The result has the form ``<model>-<experiment>-<uuid8>-<timestamp>``, where
    ``uuid8`` is the first 8 hex characters of a random UUID4 and ``timestamp``
    is the current local time formatted as ``YY_MM_DD-HH_MM_SS``.
    """
    random_tag = uuid.uuid4().hex[:8]
    timestamp = datetime.now().strftime("%y_%m_%d-%H_%M_%S")
    pieces = [model_name, experiment_name, random_tag, timestamp]
    return "-".join(pieces)

def get_device(config):
    """Pick the torch device requested by ``config['device']``.

    Resolution, in order:
      1. CUDA is available and ``config['device']`` is not ``None``:
         ``-1`` selects ``torch.device("cuda")``; any other value is passed
         directly to ``torch.device``.
      2. MPS is available and ``config['device'] != 0``: ``torch.device("mps")``.
      3. Otherwise: ``torch.device("cpu")``.

    The selected device is printed once.

    Args:
        config: mapping with a ``'device'`` entry (``None``, ``-1``, or a value
            accepted by ``torch.device``).

    Returns:
        torch.device: the selected device.
    """
    # Read the entry once; the former stray debug print of the raw value is gone.
    requested = config['device']
    if torch.cuda.is_available() and requested is not None:
        device = torch.device("cuda") if requested == -1 else torch.device(requested)
    elif torch.backends.mps.is_available() and requested != 0:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f'Running on device: {device}')
    return device


def create_actor_critic(config, device):
    """Build the TorchRL policy and value operators described by ``config``.

    Reads ``config['env']`` (a key of ``input_key_dict`` below) and
    ``config['model']['custom_model_config']`` (``critic_share_layers`` and,
    for ``'metaworld'``, ``discretize_size`` / ``discretize_mode``).

    * For ``'single_cook'`` the actor and critic also receive the ``recipe``
      entry of the tensordict.
    * For every env other than ``'metaworld'`` the actor samples a
      ``OneHotCategorical`` action from ``logits``.
    * For ``'metaworld'`` the continuous action space is discretized with
      ``discretize_size`` values per dimension: ``'exp'`` enumerates the full
      Cartesian grid of per-dimension values, ``'linear'`` keeps one row of
      values per dimension.

    Args:
        config: experiment configuration mapping.
        device: device the returned operators are moved to.

    Returns:
        tuple ``(actor, critic)``: policy and value operators on ``device``.

    Raises:
        KeyError: if ``config['env']`` has no entry in ``input_key_dict``.
        ValueError: if ``discretize_mode`` is neither ``'exp'`` nor ``'linear'``.
    """
    input_key_dict = {
        'multi_grid': 'image',
        'two_goal_grid': 'image',
        'single_cook': 'image',
        'metaworld': 'observation'
    }
    model_config = config['model']['custom_model_config']

    # Only the cooking environment feeds a recipe input into the networks.
    using_recipe = config['env'] in ["single_cook"]

    critic_share_layers = model_config.get('critic_share_layers', False)

    print(f"sharing layers {critic_share_layers}")

    if critic_share_layers:
        print("Using shared feature extractor for Actor and Critic (except final layers).")
        shared_feature_extractor = SharedFeatureExtractor(config).to(device)
        actor_model = Actor(config, shared_feature_extractor, using_recipe=using_recipe)
        critic_model = Critic(config, shared_feature_extractor, using_recipe=using_recipe)
    else:
        print("Using separate networks for Actor and Critic.")
        actor_model = Actor(config, using_recipe=using_recipe)
        critic_model = Critic(config, using_recipe=using_recipe)

    # Map tensordict entries (keys) to the models' forward() keyword names (values).
    obs_key = input_key_dict[config['env']]
    if using_recipe:
        print("recipe head used")
        actor_in_keys = {obs_key: 'obs', 'recipe': 'recipe', 'action_mask': 'action_mask'}
        critic_in_keys = {obs_key: 'obs', 'recipe': 'recipe'}
    else:
        print("recipe head not used")
        actor_in_keys = {obs_key: 'obs', 'action_mask': 'action_mask'}
        critic_in_keys = {obs_key: 'obs'}

    actor_module = TensorDictModule(
        actor_model,
        in_keys=actor_in_keys,
        out_keys=['logits', 'raw_logits', 'energy', 'raw_energy']
    )
    critic = ValueOperator(critic_model, in_keys=critic_in_keys, out_keys=['state_value'])

    if config['env'] not in ['metaworld']:
        actor = ProbabilisticActor(
            actor_module, in_keys=["logits"], distribution_class=OneHotCategorical, return_log_prob=True
        )
    else:
        from environments.env_loader import env_maker
        import numpy as np

        # NOTE(review): this env is only used to read its action space and is
        # never closed; confirm whether env_maker's result needs explicit cleanup.
        env = env_maker(config)
        action_space = env.action_space
        action_dim = action_space.shape[0]
        size = model_config['discretize_size']
        mode = model_config['discretize_mode']

        if mode == 'exp':
            # Every point of the per-dimension value grid becomes one discrete
            # action: shape [size ** action_dim, action_dim].
            values = [
                np.linspace(action_space.low[i], action_space.high[i], size)
                for i in range(action_dim)
            ]
            grid = np.meshgrid(*values)
            action_mapping = torch.from_numpy(
                np.vstack([np.ravel(g) for g in grid]).T
            )

            actor = DisretizedProbabilisticActor(
                actor_module, in_keys=["logits"], distribution_class=OneHotCategorical, return_log_prob=True,
                action_mapping=action_mapping
            )
        elif mode == 'linear':
            # One row of candidate values per action dimension: shape [action_dim, size].
            action_mapping = torch.stack(
                [torch.linspace(action_space.low[i], action_space.high[i], size) for i in range(action_dim)],
                dim=0
            )
            actor = SeparateDisretizedProbabilisticActor(
                actor_module, in_keys=['logits'], distribution_class=OneHotCategorical, return_log_prob=True,
                action_mapping=action_mapping
            )
        else:
            # Format the message as a single string (passing two positional
            # args made str(exc) a tuple repr).
            raise ValueError(f"unknown discretize mode: {mode!r}")

    actor_critic = ActorCriticWrapper(actor, critic)

    actor, critic = actor_critic.get_policy_operator().to(device), actor_critic.get_value_operator().to(device)
    print("Actor Model:")
    print(actor)

    print("Critic Model:")
    print(critic)

    return actor, critic


def eval_model(actor, test_env, num_episodes=10, max_steps=500):
    """Roll out ``actor`` in ``test_env`` and average per-episode statistics.

    Rollouts of up to ``max_steps`` steps are collected (they do not stop at the
    first ``done``) and statistics are read wherever ``("next", "done")`` is set.

    Args:
        actor: policy passed to ``test_env.rollout``.
        test_env: environment exposing a TorchRL-style ``rollout`` method.
        num_episodes: loop target. NOTE: the counter advances by
            ``len(td_test)`` (the size of the rollout's first batch dimension),
            which is not necessarily the number of finished episodes.
        max_steps: maximum number of steps per rollout.

    Returns:
        dict mapping ``"eval/<metric>"`` to the mean of all collected values for
        ``reward`` and ``episode_length``, plus ``sparse``/``shape`` when the
        rollout carries an ``"info"`` entry and ``success_rate`` when it carries
        ``("next", "success")``.
    """
    cur_eps = 0
    results = defaultdict(list)
    while cur_eps < num_episodes:
        td_test = test_env.rollout(
            policy=actor,
            auto_cast_to_device=True,
            break_when_any_done=False,
            max_steps=max_steps,
        )
        done = td_test["next", "done"]
        results['reward'].append(td_test["next", "reward"][done].cpu())
        results['episode_length'].append(td_test["next", "step_count"][done].cpu())

        if "info" in td_test.keys():
            # NOTE(review): this reads the root-level "info"; in a TorchRL
            # rollout root entries precede each step, so ("next", "info") may
            # have been intended — confirm.
            info = td_test["info"][done]
            results['sparse'].append(info["sparse_r_by_agent"].cpu())
            results['shape'].append(info["shaped_r_by_agent"].cpu())

        # Test for the key that is actually read: a "success" flag that only
        # lives under "next" used to be skipped (and a root-only one raised).
        if ("next", "success") in td_test.keys(include_nested=True):
            outcomes = td_test["next", "success"][..., None][done]
            results['success_rate'].append(outcomes.cpu())

        cur_eps += len(td_test)

    return {f'eval/{k}': torch.cat(v, 0).float().mean().item() for k, v in results.items()}


def group_episodes_by_outcome(data):
    """Split episodes by whether any of their steps reached a terminal state.

    Args:
        data: sequence of episodes; each episode is a sequence of per-step
            entries that ``torch.stack`` can combine and that expose
            ``("next", "reward")`` / ``("next", "terminated")`` (presumably
            tensordicts — confirm against callers).

    Returns:
        dict with keys ``'success'`` (episodes with at least one step whose
        ``("next", "terminated")`` flag is set) and ``'failure'`` (all others),
        each a list of the original, unstacked episodes. The success rate is
        printed; it is reported as ``nan`` for empty ``data``.
    """
    success = []
    failure = []
    for eps in data:
        stacked_eps = torch.stack(eps, 0)
        if len(stacked_eps["next", "reward"][stacked_eps["next", "terminated"]]) > 0:
            success.append(eps)
        else:
            failure.append(eps)
    # Guard the empty case, which previously raised ZeroDivisionError.
    rate = len(success) / len(data) if len(data) > 0 else float('nan')
    print(f'success rate: {rate}')
    return {'success': success, 'failure': failure}


def get_env_mask(images: torch.Tensor, name: str, recipe=None, specific_name=None) -> torch.Tensor:
    """
    Label each sample with the environment variant it came from.

    Args:
        images (torch.Tensor): batch of encoded observations, shape [D, C, H, W]
            (``C`` must be 3 for ``"multi_grid"``; channel 0 holds object ids).
        name (str): environment family, ``"multi_grid"`` or ``"single_cook"``.
        recipe (torch.Tensor, optional): one-hot recipes of shape [D, 3]; only
            used for ``"single_cook"``.
        specific_name (str, optional): for ``"multi_grid"``, ``"open_target_grid"``
            selects the goal-position rule; any other value selects the key rule.

    Returns:
        torch.Tensor | None: ``torch.int`` tensor of shape [D] created on the
        default device:
          * multi_grid + open_target_grid: 1 ("old env") if the first goal
            (object id 8) lies in rows/cols 1-3, 0 ("new env") if in rows/cols 5-7;
          * multi_grid otherwise: 1 if the image contains a key object, else 0;
          * single_cook: 0 for [1, 0, 0] (onion), 1 for [0, 1, 0] (tomato),
            2 for [0, 0, 1] (fish).
        For ``"single_cook"`` called with ``recipe=None`` the result is ``None``.

    Raises:
        AssertionError: if ``images`` or ``recipe`` has an unexpected shape.
        ValueError: if a goal is missing or misplaced, a recipe is not one of
            the three one-hot vectors, or ``name`` is unsupported.
    """
    if name == "multi_grid":
        assert images.dim() == 4 and images.shape[1] == 3, "Expected shape [D, 3, 9, 9]"

        object_channel = images[:, 0]  # [D, H, W]
        D = object_channel.shape[0]

        if specific_name == "open_target_grid":
            goal_mask = object_channel == 8  # goal ID is 8
            env_mask = torch.empty((D,), dtype=torch.int)

            for i in range(D):
                pos = torch.nonzero(goal_mask[i], as_tuple=False)

                if pos.numel() == 0:
                    raise ValueError(f"No goal found in image {i}")

                x, y = pos[0].tolist()  # use the first goal location (row-major order)

                if x in [1, 2, 3] and y in [1, 2, 3]:
                    env_mask[i] = 1  # old env
                elif x in [5, 6, 7] and y in [5, 6, 7]:
                    env_mask[i] = 0  # new env
                else:
                    raise ValueError(f"Goal found at invalid position ({x}, {y}) in image {i}")

            return env_mask

        # One vectorised check per image: does any cell hold a key object?
        has_key = (object_channel == KEY_IDX).flatten(1).any(dim=1)
        env_mask = torch.zeros((D,), dtype=torch.int)
        env_mask.copy_(has_key)  # bool -> int and onto env_mask's device
        return env_mask

    elif name == "single_cook":
        assert images.dim() == 4, "Expected shape [D, C, H, W]"
        D = images.shape[0]

        if recipe is None:
            # Without a recipe the variant cannot be determined.
            return None

        assert recipe.shape == (D, 3), f"Expected recipe shape [D, 3], got {recipe.shape}"

        # Row j of the identity is the one-hot recipe labelled j
        # (0 = onion, 1 = tomato, 2 = fish); build it once, not per sample.
        onehots = torch.eye(3, dtype=recipe.dtype, device=recipe.device)
        # matches[i, j] is True when recipe[i] equals onehots[j].
        matches = (recipe.unsqueeze(1) == onehots.unsqueeze(0)).all(dim=-1)
        valid = matches.any(dim=1)
        if not bool(valid.all()):
            bad = int(torch.nonzero(~valid)[0])
            raise ValueError(f"Invalid recipe at sample {bad}: {recipe[bad].tolist()}")

        env_mask = torch.empty((D,), dtype=torch.int)
        # Each valid row matches exactly one identity row, so nonzero() yields
        # one (row, col) pair per sample in row order; col is the label.
        env_mask.copy_(torch.nonzero(matches)[:, 1])
        return env_mask

    else:
        raise ValueError(f"Unsupported environment name: {name}")

