import jax
import jax.numpy as jnp
import torch
import re
import functools

from .vit import DinoViT

def dino_config_by_size(dino_size):
    """Return the ViT hyper-parameters for a DINOv2 model size.

    Args:
        dino_size: Size key, one of ``'s'``, ``'b'`` or ``'l'``.

    Returns:
        A tuple ``(num_heads, embed_dim, depth)``.

    Raises:
        ValueError: If ``dino_size`` is not one of the supported keys.
    """
    # (num_heads, embed_dim, depth) per size key.
    configs = {
        's': (6, 384, 12),
        'b': (12, 768, 12),
        'l': (16, 1024, 24),
    }
    try:
        return configs[dino_size]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs; either way the size is unsupported.
        raise ValueError(
            f"dino_size should be one of s, b, l (got {dino_size!r})"
        ) from None



def load_vit_params(params_jax: dict, vit_pt: torch.nn.Module, dino_size='s'):
    """Copy PyTorch DINOv2 weights into a JAX parameter pytree.

    Every leaf of ``params_jax`` is matched to a PyTorch parameter by name:
    the pytree key path is joined with ``"."`` (dropping any ``'DinoViT_0'``
    component) and a ``.scale`` / ``.kernel`` component is renamed to
    ``.weight``. Matched weights are axis-permuted and converted to JAX
    arrays; unmatched leaves keep their original value and are reported on
    stdout.

    Args:
        params_jax: JAX parameter pytree to fill (nested dict-like whose
            path entries expose a ``.key`` attribute).
        vit_pt: PyTorch model supplying the weights. If ``None``, the
            ``dinov2_vit{dino_size}14_reg`` model is fetched via
            ``torch.hub``.
        dino_size: Size key, one of ``'s'``, ``'b'`` or ``'l'``.

    Returns:
        A pytree with the same structure as ``params_jax`` holding the
        loaded weights. If ``dino_size`` is unsupported, a message is
        printed and ``params_jax`` is returned unchanged.
    """
    print('load dino weights')

    if dino_size not in ('s', 'b', 'l'):
        print('dino_size should be one of s, b, l')
        return params_jax

    if vit_pt is None:
        vit_pt = torch.hub.load("facebookresearch/dinov2", f"dinov2_vit{dino_size}14_reg")

    jax_leaves, treedef = jax.tree_util.tree_flatten_with_path(params_jax)
    torch_params = dict(vit_pt.named_parameters())

    # Parameters copied as-is, without any axis permutation.
    no_transpose = {
        "cls_token",
        "pos_embed",
        "mask_token",
        "register_tokens",
    }
    # Both dots are escaped so only a real ".scale" / ".kernel" component is
    # renamed; an unescaped "." would also rewrite names such as "mykernel".
    rename = re.compile(r"\.(?:scale|kernel)")

    loaded = []
    for key_path, param in jax_leaves:
        shape = param.shape
        name = ".".join(p.key for p in key_path if p.key != 'DinoViT_0')
        name = rename.sub(".weight", name)

        if name not in torch_params:
            # No PyTorch counterpart: keep the existing JAX value.
            print(name, shape, None)
            loaded.append(param)
            continue

        # pop() so each PyTorch parameter is consumed at most once.
        torch_param = torch_params.pop(name)
        if name not in no_transpose:
            if len(shape) == 4:
                # Presumably conv kernels: (out, in, h, w) -> (h, w, in, out).
                torch_param = torch.permute(torch_param, (2, 3, 1, 0))
            else:
                # Reverse all axes (a plain transpose for 2-D weights).
                torch_param = torch.permute(
                    torch_param, tuple(reversed(range(len(shape))))
                )
        if tuple(shape) != tuple(torch_param.shape):
            # Report the mismatch (shapes only) but still load the weight.
            print(name, shape, tuple(torch_param.shape))
        # .cpu() makes the conversion work for models living on an accelerator.
        loaded.append(jnp.asarray(torch_param.detach().cpu().numpy()))

    return jax.tree_util.tree_unflatten(treedef, loaded)


def load_dino_vits(img_size=None, dino_size='s'):
    """Build a JAX ``DinoViT`` and fill it with pretrained DINOv2 weights.

    The model is initialised with ``jax.random.PRNGKey(0)`` on an all-ones
    input of shape ``(1, img_size[0], img_size[1], 3)``; the resulting
    parameters are then passed to ``load_vit_params`` with no PyTorch model,
    so the pretrained weights are fetched through ``torch.hub``.

    Args:
        img_size: Two-element sequence giving the input image size.
            Defaults to ``[518, 518]``.
        dino_size: Size key forwarded to ``dino_config_by_size`` and
            ``load_vit_params``.

    Returns:
        A tuple ``(vit_def, {"params": params})`` of the model definition
        and its loaded variables.
    """
    if img_size is None:
        # Fresh list per call: a mutable default would be one object shared
        # between every call and every model built from it.
        img_size = [518, 518]

    mlp_ratio = 4
    num_heads, embed_dim, depth = dino_config_by_size(dino_size)

    vit_def = DinoViT(
        num_heads=num_heads,
        embed_dim=embed_dim,
        mlp_ratio=mlp_ratio,
        depth=depth,
        img_size=img_size,
        register_tokens=True,
    )
    dummy_input = jnp.ones((1, img_size[0], img_size[1], 3))
    vit_params = vit_def.init(jax.random.PRNGKey(0), dummy_input)["params"]

    # Passing None makes load_vit_params fetch the PyTorch model itself.
    params = load_vit_params(vit_params, None, dino_size=dino_size)

    return (vit_def, {"params": params})


def test_dino_vits(dino_size='s'):
    """Check that the JAX port reproduces the PyTorch DINOv2 patch tokens.

    The same random 70x70 image is run through the JAX model and the
    PyTorch reference; their ``x_norm_patchtokens`` outputs must be almost
    perfectly aligned, while the PyTorch output for a different random
    image must be clearly less aligned with the JAX output.

    Args:
        dino_size: Size key of the DINOv2 variant to test.

    Raises:
        AssertionError: If either cosine-similarity check fails.
    """
    import numpy as onp

    # Use the GPU when present, but do not require one.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def cosine_similarity(a, b):
        return onp.sum(a * b) / onp.linalg.norm(a) / onp.linalg.norm(b)

    image = jax.random.uniform(jax.random.PRNGKey(0), (1, 70, 70, 3))
    jax_vit_def, jax_params = load_dino_vits([70, 70], dino_size=dino_size)

    # JAX: forward pass
    embed_jax = jax_vit_def.apply(jax_params, image, train=False)
    embed_jax = onp.asarray(embed_jax["x_norm_patchtokens"])

    # Torch: forward pass (NHWC -> NCHW for the PyTorch model)
    image_torch = torch.from_numpy(
        onp.asarray(image.transpose((0, 3, 1, 2)))
    ).to(device)
    dinov2 = torch.hub.load(
        "facebookresearch/dinov2", f"dinov2_vit{dino_size}14_reg"
    ).to(device)
    dinov2.eval()
    with torch.no_grad():
        embed_torch = (
            dinov2.forward_features(image_torch)["x_norm_patchtokens"]
            .cpu()
            .numpy()
        )
        # Unrelated random image, used as a negative control.
        embed_torch2 = (
            dinov2.forward_features(torch.rand((1, 3, 70, 70), device=device))[
                "x_norm_patchtokens"
            ]
            .cpu()
            .numpy()
        )

    similarity_same = cosine_similarity(embed_torch, embed_jax)
    similarity_other = cosine_similarity(embed_torch2, embed_jax)

    # Cosine similarity for the first pair (same image) should be close to 1.
    assert similarity_same > 0.999, similarity_same
    # Cosine similarity for the second pair (different images) should be lower.
    # It might still be fairly high, because random noise is semantically similar.
    assert similarity_other < 0.95, similarity_other

# Script entry point: run the JAX-vs-PyTorch parity check with the default size.
if __name__ == "__main__":
    test_dino_vits()