import argparse
import os
from datetime import datetime
import torch
import torch.nn.functional as F
import wandb
from utils.config_parser import ConfigParser
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
import pandas as pd
import random
from torch.utils.data import DataLoader
from transformers import set_seed, CLIPProcessor, CLIPModel
from torch.nn import KLDivLoss
from utils.hf_captions import create_hf_coco_dataset
from utils.misc import fix_seed
from const import *


# def clip_inference(clip_model, clip_processor, caption_dataset_batch, type="image"):
#     text = [ct[0] for ct in caption_dataset_batch["captions"]]
#     images = caption_dataset_batch["image"]
#     inputs = clip_processor(text=text, images=images, return_tensors="pt", padding=True).to(clip_model.device)
#     with torch.no_grad():
#         outputs = clip_model(**inputs, return_dict=True)
#     return outputs.logits_per_image, outputs.image_embeds


def clip_inference(clip_model, clip_processor, caption_dataset_batch, mode="image"):
    """
    Compute a CLIP similarity matrix and embeddings for a batch of captions.

    Only element ``[0]`` of each entry in ``caption_dataset_batch["captions"]`` is
    used as the caption text. Captions are truncated to the tokenizer's maximum
    length so over-long captions cannot overflow CLIP's position embeddings.

    Args:
        clip_model: a HuggingFace CLIPModel instance.
        clip_processor: a HuggingFace CLIPProcessor instance.
        caption_dataset_batch: mapping (e.g. a HF Dataset or batch dict) with keys
            - "captions": sequence of indexable entries; entry[0] is the caption
              text that gets encoded.
            - "image": images in a form accepted by ``clip_processor``
              (only read in "image" mode).
        mode: "image" -> image-to-text ``logits_per_image`` and ``image_embeds``
              from a joint CLIP forward pass (original behavior);
              "text"  -> text-to-text similarity logits and L2-normalized
              ``text_embeds``.

    Returns:
        (logits, embeds): the similarity matrix and the embeddings described for
        ``mode``. Everything is computed under ``torch.no_grad()``, so no
        autograd graph is attached to the results.

    Raises:
        ValueError: if ``mode`` is neither "image" nor "text".
    """
    # Validate up front so a bad mode fails before touching the dataset.
    if mode not in ("image", "text"):
        raise ValueError(f"Unknown mode '{mode}'. Use 'image' or 'text'.")

    texts = [ct[0] for ct in caption_dataset_batch["captions"]]
    device = clip_model.device

    # Keep *all* tensor math (including logit_scale) out of autograd; previously
    # the text branch scaled outside no_grad and could return tracked tensors.
    with torch.no_grad():
        if mode == "image":
            inputs = clip_processor(
                text=texts,
                images=caption_dataset_batch["image"],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)
            outputs = clip_model(**inputs, return_dict=True)
            return outputs.logits_per_image, outputs.image_embeds

        # mode == "text": tokenize only the captions.
        inputs = clip_processor(
            text=texts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        text_embeds = clip_model.get_text_features(**inputs)
        # Unit-normalize so the dot product below is cosine similarity.
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
        # Same temperature CLIP applies to its image-text logits.
        logit_scale = clip_model.logit_scale.exp().to(text_embeds.device)
        logits_per_text = torch.matmul(text_embeds, text_embeds.t()) * logit_scale
        return logits_per_text, text_embeds


def main():
    """
    Compute and report CLIP text-to-text reference logits for a caption subset.

    Loads the first ``kl_control_size`` entries of the COCO caption dataset,
    runs a frozen ``openai/clip-vit-large-patch14`` model over their captions in
    "text" mode, and prints the shapes of the resulting logits and embeddings.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.set_num_threads(16)

    # --- Load Caption Dataset for Regularization ---
    kl_control_size = 1024
    caption_dataset = create_hf_coco_dataset(CAPTION_FILE_PATH, IMAGE_FOLDER_PATH).select(range(kl_control_size))

    # Keep the checkpoint name separate from the model object it loads.
    clip_model_name = "openai/clip-vit-large-patch14"
    clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
    clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
    # CLIP is only used to produce fixed reference similarities: freeze it.
    clip_model.requires_grad_(False)

    reference_logits, text_embeds = clip_inference(clip_model, clip_processor, caption_dataset, mode="text")
    print(f"reference_logits shape: {reference_logits.shape}")  # [kl_control_size, kl_control_size]
    print(f"text_embeds shape: {text_embeds.shape}")


if __name__ == "__main__":
    # main()

    # --- Exploratory run: caption-to-caption similarities from the SDXL text encoder ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # clip_model_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    clip_model_name = "openai/clip-vit-large-patch14"
    # clip_model_name = "laion/CLIP-ViT-g-14-laion2B-s34B-b88K"
    clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
    clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
    clip_model.requires_grad_(False)
    # clip_model.text_model.encoder.layers[0].self_attn.k_proj.weight

    # Imported here for interactive experimentation in the breakpoints below.
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import StableDiffusionPipeline, DiffusionPipeline
    # pipe3 = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
    # pipe2 = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        # torch_dtype=torch.float16,
        use_safetensors=True,
        # variant="fp16"
    )
    text_encoder = pipe.text_encoder
    tokenizer = pipe.tokenizer
    text_encoder2 = pipe.text_encoder_2

    kl_control_size = 128
    caption_dataset = create_hf_coco_dataset(CAPTION_FILE_PATH, IMAGE_FOLDER_PATH).select(range(kl_control_size))

    # `pipe` is never moved to `device`, so run on wherever the encoder lives.
    device = text_encoder.device
    texts = [ct[0] for ct in caption_dataset["captions"]]
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    # Inference only: don't build an autograd graph over the whole caption batch.
    with torch.no_grad():
        # NOTE(review): index 1 is presumably the pooled output — confirm.
        text_embeds = text_encoder(**inputs)[1]
    # breakpoint() honours PYTHONBREAKPOINT (set it to 0 to run non-interactively).
    breakpoint()
    # text_embeds = clip_model.text_projection(text_embeds)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)  # [kl_control_size, hidden_dim]
    # logit_scale = clip_model.logit_scale.exp() # 100.0
    logit_scale = 100.0
    logits_per_text = torch.matmul(text_embeds, text_embeds.t()) * logit_scale  # [kl_control_size, kl_control_size]
    logits_per_image = logits_per_text.T
    breakpoint()

            # text_embeds = []
            # for j in range(int(np.ceil(kl_control_size/kl_batch_size))):
            #     batch = caption_dataset[j*kl_batch_size:(j+1)*kl_batch_size]
            #     ts = [ct[0] for ct in batch["captions"]]
            #     inputs = tokenizer(
            #         ts, return_tensors="pt", padding=True, truncation=True
            #     ).to(device)
            #     aa = encoder_policy(**inputs).text_embeds
            #     import pdb; pdb.set_trace()
            #     text_output = encoder_policy(**inputs)[1]
            #     text_embeds.append(text_output)
            # text_embeds = torch.cat(text_embeds, dim=0)