from main.utils import get_x0_from_noise, DummyNetwork, NoOpContext
from diffusers import UNet2DConditionModel, DDIMScheduler
from main.sd_unet_forward import classify_forward
import torch.nn.functional as F
import torch.nn as nn
import torch
import types 
from segment_anything import sam_model_registry
from typing import Dict, List
from torch.utils.checkpoint import checkpoint
def create_2d_head(in_channels: int = 1280, out_channels: int = 256, num_blocks: int = 3):
    """Build a small convolutional scoring head that maps a feature map to one scalar per sample.

    Layout:
        * a 3x3 stride-1 Conv -> GroupNorm(32) -> SiLU stem (``in_channels`` -> ``out_channels``),
        * ``num_blocks`` downsampling stages of 4x4 stride-2 Conv -> GroupNorm(32) -> SiLU,
        * global average pooling, flatten, and a Linear layer to a single output.

    Args:
        in_channels: Channels of the incoming (B, C, H, W) feature map.
        out_channels: Width of every conv stage; must be divisible by 32 (GroupNorm groups).
        num_blocks: Number of stride-2 downsampling stages.

    Returns:
        ``nn.Sequential`` producing a (B, 1) tensor.
    """

    def conv_norm_act(channels_in: int, kernel: int, stride: int):
        # One Conv -> GroupNorm -> SiLU unit; padding=1 throughout.
        return [
            nn.Conv2d(in_channels=channels_in, out_channels=out_channels,
                      kernel_size=kernel, stride=stride, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=out_channels),
            nn.SiLU(),
        ]

    modules = conv_norm_act(in_channels, kernel=3, stride=1)
    for _ in range(num_blocks):
        modules.extend(conv_norm_act(out_channels, kernel=4, stride=2))

    # Collapse the spatial dims and project to a single score.
    modules.extend([
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(out_channels, 1),
    ])
    return nn.Sequential(*modules)

class SAMDiscriminator(nn.Module):
    """Multi-level GAN discriminator built on a frozen SAM ViT image encoder.

    Forward hooks capture the outputs of selected encoder blocks. Each captured feature map
    is scored by its own ``create_2d_head``, and a linear layer fuses the per-layer scores
    into one logit per image.

    Args:
        sam_model_path: Checkpoint path handed to ``sam_model_registry[model_type]``.
        model_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
        device: Device the encoder and heads are moved to.
        dtype: Parameter dtype for the encoder and heads.
        hook_layers: Indices of encoder blocks whose outputs are scored. Defaults to
            ``[3, 7, 11, 15, 19, 23, 27, 31]`` (every fourth block of a 32-block encoder);
            pass a shorter list for shallower backbones.

    Raises:
        ValueError: If a hook index does not exist in the loaded encoder.
    """

    _DEFAULT_HOOK_LAYERS = (3, 7, 11, 15, 19, 23, 27, 31)

    def __init__(self, sam_model_path: str, model_type: str = "vit_h", device='cuda',
                 dtype=torch.float16, hook_layers=None):
        super().__init__()

        self.hook_layers = list(self._DEFAULT_HOOK_LAYERS if hook_layers is None else hook_layers)

        print(f"Initializing SAM Discriminator, hooking into layers: {self.hook_layers}")

        sam = sam_model_registry[model_type](checkpoint=sam_model_path).to(device, dtype)
        # ``preprocess`` is a bound method of the full SAM model, so that model stays
        # referenced; only the image encoder is registered as a submodule here.
        self.preprocess = sam.preprocess
        self.sam_encoder = sam.image_encoder

        # Freeze the encoder; the heads and fusion layer keep requires_grad=True.
        for param in self.sam_encoder.parameters():
            param.requires_grad = False

        num_encoder_blocks = len(self.sam_encoder.blocks)
        for i in self.hook_layers:
            if not -num_encoder_blocks <= i < num_encoder_blocks:
                raise ValueError(
                    f"hook layer {i} is out of range for a SAM encoder with "
                    f"{num_encoder_blocks} blocks"
                )

        # Channel width of the encoder's block outputs (1280 for ViT-H). Using it instead of a
        # hard-coded 1280 keeps the heads valid for other SAM backbones.
        feature_dim = self.sam_encoder.neck[0].in_channels
        self.headers = nn.ModuleList(
            create_2d_head(in_channels=feature_dim, out_channels=256, num_blocks=3).to(device, dtype)
            for _ in self.hook_layers
        )
        # Fuses the per-layer scores (B, len(hook_layers)) into a single logit (B, 1).
        self.fusion_head = nn.Linear(len(self.headers), 1).to(device, dtype)
        # Filled by the forward hooks: block index -> that block's output.
        self.features: Dict[int, torch.Tensor] = {}
        self._create_hooks()

    def _hook_fn(self, layer_index):
        """Return a forward hook that stores a block's output under ``layer_index``."""
        def fn(_, __, output):
            # Keep only the first element if the block returns a tuple.
            self.features[layer_index] = output[0] if isinstance(output, tuple) else output
        return fn

    def _create_hooks(self):
        """Register a feature-capturing forward hook on every block in ``hook_layers``."""
        for i in self.hook_layers:
            self.sam_encoder.blocks[i].register_forward_hook(self._hook_fn(i))

    def forward(self, x: torch.Tensor, no_grad=False) -> torch.Tensor:
        """Score a batch of images.

        Args:
            x: Image batch; after SAM's ``preprocess`` it must be 4-D with 3 channels and a
                longest side equal to the encoder's ``img_size``.
                NOTE(review): confirm the caller's pixel range matches what SAM's
                ``preprocess`` normalization expects.
            no_grad: If True, the frozen encoder runs under ``torch.no_grad()`` (no gradient
                reaches ``x``); the heads and fusion layer still record gradients.

        Returns:
            Logits of shape (B, 1).
        """
        self.features.clear()
        x = self.preprocess(x)
        assert (
            len(x.shape) == 4
            and x.shape[1] == 3
            and max(*x.shape[2:]) == self.sam_encoder.img_size
        )
        # The encoder's return value is unused; the hooks capture the block outputs.
        if not no_grad:
            self.sam_encoder(x)
        else:
            with torch.no_grad():
                self.sam_encoder(x)

        scores = []
        for i, header in zip(self.hook_layers, self.headers):
            # Block outputs are channels-last (B, H', W', C); the conv heads need (B, C, H', W').
            layer_features = self.features[i].permute(0, 3, 1, 2)
            scores.append(header(layer_features))

        # (B, len(hook_layers)) -> (B, 1)
        return self.fusion_head(torch.cat(scores, dim=1))
    
def predict_noise(unet, noisy_latents, text_embeddings, uncond_embedding, timesteps, 
    guidance_scale=1.0, unet_added_conditions=None, uncond_unet_added_conditions=None
):
    """Run ``unet`` to predict noise, with classifier-free guidance when ``guidance_scale > 1``.

    Without guidance the UNet sees only the conditional inputs. With guidance every input is
    stacked as [unconditional; conditional] along the batch axis, the output is split back in
    two halves, and the halves are combined as ``uncond + guidance_scale * (cond - uncond)``.

    Args:
        unet: Callable ``unet(sample, timesteps, embeddings, added_cond_kwargs=...)`` whose
            result exposes ``.sample``.
        noisy_latents: Noised latents fed to the UNet.
        text_embeddings: Conditional encoder hidden states.
        uncond_embedding: Unconditional encoder hidden states (only used with guidance).
        timesteps: Per-sample timesteps.
        guidance_scale: CFG weight; values <= 1 disable guidance.
        unet_added_conditions: Optional dict of extra conditioning tensors (e.g. SDXL).
        uncond_unet_added_conditions: Unconditional counterpart; required when guidance is on
            and ``unet_added_conditions`` is given.

    Returns:
        The (optionally guided) noise prediction.
    """
    if not guidance_scale > 1:
        # Plain conditional pass, no guidance.
        return unet(
            noisy_latents, timesteps, text_embeddings,
            added_cond_kwargs=unet_added_conditions,
        ).sample

    # Every batched input uses the [uncond; cond] order, matching chunk(2) below.
    latents_in = torch.cat([noisy_latents, noisy_latents])
    embeddings_in = torch.cat([uncond_embedding, text_embeddings])
    timesteps_in = torch.cat([timesteps, timesteps])

    added_cond_in = None
    if unet_added_conditions is not None:
        assert uncond_unet_added_conditions is not None
        added_cond_in = {
            key: torch.cat([uncond_unet_added_conditions[key], unet_added_conditions[key]])
            for key in unet_added_conditions
        }

    stacked_pred = unet(latents_in, timesteps_in, embeddings_in, added_cond_kwargs=added_cond_in).sample
    uncond_pred, cond_pred = stacked_pred.chunk(2)
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)

class SDGuidance(nn.Module):
    """Guidance networks for distribution-matching training of a Stable Diffusion generator.

    Holds a frozen "real" UNet (not loaded when ``args.gan_alone``), a trainable "fake" UNet,
    a DDIM scheduler used for noising, and, when ``args.cls_on_clean_image`` is set, a
    ``SAMDiscriminator`` that scores VAE-decoded clean images.

    NOTE(review): ``self.vae`` is used by the classifier helpers but is never assigned in this
    class; it is presumably attached by the owning model after construction. Confirm this.
    """

    # Fallback SAM checkpoint, used when ``args.sam_model_path`` is missing or empty.
    DEFAULT_SAM_MODEL_PATH = "/root/cgj/models/sam_vit_h_4b8939.pth"

    def __init__(self, args, accelerator):
        """Build the guidance networks from the training ``args``.

        Args:
            args: Training configuration. Attributes read here include ``model_id``,
                ``gan_alone``, ``use_fp16``, ``num_train_timesteps``, ``min_step_percent``,
                ``max_step_percent``, ``real_guidance_scale``, ``fake_guidance_scale``,
                ``cls_on_clean_image``, ``gen_cls_loss``, ``sdxl``,
                ``gradient_checkpointing``, ``diffusion_gan``,
                ``diffusion_gan_max_timestep`` and, optionally, ``sam_model_path``.
            accelerator: Object exposing ``.device``; the SAM discriminator is placed there.
        """
        super().__init__()
        self.args = args
        self.gan_alone = args.gan_alone

        # The real (teacher) UNet is only used by the distribution-matching loss. In GAN-only
        # mode it is never queried, so it is not loaded at all (rather than loaded, cast and
        # then deleted).
        if not self.gan_alone:
            self.real_unet = UNet2DConditionModel.from_pretrained(
                args.model_id,
                subfolder="unet"
            ).float()
            self.real_unet.requires_grad_(False)

        self.fake_unet = UNet2DConditionModel.from_pretrained(
            args.model_id,
            subfolder="unet"
        ).float()
        self.fake_unet.requires_grad_(True)

        # somehow FSDP requires at least one network with dense parameters (models from diffuser are lazy initialized so their parameters are empty in fsdp mode)
        self.dummy_network = DummyNetwork()
        self.dummy_network.requires_grad_(False)

        # With mixed precision, both UNets are stored in bfloat16.
        if args.use_fp16:
            if not self.gan_alone:
                self.real_unet = self.real_unet.to(torch.bfloat16)
            self.fake_unet = self.fake_unet.to(torch.bfloat16)

        self.scheduler = DDIMScheduler.from_pretrained(
            args.model_id,
            subfolder="scheduler"
        )

        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer(
            "alphas_cumprod",
            self.scheduler.alphas_cumprod
        )

        self.num_train_timesteps = args.num_train_timesteps
        # Timestep window for the distribution-matching loss, as fractions of the
        # scheduler's training horizon.
        self.min_step = int(args.min_step_percent * self.scheduler.num_train_timesteps)
        self.max_step = int(args.max_step_percent * self.scheduler.num_train_timesteps)

        self.real_guidance_scale = args.real_guidance_scale
        self.fake_guidance_scale = args.fake_guidance_scale

        assert self.fake_guidance_scale == 1, "no guidance for fake"

        self.use_fp16 = args.use_fp16

        self.cls_on_clean_image = args.cls_on_clean_image
        self.gen_cls_loss = args.gen_cls_loss

        self.accelerator = accelerator

        if self.cls_on_clean_image:
            # The checkpoint location is configurable; the old hard-coded path is the fallback.
            sam_model_path = getattr(args, "sam_model_path", None) or self.DEFAULT_SAM_MODEL_PATH
            self.discriminator = SAMDiscriminator(
                sam_model_path=sam_model_path, device=accelerator.device, dtype=torch.bfloat16
            )

        self.sdxl = args.sdxl
        self.gradient_checkpointing = args.gradient_checkpointing

        self.diffusion_gan = args.diffusion_gan
        self.diffusion_gan_max_timestep = args.diffusion_gan_max_timestep

        self.network_context_manager = torch.autocast(device_type="cuda", dtype=torch.bfloat16) if self.use_fp16 else NoOpContext()
        # VAE decoding contexts: float32 on the gradient path, float16 on the no-grad path.
        self.vae_context_manager = torch.autocast(device_type="cuda", dtype=torch.float32)
        self.vae_fp16_context_manager = torch.autocast(device_type="cuda", dtype=torch.float16)

    def _decode_latent(self, vae, latent):
        """Decode SD latents with ``vae`` (undoing its scaling factor) and return float32 images."""
        return vae.decode(1 / vae.config.scaling_factor * latent).sample.float()

    def compute_cls_logits(self, image_latent):
        """Discriminator logits for ``image_latent`` with gradients flowing back to the latent.

        The latent is decoded under ``vae_context_manager``. The decode is wrapped in
        activation checkpointing so its activations are recomputed during backward instead of
        being stored.

        NOTE(review): the decoded image is passed to the discriminator unchanged; confirm its
        value range matches what SAM's ``preprocess`` normalization expects.

        Returns:
            ``(logits, None)``. The second slot is always None; it is kept for an optional
            debug image.
        """
        # Diffusion-GAN style noising of the latent is currently disabled; the
        # ``diffusion_gan*`` attributes are not read here.
        with self.vae_context_manager:
            # use_reentrant=True matches torch's implicit default and silences its warning.
            image = checkpoint(
                self._decode_latent,
                self.vae,
                image_latent,
                use_reentrant=True
            )

        logits = self.discriminator(image)
        return logits, None

    def compute_cls_logits_no_grad(self, image_latent):
        """Discriminator logits for ``image_latent`` without gradients through the VAE/encoder.

        Decoding runs under ``torch.no_grad()`` and the float16 autocast context. Activation
        checkpointing is skipped because nothing is saved for backward here. The
        discriminator heads still record gradients (``no_grad=True`` only freezes the SAM
        encoder pass).

        Returns:
            ``(logits, None)``. The second slot is always None; it is kept for an optional
            debug image.
        """
        with torch.no_grad(), self.vae_fp16_context_manager:
            image = self._decode_latent(self.vae, image_latent)

        logits = self.discriminator(image, no_grad=True)
        return logits, None

    def compute_distribution_matching_loss(
        self,
        latents,
        text_embedding,
        uncond_embedding,
        unet_added_conditions,
        uncond_unet_added_conditions
    ):
        """Distribution-matching loss for the generator.

        ``latents`` are noised at a timestep drawn uniformly from
        ``[min_step, min(max_step, num_train_timesteps - 1)]``. x0 is predicted with both the
        fake UNet (no guidance) and the real UNet (``real_guidance_scale``), all without
        gradients. The difference of the two residuals, normalized per sample by the mean
        absolute real residual, is ``grad``. The returned MSE loss against the detached target
        ``latents - grad`` has a gradient with respect to ``latents`` proportional to ``grad``.

        Returns:
            ``(loss_dict, dm_log_dict)``: ``{"loss_dm": loss}`` and detached tensors for logging.
        """
        original_latents = latents
        batch_size = latents.shape[0]
        with torch.no_grad():
            timesteps = torch.randint(
                self.min_step,
                min(self.max_step+1, self.num_train_timesteps),
                [batch_size],
                device=latents.device,
                dtype=torch.long
            )

            noise = torch.randn_like(latents)

            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

            # run at full precision as autocast and no_grad doesn't work well together
            pred_fake_noise = predict_noise(
                self.fake_unet, noisy_latents, text_embedding, uncond_embedding,
                timesteps, guidance_scale=self.fake_guidance_scale,
                unet_added_conditions=unet_added_conditions,
                uncond_unet_added_conditions=uncond_unet_added_conditions
            )

            pred_fake_image = get_x0_from_noise(
                noisy_latents.double(), pred_fake_noise.double(), self.alphas_cumprod.double(), timesteps
            )

            if self.use_fp16:
                # The real UNet is stored in bfloat16, so its inputs are cast to match.
                if self.sdxl:
                    bf16_unet_added_conditions = {
                        k: v.to(torch.bfloat16) for k, v in unet_added_conditions.items()
                    }
                    bf16_uncond_unet_added_conditions = {
                        k: v.to(torch.bfloat16) for k, v in uncond_unet_added_conditions.items()
                    }
                else:
                    bf16_unet_added_conditions = unet_added_conditions
                    bf16_uncond_unet_added_conditions = uncond_unet_added_conditions

                pred_real_noise = predict_noise(
                    self.real_unet, noisy_latents.to(torch.bfloat16), text_embedding.to(torch.bfloat16),
                    uncond_embedding.to(torch.bfloat16),
                    timesteps, guidance_scale=self.real_guidance_scale,
                    unet_added_conditions=bf16_unet_added_conditions,
                    uncond_unet_added_conditions=bf16_uncond_unet_added_conditions
                )
            else:
                pred_real_noise = predict_noise(
                    self.real_unet, noisy_latents, text_embedding, uncond_embedding,
                    timesteps, guidance_scale=self.real_guidance_scale,
                    unet_added_conditions=unet_added_conditions,
                    uncond_unet_added_conditions=uncond_unet_added_conditions
                )

            pred_real_image = get_x0_from_noise(
                noisy_latents.double(), pred_real_noise.double(), self.alphas_cumprod.double(), timesteps
            )

            p_real = (latents - pred_real_image)
            p_fake = (latents - pred_fake_image)

            # Per-sample normalization by the mean absolute real residual.
            grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2, 3], keepdim=True)
            grad = torch.nan_to_num(grad)

        loss = 0.5 * F.mse_loss(original_latents.float(), (original_latents-grad).detach().float(), reduction="mean")

        loss_dict = {
            "loss_dm": loss
        }

        dm_log_dict = {
            "dmtrain_noisy_latents": noisy_latents.detach().float(),
            "dmtrain_pred_real_image": pred_real_image.detach().float(),
            "dmtrain_pred_fake_image": pred_fake_image.detach().float(),
            "dmtrain_grad": grad.detach().float(),
            "dmtrain_gradient_norm": torch.norm(grad).item()
        }

        return loss_dict, dm_log_dict

    def compute_loss_fake(
        self,
        latents,
        text_embedding,
        uncond_embedding,
        unet_added_conditions=None,
        uncond_unet_added_conditions=None
    ):
        """Denoising (epsilon-prediction) loss that trains the fake UNet on generator samples.

        ``latents`` are detached and noised at timesteps drawn uniformly from
        ``[0, num_train_timesteps)``. The fake UNet is run without guidance under
        ``network_context_manager``. When ``gradient_checkpointing`` is set, checkpointing is
        enabled on the fake UNet for this call and turned off again afterwards, even if the
        step raises.

        Returns:
            ``({"loss_fake_mean": loss}, fake_log_dict)`` where the log dict holds detached
            latents, noisy latents and the fake x0 prediction.
        """
        if self.gradient_checkpointing:
            self.fake_unet.enable_gradient_checkpointing()
        try:
            latents = latents.detach()
            batch_size = latents.shape[0]
            noise = torch.randn_like(latents)

            timesteps = torch.randint(
                0,
                self.num_train_timesteps,
                [batch_size],
                device=latents.device,
                dtype=torch.long
            )
            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

            with self.network_context_manager:
                fake_noise_pred = predict_noise(
                    self.fake_unet, noisy_latents, text_embedding, uncond_embedding,
                    timesteps, guidance_scale=1,  # no guidance for training dfake
                    unet_added_conditions=unet_added_conditions,
                    uncond_unet_added_conditions=uncond_unet_added_conditions
                )

            fake_noise_pred = fake_noise_pred.float()

            fake_x0_pred = get_x0_from_noise(
                noisy_latents.double(), fake_noise_pred.double(), self.alphas_cumprod.double(), timesteps
            )

            # epsilon prediction loss
            loss_fake = torch.mean((fake_noise_pred - noise.float()) ** 2)
        finally:
            if self.gradient_checkpointing:
                self.fake_unet.disable_gradient_checkpointing()

        loss_dict = {
            "loss_fake_mean": loss_fake,
        }

        fake_log_dict = {
            "faketrain_latents": latents.detach().float(),
            "faketrain_noisy_latents": noisy_latents.detach().float(),
            "faketrain_x0_pred": fake_x0_pred.detach().float()
        }
        return loss_dict, fake_log_dict

    def compute_generator_clean_cls_loss(self,
        fake_image,
    ):
        """Non-saturating adversarial loss for the generator: ``softplus(-D(fake)).mean()``.

        Args:
            fake_image: Generator latents; they are decoded by the VAE with gradients.

        Returns:
            ``{"gen_cls_loss": loss}``.
        """
        loss_dict = {}
        pred_realism_on_fake_with_grad, _ = self.compute_cls_logits(fake_image)
        loss_dict["gen_cls_loss"] = F.softplus(-pred_realism_on_fake_with_grad).mean()
        return loss_dict

    def generator_forward(
        self,
        image,
        text_embedding,
        uncond_embedding,
        unet_added_conditions=None,
        uncond_unet_added_conditions=None
    ):
        """Generator-step losses: the distribution-matching loss (skipped when ``gan_alone``).

        The adversarial (clean-image classifier) term is computed separately in
        ``gen_adv_forward``.

        Returns:
            ``(loss_dict, log_dict)``.
        """
        loss_dict = {}
        log_dict = {}

        if not self.gan_alone:
            dm_dict, dm_log_dict = self.compute_distribution_matching_loss(
                image, text_embedding, uncond_embedding,
                unet_added_conditions, uncond_unet_added_conditions
            )

            loss_dict.update(dm_dict)
            log_dict.update(dm_log_dict)

        return loss_dict, log_dict

    def gen_adv_forward(
        self,
        image
    ):
        """Generator adversarial step: ``gen_cls_loss`` when ``cls_on_clean_image``, else nothing.

        Returns:
            ``(loss_dict, log_dict)``; ``log_dict`` is always empty.
        """
        loss_dict = {}
        log_dict = {}

        if self.cls_on_clean_image:
            clean_cls_loss_dict = self.compute_generator_clean_cls_loss(image)
            loss_dict.update(clean_cls_loss_dict)

        return loss_dict, log_dict

    def compute_guidance_clean_cls_loss(
            self, real_image, fake_image,
        ):
        """Logistic GAN loss for training the discriminator heads.

        Both inputs are latents; they are detached and decoded by the VAE without gradients.
        The loss is ``(softplus(D(fake)).mean() + softplus(-D(real)).mean()) / 2``.

        Returns:
            ``({"guidance_cls_loss": loss}, log_dict)``. The log dict holds the sigmoid
            probabilities for real and fake samples, each of shape (B,).
        """
        pred_realism_on_real, _ = self.compute_cls_logits_no_grad(real_image.detach())
        pred_realism_on_fake, _ = self.compute_cls_logits_no_grad(fake_image.detach())

        log_dict = {
            "pred_realism_on_real": torch.sigmoid(pred_realism_on_real).squeeze(dim=1).detach(),
            "pred_realism_on_fake": torch.sigmoid(pred_realism_on_fake).squeeze(dim=1).detach()
        }

        classification_loss = F.softplus(pred_realism_on_fake).mean() + F.softplus(-pred_realism_on_real).mean()
        classification_loss = classification_loss / 2
        loss_dict = {
            "guidance_cls_loss": classification_loss
        }
        return loss_dict, log_dict

    def guidance_forward(
        self,
        image,
        final_image,
        text_embedding,
        uncond_embedding,
        real_train_dict=None,
        unet_added_conditions=None,
        uncond_unet_added_conditions=None
    ):
        """Guidance-step losses: the fake-UNet denoising loss, plus the discriminator loss on
        ``real_train_dict['images']`` vs ``final_image`` when ``cls_on_clean_image``.

        Returns:
            ``(loss_dict, log_dict)``.
        """
        fake_dict, fake_log_dict = self.compute_loss_fake(
            image, text_embedding, uncond_embedding,
            unet_added_conditions=unet_added_conditions,
            uncond_unet_added_conditions=uncond_unet_added_conditions
        )

        loss_dict = fake_dict
        log_dict = fake_log_dict

        if self.cls_on_clean_image:
            clean_cls_loss_dict, clean_cls_log_dict = self.compute_guidance_clean_cls_loss(
                real_image=real_train_dict['images'],
                fake_image=final_image.detach(),
            )
            loss_dict.update(clean_cls_loss_dict)
            log_dict.update(clean_cls_log_dict)

        return loss_dict, log_dict

    def forward(
        self,
        generator_turn=False,
        guidance_turn=False,
        gen_adv_turn=False,
        generator_data_dict=None,
        guidance_data_dict=None
    ):
        """Dispatch to the generator, generator-adversarial or guidance step.

        Flags are checked in the order ``generator_turn``, ``gen_adv_turn``, ``guidance_turn``;
        the first true one wins.

        Raises:
            NotImplementedError: If no flag is set.
        """
        if generator_turn:
            loss_dict, log_dict = self.generator_forward(
                image=generator_data_dict["image"],
                text_embedding=generator_data_dict["text_embedding"],
                uncond_embedding=generator_data_dict["uncond_embedding"],
                unet_added_conditions=generator_data_dict["unet_added_conditions"],
                uncond_unet_added_conditions=generator_data_dict["uncond_unet_added_conditions"]
            )
        elif gen_adv_turn:
            loss_dict, log_dict = self.gen_adv_forward(
                image=generator_data_dict["image"],
            )
        elif guidance_turn:
            loss_dict, log_dict = self.guidance_forward(
                image=guidance_data_dict["image"],
                final_image=guidance_data_dict["final_image"],
                text_embedding=guidance_data_dict["text_embedding"],
                uncond_embedding=guidance_data_dict["uncond_embedding"],
                real_train_dict=guidance_data_dict["real_train_dict"],
                unet_added_conditions=guidance_data_dict["unet_added_conditions"],
                uncond_unet_added_conditions=guidance_data_dict["uncond_unet_added_conditions"]
            )
        else:
            raise NotImplementedError

        return loss_dict, log_dict