import itertools
import math
import os
import typing
from dataclasses import dataclass
from pathlib import Path

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import transformers
import torchvision.models as vision_models

import dataloader
import models
import noise_schedule
import utils
import logging
import hashlib
from typing import Dict, Tuple, Iterable

from logging.handlers import RotatingFileHandler

from transformers import CLIPVisionModel, CLIPProcessor

import csv
import matplotlib.pyplot as plt
from PIL import Image as PILImage
import json
from datetime import datetime

import torchvision.transforms as transforms

### ==== IMAGE ENCODER ====

class ImageEncoder(nn.Module):
    """Wrap a pretrained vision backbone and return its raw image features.

    Only the ``"clip"`` encoder type is supported. For that type the backbone
    and its processor are both loaded from
    ``config_main.image_encoder.custom_weight_path``.

    Attributes set for the ``"clip"`` type:
        use_patch_tokens: Read from the config (default ``False``); selects
            whether ``forward`` returns per-patch tokens or the pooled output.
        encoder: The ``CLIPVisionModel`` backbone.
        processor: The ``CLIPProcessor`` loaded from the same path. It is not
            used by ``forward``.
        output_dim: The backbone's ``hidden_size``.
        patch_projection: Always ``None`` here; no projection layer is built.
    """

    def __init__(self, config_main):
        """Build the encoder described by ``config_main.image_encoder``.

        Args:
            config_main: Config object exposing an ``image_encoder`` section
                with a ``type`` field and, for ``"clip"``, a
                ``custom_weight_path`` field and an optional
                ``use_patch_tokens`` entry.

        Raises:
            ValueError: If ``image_encoder.type`` is not ``"clip"``.
        """
        super().__init__()
        self.config = config_main.image_encoder

        if self.config.type != "clip":
            raise ValueError(f"Unsupported image encoder type: {self.config.type}")

        weight_path = self.config.custom_weight_path
        self.use_patch_tokens = self.config.get("use_patch_tokens", False)
        self.encoder = CLIPVisionModel.from_pretrained(weight_path)
        self.processor = CLIPProcessor.from_pretrained(weight_path)
        self.output_dim = self.encoder.config.hidden_size
        self.patch_projection = None

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Extract raw features from a batch of images.

        The backbone output is returned as-is; no projection is applied here.

        Args:
            images: Tensor of shape [batch_size, channels, height, width],
                passed to the backbone as ``pixel_values``. For CLIP, ensure
                CLIP normalization & resolution in your dataloader.

        Returns:
            image_features:
              - If ``use_patch_tokens`` is set: ``last_hidden_state`` with the
                first (CLS) token dropped, i.e. [batch_size, N, hidden].
              - Else: the backbone's ``pooler_output``, [batch_size, hidden].

        Raises:
            ValueError: If ``self.config.type`` is not ``"clip"``.
        """
        encoder_type = self.config.type
        if encoder_type != "clip":
            # Fail loudly rather than falling off the end and handing the
            # caller an implicit ``None`` in place of a tensor.
            raise ValueError(f"Unsupported image encoder type: {encoder_type}")

        outputs = self.encoder(pixel_values=images)
        if getattr(self, "use_patch_tokens", False):
            # last_hidden_state is [B, 1+N, D]; index 0 is the CLS token.
            return outputs.last_hidden_state[:, 1:, :]
        return outputs.pooler_output  # [B, D]
