from dataclasses import dataclass
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
from diffusers import AutoencoderKL, DDPMScheduler, FlowMatchEulerDiscreteScheduler
from transformers import (
    CLIPTextModel, 
    CLIPTextModelWithProjection,
    CLIPTokenizer, 
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    PretrainedConfig,
    T5TokenizerFast,
)
import time
import copy
import os
from io import BytesIO
from trainer.models.base_model import BaseModelConfig
from trainer.models.transformer_sd3_reward import SD3Transformer2DModel

from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import convert_unet_state_dict_to_peft


from accelerate.logging import get_logger
logger = get_logger(__name__)


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    """Encode prompts with a T5 encoder into per-token embeddings.

    Args:
        text_encoder: encoder called on the token ids; element ``[0]`` of its
            output is used as the embeddings and its ``dtype`` attribute sets
            the dtype of the result.
        tokenizer: tokenizer for ``text_encoder``. When ``None`` the caller
            must supply ``text_input_ids`` instead.
        max_sequence_length: pad/truncate length used when tokenizing.
        prompt: a prompt string or a list of prompts. Only required when
            ``tokenizer`` is given.
        num_images_per_prompt: number of consecutive copies of each prompt's
            embeddings in the output.
        device: device the ids and the result are moved to.
        text_input_ids: pre-tokenized ids, used when ``tokenizer`` is ``None``.

    Returns:
        Tensor of shape ``(batch * num_images_per_prompt, seq_len, dim)``;
        the copies belonging to one prompt are adjacent.

    Raises:
        ValueError: if ``tokenizer`` is ``None`` and no ``text_input_ids`` are given.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Take the batch size from what was actually encoded rather than from
    # ``len(prompt)``. This lets the pre-tokenized path work without a prompt
    # and keeps the reshape below consistent with the ids.
    batch_size, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings for each generation per prompt, using the
    # mps-friendly repeat + view pattern.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    """Encode prompts with a CLIP text encoder.

    Args:
        text_encoder: encoder called with ``output_hidden_states=True``.
            Element ``[0]`` of its output is taken as the pooled embedding and
            ``hidden_states[-2]`` (the penultimate layer) as the per-token
            embedding. Its ``dtype`` sets the per-token output dtype.
        tokenizer: tokenizer for ``text_encoder`` (fixed length 77). When
            ``None`` the caller must supply ``text_input_ids``.
        prompt: a prompt string or list of prompts. Only required when
            ``tokenizer`` is given.
        device: device the ids and per-token embeddings are moved to.
        text_input_ids: pre-tokenized ids, used when ``tokenizer`` is ``None``.
        num_images_per_prompt: number of consecutive copies of each prompt's
            embeddings in both outputs.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` with leading dimension
        ``batch * num_images_per_prompt`` for both tensors.

    Raises:
        ValueError: if ``tokenizer`` is ``None`` and no ``text_input_ids`` are given.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    elif text_input_ids is None:
        raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = encoder_output[0]
    prompt_embeds = encoder_output.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Batch size comes from the encoded tensor so the ids-only path works
    # without a prompt.
    batch_size, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings for each generation per prompt, using the
    # mps-friendly repeat + view pattern.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # The pooled embeddings must be duplicated the same way. Otherwise the two
    # outputs disagree in batch size when num_images_per_prompt > 1.
    # Assumes the pooled output is 2-D (batch, dim) — TODO confirm for all encoders.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds

def get_sigmas(timesteps, scheduler, n_dim=4, device=None, dtype=torch.float32):
    """Return the scheduler sigma for each entry of ``timesteps``.

    Each timestep is matched exactly against ``scheduler.timesteps`` and the
    sigma at the matching index is taken from ``scheduler.sigmas``. The result
    is 1-D (one sigma per timestep) with trailing singleton axes appended until
    it has ``n_dim`` dimensions, so it broadcasts against an ``n_dim`` tensor.

    Args:
        timesteps: 1-D tensor of timesteps to look up.
        scheduler: object exposing ``sigmas`` and ``timesteps`` tensors.
        n_dim: target number of dimensions of the result.
        device: device for the lookup and the result.
        dtype: dtype of the returned sigmas.

    Raises:
        RuntimeError: (from ``Tensor.item``) if a timestep does not match
            exactly one schedule entry.
    """
    all_sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
    schedule = scheduler.timesteps.to(device)

    indices = []
    for step in timesteps.to(device):
        indices.append((schedule == step).nonzero().item())

    sigma = all_sigmas[indices].flatten()
    # Add trailing singleton axes, e.g. (B,) -> (B, 1, 1, 1) for n_dim=4.
    trailing = max(n_dim - sigma.ndim, 0)
    return sigma[(...,) + (None,) * trailing]

def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    """Build SD3-style prompt embeddings from two CLIP encoders and one T5 encoder.

    The first two entries of ``text_encoders``/``tokenizers`` are encoded with
    ``_encode_prompt_with_clip`` and the last entry with ``_encode_prompt_with_t5``.
    The CLIP per-token embeddings are concatenated on the feature axis and
    zero-padded on that axis to the T5 width. They are then placed before the
    T5 tokens along the sequence axis.

    Args:
        text_encoders: encoders ordered ``[clip_1, clip_2, ..., t5]``.
        tokenizers: matching tokenizers, indexed like ``text_encoders``.
        prompt: a prompt string or list of prompts.
        max_sequence_length: tokenization length for the T5 encoder.
        device: target device; each encoder's own ``device`` is used when ``None``.
        num_images_per_prompt: copies of each prompt's embeddings.
        text_input_ids_list: optional pre-tokenized ids, indexed like ``text_encoders``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``. The pooled tensor is the
        two CLIP pooled outputs concatenated on the last axis.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    def _device_for(encoder):
        return encoder.device if device is None else device

    def _ids_at(index):
        return text_input_ids_list[index] if text_input_ids_list else None

    clip_outputs = [
        _encode_prompt_with_clip(
            text_encoder=encoder,
            tokenizer=tok,
            prompt=prompts,
            device=_device_for(encoder),
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=_ids_at(idx),
        )
        for idx, (tok, encoder) in enumerate(zip(tokenizers[:2], text_encoders[:2]))
    ]
    clip_prompt_embeds = torch.cat([per_token for per_token, _ in clip_outputs], dim=-1)
    pooled_prompt_embeds = torch.cat([pooled for _, pooled in clip_outputs], dim=-1)

    t5_encoder = text_encoders[-1]
    t5_prompt_embed = _encode_prompt_with_t5(
        t5_encoder,
        tokenizers[-1],
        max_sequence_length,
        prompt=prompts,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=_ids_at(-1),
        device=_device_for(t5_encoder),
    )

    # Right-pad the CLIP feature axis with zeros so it matches the T5 width.
    feature_gap = t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]
    clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, feature_gap))

    return torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2), pooled_prompt_embeds

def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
    """Encode ``prompt`` without tracking gradients and move the result to ``device``.

    Encoding runs on each encoder's own device (``encode_prompt`` is called
    without a ``device``). Only the outputs are transferred afterwards.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` on ``device``.
    """
    with torch.no_grad():
        embeds, pooled = encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length)
        return embeds.to(device), pooled.to(device)

@dataclass
class SD3BasePreferenceModelConfig(BaseModelConfig):
    """Configuration for ``sd3_base_preference_model``.

    Field order and defaults form the dataclass constructor signature; keep
    them stable for callers that construct this positionally.
    """

    # Dotted path of the model class; presumably consumed by a
    # Hydra-style ``instantiate`` — confirm against the trainer.
    _target_: str = "trainer.models.sd3_base_preference_model.sd3_base_preference_model"
    # Root containing scheduler/, transformer/, text_encoder*/ and tokenizer*/
    # subfolders.
    # NOTE(review): absolute default here vs. relative VAE path below — confirm.
    pretrained_model_name_or_path: str = '/SD3.5-medium'
    # Path of the VAE used to encode images into latents.
    pretrained_vae_name_or_path: str = 'SD3.5-medium/vae'
    # Width of one pooled transformer block output fed to visual_projection
    # (the projection input is 3x this when multi_scale is enabled).
    vision_embed_dim: int = 1536
    # Not read by sd3_base_preference_model in this file.
    text_embed_dim: int = 2048
    # Output width of visual_projection.
    projection_dim: int = 2048
    # Initial value of logit_scale (scores use exp(logit_scale)); 2.6592 ~= ln(1/0.07).
    logit_scale_init_value: float = 2.6592
    # Initial value of score_logit_scale (exp(4.0) ~= 54.6). Note this is NOT
    # ln(1/0.07), which the previous comment claimed.
    score_logit_scale_init_value: float = 4.0
    # When True the three text encoders are frozen (requires_grad_(False)).
    freeze_text_encoder: bool = True
    # Pool three transformer blocks (13, 23, last) instead of only the last.
    multi_scale: bool = True
    # Not read by sd3_base_preference_model in this file.
    multi_scale_cfg: bool = False
    # Classifier-free guidance is enabled when this is > 1.0.
    guidance_scale: float = 1.0
    # Add per-(sample, channel) offset noise scaled by noise_offset_coeff.
    noise_offset: bool = False
    noise_offset_coeff: float = 0.05
    # Number of discrete scheduler steps; ``u`` indexes into these timesteps.
    total_timesteps: int = 40
    # When True, skip loading the VAE, text encoders and tokenizers
    # (latent-only scoring).
    score_model: bool = False


class sd3_base_preference_model(nn.Module):
    """Preference (reward) model built on a LoRA-adapted SD3 transformer.

    Images are VAE-encoded (or latents are used directly) and noised to the
    timestep selected by ``u``. They are then passed through
    ``self.transformer``, whose per-block hidden states are mean-pooled over
    dim 1 and projected by ``visual_projection``. The resulting features are
    compared with pooled CLIP text embeddings via cosine similarity scaled by
    ``exp(logit_scale)``.

    With ``cfg.score_model`` set, the VAE, text encoders, tokenizers and image
    transform are not created. Only the latent entry point that receives
    precomputed prompt embeddings (``get_latent_preference_scores``) is usable
    in that mode.
    """

    def __init__(self, cfg: SD3BasePreferenceModelConfig):
        super().__init__()
        self.cfg = cfg

        if not cfg.score_model:
            self._load_vae_and_text_encoders(cfg)

        # Flow-matching scheduler. The copy is discretised to
        # ``cfg.total_timesteps`` steps, and ``u`` indexes into its timesteps.
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            cfg.pretrained_model_name_or_path, subfolder="scheduler"
        )
        self.scheduler_copy = copy.deepcopy(self.scheduler)
        self.scheduler_copy.set_timesteps(cfg.total_timesteps)
        self.timesteps = self.scheduler_copy.timesteps

        self.transformer = SD3Transformer2DModel.from_pretrained(
            cfg.pretrained_model_name_or_path, subfolder="transformer"
        )
        transformer_lora_config = LoraConfig(
            r=32,
            lora_alpha=32,
            init_lora_weights="gaussian",
            target_modules=[
                "attn.add_k_proj",
                "attn.add_q_proj",
                "attn.add_v_proj",
                "attn.to_add_out",
                "attn.to_k",
                "attn.to_out.0",
                "attn.to_q",
                "attn.to_v",
            ],
        )
        self.transformer.add_adapter(transformer_lora_config)

        if not cfg.score_model:
            # Reads ``self.transformer.config.sample_size``, so it must run after
            # the transformer is loaded. Previously this ran first and raised
            # AttributeError.
            self._build_val_transform()

        # Multi-scale pooling concatenates three block outputs, so the
        # projection input is three times wider.
        in_features = 3 * cfg.vision_embed_dim if cfg.multi_scale else cfg.vision_embed_dim
        self.visual_projection = nn.Linear(in_features, cfg.projection_dim, bias=False)
        nn.init.normal_(self.visual_projection.weight, std=0.02)

        self.logit_scale = nn.Parameter(torch.ones([]) * cfg.logit_scale_init_value)
        self.score_logit_scale = nn.Parameter(torch.ones([]) * cfg.score_logit_scale_init_value)

        self.do_classifier_free_guidance = self.cfg.guidance_scale > 1.0
        # Unconditional prompt appended to every batch when guidance is enabled.
        self.neg_prompt = ""

    def _load_vae_and_text_encoders(self, cfg):
        """Load the frozen VAE and the two CLIP + one T5 text encoders/tokenizers."""
        self.vae = AutoencoderKL.from_pretrained(cfg.pretrained_vae_name_or_path)

        root = cfg.pretrained_model_name_or_path
        self.text_encoder_one = CLIPTextModelWithProjection.from_pretrained(root, subfolder="text_encoder")
        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(root, subfolder="text_encoder_2")
        self.text_encoder_three = T5EncoderModel.from_pretrained(root, subfolder="text_encoder_3")

        self.tokenizer_one = CLIPTokenizer.from_pretrained(root, subfolder="tokenizer")
        self.tokenizer_two = CLIPTokenizer.from_pretrained(root, subfolder="tokenizer_2")
        self.tokenizer_three = T5TokenizerFast.from_pretrained(root, subfolder="tokenizer_3")

        self.tokenizers = [self.tokenizer_one, self.tokenizer_two, self.tokenizer_three]
        self.text_encoders = [self.text_encoder_one, self.text_encoder_two, self.text_encoder_three]

        self.vae.requires_grad_(False)
        if cfg.freeze_text_encoder:
            for encoder in self.text_encoders:
                encoder.requires_grad_(False)

    def _build_val_transform(self):
        """Derive the default image size from the transformer/VAE configs and build the eval transform."""
        self.default_sample_size = self.transformer.config.sample_size
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.height = self.default_sample_size * self.vae_scale_factor
        self.width = self.default_sample_size * self.vae_scale_factor

        self.val_transform = transforms.Compose(
            [
                # ``Resize`` expects (height, width).
                transforms.Resize((self.height, self.width), interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def _expand_text_conditions(self, encoder_hidden_states, pooled_output, n_prompts, n_images):
        """Duplicate the text conditioning when every prompt is paired with two images.

        Only acts when ``n_images == 2 * n_prompts``. The images are then
        presumably laid out as two consecutive groups of ``n_prompts`` —
        NOTE(review): confirm against the data collator. With guidance the
        result is ordered ``[text, text, uncond, uncond]``, matching
        ``_run_transformer``'s ``[latents, latents]`` duplication. In all other
        cases the inputs are returned unchanged.
        """
        if n_images != 2 * n_prompts:
            return encoder_hidden_states, pooled_output
        if self.do_classifier_free_guidance:
            hidden_text, hidden_ucond = encoder_hidden_states.chunk(2, dim=0)
            pooled_text, pooled_ucond = pooled_output.chunk(2, dim=0)
            encoder_hidden_states = torch.cat([hidden_text] * 2 + [hidden_ucond] * 2, dim=0)
            pooled_output = torch.cat([pooled_text] * 2 + [pooled_ucond] * 2, dim=0)
        else:
            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
            pooled_output = torch.cat([pooled_output, pooled_output], dim=0)
        return encoder_hidden_states, pooled_output

    def _run_transformer(self, hidden_states, timesteps, encoder_hidden_states, pooled_projections):
        """Call the transformer, duplicating inputs for guidance; returns ``(pred_v, block_res_samples)``."""
        if self.do_classifier_free_guidance:
            # [latents (text-conditioned), latents (unconditioned)]
            hidden_states = torch.cat([hidden_states] * 2, dim=0)
            timesteps = torch.cat([timesteps] * 2, dim=0)
        return self.transformer(
            hidden_states=hidden_states,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            return_dict=False,
        )

    def _project_block_features(self, block_res_samples):
        """Mean-pool transformer block outputs over dim 1 and project them with ``visual_projection``."""
        if self.cfg.multi_scale:
            pooled = torch.cat(
                [
                    block_res_samples[13].mean(dim=1),
                    block_res_samples[23].mean(dim=1),
                    block_res_samples[-1].mean(dim=1),
                ],
                dim=-1,
            )
            # NOTE(review): unlike the single-scale branch, guidance is not
            # applied here. With CFG the features therefore keep both the
            # conditional and unconditional rows. This was a placeholder in the
            # original and is kept as-is.
            return self.visual_projection(pooled)

        output = block_res_samples[-1].mean(dim=1)
        if self.do_classifier_free_guidance:
            output_text, output_ucond = output.chunk(2, dim=0)
            output = output_ucond + self.cfg.guidance_scale * (output_text - output_ucond)
        return self.visual_projection(output)

    def get_text_features(self, prompt, device):
        """Encode prompts into transformer conditioning and text features.

        Args:
            prompt: a prompt string or a list of prompts.
            device: device the embeddings are moved to.

        Returns:
            ``(encoder_hidden_states, pooled_output_for_time, pooled_output_text)``.
            With guidance, the first two hold the prompt rows followed by the
            same number of empty-prompt rows, while the third holds only the
            prompt rows.
        """
        prompt = [prompt] if isinstance(prompt, str) else list(prompt)
        if self.do_classifier_free_guidance:
            prompt = prompt + [self.neg_prompt] * len(prompt)

        prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
            prompt, self.text_encoders, self.tokenizers, self.tokenizers[0].model_max_length, device
        )

        if self.do_classifier_free_guidance:
            pooled_output_text = pooled_prompt_embeds.chunk(2, dim=0)[0]
        else:
            pooled_output_text = pooled_prompt_embeds
        return prompt_embeds, pooled_prompt_embeds, pooled_output_text

    def get_image_features(self, encoder_hidden_states, pooled_projections, image_inputs, u, generator=None):
        """Noise VAE latents of ``image_inputs`` at timesteps ``self.timesteps[u]`` and score them.

        Returns:
            ``(image_features, prediction)``. ``prediction`` is the per-sample
            mean squared error between the transformer's (guided) velocity
            prediction and the flow-matching target ``noise - latents``.
        """
        with torch.no_grad():
            latents = self.vae.encode(image_inputs).latent_dist.sample()
            # NOTE(review): diffusers' SD3 training/pipeline code normalises
            # latents as (latents - vae.config.shift_factor) * scaling_factor.
            # The shift is not applied here. Left unchanged so existing
            # checkpoints keep seeing the same inputs — confirm this is intended.
            latents = latents * self.vae.config.scaling_factor

        if generator is not None:
            noise = torch.randn(latents.size(), generator=generator, dtype=latents.dtype, device=latents.device)
        else:
            noise = torch.randn_like(latents)

        if self.cfg.noise_offset:
            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
            # Uses the same generator as the base noise so seeded calls stay reproducible.
            offset = torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1),
                generator=generator,
                dtype=latents.dtype,
                device=latents.device,
            )
            noise = noise + self.cfg.noise_offset_coeff * offset

        timesteps = self.timesteps[u.cpu()].to(device=latents.device)
        sigmas = get_sigmas(
            timesteps, scheduler=self.scheduler_copy, n_dim=latents.ndim, device=latents.device, dtype=latents.dtype
        )
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

        pred_v, block_res_samples = self._run_transformer(
            noisy_latents, timesteps, encoder_hidden_states, pooled_projections
        )
        image_features = self._project_block_features(block_res_samples)

        # Both guidance halves of the transformer input are the same noisy
        # latents, so they share one target. Only the prediction is mixed. The
        # old code also chunked/mixed ``v``, which was never duplicated and so
        # mismatched in shape or mixed targets of different images.
        target = (noise - latents).to(self.transformer.dtype)
        if self.do_classifier_free_guidance:
            pred_text, pred_ucond = pred_v.chunk(2, dim=0)
            pred_v = pred_ucond + self.cfg.guidance_scale * (pred_text - pred_ucond)

        prediction = ((pred_v.float() - target.float()) ** 2).reshape(target.shape[0], -1).mean(dim=1)
        return image_features, prediction

    def forward(self, prompt, image_inputs, u, generator=None):
        """Return ``(text_features, image_features, prediction)`` for prompts and images.

        ``image_inputs`` may hold one image per prompt, or two per prompt (see
        ``_expand_text_conditions``).
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)

        encoder_hidden_states, pooled_output, text_features = self.get_text_features(prompts, image_inputs.device)
        encoder_hidden_states, pooled_output = self._expand_text_conditions(
            encoder_hidden_states, pooled_output, len(prompts), image_inputs.shape[0]
        )
        image_features, prediction = self.get_image_features(
            encoder_hidden_states, pooled_output, image_inputs, u, generator=generator
        )
        return text_features, image_features, prediction

    def save(self, path):
        """Write the projection weights and logit scales to ``<path>/state_dict.pt``.

        LoRA and text-encoder weights are not written here.
        """
        state_dict = {
            'visual_projection': self.visual_projection.state_dict(),
            'score_logit_scale': self.score_logit_scale.data.item(),
            'logit_scale': self.logit_scale.data.item(),
        }
        torch.save(state_dict, os.path.join(path, "state_dict.pt"))

    def load(self, path):
        """Load LoRA adapter weights, optional text encoders and ``state_dict.pt`` from ``path``."""
        lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(path)
        prefix = "transformer."
        transformer_state_dict = {
            k[len(prefix):]: v for k, v in lora_state_dict.items() if k.startswith(prefix)
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(self.transformer, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        if not self.cfg.freeze_text_encoder:
            # The attributes are text_encoder_one/_two. The old code referenced
            # nonexistent self.text_encoder / self.text_encoder_2.
            # NOTE(review): from_pretrained loads onto CPU; move to the model's
            # device afterwards if load() is called after .to(device).
            self.text_encoder_one = CLIPTextModelWithProjection.from_pretrained(os.path.join(path, "text_encoder"))
            self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(os.path.join(path, "text_encoder_2"))
            # Keep the list used by get_text_features in sync with the reloaded modules.
            self.text_encoders = [self.text_encoder_one, self.text_encoder_two, self.text_encoder_three]

        # map_location avoids depending on the device the file was saved from;
        # load_state_dict/fill_ then copy into the existing parameters in place.
        state_dict = torch.load(os.path.join(path, "state_dict.pt"), map_location="cpu")
        self.visual_projection.load_state_dict(state_dict['visual_projection'])
        with torch.no_grad():
            # fill_ keeps each parameter's device and dtype (assigning
            # ``.data = torch.tensor(...)`` moved them to CPU float32).
            self.logit_scale.fill_(state_dict['logit_scale'])
            self.score_logit_scale.fill_(state_dict['score_logit_scale'])

    def preprocess_image(self, images):
        """Convert images to a stacked, normalised tensor batch via ``val_transform``.

        Each item may be a PIL image, raw bytes, a file path, or a dict whose
        ``"bytes"`` entry holds one of those. A single item is also accepted.
        """
        if not isinstance(images, list):
            images = [images]

        image_inputs = []
        for image in images:
            if isinstance(image, dict):
                image = image["bytes"]
            if isinstance(image, bytes):
                image = Image.open(BytesIO(image))
            elif isinstance(image, str):
                image = Image.open(image)
            image_inputs.append(self.val_transform(image.convert("RGB")))
        return torch.stack(image_inputs, dim=0)

    def get_preference_scores(self, images, prompt, u, generator=None):
        """Return ``exp(logit_scale) * cos_sim(text, image)`` as a (prompts x images) matrix."""
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)

        encoder_hidden_states, pooled_output, text_features = self.get_text_features(prompts, images.device)
        encoder_hidden_states, pooled_output = self._expand_text_conditions(
            encoder_hidden_states, pooled_output, len(prompts), images.shape[0]
        )
        image_features, _ = self.get_image_features(
            encoder_hidden_states, pooled_output, images, u, generator=generator
        )

        image_embs = image_features / torch.norm(image_features, dim=-1, keepdim=True)
        text_embs = text_features / torch.norm(text_features, dim=-1, keepdim=True)
        return self.logit_scale.exp() * (text_embs @ image_embs.T)

    def get_latent_features(self, encoder_hidden_states, pooled_projections, latents, u, generator = None):
        """Run the transformer directly on ``latents`` at timesteps ``self.timesteps[u]``.

        ``latents`` are used as given (no VAE encoding or noising here).
        ``generator`` is accepted for interface compatibility and is unused.
        """
        # The previously computed sigmas were never used; dropping them avoids
        # a per-sample ``.item()`` device sync.
        timesteps = self.timesteps[u.cpu()].to(device=latents.device)
        _, block_res_samples = self._run_transformer(
            latents, timesteps, encoder_hidden_states, pooled_projections
        )
        return self._project_block_features(block_res_samples)

    def get_latent_preference_scores(self, reward_prompt, latents, u, generator=None):
        """Score latents against precomputed ``(prompt_embeds, pooled_prompt_embeds)``.

        Returns the diagonal of ``exp(logit_scale) * cos_sim(text, latent)``,
        i.e. each prompt scored against its paired latent.
        """
        prompt_embeds, pooled_prompt_embeds = reward_prompt
        prompt_embeds = prompt_embeds.to(latents.device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(latents.device)
        text_features = pooled_prompt_embeds

        encoder_hidden_states, pooled_output = self._expand_text_conditions(
            prompt_embeds, pooled_prompt_embeds, prompt_embeds.shape[0], latents.shape[0]
        )
        latent_features = self.get_latent_features(
            encoder_hidden_states, pooled_output, latents, u, generator=generator
        )

        latent_features = latent_features / torch.norm(latent_features, dim=-1, keepdim=True)
        text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True)
        scores = self.logit_scale.exp() * (text_features @ latent_features.T)
        return torch.diag(scores)