import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import timm
import model.vision_transformer as vits


class DINO(nn.Module):
    """Frozen pretrained ViT-B backbone that returns patch-token features.

    Depending on ``args`` the backbone is DINO (v1), DINOv2, or a supervised
    timm ViT; every backbone parameter is frozen. ``forward`` can keep only a
    random, sorted subset of the patch tokens (token dropping) before running
    the transformer blocks.
    """

    def __init__(self, args):
        super().__init__()

        # === DINO layers ===
        self.dino = self.load_dino(args)
        self.dinov2 = args.dinov2
        self.supervised = args.supervised
        self.resize_to = args.resize_to

        # === Token attributes ===
        # Patch-grid resolution for each spatial dim of the resized input.
        self.feat_res = [size // args.patch_size for size in args.resize_to]
        self.token_drop_ratio = args.token_drop_ratio

        # N: number of patch tokens (CLS token excluded).
        self.token_num = int(self.feat_res[0] * self.feat_res[1])
        # N': number of patch tokens kept when token dropping is active.
        self.reduced_token_num = int(self.token_num * (1.0 - self.token_drop_ratio))

    def load_dino(self, args):
        """Build the pretrained backbone selected by ``args`` and freeze it.

        :arg args: config exposing ``arch``, ``patch_size``, ``resize_to``,
            ``dinov2`` and ``supervised``.
        :return: the backbone module with ``requires_grad=False`` on all parameters.
        :raises ValueError: if ``resize_to`` is not divisible by ``patch_size``,
            or the (arch, patch_size, dinov2, supervised) combination is unsupported.
        """
        # Explicit checks rather than `assert`, so they still run under `python -O`.
        if args.resize_to[0] % args.patch_size != 0 or args.resize_to[1] % args.patch_size != 0:
            raise ValueError(
                f"resize_to {args.resize_to} must be divisible by patch_size {args.patch_size}"
            )

        if args.arch == "vit_base" and args.patch_size == 8:
            model = torch.hub.load("facebookresearch/dino:main", "dino_vitb8")
        elif args.arch == "vit_base" and args.patch_size == 16 and not args.supervised:
            model = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
        elif args.arch == "vit_base" and args.patch_size == 14 and args.dinov2:
            model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
        elif args.arch == "vit_base" and args.patch_size == 16 and args.supervised:
            model = timm.create_model("vit_base_patch16_224", pretrained=True, img_size=(args.resize_to[0], args.resize_to[1]))
        else:
            raise ValueError(
                f"Unsupported backbone: arch={args.arch!r}, patch_size={args.patch_size}, "
                f"dinov2={args.dinov2}, supervised={args.supervised}"
            )

        for p in model.parameters():
            p.requires_grad = False

        # wget https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth
        # wget https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth
        # wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth

        return model

    @torch.no_grad()
    def forward(self, x, reduce=True):
        """Extract frozen patch features, optionally dropping a random subset of tokens.

        :arg x: input of shape (B, F, 3, H, W). B and F are flattened together,
            so both outputs have leading dimension B * F.
        :arg reduce: if True and ``token_drop_ratio != 0``, keep only
            ``reduced_token_num`` randomly chosen patch tokens per sample.
        :return x: (B * F, N', 768) patch features with the CLS token removed.
            N' is ``reduced_token_num`` when ``reduce`` is True, else ``token_num``.
        :return rand_indices: (B * F, N') sorted indices into the N patch
            tokens of the features that were kept (0..N-1 when nothing is dropped).
        """
        x = torch.flatten(x, start_dim=0, end_dim=1)
        # Token sequence (B*F, 1 + N, D); position 0 holds the CLS token.
        if self.dinov2:
            x = self.dino.prepare_tokens_with_masks(x)
        elif self.supervised:
            x = self.dino.patch_embed(x)
            x = self.dino._pos_embed(x)
        else:
            x = self.dino.prepare_tokens(x)

        B = x.shape[0]  # B * F of the original input
        embed_dim = x.shape[-1]

        if self.token_drop_ratio != 0.0 and reduce:
            # Independent random subset of N' patch positions per sample, sorted
            # so the kept tokens stay in their original order.
            rand_indices = torch.vstack([
                torch.randperm(self.token_num, device=x.device)[:self.reduced_token_num]
                for _ in range(B)
            ]).sort()[0]  # (B, N')

            # Shift past the CLS token (+1) and prepend index 0 so CLS is always kept.
            cls_index = torch.zeros(B, 1, device=x.device, dtype=torch.long)
            gather_index = torch.cat([cls_index, rand_indices + 1], dim=-1)  # (B, N' + 1)

            x = torch.gather(x, dim=1, index=gather_index.unsqueeze(-1).expand(-1, -1, embed_dim))  # (B, N' + 1, D)
        else:
            # No dropping: every patch token is kept in order.
            rand_indices = torch.arange(self.token_num, device=x.device).repeat(B, 1)  # (B, N)

        for blk in self.dino.blocks:
            x = blk(x)
        x = x[:, 1:]  # drop the CLS token

        expected_tokens = self.reduced_token_num if reduce else self.token_num
        assert x.shape[0] == B
        assert x.shape[1] == expected_tokens
        assert x.shape[2] == 768

        return x, rand_indices

class MLP(nn.Module):
    """Two-layer perceptron with optional residual connection and LayerNorm.

    Computes ``dropout(layer2(relu(layer1(x))))`` with dropout p=0.1. It
    optionally adds the input back (``residual=True``) and applies LayerNorm
    either to the input (``layer_order="pre"``) or to the final output
    (``layer_order="post"``).

    :arg input_dim: width of the input features.
    :arg hidden_dim: width of the hidden layer.
    :arg output_dim: width of the output features.
    :arg residual: add the input to the output; requires ``input_dim == output_dim``.
    :arg layer_order: one of ``"pre"``, ``"post"`` or ``"none"``.
    :raises ValueError: on an unknown ``layer_order`` or an incompatible residual.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, residual=False, layer_order="none"):
        super().__init__()
        # Explicit checks rather than `assert`, so they still run under `python -O`.
        if layer_order not in ("pre", "post", "none"):
            raise ValueError(f"layer_order must be 'pre', 'post' or 'none', got {layer_order!r}")
        if residual and input_dim != output_dim:
            raise ValueError(
                f"residual=True requires input_dim == output_dim, got {input_dim} and {output_dim}"
            )

        self.residual = residual
        self.layer_order = layer_order

        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.1)

        # "pre" normalizes the input and "post" normalizes the output, so the
        # LayerNorm width must match whichever tensor it is applied to.
        if layer_order == "pre":
            self.norm = nn.LayerNorm(input_dim)
        elif layer_order == "post":
            self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension ``input_dim``); returns last dim ``output_dim``."""
        shortcut = x

        if self.layer_order == "pre":
            x = self.norm(x)

        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.dropout(x)

        if self.residual:
            x = x + shortcut
        if self.layer_order == "post":
            x = self.norm(x)

        return x