from baseline_configs.rearrange_base import (
    RearrangeBaseExperimentConfig,
)
from allenact.base_abstractions.preprocessor import SensorPreprocessorGraph
from allenact.embodiedai.preprocessors.resnet import ResNetPreprocessor
#from allenact.embodiedai.preprocessors.resnet import ViTPreprocessor
#from allenact.embodiedai.preprocessors.vit import ViTPreprocessor
from allenact.base_abstractions.sensor import SensorSuite, Sensor, ExpertActionSensor
from torch import nn, cuda, optim
import gym.spaces
import torchvision.models

from rearrange.baseline_models import (
    RearrangeActorCriticSimpleConvRNN,
    ResNetRearrangeActorCriticRNN,
    ViTRearrangeActorCriticRNN,
    ImageNetViTRearrangeActorCriticRNN,
)
from rearrange.constants import (
    OBJECT_TYPES_WITH_PROPERTIES,
    THOR_COMMIT_ID,
)

# from pytorch_pretrained_vit import ViT


class RearrangeViTExperimentConfig(RearrangeBaseExperimentConfig):
    @classmethod
    def tag(cls) -> str:
        return f"OnePhaseRGBClipViTDagger_{cls.IL_PIPELINE_TYPE}"

    @classmethod
    def resnet_preprocessor_graph(cls, mode: str) -> SensorPreprocessorGraph:
        def create_resnet_builder(in_uuid: str, out_uuid: str):
            cnn_type, pretraining_type = cls.CNN_PREPROCESSOR_TYPE_AND_PRETRAINING
            
            if pretraining_type == "imagenet" and 'RN' in cnn_type:
                assert cnn_type in [
                    "RN18",
                    "RN50",
                ], "Only allow using RN18/RN50 with `imagenet` pretrained weights."
                return ResNetPreprocessor(
                    input_height=cls.THOR_CONTROLLER_KWARGS["height"],
                    input_width=cls.THOR_CONTROLLER_KWARGS["width"],
                    output_width=7,
                    output_height=7,
                    output_dims=512 if "18" in cnn_type else 2048,
                    pool=False,
                    torchvision_resnet_model=getattr(
                        torchvision.models, f"resnet{cnn_type.replace('RN', '')}"
                    ),
                    input_uuids=[in_uuid],
                    output_uuid=out_uuid,
                )
            elif pretraining_type == "imagenet" and 'ViT' in cnn_type:
                assert cnn_type in [
                    "ViT-B/32",
                    "ViT-B/16",
                ], "Only allow using ViT-B/32, ViT-B/16 with `imagenet` pretrained weights."
                return ViTPreprocessor(
                    input_height=cls.THOR_CONTROLLER_KWARGS["height"],
                    input_width=cls.THOR_CONTROLLER_KWARGS["width"],

                    output_width=1,
                    output_height=1,
                    #output_dims=512 if "18" in cnn_type else 2048,
                    output_dims=197 if '16' in cnn_type else 50,
                    input_uuids=[in_uuid],
                    output_uuid=out_uuid,
                    model_type = cnn_type
                )
            elif pretraining_type == "clip":
                from allenact_plugins.clip_plugin.clip_preprocessors import (
                    PromptClipViTPreprocessor,
                    ClipViTPreprocessor, HighpassClipPreprocessor
                )
                import clip

                # Let's make sure we download the clip model now
                # so we don't download it on every spawned process
                # clip.load(cnn_type, "cpu")
                # return PromptClipViTPreprocessor(
                return ClipViTPreprocessor(
                #return HighpassClipPreprocessor(
                    rgb_input_uuid=in_uuid,
                    clip_model_type=cnn_type,
                    class_emb_only=False,
                    output_uuid=out_uuid,
                )
            else:
                raise NotImplementedError

        img_uuids = [cls.EGOCENTRIC_RGB_UUID, cls.UNSHUFFLED_RGB_UUID]
        return SensorPreprocessorGraph(
            source_observation_spaces=SensorSuite(
                [
                    sensor
                    for sensor in cls.sensors()
                    if (mode == "train" or not isinstance(sensor, ExpertActionSensor))
                ]
            ).observation_spaces,
            preprocessors=[
                create_resnet_builder(sid, f"{sid}_resnet") for sid in img_uuids
            ],
        )
    @classmethod
    def create_model(cls, **kwargs) -> nn.Module:
        cnn_type, pretraining_type = cls.CNN_PREPROCESSOR_TYPE_AND_PRETRAINING
        if 'ViT' in cnn_type and pretraining_type == "imagenet":
            return ImageNetViTRearrangeActorCriticRNN(
                action_space=gym.spaces.Discrete(len(cls.actions())),
                observation_space=kwargs[
                    "sensor_preprocessor_graph"
                ].observation_spaces,
                rgb_uuid=cls.EGOCENTRIC_RGB_RESNET_UUID,
                unshuffled_rgb_uuid=cls.UNSHUFFLED_RGB_RESNET_UUID,
                cnn_type=cnn_type
                )
        if cls.CNN_PREPROCESSOR_TYPE_AND_PRETRAINING is None:
            return RearrangeActorCriticSimpleConvRNN(
                action_space=gym.spaces.Discrete(len(cls.actions())),
                observation_space=SensorSuite(cls.sensors()).observation_spaces,
                rgb_uuid=cls.EGOCENTRIC_RGB_UUID,
                unshuffled_rgb_uuid=cls.UNSHUFFLED_RGB_UUID,
                cnn_type=cnn_type
                )

        else:
            return ViTRearrangeActorCriticRNN(
                action_space=gym.spaces.Discrete(len(cls.actions())),
                observation_space=kwargs[
                    "sensor_preprocessor_graph"
                ].observation_spaces,
                rgb_uuid=cls.EGOCENTRIC_RGB_RESNET_UUID,
                unshuffled_rgb_uuid=cls.UNSHUFFLED_RGB_RESNET_UUID,
                cnn_type=cnn_type
                )
            # model = PromptViTRearrangeActorCriticRNN(
            #     action_space=gym.spaces.Discrete(len(cls.actions())),
            #     observation_space=kwargs[
            #         "sensor_preprocessor_graph"
            #     ].observation_spaces,
            #     rgb_uuid=cls.EGOCENTRIC_RGB_RESNET_UUID,
            #     unshuffled_rgb_uuid=cls.UNSHUFFLED_RGB_RESNET_UUID,
            #     cnn_type=cnn_type
            #     )

            # name_to_check = "promptpreprocessor._vit.prompt_"
            # for name, param in model.named_parameters():
            #     if name_to_check in name:
            #         if "task" in name:
            #             print(f"turning on {name}")
            #             param.requires_grad_(True)
            #         else:
            #             print(f"turning on {name}")
            #             param.requires_grad_(True)

            # Double check
            # enabled = set()
            # for name, param in model.named_parameters():
            #     if param.requires_grad:
            #         if param.requires_grad == True:
            #             enabled.add((name,param.requires_grad))
            # enabled = sorted(enabled)
            # print(f"Parameters to be updated: {enabled}")
            # exit()
 
            return model
                
