import torch.nn as nn
import torch.nn.functional as F
from constants import RESNET_SIZE
from models.model_utils import set_attributes_from_args, imagenet_mean, imagenet_std, hash_tensor
import numpy as np
import torch
import cv2
from torchvision import transforms
from torchvision.utils import save_image

from models.model_wrapper import ModelWrapper

# Preprocessing applied to [B, C, H, W] float frames in [0, 1] before the ResNet:
# resize to a fixed 224x224, then normalize each channel with the usual
# ImageNet mean/std statistics.
default_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), antialias=False),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class ResNet(ModelWrapper):
    """Feature-extraction front end: frames -> ResNet-18 features -> ``wrapped``.

    Raw frames are embedded with torchvision's ``resnet18`` (loaded through
    ``torch.hub``, with its last child module removed) and the resulting
    feature vectors are passed to the ``wrapped`` module.

    Recognized keyword arguments (defaults in ``DEFAULT_RESNET_CONFIG``):
        device: Device for the ResNet, for the output features, and for
            numpy frames converted to tensors.
        rgb_height, rgb_width: Frame size used to un-flatten flattened inputs.
        grayscale: If True, frames have one channel, which is replicated to
            three channels before the ResNet.
        pretrained: Forwarded to ``torch.hub.load`` as ``pretrained``.
    """

    def __init__(self, wrapped: nn.Module, **kwargs):
        DEFAULT_RESNET_CONFIG = {
            'device': None,
            'rgb_height': 224,
            'rgb_width': 224,
            'grayscale': False,
            'pretrained': True
        }

        super(ResNet, self).__init__()
        set_attributes_from_args(self, DEFAULT_RESNET_CONFIG, kwargs)

        resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=self.pretrained).to(self.device)
        # Drop the last child (presumably the `fc` classification layer) so the
        # network outputs pooled features instead of class logits.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        self.wrapped = wrapped
        # Length of one flattened input frame; must match the channel count
        # used when un-flattening inputs in `_to_batched_frames`.
        self.input_len = self.rgb_height * self.rgb_width * (1 if self.grayscale else 3)
        self.output_len = wrapped.output_len

    def forward(self, input):
        """Embed ``input`` frames with the ResNet and feed the features to ``wrapped``."""
        return self.wrapped(self.frames_to_resnet(input))

    def _to_batched_frames(self, frames):
        """Reshape ``frames`` to the batched [B, H, W, C] layout, keeping its type and device."""
        channels = 1 if self.grayscale else 3
        if frames.ndim == 1:  # [H * W * C] -> [1, H, W, C]
            return frames.reshape(1, self.rgb_height, self.rgb_width, channels)
        if frames.ndim == 2:  # [B, H * W * C] -> [B, H, W, C]
            return frames.reshape(frames.shape[0], self.rgb_height, self.rgb_width, channels)
        if frames.ndim == 3:  # [H, W, C] -> [1, H, W, C]
            return frames[None]
        return frames  # already [B, H, W, C]

    def frames_to_resnet(self, frames: np.ndarray | torch.Tensor, batch_size: int = 1) -> torch.Tensor:
        """Convert raw frames to ResNet feature vectors.

        Accepted layouts (C is 1 when ``self.grayscale`` else 3):
            [B, H, W, C]    batched frames
            [H, W, C]       a single frame
            [B, H * W * C]  batched, flattened frames
            [H * W * C]     a single flattened frame
        Flattened inputs are un-flattened using ``self.rgb_height`` and
        ``self.rgb_width``. Pixel values are expected in [0, 255].

        Args:
            frames: numpy array or tensor in one of the layouts above. numpy
                input is converted and moved to ``self.device`` one batch at
                a time.
            batch_size: Number of frames pushed through the ResNet per step.

        Returns:
            Tensor of shape [B, RESNET_SIZE] allocated with
            ``device=self.device`` (B is 1 for batch-less input).

        Raises:
            TypeError: If ``frames`` is neither a numpy array nor a tensor.
            ValueError: If ``batch_size`` is not positive.
        """
        if not isinstance(frames, (np.ndarray, torch.Tensor)):
            raise TypeError(f"frames must be a numpy array or torch tensor, got {type(frames).__name__}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Normalize the layout BEFORE batching: slicing a batch-less frame
        # would otherwise iterate over its rows instead of over frames.
        frames = self._to_batched_frames(frames)
        num_frames = len(frames)

        # Track gradients through the ResNet only while it is in training mode.
        with torch.set_grad_enabled(self.resnet.training):
            resnet_features = torch.empty((num_frames, RESNET_SIZE), device=self.device)

            for start in range(0, num_frames, batch_size):
                end = start + batch_size
                curr_frames = frames[start:end]

                # numpy batches are converted and moved to the device lazily.
                if isinstance(curr_frames, np.ndarray):
                    curr_frames = torch.as_tensor(curr_frames, device=self.device)

                if self.grayscale:
                    # Grayscale -> RGB by replicating the single channel.
                    curr_frames = curr_frames.repeat(1, 1, 1, 3)

                # [B, H, W, C] -> [B, C, H, W] and [0, 255] -> [0, 1]
                curr_frames = (curr_frames.permute(0, 3, 1, 2) / 255.0).to(torch.float32)
                curr_frames = default_transform(curr_frames)

                # ResNet output is [B, F, 1, 1]; drop the trailing spatial dims.
                resnet_features[start:end] = self.resnet(curr_frames).squeeze(2).squeeze(2)
                # Release cached CUDA allocator blocks after each batch
                # (no-op when CUDA was never initialized).
                torch.cuda.empty_cache()

            return resnet_features
