import os
import argparse
import json
import numpy as np
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from torchvision import datasets
import copy

# Instruction sent to the MLLM together with the sampled images.
_STYLE_PROMPT = "Describe the aspects of the style that applies regardless of category. Provide a description. Do not describe the object in the image, but the style of image. Be as detailed, complete, and comprehensive as possible. Explain every minute detail."


def _parse_args():
    """Parse and validate the command-line options of this script."""
    parser = argparse.ArgumentParser(description="Calculate hidden states from an MLLM for generated images.")
    parser.add_argument('--model_id', type=str, default='llava-hf/llama3-llava-next-8b-hf', help='The model ID for the MLLM.')
    parser.add_argument('--image_dir', type=str, default='./new_diffusion_images', help='Directory with the generated diffusion images.')
    parser.add_argument('--output_file', type=str, default='./text_and_embedding.json', help='Output JSON file for hidden states and captions.')
    parser.add_argument('--num_iterations', type=int, default=4, help='Number of sampling iterations per style.')
    parser.add_argument('--images_per_iteration', type=int, default=3, help='Number of images to sample per iteration.')
    args = parser.parse_args()
    # An empty sample would only fail later, inside the processor, with an
    # unrelated-looking error; reject it up front instead.
    if args.images_per_iteration < 1:
        parser.error('--images_per_iteration must be at least 1')
    return args


def _group_paths_by_domain(dataset):
    """Map every class name of an ``ImageFolder`` dataset to its image paths.

    Grouping goes through the class index stored in ``dataset.samples``
    rather than through substring matching on the path, so it does not
    depend on how the root directory was spelled (trailing slash, path
    separator) and never treats stray files in the root as domains.
    """
    paths_by_domain = {name: [] for name in dataset.classes}
    for path, class_idx in dataset.samples:
        paths_by_domain[dataset.classes[class_idx]].append(path)
    return paths_by_domain


def _load_images(paths, size=(300, 300)):
    """Load the given image files, resized to ``size`` and in RGB mode."""
    images = []
    for path in paths:
        # resize() returns an in-memory image that is independent of the
        # file, so the file handle can be released right away.
        with Image.open(path) as raw:
            image = raw.resize(size)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        images.append(image)
    return images


def main():
    """Extract style embeddings and style captions for each image domain.

    For every class sub-directory (domain) of ``--image_dir`` the script
    repeats ``--num_iterations`` times: sample up to
    ``--images_per_iteration`` images without replacement, feed them with a
    style-description prompt to the MLLM, and record

    * ``'last'``: the final hidden-state layer of a forward pass, averaged
      over dim 1, as nested lists, and
    * ``'text'``: the generated style description.

    The per-domain records are written as JSON to ``--output_file``.
    """
    args = _parse_args()

    processor = LlavaNextProcessor.from_pretrained(args.model_id)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        args.model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation='flash_attention_2',
        device_map='auto'
    ).eval()

    train_dataset = datasets.ImageFolder(root=args.image_dir, transform=None)
    paths_by_domain = _group_paths_by_domain(train_dataset)

    total_image_caption = {}

    with torch.inference_mode():
        for domain, domain_paths in paths_by_domain.items():
            print(f"Processing domain: {domain}")
            domain_entry = total_image_caption[domain] = {}

            if not domain_paths:
                continue

            sample_size = min(args.images_per_iteration, len(domain_paths))

            for _ in range(args.num_iterations):
                chosen = np.random.choice(len(domain_paths), size=sample_size, replace=False)
                images = _load_images([domain_paths[i] for i in chosen])

                # One text instruction followed by one image placeholder per
                # sampled image.
                conversation = [{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": _STYLE_PROMPT},
                        *[{"type": "image"} for _ in images],
                    ],
                }]

                prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
                inputs = processor(images=images, text=prompt, return_tensors="pt").to(model.device)

                # Get hidden states
                outputs = model(**inputs, output_hidden_states=True, return_dict=True)
                last_hidden_states = outputs.hidden_states[-1].mean(dim=1).detach().cpu()

                # Generate text
                generate_ids = model.generate(**inputs, max_new_tokens=150)
                # generate() returns prompt + continuation; decode only the
                # continuation so the caption is not cut short when the model
                # itself writes the word "assistant".
                prompt_length = inputs["input_ids"].shape[1]
                cleaned_text = processor.batch_decode(
                    generate_ids[:, prompt_length:],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )[0].strip()

                domain_entry.setdefault('last', []).append(last_hidden_states.tolist())
                domain_entry.setdefault('text', []).append(cleaned_text)

    with open(args.output_file, "w") as f:
        json.dump(total_image_caption, f, indent=4)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
