from collections.abc import Sequence
import dataclasses
import logging
import pathlib
from typing import Any

import etils.epath as epath

import openpi.models.model as _model
import openpi.policies.policy as _policy
import openpi.shared.download as download
from openpi.training import config as _config
from openpi.models.pi0 import Pi0
import openpi.transforms as transforms

from safetensors.torch import load_file
import torch

@dataclasses.dataclass
class PolicyConfig:
    """Bundle of the pieces needed to assemble a policy around a ``Pi0`` model.

    The declaration order of the fields below defines the positional order of the
    generated ``__init__``; the last three fields are optional and carry defaults.
    """

    # Model instance the policy wraps.
    model: Pi0
    # Normalization statistics, keyed by string name.
    norm_stats: dict[str, transforms.NormStats]

    # Transform sequences; presumably applied to policy inputs / outputs respectively
    # (inferred from the field names — confirm against the consumer of this config).
    input_layers: Sequence[transforms.DataTransformFn]
    output_layers: Sequence[transforms.DataTransformFn]

    # Optional settings.
    model_type: _model.ModelType = dataclasses.field(default=_model.ModelType.PI0)
    default_prompt: str | None = dataclasses.field(default=None)
    sample_kwargs: dict[str, Any] | None = dataclasses.field(default=None)


def create_trained_policy(
    train_config: _config.TrainConfig,
    checkpoint_dir: pathlib.Path | str,
    *,
    repack_transforms: transforms.Group | None = None,
    sample_kwargs: dict[str, Any] | None = None,
    default_prompt: str | None = None,
    norm_stats: dict[str, transforms.NormStats] | None = None,
) -> _policy.Policy:
    """Create a policy from a trained checkpoint.

    Args:
        train_config: The training config to use to create the model.
        checkpoint_dir: Location of the checkpoint. It is resolved with `download.maybe_download`
            and the result is passed directly to `torch.load`, so it must resolve to a file that
            holds the model's state dict.
        repack_transforms: Optional transforms that will be applied before any other transforms.
        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
            kwargs will be used.
        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
            data if it doesn't already exist.
        norm_stats: The norm stats to use for the policy. If not provided, the norm stats are loaded
            from `train_config.assets_dirs` using the data config's asset id.

    Returns:
        A `Policy` wrapping the loaded model, with input transforms (repack -> default prompt ->
        data -> normalize -> model) and the mirrored output transforms.

    Raises:
        ValueError: If `norm_stats` is not given and cannot be loaded from the assets directory.
    """
    repack_transforms = repack_transforms or transforms.Group()
    checkpoint_dir = download.maybe_download(str(checkpoint_dir))

    logging.info("Loading model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = train_config.model.create()
    # Map tensors straight onto the selected device: without `map_location`, a checkpoint that was
    # saved from GPU tensors cannot be loaded on the CPU fallback chosen above.
    state_dict = torch.load(checkpoint_dir, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)

    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    if norm_stats is None:
        # No explicit stats were supplied: load them from the configured assets directory.
        assets_dir = epath.Path(train_config.assets_dirs)
        norm_stats = train_config.data._load_norm_stats(assets_dir, data_config.asset_id)
        if norm_stats is None:
            raise ValueError(
                f"Could not load norm stats from {assets_dir} (asset_id={data_config.asset_id!r}). "
                "Please make sure the norm stats are available in the assets directory."
            )

    # Use the resolved `norm_stats` (caller-supplied or loaded above) for both directions so that
    # outputs are un-normalized with exactly the statistics used to normalize the inputs.
    return _policy.Policy(
        model,
        transforms=[
            *repack_transforms.inputs,
            transforms.InjectDefaultPrompt(default_prompt),
            *data_config.data_transforms.inputs,
            transforms.Normalize(data_mask=data_config.data_mask, norm_stats=norm_stats),
            *data_config.model_transforms.inputs,
        ],
        output_transforms=[
            *data_config.model_transforms.outputs,
            transforms.Unnormalize(data_mask=data_config.data_mask, norm_stats=norm_stats),
            *data_config.data_transforms.outputs,
            *repack_transforms.outputs,
        ],
        sample_kwargs=sample_kwargs,
        metadata=train_config.policy_metadata,
    )
