import torch
import argparse
import sys
import os
import pathlib
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset
import hydra
from omegaconf import DictConfig, OmegaConf
import dill
from einops import rearrange, repeat
import json
from diffusion_policy.dataset.pusht_image_dynamics_dataset import DynamicsModelDataset
from diffusion_policy.workspace.base_workspace import BaseWorkspace


class ResNetEncoder(torch.nn.Module):
    """Per-view image encoder that reuses backbones from a trained policy checkpoint.

    The checkpoint is loaded into its workspace, the (optionally EMA) policy is
    taken from it, and for every requested view the module found at
    ``policy.obs_encoder.obs_nets[view_name].backbone`` is kept.  At call time
    each view is encoded by its backbone, globally average-pooled, and returned
    with a singleton "patch" axis.
    """

    def __init__(self, policy_ckpt_path, view_names, ):
        """Load the policy checkpoint and extract one backbone per view.

        Args:
            policy_ckpt_path: Path to a checkpoint file readable by
                ``torch.load(..., pickle_module=dill)`` whose payload has a
                ``'cfg'`` entry describing the workspace to instantiate.
            view_names: Iterable of observation keys; each must be a key of
                ``policy.obs_encoder.obs_nets`` and a valid ``ModuleDict`` key
                (a string without ``'.'``).
        """
        super().__init__()
        self.policy_ckpt_path = policy_ckpt_path
        self.view_names = view_names
        # NOTE(review): presumably the channel count produced by the backbones;
        # nothing in this class verifies it -- confirm against the checkpoint.
        self.emb_dim = 512
        self.latent_ndim = 2
        self.name = 'resnet'

        with open(self.policy_ckpt_path, 'rb') as f:
            payload = torch.load(f, pickle_module=dill)
        cfg = payload['cfg']

        workspace_cls = hydra.utils.get_class(cfg._target_)
        workspace: BaseWorkspace = workspace_cls(cfg, output_dir='debug_obs_encoder')
        workspace.load_payload(payload, exclude_keys=None, include_keys=None)

        policy = workspace.ema_model if cfg.training.use_ema else workspace.model

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        policy.to(device)
        policy.eval()

        # A ModuleDict (instead of a plain dict) registers the backbones as
        # submodules, so .to(), .parameters(), .state_dict() and
        # .train()/.eval() called on this encoder actually reach them.
        self.obs_encoder = torch.nn.ModuleDict({
            view_name: policy.obs_encoder.obs_nets[view_name].backbone
            for view_name in self.view_names
        })
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        # Drop the references to the rest of the policy/workspace and release
        # any cached CUDA memory they were holding.
        del workspace
        del policy
        torch.cuda.empty_cache()

    def forward(self, x):
        """Encode every configured view.

        Args:
            x: Mapping from view name to an image tensor whose first two axes
                are (batch, time) and whose last axis has length 128.  The
                remaining axes are passed to the backbone unchanged
                (presumably channels, height, width -- confirm with caller).

        Returns:
            Dict mapping each view name to a tensor of shape
            ``(batch, time, 1, d)``, where ``d`` is the channel count of the
            pooled backbone output and the singleton axis is a dummy patch
            dimension.
        """
        view_embs = {}
        for view_name in self.view_names:
            imgs = x[view_name]
            batch, steps = imgs.shape[:2]
            # Fold time into the batch axis so the backbone sees a plain batch.
            imgs = imgs.flatten(0, 1)
            assert imgs.shape[-1] == 128, (
                f"expected last image dim to be 128, got {imgs.shape[-1]}"
            )
            feats = self.obs_encoder[view_name](imgs)
            pooled = self.avgpool(feats).squeeze(-1).squeeze(-1)  # (batch*time, d)
            # Restore (batch, time) and insert the dummy patch axis.
            view_embs[view_name] = pooled.reshape(batch, steps, 1, pooled.shape[-1])
        return view_embs


