from typing import Callable, Any, Optional, List
import os
import torch
from torch import nn
import torchvision
import torchvision.transforms.functional as TF
from functools import partial

def load_simclr(model, path):
    """Load SimCLR backbone weights from a checkpoint into ``model``.

    The checkpoint must contain a ``'state_dict'`` mapping. Entries whose key
    starts with ``'backbone.'`` (except those starting with ``'backbone.fc'``)
    are kept with the ``'backbone.'`` prefix removed; every other entry is
    discarded. The remaining weights are loaded non-strictly, so the model's
    own ``fc`` layer keeps its current values.

    Args:
        model: Module to receive the weights. After loading, the only keys it
            may be missing are ``fc.weight`` and ``fc.bias``.
        path: Location of the checkpoint file, passed to ``torch.load``.

    Returns:
        The same ``model`` object, with the backbone weights loaded.

    Raises:
        RuntimeError: If the set of keys missing from the checkpoint is not
            exactly ``{'fc.weight', 'fc.bias'}``.
    """
    prefix = "backbone."
    # Deserialize onto the CPU so loading does not require a GPU;
    # load_state_dict copies the values into the model's existing parameters.
    checkpoint = torch.load(path, map_location="cpu")

    # Build a fresh dict instead of renaming/deleting in place, so a renamed
    # key can never be clobbered by an un-prefixed key of the same name.
    backbone_weights = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix) and not key.startswith("backbone.fc")
    }

    result = model.load_state_dict(backbone_weights, strict=False)
    # Compare as a set: the order of missing_keys is not part of the contract.
    if set(result.missing_keys) != {"fc.weight", "fc.bias"}:
        raise RuntimeError(
            "unexpected missing keys when loading SimCLR checkpoint: "
            f"{result.missing_keys}"
        )
    return model

def load_moco(model, path):
    """Load the MoCo query-encoder weights from a checkpoint into ``model``.

    The checkpoint must contain a ``"state_dict"`` mapping. Entries whose key
    starts with ``"module.encoder_q."`` (except those starting with
    ``"module.encoder_q.fc"``) are kept with the ``"module.encoder_q."``
    prefix removed; every other entry is discarded. The remaining weights are
    loaded non-strictly, so the model's own ``fc`` layer keeps its current
    values.

    Args:
        model: Module to receive the weights. After loading, the only keys it
            may be missing are ``fc.weight`` and ``fc.bias``.
        path: Location of the checkpoint file.

    Returns:
        The same ``model`` object, with the encoder weights loaded.

    Raises:
        FileNotFoundError: If ``path`` is not an existing file.
        RuntimeError: If the set of keys missing from the checkpoint is not
            exactly ``{"fc.weight", "fc.bias"}``.
    """
    # Fail loudly here rather than returning None and letting the caller
    # discover the problem only when it tries to use the model.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"no MoCo checkpoint found at '{path}'")

    prefix = "module.encoder_q."
    # Deserialize onto the CPU so loading does not require a GPU;
    # load_state_dict copies the values into the model's existing parameters.
    checkpoint = torch.load(path, map_location="cpu")

    # Retain only encoder_q, up to (not including) its embedding layer.
    encoder_weights = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix) and not key.startswith("module.encoder_q.fc")
    }

    msg = model.load_state_dict(encoder_weights, strict=False)
    if set(msg.missing_keys) != {"fc.weight", "fc.bias"}:
        raise RuntimeError(
            "unexpected missing keys when loading MoCo checkpoint: "
            f"{msg.missing_keys}"
        )
    return model

def load_vc1(model, checkpoint_path):
    """Load the ``"model"`` weights of a checkpoint into ``model`` (strictly).

    Steps, in order:

    1. Read the checkpoint on CPU and take its ``"model"`` entry.
    2. If the stored ``pos_embed`` shape differs from ``model.pos_embed``,
       replace it with the result of ``resize_pos_embed`` (a helper defined
       outside this function).
    3. Drop every key containing ``decoder`` or ``mask_token``.
    4. When ``model.classifier_feature == "global_pool"``, also drop keys
       starting with ``norm`` and supply ``fc_norm`` weight/bias from the
       model itself, so those two parameters keep their current values.

    Args:
        model: Module exposing ``pos_embed``, ``classifier_feature`` and,
            for the ``"global_pool"`` case, ``fc_norm``.
        checkpoint_path: Location of the checkpoint file.

    Returns:
        The same ``model`` object after ``load_state_dict``.
    """
    weights = torch.load(checkpoint_path, map_location="cpu")["model"]

    stored_pos_embed = weights["pos_embed"]
    if stored_pos_embed.shape != model.pos_embed.shape:
        weights["pos_embed"] = resize_pos_embed(
            stored_pos_embed,
            model.pos_embed,
            getattr(model, "num_tokens", 1),
            model.patch_embed.grid_size,
        )

    uses_global_pool = model.classifier_feature == "global_pool"

    kept = {}
    for name, tensor in weights.items():
        # Keys naming the decoder or the mask token are never loaded.
        if "decoder" in name or "mask_token" in name:
            continue
        # With global pooling, the checkpoint's `norm*` entries are skipped.
        if uses_global_pool and name.startswith("norm"):
            continue
        kept[name] = tensor

    if uses_global_pool:
        # fc_norm is taken from the model so the strict load below succeeds.
        kept["fc_norm.weight"] = model.fc_norm.weight
        kept["fc_norm.bias"] = model.fc_norm.bias

    model.load_state_dict(kept)
    return model

class SSL_Encoder(nn.Module):
    """Encode a channel-stacked group of images with a pretrained backbone.

    The input tensor is split along dim 1 into chunks of 3 channels; each
    chunk is passed through the same backbone (``ssl_head``), the per-chunk
    encodings are concatenated along the last dim, and a linear layer
    followed by ReLU maps them to ``obs_encoding_size`` features.

    Supported ``model_name`` values:
        * ``"simclr"`` - torchvision ResNet-18 loaded via ``load_simclr``.
        * ``"moco"``   - torchvision ResNet-50 loaded via ``load_moco``.
        * ``"vc-1"``   - ``VisionTransformer`` loaded via ``load_vc1``.
    """

    def __init__(
        self,
        model_name,
        obs_encoding_size,
        num_imgs,
        ssl_path):
        """Build the backbone, load its weights and create the output layer.

        Args:
            model_name: One of ``"simclr"``, ``"moco"`` or ``"vc-1"``.
            obs_encoding_size: Number of features produced by ``forward``.
            num_imgs: Number of 3-channel images stacked in each input; the
                linear layer expects ``num_imgs`` backbone encodings.
            ssl_path: Checkpoint path handed to the matching loader.

        Raises:
            ValueError: If ``model_name`` is not a supported backbone.
        """
        super().__init__()
        self.model_name = model_name

        if model_name == "simclr":
            obs_img_enc_size = int(obs_encoding_size // num_imgs)
            # No pretrained/weights argument: the torchvision default is a
            # randomly initialised network, which the loader then fills in.
            net = torchvision.models.resnet18(num_classes=obs_img_enc_size)
            self.ssl_head = load_simclr(net, ssl_path)
        elif model_name == "moco":
            obs_img_enc_size = int(obs_encoding_size // num_imgs)
            net = torchvision.models.resnet50(num_classes=obs_img_enc_size)
            self.ssl_head = load_moco(net, ssl_path)
        elif model_name == "vc-1":
            obs_img_enc_size = 768
            net = VisionTransformer(
                patch_size=16,
                embed_dim=768,
                depth=12,
                num_heads=12,
                mlp_ratio=4,
                qkv_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                img_size=224,
                use_cls=True,
                drop_path_rate=0.0
            )
            self.ssl_head = load_vc1(net, ssl_path)
        else:
            # Previously this fell through and crashed on an unbound local.
            raise ValueError(
                f"unsupported model_name {model_name!r}; "
                "expected 'simclr', 'moco' or 'vc-1'"
            )

        self.mlp = nn.Linear(num_imgs * obs_img_enc_size, obs_encoding_size)
        self.relu = nn.ReLU()

    def transform(self, img):
        """Apply the backbone-specific preprocessing to one image chunk.

        For ``"vc-1"`` the image is resized to 256 with bicubic interpolation
        and center-cropped to 224 (the ``img_size`` the transformer is built
        with). For every other backbone ``img`` is returned unchanged.
        """
        if self.model_name == "vc-1":
            resize_size = 256
            output_size = 224
            img = TF.resize(
                img,
                size=resize_size,
                # InterpolationMode is exposed by the functional module, so no
                # extra import is required.
                interpolation=TF.InterpolationMode.BICUBIC
            )
            # The functional API is `center_crop`; `CenterCrop` is the
            # transform class and does not exist on the functional module.
            img = TF.center_crop(img, output_size=output_size)
        return img

    def forward(self, img_tensor):
        """Encode ``img_tensor`` into ``obs_encoding_size`` features.

        Args:
            img_tensor: Tensor whose dim 1 holds the stacked image channels;
                presumably ``(batch, 3 * num_imgs, H, W)`` - confirm against
                the caller.

        Returns:
            ``relu(mlp(concat(encodings)))`` where the encodings are the
            backbone outputs of each 3-channel chunk, joined on the last dim.
        """
        img_list = torch.split(img_tensor, 3, dim=1)
        encs = [self.ssl_head(self.transform(img)) for img in img_list]
        concatenated_imgs = torch.cat(encs, dim=-1)
        return self.relu(self.mlp(concatenated_imgs))