import robomimic.utils.tensor_utils as TensorUtils
import torch
import hydra
import omegaconf
import torch.nn as nn
import sys
import os
import gdown

from libero.lifelong.models.modules.rgb_modules import *
from libero.lifelong.models.modules.language_modules import *
from libero.lifelong.models.modules.transformer_modules import *
from libero.lifelong.models.base_policy import BasePolicy
from libero.lifelong.models.policy_head import *
from r3m import remove_language_head, cleanup_config
from collections import defaultdict

# Make the parent of this file's directory importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Module-level torch device string: "cuda" when a GPU is visible, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"


###############################################################################
#
# A model handling extra input modalities besides images at time t.
#
###############################################################################


class ExtraModalityTokens(nn.Module):
    """Embed each enabled low-dimensional modality as one token of equal size."""

    def __init__(
        self,
        use_joint=False,
        use_gripper=False,
        use_ee=False,
        extra_num_layers=0,
        extra_hidden_size=64,
        extra_embedding_size=32,
    ):
        """
        Build one independent encoder per enabled modality.

        Args:
            use_joint: enable the "joint_states" modality (7 features).
            use_gripper: enable the "gripper_states" modality (2 features).
            use_ee: enable the "ee_states" modality (3 features).
            extra_num_layers: 0 gives a single Linear layer per modality;
                n > 0 gives Linear(in, hidden), then (n - 1) x
                [Linear(hidden, hidden), ReLU], then Linear(hidden, embedding).
            extra_hidden_size: hidden width used when extra_num_layers > 0.
            extra_embedding_size: size of every produced token.

        Raises:
            AssertionError: if no modality is enabled.
        """
        super().__init__()
        self.use_joint = use_joint
        self.use_gripper = use_gripper
        self.use_ee = use_ee
        self.extra_embedding_size = extra_embedding_size

        joint_states_dim = 7
        gripper_states_dim = 2
        ee_dim = 3

        # (observation key, feature width, enabled flag), in token order.
        modality_specs = [
            ("joint_states", joint_states_dim, self.use_joint),
            ("gripper_states", gripper_states_dim, self.use_gripper),
            ("ee_states", ee_dim, self.use_ee),
        ]

        self.num_extra = sum(int(flag) for flag in (use_joint, use_gripper, use_ee))

        total_feature_dim = sum(
            int(flag) * dim
            for flag, dim in (
                (use_joint, joint_states_dim),
                (use_gripper, gripper_states_dim),
                (use_ee, ee_dim),
            )
        )
        assert total_feature_dim > 0, "[error] no extra information"

        self.extra_encoders = {}

        def build_mlp(input_dim):
            # Layers are created strictly in order so parameter
            # initialisation consumes the RNG in a fixed sequence.
            assert input_dim > 0  # we indeed have extra information
            if extra_num_layers > 0:
                stack = [nn.Linear(input_dim, extra_hidden_size)]
                for _ in range(extra_num_layers - 1):
                    stack.append(nn.Linear(extra_hidden_size, extra_hidden_size))
                    stack.append(nn.ReLU(inplace=True))
                stack.append(nn.Linear(extra_hidden_size, extra_embedding_size))
            else:
                stack = [nn.Linear(input_dim, extra_embedding_size)]
            return nn.Sequential(*stack)

        for key, feature_dim, enabled in modality_specs:
            if not enabled:
                continue
            # `proprio_mlp` is rebound on every iteration, so after the loop it
            # aliases the last enabled encoder. It stays a registered submodule
            # and therefore contributes `proprio_mlp.*` entries to state_dict.
            self.proprio_mlp = build_mlp(feature_dim)
            self.extra_encoders[key] = {"encoder": self.proprio_mlp}

        # `extra_encoders` is a plain dict; the ModuleList is what registers
        # every encoder with this module.
        self.encoders = nn.ModuleList(
            [entry["encoder"] for entry in self.extra_encoders.values()]
        )

    def forward(self, obs_dict):
        """
        Encode every enabled modality and stack the tokens.

        obs_dict: {
            (optional) joint_states: (B, T, 7),
            (optional) gripper_states: (B, T, 2),
            (optional) ee_states: (B, T, 3)
        }
        Only the keys of enabled modalities are read. The encoded tensors are
        stacked on a new second-to-last axis, giving (B, T, num_extra, E) for
        inputs shaped as above, in the order joint, gripper, ee.
        """
        enabled_keys = [
            key
            for enabled, key in (
                (self.use_joint, "joint_states"),
                (self.use_gripper, "gripper_states"),
                (self.use_ee, "ee_states"),
            )
            if enabled
        ]
        tokens = [
            self.extra_encoders[key]["encoder"](obs_dict[key]) for key in enabled_keys
        ]
        return torch.stack(tokens, dim=-2)


###############################################################################
#
# A Transformer Policy
#
###############################################################################


class BCTransformerPolicyR3M(BasePolicy):
    """
    Transformer behavior-cloning policy whose image features come from a
    pretrained R3M encoder that is run without gradients.

    Input: (o_{t-H}, ... , o_t)
    Output: a_t or distribution of a_t
    """

    def __init__(self, cfg, shape_meta):
        """
        Build the R3M encoder and all trainable heads.

        Args:
            cfg: experiment config; reads `cfg.policy.*` and
                `cfg.data.use_joint / use_gripper / use_ee`. Several
                `network_kwargs` entries of `cfg.policy` are written in place
                so that every sub-network agrees on `embed_size`.
            shape_meta: dict providing "all_shapes" (observation names) and
                "ac_dim" (action dimension).
        """
        super().__init__(cfg, shape_meta)
        # Per-environment list of raw R3M embeddings collected by
        # `spatial_encode(..., save_emb=True)`.
        self.rollouts_embeddings = defaultdict(list)
        policy_cfg = cfg.policy
        self.r3m = self.load_r3m().eval()

        ### 1. encode image
        embed_size = policy_cfg.embed_size
        self.image_encoders = nn.ModuleDict()
        for name in shape_meta["all_shapes"].keys():
            if "rgb" in name or "depth" in name:
                # Projects the 512-d R3M feature to the shared token size.
                self.image_encoders[name] = nn.Sequential(
                    nn.Linear(512, 128),
                    nn.ReLU(inplace=True),
                    nn.Linear(128, embed_size),
                )

        ### 2. encode language
        # NOTE(review): the `eval` calls below resolve class names taken from
        # the config against this module's star-imports. Only use with trusted
        # configuration files.
        policy_cfg.language_encoder.network_kwargs.output_size = embed_size
        self.language_encoder = eval(policy_cfg.language_encoder.network)(
            **policy_cfg.language_encoder.network_kwargs
        )

        ### 3. encode extra information (e.g. gripper, joint_state)
        self.extra_encoder = ExtraModalityTokens(
            use_joint=cfg.data.use_joint,
            use_gripper=cfg.data.use_gripper,
            use_ee=cfg.data.use_ee,
            extra_num_layers=policy_cfg.extra_num_layers,
            extra_hidden_size=policy_cfg.extra_hidden_size,
            extra_embedding_size=embed_size,
        )

        ### 4. define temporal transformer
        policy_cfg.temporal_position_encoding.network_kwargs.input_size = embed_size
        self.temporal_position_encoding_fn = eval(
            policy_cfg.temporal_position_encoding.network
        )(**policy_cfg.temporal_position_encoding.network_kwargs)

        self.temporal_transformer = TransformerDecoder(
            input_size=embed_size,
            num_layers=policy_cfg.transformer_num_layers,
            num_heads=policy_cfg.transformer_num_heads,
            head_output_size=policy_cfg.transformer_head_output_size,
            mlp_hidden_size=policy_cfg.transformer_mlp_hidden_size,
            dropout=policy_cfg.transformer_dropout,
        )

        policy_head_kwargs = policy_cfg.policy_head.network_kwargs
        policy_head_kwargs.input_size = embed_size
        policy_head_kwargs.output_size = shape_meta["ac_dim"]

        self.policy_head = eval(policy_cfg.policy_head.network)(
            **policy_cfg.policy_head.loss_kwargs,
            **policy_cfg.policy_head.network_kwargs
        )

        # Rolling window of per-step spatial encodings used by `get_action`.
        self.latent_queue = []
        self.max_seq_len = policy_cfg.transformer_max_seq_len

    def train(self, mode=True):
        """
        Set the training mode of the policy while keeping R3M in eval mode.

        R3M is only ever called under `torch.no_grad()` and is put in eval
        mode at construction time. Because it is a registered submodule, the
        default `nn.Module.train()` would switch it back to training mode, so
        it is forced back to eval here.

        Returns:
            self, like `nn.Module.train`.
        """
        super().train(mode)
        r3m = getattr(self, "r3m", None)
        if r3m is not None:
            r3m.eval()
        return self

    def temporal_encode(self, x):
        """
        Run the temporal transformer over all tokens of all time steps.

        Args:
            x: spatial encodings, (B, T, num_modality, E).

        Returns:
            The transformer output at token index 0 of every step, (B, T, E).
        """
        pos_emb = self.temporal_position_encoding_fn(x)
        x = x + pos_emb.unsqueeze(1)  # (B, T, num_modality, E)
        sh = x.shape
        self.temporal_transformer.compute_mask(x.shape)

        x = TensorUtils.join_dimensions(x, 1, 2)  # (B, T*num_modality, E)
        x = self.temporal_transformer(x)
        x = x.reshape(*sh)
        return x[:, :, 0]  # (B, T, E)

    def spatial_encode(self, data, save_emb=False):
        """
        Encode one batch of observations into per-step modality tokens.

        Tokens are concatenated in the order: language, extra modalities, then
        one token per image stream in `self.image_encoders`.

        Args:
            data: batch dict; `data["obs"]` holds the low-dimensional states
                and the image tensors, and the whole dict is passed to the
                language encoder.
            save_emb: when True, raw R3M embeddings of the "agentview_rgb"
                stream are appended to `self.rollouts_embeddings`.

        Returns:
            Tensor of shape (B, T, num_modalities, E).
        """
        # 1. encode extra
        extra = self.extra_encoder(data["obs"])  # (B, T, num_extra, E)

        # 2. encode language, treat it as action token
        B, T = extra.shape[:2]
        text_encoded = self.language_encoder(data)  # (B, E)
        text_encoded = text_encoded.view(B, 1, 1, -1).expand(
            -1, T, -1, -1
        )  # (B, T, 1, E)
        encoded = [text_encoded, extra]

        # 3. encode image, using pretrained R3M
        for img_name in self.image_encoders.keys():
            x = data["obs"][img_name]
            B, T, C, H, W = x.shape
            with torch.no_grad():
                # Images are rescaled by 255 before R3M — presumably because
                # R3M expects pixel values in [0, 255]; confirm against r3m.
                y = self.r3m(x.reshape(B * T, C, H, W) * 255.0)
                if save_emb and img_name == "agentview_rgb":
                    # NOTE(review): `y` is flattened to (B*T, 512), so `y[i]`
                    # is environment i's embedding only when T == 1, which
                    # appears to be the case on the `get_action` path.
                    for i in range(B):
                        self.rollouts_embeddings[i].append(y[i].unsqueeze(0))
            img_encoded_r3m = self.image_encoders[img_name](y).view(B, T, 1, -1)
            encoded.append(img_encoded_r3m)

        encoded = torch.cat(encoded, -2)  # (B, T, num_modalities, E)
        return encoded

    def load_r3m(self):
        """
        Load the pretrained ResNet-18 R3M encoder, downloading it if needed.

        The checkpoint is stored under `../../../models/r3m` and its config
        under `../../../IL/configs`, both relative to this file. The config is
        fetched whenever the checkpoint has to be fetched, and also when the
        config file alone is missing.

        Returns:
            The instantiated R3M module with the pretrained weights loaded and
            the language head removed from the state dict.
        """
        model_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../models/r3m"))
        config_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../IL/configs"))
        modelpath = os.path.join(model_dir, "model_resnet18.pt")
        configpath = os.path.join(config_dir, "config.yaml")

        modelurl = 'https://drive.google.com/uc?id=1A1ic-p4KtYlKXdXHcV2QV0cUzI4kn0u-'
        configurl = 'https://drive.google.com/uc?id=1nitbHQ-GRorxc7vMUiEHjHWP5N11Jvc6'

        model_missing = not os.path.exists(modelpath)
        if model_missing:
            # exist_ok: the directory may already exist without the checkpoint
            # (e.g. after an interrupted download).
            os.makedirs(model_dir, exist_ok=True)
            gdown.download(modelurl, modelpath, quiet=False)

        if model_missing or not os.path.exists(configpath):
            os.makedirs(config_dir, exist_ok=True)
            gdown.download(configurl, configpath, quiet=False)

        modelcfg = omegaconf.OmegaConf.load(configpath)
        cleancfg = cleanup_config(modelcfg)
        rep = hydra.utils.instantiate(cleancfg)
        # Wrapped only so the parameter names carry the "module." prefix that
        # the checkpoint keys are loaded against; the bare module is returned.
        rep = torch.nn.DataParallel(rep)

        # Deserialize onto the CPU so the checkpoint also loads on hosts
        # without CUDA; load_state_dict copies the values into the existing
        # parameters wherever those live.
        checkpoint = torch.load(modelpath, map_location="cpu")
        r3m_state_dict = remove_language_head(checkpoint['r3m'])
        rep.load_state_dict(r3m_state_dict)

        return rep.module

    # during training
    def forward(self, data):
        """Return the policy head output for every step of a training batch."""
        x = self.spatial_encode(data)
        x = self.temporal_encode(x)
        dist = self.policy_head(x)
        return dist

    # during evaluation and testing
    def get_action(self, data, save_emb=False):
        """
        Sample an action for the current observation.

        Keeps a window of at most `max_seq_len` past spatial encodings in
        `self.latent_queue`; call `reset()` at the start of every episode.

        Args:
            data: raw observation batch, passed through `preprocess_input`.
            save_emb: forwarded to `spatial_encode`.

        Returns:
            numpy array of shape (batch, -1) holding the sampled actions.
        """
        self.eval()
        with torch.no_grad():
            data = self.preprocess_input(data, train_mode=False)
            x = self.spatial_encode(data, save_emb)
            self.latent_queue.append(x)
            if len(self.latent_queue) > self.max_seq_len:
                self.latent_queue.pop(0)
            # Concatenate the window along time: (B, T, num_modalities, E)
            x = torch.cat(self.latent_queue, dim=1)
            x = self.temporal_encode(x)
            dist = self.policy_head(x[:, -1])
        action = dist.sample().detach().cpu()
        return action.view(action.shape[0], -1).numpy()

    def reset(self):
        """Clear the history of spatial encodings kept by `get_action`."""
        self.latent_queue = []
