import time
import torch.nn as nn
from models.model_utils import forward_with_checkpoint, set_attributes_from_args, imagenet_mean, imagenet_std, hash_tensor
from models.mlp import MLP
import warnings
import torch
import numpy as np
import cv2
from constants import DINO_SIZE
from torchvision import transforms
from torchvision.utils import save_image
from models.model_wrapper import ModelWrapper
from nn_util import set_seed

# Preprocessing applied to every frame batch before it reaches the backbone:
# resize to 224x224 (no antialiasing), then normalise each channel with the
# standard ImageNet mean / std. Expects float input of shape [..., C, H, W].
default_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), antialias=False),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class DINO(ModelWrapper):
    """Run frames through a DINOv2 ViT-B/14 backbone, then through ``wrapped``.

    The backbone is fetched with ``torch.hub.load('facebookresearch/dinov2',
    'dinov2_vitb14')``. ``forward`` converts raw frames into DINO features via
    :meth:`frames_to_dino` and hands those features to the wrapped module.
    """

    def __init__(self, wrapped: nn.Module, **kwargs):
        """Load the DINOv2 backbone and store the module that consumes it.

        Args:
            wrapped: Module applied to the DINO features. Must expose an
                ``output_len`` attribute, which is mirrored on this wrapper.
            **kwargs: Overrides for the keys of ``DEFAULT_DINO_CONFIG`` below.
                They are applied through ``set_attributes_from_args``
                (presumably defaults overridden by kwargs, set as attributes
                on ``self`` -- confirm against ``models.model_utils``).
        """
        DEFAULT_DINO_CONFIG = {
            'input_len': None,
            'device': None,
            'rgb_height': 224,
            'rgb_width': 224,
            'grayscale': False,
            'optimizer': 0,
        }

        super().__init__()
        set_attributes_from_args(self, DEFAULT_DINO_CONFIG, kwargs)
        assert isinstance(self.rgb_height, int) and isinstance(self.rgb_width, int), "Image dimensions must be integers!"

        with warnings.catch_warnings():
            # DINOv2 warns when the optional xFormers package is missing;
            # silence only that message while loading.
            warnings.filterwarnings("ignore", message="xFormers is not available*")
            self.dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', verbose=False).to(self.device)

        self.wrapped = wrapped
        self.output_len = wrapped.output_len

        # Tag every backbone parameter with the configured optimizer value.
        # NOTE(review): presumably read elsewhere to assign parameters to an
        # optimizer group -- confirm against the training code.
        for param in self.dino.parameters():
            param.optimizer = self.optimizer

    def forward(self, input):
        """Return ``wrapped`` applied to the DINO features of ``input``."""
        return self.wrapped(self.frames_to_dino(input))

    def _as_frame_batch(self, frames: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
        """Reshape ``frames`` to ``[B, H, W, C]`` without changing its type or device."""
        channels = 1 if self.grayscale else 3
        if frames.ndim == 1:
            # Single flattened frame: [H * W * C]
            return frames.reshape(1, self.rgb_height, self.rgb_width, channels)
        if frames.ndim == 2:
            # Batch of flattened frames: [B, H * W * C]
            return frames.reshape(frames.shape[0], self.rgb_height, self.rgb_width, channels)
        if frames.ndim == 3:
            # Single unbatched frame: [H, W, C]
            return frames[None]
        return frames

    def frames_to_dino(self, frames: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Compute DINO features for a set of frames.

        Accepted layouts are ``[B, H, W, C]``, batch-less ``[H, W, C]``,
        flattened ``[B, H * W * C]`` and flattened batch-less ``[H * W * C]``.
        Flattened inputs are reshaped using ``rgb_height``, ``rgb_width`` and
        one channel if ``grayscale`` is set, otherwise three. Pixel values are
        divided by 255 before ``default_transform`` is applied, so they are
        expected to be in ``[0, 255]``.

        The layout is normalised to ``[B, H, W, C]`` *before* splitting into
        mini-batches, so a batch-less frame yields exactly one feature row
        rather than being sliced along its height.

        Gradients are tracked only while the backbone is in training mode; in
        that case the backbone is called through ``forward_with_checkpoint``.

        Args:
            frames: Frames as a numpy array or torch tensor. Tensors on
                another device are moved to the backbone's device one
                mini-batch at a time.

        Returns:
            Tensor of shape ``[B, DINO_SIZE]`` on the backbone's device.
        """
        # HACK: read the device from the backbone's parameters instead of
        # self.device, which is not stored properly once the model is compiled.
        device = next(self.dino.parameters()).device

        if len(frames) == 0:
            return torch.empty((0, DINO_SIZE), device=device)

        assert isinstance(frames, (np.ndarray, torch.Tensor))

        frames = self._as_frame_batch(frames)
        num_frames = frames.shape[0]
        # Upper bound on the number of frames sent through the backbone at once.
        batch_size = 2 ** 9

        with torch.set_grad_enabled(self.dino.training):
            dino_features = torch.empty((num_frames, DINO_SIZE), device=device)

            for start in range(0, num_frames, batch_size):
                end = start + batch_size

                # Only the current mini-batch is placed on the device.
                curr_frames = torch.as_tensor(frames[start:end], device=device)

                if self.grayscale:
                    # Grayscale -> RGB by replicating the single channel.
                    curr_frames = curr_frames.repeat(1, 1, 1, 3)

                # [B, H, W, C] -> [B, C, H, W]
                # [0, 255] -> [0, 1]
                curr_frames = (curr_frames.permute(0, 3, 1, 2) / 255.0).to(torch.float32)

                curr_frames = default_transform(curr_frames)

                # Extract features
                if self.dino.training:
                    dino_features[start:end] = forward_with_checkpoint(self.dino, curr_frames)
                else:
                    dino_features[start:end] = self.dino(curr_frames)

            return dino_features
