from torch import nn
from transformers import ViTModel, ViTConfig
import torch
import torch.nn.functional as F
import numpy as np
from torchvision.transforms import ToPILImage

class PretrainedViTForReconstructionNoPreprocessing(nn.Module):
    """
    Pretrained ViT ("google/vit-base-patch16-224") exposed through the same
    attribute interface as the other models in model_reconstruction.py
    (num_layers, num_hidden, D, A, sparse_code, reconstructions, ...).

    No image preprocessing is done inside this module: ``forward`` hands its
    input directly to the pretrained ``ViTModel``. Inputs are presumably
    already normalized by a ViT image processor -- TODO confirm with the data
    pipeline.

    The per-layer "reconstruction" reads every patch token's hidden vector as
    the pixels of that patch (an identity dictionary, see ``D``). This only
    works when ``hidden_size == patch_size * patch_size * num_channels``
    (768 == 16 * 16 * 3 for this checkpoint).
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")

        self.num_layers = len(self.vit.encoder.layer)
        self.num_hidden = self.vit.config.hidden_size
        self.num_channels = self.vit.config.num_channels
        self.out_patch_dim = self.vit.config.patch_size * self.vit.config.patch_size

        # Publish the architecture sizes into the shared config dict
        # (this mutates the caller's config object).
        self.config["num_layers"] = self.num_layers
        self.config["num_hidden"] = self.num_hidden

        # ViT patch size; usually coincides with config["tile_size"].
        self.patch_size = self.vit.config.patch_size
        patch_dim = self.patch_size * self.patch_size
        # Not used by forward(); kept so parameter names / state_dict are unchanged.
        self.input_proj = nn.Linear(patch_dim, self.vit.config.hidden_size)
        self.output_proj = nn.Linear(self.vit.config.hidden_size, patch_dim)

        # Filled in by forward(): the latest per-layer codes and reconstructions.
        self.sparse_code = None
        self.reconstructions = None

        self.reconstruction_size = self.patch_size * self.patch_size * self.num_channels
        # One identity dictionary per layer: (num_layers, num_hidden, reconstruction_size).
        # NOTE(review): plain tensor attribute, so it does not follow .to(device).
        self.D = torch.eye(self.num_hidden, self.reconstruction_size).unsqueeze(0).repeat(self.num_layers, 1, 1)
        # Zero-initialised: torch.empty would leave undefined (possibly NaN/inf)
        # values in the parameter until something overwrites them.
        self.A = nn.Parameter(torch.zeros(self.num_hidden, config["num_features"]))

        # unused
        self.lambda1 = nn.Parameter(torch.tensor([0.01]))
        self.lambda2 = nn.Parameter(torch.tensor([0.01]))
        self.h_0 = nn.Parameter(torch.zeros(self.num_hidden))
        self.alpha = nn.Parameter(torch.tensor([1.0]))
        self.sparsity = torch.zeros(self.num_layers)

        print(f"Initialized ViT with {self.num_layers} transformer layers and hidden size {self.num_hidden}")

    def forward(self, x):
        """
        Run the pretrained ViT and build a reconstruction for every layer.

        Args:
            x: Input tensor of shape (T, B, C, W, H). The data loader
               presumably provides (T, B, 3, 224, 224). W and H must be
               multiples of ``self.patch_size``.

        Returns:
            Reconstruction from the final transformer layer, shape (T, B, C, W, H).

        Side effects:
            self.reconstructions: (T, L, B, C, W, H), one reconstruction per layer.
            self.sparse_code: (T, L, B, F) with F = num_patches * num_hidden.

        Raises:
            ValueError: if W or H is not a multiple of the patch size.
        """
        T, B, C, W, H = x.shape
        ps = self.patch_size
        if W % ps or H % ps:
            raise ValueError(
                f"Spatial size ({W}, {H}) is not a multiple of the patch size {ps}"
            )
        # Patch grid derived from the input instead of a hard-coded 14x14.
        grid_w, grid_h = W // ps, H // ps

        out = self.vit(x.reshape(T * B, C, W, H), output_hidden_states=True)

        # hidden_states[0] is the embedding output and token 0 is CLS; drop both.
        hidden_states = torch.stack(out.hidden_states[1:])[:, :, 1:, :]  # (L, T*B, N, hidden)

        # View each token's hidden vector as a (ps, ps, C) pixel block and stitch
        # the blocks back into full images, for all layers at once.
        patches = hidden_states.reshape(self.num_layers, T * B, grid_w, grid_h, ps, ps, C)
        images = patches.permute(0, 1, 6, 2, 4, 3, 5).reshape(self.num_layers, T, B, C, W, H)
        self.reconstructions = images.permute(1, 0, 2, 3, 4, 5)  # (T, L, B, C, W, H)

        codes = hidden_states.reshape(self.num_layers, T * B, -1).permute(1, 0, 2)  # (T*B, L, F)
        self.sparse_code = codes.reshape(T, B, self.num_layers, -1).permute(0, 2, 1, 3)  # (T, L, B, F)

        # Reconstruction from the final layer.
        return self.reconstructions[:, -1]  # (T, B, C, W, H)

class ViTClipProcessor:
    """Apply a HuggingFace ViT image processor frame-by-frame to a clip, and
    approximately invert that preprocessing."""

    def __init__(self, processor:'transformers.ViTImageProcessor',device='cpu'):
        self.processor = processor
        self.device = device
        self.tensor_to_pil = ToPILImage()

    def __call__(self, clip_frames:'np.array') -> 'np.array':
        """Preprocess every frame of a single clip.

        Presumably called on one clip taken from a (B, T, ...) batch before the
        (B, T) -> (T, B) swap, so iterating over ``clip_frames`` yields frames
        (TODO confirm against the caller).

        Args:
            clip_frames: Sequence of frames, iterated along its first axis. A
                2-D frame gets a leading channel axis added before conversion.
                Other frames are passed to ``ToPILImage`` as-is (presumably
                (C, H, W)).

        Returns:
            np.ndarray of shape (1, N, 3, S, S), where N = len(clip_frames).
            The leading 1 is the processor's batch axis. S is the processor's
            output size (224 for the default ViT processor).
        """
        processed_frames = []
        for frame in clip_frames:
            frame_tensor = torch.tensor(frame,device=self.device)
            if frame_tensor.ndim == 2:
                frame_tensor = frame_tensor.unsqueeze(0)

            # Convert to a PIL image and force 3-channel RGB (e.g. from grayscale).
            frame_tensor = self.tensor_to_pil(frame_tensor).convert("RGB")

            processed_frame = self.processor(
                images=frame_tensor,
                return_tensors="pt")
            processed_frame = processed_frame['pixel_values'].numpy()

            processed_frames.append(processed_frame)

        # Each entry is (1, 3, S, S); stacking on axis 1 gives (1, N, 3, S, S).
        processed_frames = np.stack(processed_frames, axis=1)
        return processed_frames

    def invert_processor(self, processed_tensor):
        """
        Undo the normalization and rescaling applied by the image processor.

        Each step is reversed only if the processor actually performs it.

        Args:
            processed_tensor: Tensor with the channel axis third from the end,
                e.g. (batch, channels, height, width) such as (1, 3, 224, 224).

        Returns:
            Floating-point tensor in the processor's input pixel range
            (0-255 when the processor rescales), clamped to [0, 255].
            No uint8 conversion is performed.
        """
        device = processed_tensor.device
        x = processed_tensor

        # Reverse the normalization x_norm = (x - mean) / std, i.e. x = x_norm * std + mean.
        # Missing attribute -> assume normalization was applied (previous behavior).
        if getattr(self.processor, "do_normalize", True):
            std = torch.tensor(self.processor.image_std,device=device).view(1, -1, 1, 1)
            mean = torch.tensor(self.processor.image_mean,device=device).view(1, -1, 1, 1)
            x = x * std + mean

        # Reverse the rescaling x * rescale_factor (1/255 by default) using the
        # processor's own factor instead of assuming 255.
        if getattr(self.processor, "do_rescale", False):
            x = x / getattr(self.processor, "rescale_factor", 1 / 255)

        # Clip to the valid pixel range; the result stays floating point.
        return x.clamp(0, 255)