# %%
import copy
from functools import partial
import os
from typing import Any, Dict, Optional, Tuple, Union
from einops import rearrange

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from config import AutoConfig
from registry import Registry

from torchvision.models import list_models, get_model
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

# Registry mapping a backbone name to a factory. `build_backbone` looks names up
# here and calls the entry as `factory(pretrained, cache_dir)`; names that are
# not registered fall back to torchvision's `get_model`.
BACKBONES = Registry()

# models = list_models()
# for model in models:
#     # BACKBONES.register(model, module=lambda **kw: get_model(model, **kw))
#     def fn(**kw):
#         return get_model(model, **kw)
#     BACKBONES[model] = fn


def clean_state_dict(state_dict):
    """Return a copy of ``state_dict`` whose keys have ``.module.`` collapsed to ``.``.

    Values are carried over untouched and ``state_dict`` itself is not modified.
    If two keys become identical after the substitution, the value of the one
    that appears later in ``state_dict`` wins.
    """
    # str.replace leaves keys without the infix unchanged, so no membership
    # test is needed before substituting.
    return {key.replace(".module.", "."): value for key, value in state_dict.items()}


@BACKBONES.register("drn_d_22")
def _drn_d_22(pretrained, cache_dir):
    """Build the ``drn_d_22`` network from the ``drn`` module.

    ``pretrained`` is forwarded to the constructor. ``cache_dir`` is accepted so
    the factory matches the ``(pretrained, cache_dir)`` call made by
    ``build_backbone`` but is not used here.
    """
    # Imported at call time, so ``drn`` is only needed when this backbone is requested.
    from drn import drn_d_22 as build_model

    model = build_model(pretrained=pretrained)
    return model


@BACKBONES.register("drn_c_26")
def _drn_c_26(pretrained, cache_dir):
    """Build the ``drn_c_26`` network from the ``drn`` module.

    ``pretrained`` is forwarded to the constructor. ``cache_dir`` is accepted so
    the factory matches the ``(pretrained, cache_dir)`` call made by
    ``build_backbone`` but is not used here.
    """
    # Imported at call time, so ``drn`` is only needed when this backbone is requested.
    from drn import drn_c_26 as build_model

    model = build_model(pretrained=pretrained)
    return model


@BACKBONES.register("drn_a_50")
def _drn_a_50(pretrained, cache_dir):
    """Build the ``drn_a_50`` network from the ``drn`` module.

    ``pretrained`` is forwarded to the constructor. ``cache_dir`` is accepted so
    the factory matches the ``(pretrained, cache_dir)`` call made by
    ``build_backbone`` but is not used here.
    """
    # Imported at call time, so ``drn`` is only needed when this backbone is requested.
    from drn import drn_a_50 as build_model

    model = build_model(pretrained=pretrained)
    return model


@BACKBONES.register("drn_c_58")
def _drn_c_58(pretrained, cache_dir):
    """Build the ``drn_c_58`` network from the ``drn`` module.

    ``pretrained`` is forwarded to the constructor. ``cache_dir`` is accepted so
    the factory matches the ``(pretrained, cache_dir)`` call made by
    ``build_backbone`` but is not used here.
    """
    # Imported at call time, so ``drn`` is only needed when this backbone is requested.
    from drn import drn_c_58 as build_model

    model = build_model(pretrained=pretrained)
    return model


@BACKBONES.register("drn_d_56")
def _drn_d_56(pretrained, cache_dir):
    """Build the ``drn_d_56`` network from the ``drn`` module.

    ``pretrained`` is forwarded to the constructor. ``cache_dir`` is accepted so
    the factory matches the ``(pretrained, cache_dir)`` call made by
    ``build_backbone`` but is not used here.
    """
    # Imported at call time, so ``drn`` is only needed when this backbone is requested.
    from drn import drn_d_56 as build_model

    model = build_model(pretrained=pretrained)
    return model


@BACKBONES.register("drn_d_24")
def _drn_d_24(pretrained, cache_dir):
    """Build the ``drn_d_24`` network from the ``drn`` module.

    ``pretrained`` is forwarded to the constructor. ``cache_dir`` is accepted so
    the factory matches the ``(pretrained, cache_dir)`` call made by
    ``build_backbone`` but is not used here.
    """
    # Imported at call time, so ``drn`` is only needed when this backbone is requested.
    from drn import drn_d_24 as build_model

    model = build_model(pretrained=pretrained)
    return model


def modify_efficientnetv2(model, layers="all"):
    """Patch ``model`` so its forward pass returns intermediate block outputs.

    Args:
        model: Object exposing a callable ``stem`` and a sized, iterable
            ``blocks`` attribute.
        layers: Which block outputs to return. Either the string ``"all"``, a
            single name such as ``"layer3"``, or a sequence of ``"layer<i>"``
            names. A sequence whose first item is ``"all"`` also selects every
            block.

    Returns:
        The same ``model`` object. Calling its ``forward(x)`` afterwards yields
        a dict mapping each requested ``"layer<i>"`` to the output of
        ``model.blocks[i]``.

    Raises:
        ValueError: If ``layers`` is empty or a name does not end in an integer.

    Note:
        ``forward`` is replaced on ``model.__class__``, so every instance of
        that class is affected, not only ``model``.
    """
    # A bare string used to be indexed character by character, which made the
    # default "all" fail on int("a"); normalise it to a one-element list.
    if isinstance(layers, str):
        layers = [layers]
    layers = list(layers)
    if not layers:
        raise ValueError("layers must name at least one block")
    if layers[0] == "all":
        layers = [f"layer{i}" for i in range(len(model.blocks))]

    wanted = frozenset(layers)
    # Blocks after the deepest requested one are never evaluated.
    last_needed = max(int(name.split("layer")[-1]) for name in wanted)

    def new_forward(self, x):
        features = {}
        x = self.stem(x)
        for index, block in enumerate(self.blocks):
            if index > last_needed:
                break
            x = block(x)
            name = f"layer{index}"
            if name in wanted:
                features[name] = x
        return features

    setattr(model.__class__, "forward", new_forward)
    return model


@BACKBONES.register("efficientnet_v2_s_in21k")
def _efin21k_s(pretrained, cache_dir):
    """Load ``efficientnet_v2_s_in21k`` from the ``hankyul2/EfficientNetV2-pytorch`` hub repo.

    ``pretrained`` is forwarded to the hub entry point. ``cache_dir`` is
    accepted for signature parity with the other registered factories but is
    not forwarded by this loader.
    """
    # NOTE(review): `_efin21k_m` forwards cache_dir to the hub entry point while
    # this loader does not — confirm which behaviour is intended.
    hub_kwargs = {"pretrained": pretrained, "verbose": False}
    return torch.hub.load(
        "hankyul2/EfficientNetV2-pytorch", "efficientnet_v2_s_in21k", **hub_kwargs
    )


@BACKBONES.register("efficientnet_v2_m_in21k")
def _efin21k_m(pretrained, cache_dir):
    """Load ``efficientnet_v2_m_in21k`` from the ``hankyul2/EfficientNetV2-pytorch`` hub repo.

    Both ``pretrained`` and ``cache_dir`` are forwarded as keyword arguments
    through ``torch.hub.load``.
    """
    # NOTE(review): the `_s` and `_l` loaders do not forward cache_dir —
    # confirm the hub entry point actually accepts this keyword.
    hub_kwargs = {"pretrained": pretrained, "verbose": False, "cache_dir": cache_dir}
    return torch.hub.load(
        "hankyul2/EfficientNetV2-pytorch", "efficientnet_v2_m_in21k", **hub_kwargs
    )


@BACKBONES.register("efficientnet_v2_l_in21k")
def _efin21k_l(pretrained, cache_dir):
    """Load ``efficientnet_v2_l_in21k`` from the ``hankyul2/EfficientNetV2-pytorch`` hub repo.

    ``pretrained`` is forwarded to the hub entry point. ``cache_dir`` is
    accepted for signature parity with the other registered factories but is
    not forwarded by this loader.
    """
    # NOTE(review): `_efin21k_m` forwards cache_dir to the hub entry point while
    # this loader does not — confirm which behaviour is intended.
    hub_kwargs = {"pretrained": pretrained, "verbose": False}
    return torch.hub.load(
        "hankyul2/EfficientNetV2-pytorch", "efficientnet_v2_l_in21k", **hub_kwargs
    )


@BACKBONES.register("CLIP-RN50")
def _clip_rn50(pretrained, cache_dir):
    """Return the ``visual`` sub-model of the CLIP ``RN50`` model, loaded on the CPU.

    Neither ``pretrained`` nor ``cache_dir`` is consulted; the model is always
    obtained through ``clip.load`` with no explicit download location.
    """
    import clip as clip_pkg

    # clip.load returns a pair; only the first item (the model) is used.
    clip_model, _ = clip_pkg.load("RN50", device="cpu")
    visual_tower = clip_model.visual
    return visual_tower


@BACKBONES.register("CLIP-RN101")
def _clip_rn101(pretrained, cache_dir):
    """Return the ``visual`` sub-model of the CLIP ``RN101`` model, loaded on the CPU.

    Neither ``pretrained`` nor ``cache_dir`` is consulted; the model is always
    obtained through ``clip.load`` with no explicit download location.
    """
    import clip as clip_pkg

    # clip.load returns a pair; only the first item (the model) is used.
    clip_model, _ = clip_pkg.load("RN101", device="cpu")
    visual_tower = clip_model.visual
    return visual_tower


@BACKBONES.register("CLIP-RN50x4")
def _clip_rn50x4(pretrained, cache_dir):
    """Return the ``visual`` sub-model of the CLIP ``RN50x4`` model, loaded on the CPU.

    Neither ``pretrained`` nor ``cache_dir`` is consulted; the model is always
    obtained through ``clip.load`` with no explicit download location.
    """
    import clip as clip_pkg

    # clip.load returns a pair; only the first item (the model) is used.
    clip_model, _ = clip_pkg.load("RN50x4", device="cpu")
    visual_tower = clip_model.visual
    return visual_tower


@BACKBONES.register("CLIP-RN50x16")
def _clip_rn50x16(pretrained, cache_dir):
    """Return the ``visual`` sub-model of the CLIP ``RN50x16`` model, loaded on the CPU.

    This factory was previously defined under the copy-pasted name
    ``_clip_rn50x4``, which rebound that module-level name away from the real
    RN50x4 loader. The registry key ``"CLIP-RN50x16"`` is unchanged.

    Neither ``pretrained`` nor ``cache_dir`` is consulted; the model is always
    obtained through ``clip.load`` with no explicit download location.
    """
    import clip

    # clip.load returns a pair; only the first item (the model) is used.
    model, _ = clip.load("RN50x16", device="cpu")
    return model.visual


@BACKBONES.register("CLIP-RN50x64")
def _clip_rn50x64(pretrained, cache_dir):
    """Return the ``visual`` sub-model of the CLIP ``RN50x64`` model, loaded on the CPU.

    This factory was previously defined under the copy-pasted name
    ``_clip_rn50x4``, which rebound that module-level name away from the real
    RN50x4 loader. The registry key ``"CLIP-RN50x64"`` is unchanged.

    Neither ``pretrained`` nor ``cache_dir`` is consulted; the model is always
    obtained through ``clip.load`` with no explicit download location.
    """
    import clip

    # clip.load returns a pair; only the first item (the model) is used.
    model, _ = clip.load("RN50x64", device="cpu")
    return model.visual


@BACKBONES.register("DiNOv2-ViT-B")
def _dino_vit_b(pretrained, cache_dir):
    """Load ``dinov2_vitb14`` from the ``facebookresearch/dinov2`` hub repo.

    ``pretrained`` and ``cache_dir`` are accepted for signature parity with the
    other registered factories but are not used here.

    Returns:
        The model, moved to the GPU when CUDA is available and left on the CPU
        otherwise.
    """
    dinov2_vitb14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
    # Offline alternative:
    # dinov2_vitb14 = torch.hub.load('/data/cache/facebookresearch_dinov2_main', 'dinov2_vitb14', source='local')

    # The move to GPU used to be unconditional, which raises on CPU-only hosts.
    if torch.cuda.is_available():
        dinov2_vitb14 = dinov2_vitb14.cuda()

    return dinov2_vitb14


@BACKBONES.register("DiNOv2-ViT-L")
def _dino_vit_l(pretrained, cache_dir):
    """Load ``dinov2_vitl14`` from the ``facebookresearch/dinov2`` hub repo.

    Two defects are fixed here: the function ended in a bare ``return`` and so
    handed ``None`` to the caller, and it reused the name ``_dino_vit_b`` of the
    ViT-B factory. The registry key ``"DiNOv2-ViT-L"`` is unchanged.

    ``pretrained`` and ``cache_dir`` are accepted for signature parity with the
    other registered factories but are not used here.

    Returns:
        The model, moved to the GPU when CUDA is available and left on the CPU
        otherwise.
    """
    dinov2_vitl14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14")

    # Guarded so that CPU-only hosts do not raise when the model is built.
    if torch.cuda.is_available():
        dinov2_vitl14 = dinov2_vitl14.cuda()

    return dinov2_vitl14


def freeze(module):
    """Freeze ``module`` through Lightning's ``BaseFinetuning.freeze``.

    The call passes ``train_bn=True``.
    NOTE(review): per the Lightning documentation this flag leaves BatchNorm
    layers trainable — confirm that is intended for the backbones frozen here.
    """
    # Imported at call time, so pytorch_lightning is only needed once something is frozen.
    from pytorch_lightning.callbacks import finetuning

    finetuning.BaseFinetuning.freeze(module, train_bn=True)


def modify_sd_unet(unet):
    """Replace ``forward`` on the class of ``unet`` with a feature-returning variant.

    The replacement is installed on ``unet.__class__``, so every instance of
    that class is affected, not only ``unet``. Returns the same ``unet`` object.
    """

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Dict:
        r"""
        Encoder-only forward pass that returns intermediate feature maps.

        Runs the time embedding, ``conv_in``, the down blocks and the mid block. The up blocks are
        never executed, so no denoised sample is produced.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels (`torch.Tensor`, *optional*):
                Required when `self.class_embedding` is not None; a `ValueError` is raised otherwise.
            timestep_cond (`torch.Tensor`, *optional*):
                Passed to `self.time_embedding` together with the projected timesteps.
            attention_mask (`torch.Tensor`, *optional*):
                If given, converted to `(1 - mask) * -10000.0` and given an extra dim at position 1.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded unchanged to the cross-attention down blocks and to the mid block.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                If given, added in place, pairwise, to the collected down-block residuals.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                Accepted but not used by this implementation.
            return_dict (`bool`, *optional*, defaults to `True`):
                Accepted but not used; a plain dict is always returned.

        Returns:
            `dict` with four entries: `"unet1"`, `"unet2"` and `"unet3"` are entries 2, 5 and 8 of the
            collected down-block residuals, and `"unet4"` is `sample` after the mid block (or after the
            last down block when `self.mid_block` is None).
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        # NOTE: `forward_upsample_size` and `upsample_size` are never read below, because this
        # variant stops after the mid block.
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            # logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0"
                )

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        # The residual tuple starts with the conv_in output and grows by whatever each block returns.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if (
                hasattr(downsample_block, "has_cross_attention")
                and downsample_block.has_cross_attention
            ):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                # In-place add: the tensor held in the residual tuple is modified directly.
                down_block_res_sample += down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # Fixed positions in the residual tuple; this assumes at least nine residuals were
        # collected — TODO confirm against the UNet configuration in use.
        ret_dict = {}
        ret_dict["unet1"] = down_block_res_samples[2]
        ret_dict["unet2"] = down_block_res_samples[5]
        ret_dict["unet3"] = down_block_res_samples[8]
        ret_dict["unet4"] = sample  # mid_block
        return ret_dict

    # Installed on the class rather than the instance: all instances of this class are patched.
    setattr(unet.__class__, "forward", forward)
    return unet


@BACKBONES.register("sd1pt5")
class SD1pt5(nn.Module):
    """Feature backbone built from a Stable Diffusion v1.5 VAE/UNet and a CLIP RN50x4 visual model.

    ``forward`` encodes the image with the VAE, projects the CLIP ``attnpool``
    output through a small MLP into a length-1 conditioning sequence, and runs
    the patched UNet at timestep 0. It returns the CLIP ``layer*`` and UNet
    ``unet*`` feature maps whose names are listed in ``self.ret_layers``.
    """

    def __init__(self, pretrained=True, cache_dir=None):
        """Load the pretrained components and pass them through ``freeze``.

        ``pretrained`` and ``cache_dir`` are accepted for parity with the other
        registered backbones but are not consulted by this constructor.
        """
        super().__init__()
        from diffusers import StableDiffusionPipeline
        from diffusers import UNet2DConditionModel, AutoencoderKL

        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        pipe.enable_attention_slicing("max")
        pipe.enable_vae_slicing()
        pipe.enable_vae_tiling()
        pipe.enable_xformers_memory_efficient_attention()

        # Deep copies, so the modules kept here do not alias the pipeline's own.
        vae_copy = copy.deepcopy(pipe.vae)
        unet_copy = copy.deepcopy(pipe.unet)
        self.vae_scaling_factor = pipe.vae.config["scaling_factor"]
        self.vae: AutoencoderKL = vae_copy
        # The UNet forward is swapped for the feature-returning variant.
        self.unet: UNet2DConditionModel = modify_sd_unet(unet_copy)

        self.sd_resolution = (512, 512)

        import clip as clip_pkg
        from clip.model import ModifiedResNet

        clip_model, _ = clip_pkg.load("RN50x4", device="cpu")
        self.clip_visual: ModifiedResNet = copy.deepcopy(clip_model.visual)
        del clip_model

        from torchvision.models.feature_extraction import (
            create_feature_extractor,
        )

        # Each tapped node is returned under its own name.
        tapped = ("layer1", "layer2", "layer3", "layer4", "attnpool")
        self.fe_clip_visual = create_feature_extractor(
            self.clip_visual,
            return_nodes=dict(zip(tapped, tapped)),
        )

        # Projects the CLIP attnpool output to the width passed to the UNet.
        projection = [nn.Linear(640, 768), nn.SiLU(), nn.Linear(768, 768)]
        self.mlp = nn.Sequential(*projection)

        self.clip_resolution = (288, 288)

        # The MLP is the only sub-module not passed through freeze().
        for frozen in (self.vae, self.unet, self.clip_visual, self.fe_clip_visual):
            freeze(frozen)

        self.ret_layers = [
            "layer1", "layer2", "layer3", "layer4",
            "unet1", "unet2", "unet3", "unet4",
        ]

    def set_ret_layers(self, layers):
        """Set the names of the feature maps that ``forward`` returns."""
        self.ret_layers = layers

    def downsample(self, x, resolution):
        """Bilinearly resize ``x`` to ``resolution`` unless its last two dims already match."""
        if x.shape[-2:] != resolution:
            x = F.interpolate(x, resolution, mode="bilinear", align_corners=False)
        return x

    def forward(self, x):
        """Return a dict of the feature maps named in ``self.ret_layers`` for image batch ``x``."""
        # VAE encoding and the CLIP pass run without gradient tracking.
        with torch.no_grad():
            x = self.downsample(x, self.sd_resolution)
            posterior = self.vae.encode(x).latent_dist
            sample = posterior.sample() * self.vae_scaling_factor

            clip_input = self.downsample(x, self.clip_resolution)
            clip_outs = self.fe_clip_visual(clip_input)

        # attnpool is removed from the CLIP outputs: it conditions the UNet and
        # is never part of the returned features.
        prompt_embeds = self.mlp(clip_outs.pop("attnpool"))
        prompt_embeds = rearrange(
            prompt_embeds, "b (s c) -> b s c", s=1
        )  # sequence length 1

        # The UNet call is outside no_grad, so gradients can reach the MLP through it.
        timestamps = 0
        unet_outs = self.unet(sample, timestamps, prompt_embeds)

        merged = {**clip_outs, **unet_outs}
        return {name: feat for name, feat in merged.items() if name in self.ret_layers}


class BackboneNoGradWrap(nn.Module):
    """Wrap a backbone so it runs under ``torch.no_grad()`` unless ``yes_grad`` is set."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

        # Flip to True to call the backbone without the no_grad guard.
        self.yes_grad = False

    def forward(self, x):
        """Call the wrapped backbone on ``x`` and return its output."""
        if not self.yes_grad:
            with torch.no_grad():
                return self.backbone(x)
        return self.backbone(x)

def build_backbone(
    cfg: AutoConfig,
):
    """Build the feature-extraction backbone described by ``cfg.MODEL.BACKBONE``.

    Reads ``NAME``, ``PRETRAINED``, ``LAYERS`` and ``CACHE_DIR`` from the config.
    A name registered in ``BACKBONES`` is built with
    ``BACKBONES[name](pretrained, cache_dir)``; any other name is passed to
    torchvision's ``get_model``. What happens next depends on the name prefix:

    * ``"sd"``: ``set_ret_layers(layers)`` is called and the model is returned
      as is.
    * ``"DiNOv2"``: the model is wrapped in a callable that returns
      ``{layer_name: feature}``, where ``"layer<i>"`` selects block index ``i``.
    * anything else: the model is turned into a torchvision feature extractor
      whose return nodes are ``layers``.

    Every non-``"sd"`` backbone is finally wrapped in ``BackboneNoGradWrap`` and
    passed through ``freeze``.

    Args:
        cfg: Project config exposing the ``MODEL.BACKBONE`` fields listed above.

    Returns:
        The ``SD*`` module itself, or a ``BackboneNoGradWrap`` around the
        feature extractor.
    """
    name = cfg.MODEL.BACKBONE.NAME
    pretrained = cfg.MODEL.BACKBONE.PRETRAINED
    layers = cfg.MODEL.BACKBONE.LAYERS
    cache_dir = cfg.MODEL.BACKBONE.CACHE_DIR

    from filelock import FileLock

    # Hold an on-disk lock while models are constructed (presumably so that
    # concurrent processes do not download the same weights at once).
    with FileLock(os.path.expanduser("~/.model.lock")):
        # torch.hub.set_dir expands the path it is given and raises TypeError
        # on None, so keep torch.hub's default directory when none is configured.
        if cache_dir is not None:
            torch.hub.set_dir(cache_dir)

        if name in BACKBONES:
            backbone = BACKBONES[name](pretrained, cache_dir)
        else:
            backbone = get_model(name, pretrained=pretrained)

        if name.startswith("sd"):
            backbone.set_ret_layers(layers)
        elif name.startswith("DiNOv2"):
            # An earlier experiment patched the model's private
            # `_get_intermediate_layers_*` methods to stop early and to count
            # layers from the end; it was disabled and has been dropped here.

            def dino_new_forward(x, backbone, layers):
                block_ids = [int(layer[5:]) for layer in layers]
                # Always pass a list. get_intermediate_layers treats an int as
                # "the last n blocks", so unwrapping a single index (the old
                # "speedup") returned a different block than the one named.
                # Outputs come back in block order, hence the sorted, de-duplicated
                # request and the lookup by index afterwards.
                unique_ids = sorted(set(block_ids))
                outs = backbone.get_intermediate_layers(
                    x, n=unique_ids, reshape=True
                )
                feature_of = dict(zip(unique_ids, outs))
                return {
                    layer: feature_of[idx] for layer, idx in zip(layers, block_ids)
                }

            # NOTE(review): a functools.partial is not an nn.Module, so the
            # wrapper below will not register the DiNOv2 model as a submodule
            # (its parameters are not reached by .to(), state_dict() or freeze).
            backbone = partial(dino_new_forward, backbone=backbone, layers=layers)
        else:
            return_nodes = {x: x for x in layers}
            backbone = create_feature_extractor(backbone, return_nodes=return_nodes)

    if not name.startswith("sd"):
        backbone = BackboneNoGradWrap(backbone)
        freeze(backbone)

    return backbone


# Per-architecture lists of layer / graph-node names, keyed by model name.
# Not referenced by the code visible in this file; presumably these are the
# values meant for cfg.MODEL.BACKBONE.LAYERS — TODO confirm against callers.
LAYER_DICT = {
    # Earlier efficientnet_b0 selections, kept for reference:
    # 'efficientnet_b0': [
    #     'features.5.0',
    #     'features.5.1',
    #     'features.5.2',
    #     'features.6.0',
    #     'features.6.1',
    #     'features.6.2',
    #     'features.6.3',
    #     'features.7.0',
    #     'features.8',
    # ]
    # "efficientnet_b0": [f"features.{i}" for i in range(2, 8)],
    "efficientnet_b0": [
        "features.2.0.block.3",
        "features.2.1.block.3",
        "features.3.0.block.3",
        "features.3.1.block.3",
        "features.4.0.block.3",
        "features.4.1.block.3",
        "features.4.2.block.3",
        "features.5.0.block.3",
        "features.5.1.block.3",
        "features.5.2.block.3",
        "features.6.0.block.3",
        "features.6.1.block.3",
        "features.6.2.block.3",
        "features.6.3.block.3",
        "features.7.0.block.3",
    ],
    "efficientnet_b3": [
        "features.2",
        "features.3",
        "features.4",
        "features.5",
        "features.6",
        "features.7",
    ],
    "efficientnet_v2_s": [
        "features.3",
        "features.4",
        "features.5.4",
        "features.5.8",
        "features.6.7",
        "features.6.14",
    ],
    "efficientnet_v2_m": [
        "features.3",
        "features.4",
        "features.5.7",
        "features.5.13",
        "features.6.8",
        "features.6.17",
        "features.7",
    ],
    # The ResNet / ResNeXt entries below all use the same four names.
    "resnet50": [
        "layer1",
        "layer2",
        "layer3",
        "layer4",
    ],
    "resnext101_64x4d": [
        "layer1",
        "layer2",
        "layer3",
        "layer4",
    ],
    "resnet34": [
        "layer1",
        "layer2",
        "layer3",
        "layer4",
    ],
    "resnext50_32x4d": [
        "layer1",
        "layer2",
        "layer3",
        "layer4",
    ],
    "resnext101_32x8d": [
        "layer1",
        "layer2",
        "layer3",
        "layer4",
    ],
}

# Input spatial size per model name. The script cell further down unpacks an
# entry as the last two dims of a (1, 3, *size) random test tensor.
# NOTE: resnet34, resnext50_32x4d and resnext101_32x8d appear in LAYER_DICT but
# have no entry here.
RESOLUTION_DICT = {
    "resnet50": [224, 224],
    "efficientnet_b0": [224, 224],
    "efficientnet_b3": [300, 300],
    "efficientnet_v2_s": [384, 384],
    "efficientnet_v2_m": [480, 480],
    "resnext101_64x4d": [224, 224],
}
# %%
if __name__ == "__main__":
    # Smoke test: run SD1pt5 on one random 512x512 image and list the returned features.
    sd = SD1pt5(pretrained=True, cache_dir=None)
    out_dict = sd(torch.randn(1, 3, 512, 512))
    for name in out_dict:
        print(name, out_dict[name].shape)

    def count_parameters(model):
        """Count the elements of all parameters of ``model`` that require grad."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    # One count per sub-model that was passed through freeze().
    for part in (sd.vae, sd.unet, sd.clip_visual):
        print(count_parameters(part))

# %%
if __name__ == "__main__":
    from torchvision.models import get_model

    # Probe: list the graph nodes of one torchvision model, then print the
    # output shape of every node for a random input.
    b = "resnext101_64x4d"
    model = get_model(b, pretrained=False)
    nodes, _ = get_graph_node_names(model)
    print([node for node in nodes if "relu_2" in node])
    # print('\n'.join(nodes))
    model = create_feature_extractor(model, return_nodes=nodes)
    r = RESOLUTION_DICT[b]
    dummy = torch.randn(1, 3, *r)
    out = model(dummy)
    for k in out:
        print(k, "\t\t\t\t", out[k].shape)


# %%
if __name__ == "__main__":
    # Smoke test for build_backbone: build an efficientnet_b5 feature extractor
    # from the default config and print the shape of each returned feature.
    # (The unused `efficientnet_v2_s, ResNet, swin_s` import was removed; nothing
    # in this cell referenced those names.)
    # model = BACKBONES['swin_v2_b'](pretrained=True, cache_dir=None)
    from config_utils import get_cfg_defaults

    cfg = get_cfg_defaults()
    # cfg.MODEL.BACKBONE.NAME = "efficientnet_v2_s"
    cfg.MODEL.BACKBONE.PRETRAINED = True
    cfg.MODEL.BACKBONE.LAYERS = [f"features.{i}" for i in range(0, 8)]
    cfg.MODEL.BACKBONE.NAME = "efficientnet_b5"
    # cfg.MODEL.BACKBONE.LAYERS = ["layer1", "layer2", "layer3", "layer4"]
    backbone = build_backbone(cfg)

    out = backbone(torch.randn(1, 3, 224, 224))
    for k, v in out.items():
        print(k, v.shape)

# %%
if __name__ == "__main__":
    # Probe: print the two node-name lists that get_graph_node_names returns for
    # the hub EfficientNetV2-L model.
    # (The unused `config_utils` import was removed; this cell never used it.)
    backbone = _efin21k_l(pretrained=True, cache_dir=None)
    nodes1, nodes2 = get_graph_node_names(backbone)
    print("nodes1")
    print(nodes1)
    print("nodes2")
    print(nodes2)
# %%
