# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from concurrent.futures import Executor
from typing import List, Dict, Tuple, Optional, Union, Any, Callable    
import matplotlib.pyplot as plt
import time
from contextlib import contextmanager
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from data.data_utils import (
    pil_img2rgb, len2weight, 
    get_flattened_position_ids_extrapolate, get_flattened_position_ids_interpolate, 
    patchify,
)
import math
from modeling.bagel.qwen2_navit import NaiveCache, NaiveCacheMultiSeq
from modeling.bagel.bagel import Bagel
import random
from tqdm import tqdm
import re
from dataclasses import dataclass, field
from train.general_llm_server import GeneralLLMServer
import torch.distributed as dist

# System prompt asking the model to reason inside <think> </think> tags before
# giving its answer. The name suggests it targets understanding (VLM) turns; it
# is not referenced in the part of the file reviewed here -- verify at call sites.
# NOTE: the text (including the trailing space before the line break) is sent
# to the model verbatim, so treat any edit as a behavior change.
VLM_THINK_SYSTEM_PROMPT = '''You should first think about the reasoning process in the mind and then provide the user with the answer. 
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''

# System prompt asking the model to plan inside <think> </think> tags before
# generating an image. Same caveat as above: the literal text is runtime data.
GEN_THINK_SYSTEM_PROMPT = '''You should first think about the planning process in the mind and then generate the image. 
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here.'''


class MultiRoundRolloutController:
    def __init__(self, model:Bagel, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids, sde_sampler):
        self.model = model
        self.vae_model = vae_model
        self.tokenizer = tokenizer
        self.vae_transform = vae_transform
        self.vit_transform = vit_transform
        self.new_token_ids = new_token_ids
        self.sde_sampler = sde_sampler
        
    @property
    def unwrapped_model(self):
        if isinstance(self.model, FSDP):
            return self.model.module
        else:
            return self.model
        
    def init_gen_context(self, batch_size:int = 1, multi_seq: bool = False,): 
        num_layers = self.model.config.llm_config.num_hidden_layers
        gen_context = {
            'kv_lens': [0] * batch_size,
            'ropes': [0] * batch_size,
            'past_key_values': NaiveCacheMultiSeq(num_layers, batch_size) if multi_seq else NaiveCache(num_layers),
        }
        return gen_context

    @torch.no_grad()
    def update_context_text(self, text, gen_context):
        # used for interleave data, currently only support 1 data inference, 
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        generation_input, kv_lens, ropes = self.model.prepare_prompts(
            curr_kvlens=kv_lens,
            curr_rope=ropes, 
            prompts=text if isinstance(text, list) else [text],
            tokenizer=self.tokenizer, 
            new_token_ids=self.new_token_ids,
        )
        selected_cache_indices = torch.arange(0, len(kv_lens), device=self.model.device)

        past_key_values = self.model.forward_cache_update_text(past_key_values, 
                                                               selected_cache_indices=selected_cache_indices, 
                                                               **generation_input)     
        gen_context['kv_lens'] = kv_lens
        gen_context['ropes'] = ropes
        gen_context['past_key_values'] = past_key_values        
        return gen_context

    @torch.no_grad()
    def update_context_image(self, image, gen_context, vae=True, vit=True):
        # used for interleave data, currently only support 1 data inference, 
        assert vae or vit
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes =  gen_context['ropes']

        if vae:
            ## update vae
            generation_input, kv_lens, ropes = self.model.prepare_vae_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes, 
                images=image if isinstance(image, list) else [image],
                transforms=self.vae_transform, 
                new_token_ids=self.new_token_ids,
            )
            selected_cache_indices = torch.arange(0, len(kv_lens), device=self.model.device)
            past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, 
                            selected_cache_indices=selected_cache_indices, **generation_input)
        
        if vit:
            ## update vit
            generation_input, kv_lens, ropes = self.model.prepare_vit_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes, 
                images=image if isinstance(image, list) else [image],
                transforms=self.vit_transform, 
                new_token_ids=self.new_token_ids,
            )
            selected_cache_indices = torch.arange(0, len(kv_lens), device=self.model.device)
            past_key_values = self.model.forward_cache_update_vit(past_key_values, 
                            selected_cache_indices=selected_cache_indices, **generation_input)

        gen_context['kv_lens'] = kv_lens
        gen_context['ropes'] = ropes
        gen_context['past_key_values'] = past_key_values
        
        return gen_context

    @torch.no_grad()
    def gen_image(
        self, 
        image_shape, 
        gen_context, 
        batch_size: int = 1,
        cfg_text_scale=4.0,
        cfg_img_scale=1.5,
        cfg_text_precontext=None, 
        cfg_img_precontext=None, 
        cfg_interval=(0.4, 1.0),
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        num_timesteps=50, 
        timestep_shift=3.0,
        initial_noise=None,
        enable_sde=False,
        sde_timestep_idx=None,
    ):
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']
        generation_input = self.model.prepare_vae_latent(
            curr_kvlens=kv_lens,
            curr_rope=ropes, 
            image_sizes=[image_shape] if isinstance(image_shape, tuple) else image_shape, 
            new_token_ids=self.new_token_ids,
            initial_noise=initial_noise
        ) 
        
        # text cfg
        cfg_text_past_key_values = cfg_text_precontext['past_key_values']
        kv_lens_cfg = cfg_text_precontext['kv_lens']
        ropes_cfg = cfg_text_precontext['ropes']
        generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg, 
            image_sizes=[image_shape] if isinstance(image_shape, tuple) else image_shape, 
        )

        # img cfg
        cfg_img_past_key_values = cfg_img_precontext['past_key_values']
        kv_lens_cfg = cfg_img_precontext['kv_lens']
        ropes_cfg = cfg_img_precontext['ropes']
        generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg, 
            image_sizes=[image_shape] if isinstance(image_shape, tuple) else image_shape, 
        )

        if self.sde_sampler is None or not enable_sde:
            unpacked_latent = self.model.generate_image(
                past_key_values=past_key_values,
                cfg_text_past_key_values=cfg_text_past_key_values,
                cfg_img_past_key_values=cfg_img_past_key_values,
                num_timesteps=num_timesteps,
                cfg_text_scale=cfg_text_scale,
                cfg_img_scale=cfg_img_scale,
                cfg_interval=cfg_interval,
                cfg_renorm_min=cfg_renorm_min,
                cfg_renorm_type=cfg_renorm_type,
                timestep_shift=timestep_shift,
                **generation_input,
                cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
                cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
                cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
                cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
                cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
                cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
                cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
                cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
            )
            image = self.decode_image(unpacked_latent, image_shape, batch_size=batch_size)
            return image, [], [], [], []
        
        unpacked_latent, latents, log_probs, timesteps, dts = self.model.generate_image_mix(
            past_key_values=past_key_values,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_img_past_key_values=cfg_img_past_key_values,
            num_timesteps=num_timesteps,
            cfg_text_scale=cfg_text_scale,
            cfg_img_scale=cfg_img_scale,
            cfg_interval=cfg_interval,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            timestep_shift=timestep_shift,
            sde_sampler=self.sde_sampler,
            **generation_input,
            cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
            cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
            cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
            cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
            cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
            cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
            cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
            cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
            sde_timesteps_idx=sde_timestep_idx,
        )

        image = self.decode_image(unpacked_latent, image_shape, batch_size=batch_size)
        return image, latents, log_probs, timesteps, dts

    def decode_image(self, latent, image_shape, batch_size: int = 1):
        H, W = image_shape if isinstance(image_shape, tuple) else image_shape[0]
        h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
        latent = torch.cat(latent, dim=0)
        latent = latent.reshape(batch_size, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
        latent = torch.einsum("nhwpqc->nchpwq", latent)
        latent = latent.reshape(batch_size, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
        image = self.vae_model.decode(latent)
        image = (image * 0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1) * 255
        images = []
        for i in range(batch_size):
            images.append((image[i]).to(torch.uint8).cpu().numpy())
        return images

    @torch.no_grad()
    def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
        gen_context = deepcopy(gen_context)
        past_key_values = gen_context['past_key_values']
        kv_lens = gen_context['kv_lens']
        ropes = gen_context['ropes']

        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
        unpacked_latent = self.model.generate_text(
            past_key_values=past_key_values,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=self.new_token_ids['eos_token_id'],
            **generation_input,
        )
        output = self.tokenizer.decode(unpacked_latent[:,0])
        output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
        return output