import torch
import torch.nn as nn
import sys
import os
# Optional directory to add to the import search path (presumably the TokenFD
# checkout that provides the `internvl` package -- confirm for your setup).
# While it is left empty, `os.path.isdir` is false and nothing is appended.
path_to_tokenfd = ''
if path_to_tokenfd not in sys.path:
    if os.path.isdir(path_to_tokenfd):
        print(f"Adding {path_to_tokenfd} to system path")
        sys.path.append(path_to_tokenfd)
from internvl.model.internvl_chat import InternVisionModel

class TokenFDViT(nn.Module):
    """Wrap a pretrained ``InternVisionModel`` and return its patch tokens.

    Attributes:
        vision_encoder: Model returned by ``InternVisionModel.from_pretrained``.
        output_dim: ``hidden_size`` taken from the encoder config.
        num_patches: Expected number of patch tokens per image, computed as
            ``(image_size // patch_size) ** 2`` from the encoder config.
    """

    def __init__(self, checkpoint_path, torch_dtype: torch.dtype = torch.bfloat16) -> None:
        """Load the visual backbone.

        Args:
            checkpoint_path: Forwarded unchanged to
                ``InternVisionModel.from_pretrained``.
            torch_dtype: dtype forwarded to ``from_pretrained``.
        """
        super().__init__()
        print(f"Loading TokenFD visual backbone from: {checkpoint_path}")
        self.vision_encoder = InternVisionModel.from_pretrained(
            checkpoint_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch_dtype,
        )
        config = self.vision_encoder.config
        self.output_dim = config.hidden_size  # 1024 according to the original author's note
        self.num_patches = (config.image_size // config.patch_size) ** 2

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and return the tokens after the first position.

        Args:
            images: Tensor passed to the encoder as ``pixel_values``.

        Returns:
            The first element of the encoder output with the token at
            sequence index 0 removed (presumably the class token -- confirm
            against the encoder implementation).

        Raises:
            ValueError: If the number of remaining tokens differs from
                ``self.num_patches``.
        """
        outputs = self.vision_encoder(pixel_values=images)
        last_hidden_state = outputs[0]
        patch_tokens = last_hidden_state[:, 1:]

        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let a mismatched token count pass silently.
        actual_patches = patch_tokens.shape[1]
        if actual_patches != self.num_patches:
            raise ValueError(
                "The number of patch tokens does not match the expected value. "
                f"Expected {self.num_patches}, got {actual_patches}."
            )

        return patch_tokens