#!/usr/bin/env python3
"""
Semantic Adaptive Reason Inference
"""

import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
import numpy as np
import os
import PIL.Image
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import torchvision
import json
import argparse
import copy
import random
from typing import List, Dict
import time

from semantic_length_predictor import SemanticLengthPredictor, t2i_r1_adaptive_generation

def seed_all(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and torch (CPU, the current CUDA
    device, and all CUDA devices), then switches cuDNN to deterministic
    mode with benchmarking disabled.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_caption_height(text, font, img_width, draw):
    """Return the pixel height needed to draw *text* word-wrapped to *img_width*.

    Words are packed greedily: a word joins the current line while the
    rendered width (``draw.textlength``) stays below ``img_width - 20``;
    otherwise it starts a new line. A single word wider than that limit
    occupies one line of its own.

    Args:
        text: Caption text; split on whitespace.
        font: Font object passed to ``draw.textlength``. Its ``size``
            attribute is used as the glyph height when present; otherwise
            the size is measured from the glyph ``"X"``.
        img_width: Width in pixels the caption must fit into.
        draw: Object exposing ``textlength(text, font=...)``.

    Returns:
        ``line_count * (font_size + 4) + 20``; ``20`` for empty text.
    """
    max_line_width = img_width - 20
    line_count = 0
    current_line = ""

    for word in text.split():
        candidate = f"{current_line} {word}" if current_line else word
        if draw.textlength(candidate, font=font) < max_line_width:
            current_line = candidate
            continue
        # The candidate does not fit. Only close the current line if it
        # actually holds text: when the very first word is already too wide,
        # there is nothing to flush and counting an empty line would inflate
        # the height.
        if current_line:
            line_count += 1
        current_line = word

    if current_line:
        line_count += 1

    font_size = getattr(font, "size", None)
    if font_size is None:
        # Fonts without a ``size`` attribute: measure a reference glyph.
        # ``getsize`` was removed in Pillow 10, so prefer ``getbbox`` and
        # keep ``getsize`` only as a fallback for older Pillow versions.
        if hasattr(font, "getbbox"):
            left, top, right, bottom = font.getbbox("X")
            font_size = max(right - left, bottom - top)
        else:
            font_size = max(font.getsize("X"))

    line_height = font_size + 4
    return line_count * line_height + 20

def create_grid_with_captions(visual_img, answer_list, save_dir, prompt_text, num_generation):
    """Tile the generated images into one grid and save it as a JPEG.

    Despite the name, no caption text is drawn onto the image.

    Args:
        visual_img: Indexable collection of images; each ``visual_img[i]``
            is passed to ``Image.fromarray`` and then permuted from
            (H, W, C) to (C, H, W).
        answer_list: Accepted for interface compatibility; not used here.
        save_dir: Output directory. An empty string means the current
            working directory.
        prompt_text: Used to build the file name (spaces and path
            separators become underscores).
        num_generation: Number of images taken from ``visual_img``.

    Returns:
        Path of the saved grid image.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the directory
    # when one was actually given.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    image_tensors = [
        torch.from_numpy(np.array(Image.fromarray(visual_img[i]))).permute(2, 0, 1)
        for i in range(num_generation)
    ]

    # Roughly square layout: ceil(sqrt(n)) images per row.
    nrow = int(np.ceil(np.sqrt(num_generation)))
    grid = torchvision.utils.make_grid(image_tensors, nrow=nrow)
    grid = grid.permute(1, 2, 0).numpy().astype(np.uint8)

    # Keep the prompt from being interpreted as a nested path.
    file_stem = prompt_text.replace(' ', '_')
    for separator in (os.sep, os.altsep):
        if separator:
            file_stem = file_stem.replace(separator, '_')

    grid_path = os.path.join(save_dir, file_stem + ".jpg")
    print(grid_path)
    PIL.Image.fromarray(grid).save(grid_path)
    return grid_path

@torch.inference_mode()
def generate_with_semantic_adaptive(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    prompt_text: str,
    semantic_predictor: SemanticLengthPredictor,
    save_dir: str,
    temperature: float = 1,
    num_generation: int = 9,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
    conversation: List[Dict[str, str]] = None,
):
    """Generate reasoning text for a prompt, then sample images conditioned on it.

    Pipeline:
      1. Ask ``semantic_predictor`` for the ``max_new_tokens`` budget of the
         reasoning text.
      2. Sample ``num_generation`` reasoning completions from the language
         model for the already-formatted ``prompt``.
      3. Build one image-generation prompt per completion
         (``"{prompt_text}. {answer}"``), append the image-start token and
         sample ``image_token_num_per_image`` image tokens per sample with
         classifier-free guidance.
      4. Decode the tokens to pixels and save them as one grid image via
         ``create_grid_with_captions``.

    Args:
        mmgpt: Model providing ``language_model``, ``gen_head``,
            ``prepare_gen_img_embeds`` and ``gen_vision_model``.
        vl_chat_processor: Supplies the tokenizer, SFT template, pad id and
            image start tag.
        prompt: Fully formatted reasoning prompt fed to the language model.
        prompt_text: Raw user prompt; used for length prediction, for the
            image prompt and for the output file name.
        semantic_predictor: Object exposing ``predict_optimal_length``.
        save_dir: Directory the grid image is written to.
        temperature: Softmax temperature for image-token sampling only; it
            is not passed to the text ``generate`` call.
        num_generation: Number of samples generated for the prompt.
        cfg_weight: Classifier-free guidance scale.
        image_token_num_per_image: Image tokens sampled per image.
        img_size: Side length in pixels of the decoded images.
        patch_size: Pixels per image token along one side.
        conversation: Accepted for interface compatibility; not read.

    Returns:
        None. The grid image is written to disk as a side effect.
    """
    print(f"\n{'='*60}")
    print(f"🎯 Processing prompt: {prompt_text}")
    print(f"{'='*60}")

    optimal_tokens = semantic_predictor.predict_optimal_length(prompt_text)

    tokenizer = vl_chat_processor.tokenizer
    embed_tokens = mmgpt.language_model.get_input_embeddings()

    # ---- Stage 1: sample the reasoning text -------------------------------
    prompt_inputs = tokenizer(
        text=[prompt],
        return_tensors="pt",
        padding=True,
        padding_side="right",
        add_special_tokens=True,
    )
    prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

    # One row per requested sample.
    prompt_ids = prompt_ids.repeat_interleave(num_generation, dim=0).to('cuda')
    prompt_mask = prompt_mask.repeat_interleave(num_generation, dim=0).to('cuda')
    input_embeds = embed_tokens(prompt_ids)

    # NOTE: the previous ``num_generation > 20`` branch iterated
    # ``prompt_ids.shape[0] // num_generation == 1`` times over the full
    # batch, i.e. it issued exactly this call; it is collapsed here.
    # The result is treated as completion-only token ids (no prompt prefix).
    completion_ids = mmgpt.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=prompt_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=optimal_tokens,
        do_sample=True,
        use_cache=True,
    )

    # ---- Stage 2: build one image-generation prompt per completion --------
    image_gen_prompt_list = []
    for i in range(completion_ids.shape[0]):
        answer = tokenizer.decode(completion_ids[i].cpu().tolist(), skip_special_tokens=True)
        image_gen_prompt = f"{prompt_text}. {answer}"

        # Local name chosen so the ``conversation`` parameter is not overwritten.
        image_conversation = [
            {
                "role": "User",
                "content": image_gen_prompt,
            },
            {"role": "Assistant", "content": ""},
        ]
        sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=image_conversation,
            sft_format=vl_chat_processor.sft_format,
            system_prompt="",
        )

        print(f"📝 Prompt {i}: {sft_format}")
        print(f"🧠 Semantic-CoT {i}: {answer}")
        print(f"📏 Used length: {optimal_tokens} tokens")
        print("-" * 40)
        image_gen_prompt_list.append(sft_format)

    # ---- Stage 3: sample image tokens with classifier-free guidance -------
    prompt_inputs = tokenizer(
        text=image_gen_prompt_list,
        return_tensors="pt",
        padding=True,
        padding_side="right",
        add_special_tokens=True,
    )

    prompt_ids, attention_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
    prompt_ids = prompt_ids.to('cuda')
    attention_mask = attention_mask.to('cuda')

    # Index 1 skips the first encoded token — presumably the BOS the
    # tokenizer prepends; verify if the tokenizer configuration changes.
    image_start_token_id = tokenizer.encode(vl_chat_processor.image_start_tag)[1]
    prompt_ids = torch.cat([prompt_ids, prompt_ids.new_full((prompt_ids.size(0), 1), image_start_token_id)], dim=1)
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))], dim=1)

    inputs_embeds = embed_tokens(prompt_ids)
    pad_input_embeds = embed_tokens(prompt_ids.new_full((1, 1), vl_chat_processor.pad_id))
    total_generated_tokens_img = []

    for j in range(inputs_embeds.shape[0] // num_generation):
        cond_inputs_embeds = inputs_embeds[j*num_generation: (j+1)*num_generation]
        cond_attention_mask = attention_mask[j*num_generation: (j+1)*num_generation]

        # Unconditional branch: keep first and last position, replace every
        # position in between with the pad embedding.
        uncond_inputs_embeds = cond_inputs_embeds.clone()
        uncond_inputs_embeds[:, 1:-1] = pad_input_embeds

        # Interleave rows as cond, uncond, cond, uncond, ...
        inputs_embeds_img = torch.repeat_interleave(cond_inputs_embeds, 2, dim=0)
        inputs_embeds_img[1::2] = uncond_inputs_embeds
        attention_mask_img = torch.repeat_interleave(cond_attention_mask, 2, dim=0)
        # Unconditional rows attend to every position.
        attention_mask_img[1::2] = torch.ones_like(attention_mask_img[1::2])

        split_size = 2 * num_generation
        for jj in range(0, inputs_embeds_img.shape[0], split_size):
            print(f"🎨 Generating images {jj}")
            start = jj
            end = min(jj + split_size, inputs_embeds_img.shape[0])
            generated_tokens = torch.zeros(((end-start)//2, image_token_num_per_image), dtype=torch.int64).cuda()
            cur_inputs_embeds_img = inputs_embeds_img[start: end]
            cur_attention_mask_img = attention_mask_img[start: end]

            for k in range(image_token_num_per_image):
                # First step feeds the whole prompt; later steps feed only the
                # newest token embedding and reuse the KV cache.
                outputs = mmgpt.language_model.model(
                    inputs_embeds=cur_inputs_embeds_img,
                    use_cache=True,
                    past_key_values=outputs.past_key_values if k != 0 else None,
                    attention_mask=cur_attention_mask_img
                )

                hidden_states = outputs.last_hidden_state
                logits = mmgpt.gen_head(hidden_states[:, -1, :])
                logit_cond = logits[0::2, :]
                logit_uncond = logits[1::2, :]

                # Classifier-free guidance mix of the two branches.
                logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
                probs = torch.softmax(logits / temperature, dim=-1)

                next_token = torch.multinomial(probs, num_samples=1)
                generated_tokens[:, k] = next_token.squeeze(dim=-1)

                # Duplicate each sampled token so the cond and uncond rows of
                # a sample both receive it.
                next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
                img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
                cur_inputs_embeds_img = img_embeds.unsqueeze(dim=1)
                cur_attention_mask_img = torch.cat([cur_attention_mask_img, cur_attention_mask_img.new_ones((cur_attention_mask_img.shape[0], 1), dtype=torch.int)], dim=1)

            print(f"✅ Generated completed, shape: {generated_tokens.shape}")
            total_generated_tokens_img.append(generated_tokens)

    total_generated_tokens_img = torch.cat(total_generated_tokens_img, dim=0)

    # ---- Stage 4: decode tokens to pixels and save ------------------------
    # Decode the concatenation of all chunks, not just the last chunk's
    # ``generated_tokens``. The second shape entry (8) is kept from the
    # original call; its meaning is defined by ``decode_code``.
    dec = mmgpt.gen_vision_model.decode_code(
        total_generated_tokens_img.to(dtype=torch.int),
        shape=[num_generation, 8, img_size//patch_size, img_size//patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    # Map from [-1, 1] to [0, 255] and clamp.
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((num_generation, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    create_grid_with_captions(visual_img, image_gen_prompt_list, save_dir, prompt_text, num_generation)

def main():
    """Command-line entry point: generate an image grid for every prompt.

    Loads the model and processor from ``--model_path``, reads one prompt
    per line from ``--data_path`` (blank lines are skipped), wraps each
    prompt in the reasoning template read from ``--reasoning_prompt_path``
    and calls ``generate_with_semantic_adaptive``. With ``--save_analysis``
    the per-prompt processing times are written to
    ``semantic_analysis_results.json`` in ``--save_dir`` (or the current
    directory when ``--save_dir`` is empty).
    """
    parser = argparse.ArgumentParser(description="Semantic Adaptive Reason Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the data directory")
    parser.add_argument("--reasoning_prompt_path", type=str, default="../../../data/prompt/reasoning_prompt.txt")
    parser.add_argument("--save_dir", type=str, default='', help="Path to the save directory")
    parser.add_argument("--num_generation", type=int, default=1)
    parser.add_argument("--save_analysis", action="store_true", help="Save semantic analysis results")

    args = parser.parse_args()

    seed_all(42)

    print("🚀 Loading model...")
    model_path = args.model_path
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)

    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    )
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
    print("✅ Model loaded!")

    print("🧠 Initializing semantic length predictor...")
    semantic_predictor = SemanticLengthPredictor(vl_gpt, vl_chat_processor)
    print("✅ Semantic length predictor initialized!")

    # Explicit UTF-8 matches the encoding used for the JSON output below.
    # Blank lines are dropped: an empty prompt would trigger a full
    # generation and a file named ".jpg".
    with open(args.data_path, 'r', encoding='utf-8') as f:
        prompt_list = [stripped for stripped in (line.strip() for line in f) if stripped]

    with open(args.reasoning_prompt_path, 'r', encoding='utf-8') as f:
        cot_prompt = f.read().strip()

    analysis_results = []

    # Order is reproducible because the RNGs were seeded above.
    random.shuffle(prompt_list)

    system_prompt = 'You are a helpful assistant that receives an image prompt and generate a visualization of the prompt.'

    for i, prompt_text in enumerate(prompt_list):
        print(f"\n🔄 Processing: {i+1}/{len(prompt_list)}")

        start_time = time.time()

        conversation = [
            {
                "role": "User",
                "content": cot_prompt.format(prompt_text),
            },
            {"role": "Assistant", "content": ""},
        ]

        sft_prompt = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=system_prompt,
        )

        generate_with_semantic_adaptive(
            vl_gpt,
            vl_chat_processor,
            sft_prompt,
            prompt_text,
            semantic_predictor,
            args.save_dir,
            num_generation=args.num_generation,
            conversation=conversation
        )

        processing_time = time.time() - start_time

        if args.save_analysis:
            analysis_results.append({
                'prompt': prompt_text,
                'processing_time': processing_time,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
            })

    if args.save_analysis and analysis_results:
        analysis_file = os.path.join(args.save_dir, "semantic_analysis_results.json") if args.save_dir else "semantic_analysis_results.json"
        with open(analysis_file, 'w', encoding='utf-8') as f:
            json.dump(analysis_results, f, indent=2, ensure_ascii=False)
        print(f"📊 Analysis results saved to: {analysis_file}")

    print(f"\n🎉 All tasks completed! Processed {len(prompt_list)} prompts")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
