import os
import copy
import torch
from torch import nn
from torchvision.transforms import v2
from torch.nn.functional import binary_cross_entropy_with_logits
import torch.nn.functional as F
from diffusers import (AutoencoderKL, UNet2DConditionModel, StableDiffusionPipeline, FluxPipeline)
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from peft.utils import get_peft_model_state_dict
from diffusers.utils import convert_state_dict_to_diffusers
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from peft import LoraConfig
from PIL import Image
from .utils import str2torch_dtype, cast_training_params
from .model_opt import get_train_tuple


class StableDiffision(torch.nn.Module):
    """Stable Diffusion UNet fine-tuned to map image latents to mask latents.

    Images and masks (replicated to 3 channels) are both encoded with the VAE.
    During training (``forward``) the UNet regresses the target ``v`` returned by
    ``get_train_tuple`` for the pair (image latent, mask latent). At evaluation
    time (``evaluate``) the prediction is integrated from the image latent with
    explicit Euler steps (``sample_ode``) and decoded back to a mask with the VAE.

    Args:
        config_diffusion: Diffusion configuration. Requires
            ``pretrained_model_name_or_path``; optional keys are ``unet_dtype``,
            ``text_encoder_dtype``, ``train_unet`` and
            ``enable_gradient_checkpointing``. Defaults are written back into the
            stored copy.
        config_vae: Optional VAE configuration (``pretrained_vae_name_or_path``,
            ``vae_dtype``).
        config_lora: Optional LoRA configuration (``rank``,
            ``enable_gradient_checkpointing``). When given, LoRA adapters are
            added to the UNet and their parameters become trainable.

    Raises:
        ValueError: if LoRA is combined with ``train_unet``, or if the
            configuration yields no trainable parameters.
    """

    def __init__(self,
                 config_diffusion: dict,
                 config_vae: dict = None,
                 config_lora: dict = None,
                 ):

        super().__init__()
        self.latent_diffusion_model = None
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.lora = None
        self.head = None
        self.trainable_params = []
        self.miou_thresh = 0.5
        if torch.cuda.device_count() == 1:
            self.device = "cuda"
            self.weight_dtype = torch.float16
        else:
            # Multi-GPU / CPU: let Accelerate decide the device and mixed-precision dtype.
            accelerator = Accelerator()
            self.device = copy.deepcopy(accelerator.device)
            self.weight_dtype = str2torch_dtype(accelerator.mixed_precision, torch.float16)
            del accelerator

        # Deep copies so defaults filled in below never leak into the caller's dicts.
        self.config_diffusion = copy.deepcopy(config_diffusion)
        self.config_vae = copy.deepcopy(config_vae)
        self.config_lora = copy.deepcopy(config_lora)

        self.init_diffusion(self.config_diffusion)
        self.init_vae(self.config_vae)
        if self.config_lora:
            self.init_lora(self.config_lora)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not self.trainable_params:
            raise ValueError("No trainable parameters")

    def init_diffusion(self, config):
        """Load tokenizer, text encoder and UNet, cast them, and set trainability.

        Args:
            config: Diffusion configuration; missing ``unet_dtype`` ("fp16"),
                ``text_encoder_dtype`` (same as UNet) and ``train_unet`` (False)
                are filled in place.
        """
        pretrained_model_name_or_path = config["pretrained_model_name_or_path"]
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
        self.latent_diffusion_model = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="unet")

        config.setdefault("unet_dtype", "fp16")
        self.latent_diffusion_model.to(self.device, str2torch_dtype(config["unet_dtype"], default=self.weight_dtype))

        config.setdefault("text_encoder_dtype", config["unet_dtype"])
        self.text_encoder.to(self.device, str2torch_dtype(config["text_encoder_dtype"], default=self.weight_dtype))

        config.setdefault("train_unet", False)
        if not config["train_unet"]:
            print("freeze unet")
            self.latent_diffusion_model.requires_grad_(False)
        else:
            if config.get("enable_gradient_checkpointing", False):
                self.latent_diffusion_model.enable_gradient_checkpointing()
            self.trainable_params.extend(cast_training_params(self.latent_diffusion_model))

        # The text encoder is never trained.
        self.text_encoder.requires_grad_(False)

    def init_vae(self, config):
        """Load the (frozen) VAE used to encode images/masks and decode masks.

        Args:
            config: VAE configuration or None. When ``pretrained_vae_name_or_path``
                is missing, the diffusion checkpoint's ``vae`` subfolder is used and
                the resulting dict is stored as ``self.config_vae``.
        """
        if config is None:
            config = {}
        if "pretrained_vae_name_or_path" not in config:
            config["pretrained_vae_name_or_path"] = self.config_diffusion["pretrained_model_name_or_path"]
            self.config_vae = config
        self.vae = AutoencoderKL.from_pretrained(config["pretrained_vae_name_or_path"], subfolder="vae")
        config.setdefault("vae_dtype", None)
        self.vae.to(self.device, str2torch_dtype(config["vae_dtype"], default=self.weight_dtype))
        self.vae.requires_grad_(False)

    def init_lora(self, config):
        """Add LoRA adapters to the UNet attention projections and register them as trainable.

        Args:
            config: LoRA configuration; ``rank`` is used for both ``r`` and ``lora_alpha``.

        Raises:
            ValueError: if full UNet training (``train_unet``) is also enabled.
        """
        if self.config_diffusion.get("train_unet"):
            raise ValueError("LoRA cannot be combined with full UNet training (config_diffusion['train_unet'])")
        if config.get("enable_gradient_checkpointing", False):
            self.latent_diffusion_model.enable_gradient_checkpointing()
        rank = config["rank"]
        unet_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        self.latent_diffusion_model.add_adapter(unet_lora_config)
        unet_lora_parameters = cast_training_params(self.latent_diffusion_model)
        self.lora = unet_lora_parameters
        self.trainable_params.extend(unet_lora_parameters)

    def _encode_prompt(self, prompt):
        """Tokenize ``prompt`` (padded/truncated to the tokenizer max length) and return text-encoder states."""
        input_ids = self.tokenizer(
            prompt, max_length=self.tokenizer.model_max_length,
            padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids
        return self.text_encoder(input_ids.to(self.device))[0]

    @staticmethod
    def _bucket_losses(timesteps, loss_arr, bucket=200):
        """Average per-sample losses into ``loss_{lo}~{hi}`` entries of ``bucket`` timesteps each.

        Samples sharing a bucket are averaged rather than overwriting one another.
        """
        grouped = {}
        bucket_ids = (timesteps // bucket).tolist()
        values = loss_arr.detach().float().cpu().tolist()
        for b, value in zip(bucket_ids, values):
            grouped.setdefault(b, []).append(value)
        return {f"loss_{b * bucket}~{(b + 1) * bucket}": sum(v) / len(v) for b, v in grouped.items()}

    def forward(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: Dict with ``image``, ``mask``, ``prompt`` and optionally
                ``mask_refine``. Masks are replicated to 3 channels before VAE
                encoding (presumably single-channel masks; confirm with the data loader).

        Returns:
            Tuple ``(loss, metric_dict)``. When ``mask_refine`` is present, each
            sample's loss is the minimum over the two mask targets, and
            ``metric_dict`` also counts how often each target was picked.
        """
        latents = self.run_vae(batch["image"])
        mask = batch["mask"]
        z1 = self.run_vae(torch.cat([mask, mask, mask], 1))
        v_refine = None
        if "mask_refine" in batch:
            mask_refine = batch["mask_refine"]
            z1_refine = self.run_vae(torch.cat([mask_refine, mask_refine, mask_refine], 1))
            # Only the refine target is used; the prediction below is made from the main z_t.
            # Called before the main tuple to keep the original RNG consumption order.
            _, _, v_refine = get_train_tuple(latents, z1_refine)

        z_t, t, v = get_train_tuple(latents, z1)
        # Flip t and map it to integer UNet timesteps (assumes t is drawn in [0, 1] -- TODO confirm).
        t = (1 - t).to(dtype=t.dtype)
        timesteps = (999 * t).long().reshape(-1)
        encoder_hidden_states = self._encode_prompt(batch["prompt"])
        model_pred = self.run_diffusion_model(z_t, timesteps, encoder_hidden_states)

        loss_arr = ((model_pred - v) ** 2).mean([1, 2, 3])
        indices = None
        if v_refine is not None:
            loss_arr_refine = ((model_pred - v_refine) ** 2).mean([1, 2, 3])
            # 0 where the original mask is the strictly closer target, 1 otherwise.
            indices = torch.where(loss_arr < loss_arr_refine, 0, 1)
            loss_arr = torch.min(loss_arr, loss_arr_refine)
        loss = loss_arr.mean()

        metric_dict = self._bucket_losses(timesteps, loss_arr)
        if indices is not None:
            metric_dict["pick_mask"] = (indices == 0).sum().cpu().item()
            metric_dict["pick_mask_refine"] = (indices == 1).sum().cpu().item()

        return loss, metric_dict

    def _decode_first_channel(self, scaled_latents):
        """Decode latents (already divided by the VAE scaling factor) and keep channel 0.

        Batches larger than 8 are decoded in two halves to limit peak memory.
        """
        def decode(chunk):
            return self.vae.decode(chunk.to(dtype=self.vae.dtype), return_dict=False)[0][:, :1]

        bs = len(scaled_latents)
        if bs > 8:
            half = bs // 2
            return torch.cat([decode(scaled_latents[:half]), decode(scaled_latents[half:])], 0)
        return decode(scaled_latents)

    @torch.no_grad()
    def evaluate(self, batch, keep_dim=False, num_samples=1, only_pred=False, step_enhance=False):
        """Predict masks for a batch and score them, without gradients.

        Args:
            batch: Dict with ``image``, ``mask`` and ``prompt``.
            keep_dim: Return per-sample mIoU (numpy array) instead of the batch mean.
            num_samples: Number of ODE integration steps passed to ``sample_ode``
                (despite the name, it is a step count, not a sample count).
            only_pred: Return only the decoded mask probabilities.
            step_enhance: Use the single-step ``sample_ode_step1`` scheme instead.

        Returns:
            ``prob`` if ``only_pred``; otherwise ``(latent_mse, {"miou": miou})``.
        """
        latents = self.run_vae(batch["image"])
        mask = batch["mask"]
        z1 = self.run_vae(torch.cat([mask, mask, mask], 1))

        encoder_hidden_states = self._encode_prompt(batch["prompt"])
        if step_enhance:
            # Latent of an all-zero image, used as the reference point for the step rescaling.
            z1_zero = self.run_vae(torch.zeros_like(batch["image"][:1]))
            _, output = self.sample_ode_step1(latents, encoder_hidden_states, z1_zero)
        else:
            output = self.sample_ode(latents, encoder_hidden_states, num_samples)
        loss = ((output - z1) ** 2).mean()
        prob = self._decode_first_channel(output / self.vae.config.scaling_factor)
        if only_pred:
            return prob
        miou = self.run_miou(prob, mask, keep_dim=keep_dim)
        return loss, {"miou": miou}

    @torch.no_grad()
    def sample_ode(self, z0, encoder_hidden_states, sample_steps=None):
        """Integrate the predicted velocity from ``z0`` with explicit Euler steps.

        Args:
            z0: Starting latents (image latents).
            encoder_hidden_states: Text conditioning.
            sample_steps: Number of Euler steps. Falls back to ``self.T`` if the
                instance defines it.

        Returns:
            Latents after ``sample_steps`` steps of size ``1 / sample_steps``.

        Raises:
            ValueError: if no step count is available or it is less than 1.
        """
        if sample_steps is None:
            sample_steps = getattr(self, "T", None)
            if sample_steps is None:
                raise ValueError("sample_steps must be given when the model has no default `T` attribute")
        if sample_steps < 1:
            raise ValueError(f"sample_steps must be >= 1, got {sample_steps}")
        dt = 1. / sample_steps
        z = z0.detach().clone()
        batch_size = z.shape[0]
        for i in range(sample_steps):
            # Timesteps decrease from 999 as integration proceeds (i / sample_steps plays the role of t).
            t = (torch.full((batch_size, ), i, device=z.device) / sample_steps * 999).long()
            t = 999 - t
            pred = self.run_diffusion_model(z, t, encoder_hidden_states)
            z = z + pred * dt
        return z

    @torch.no_grad()
    def sample_ode_step1(self, z0, encoder_hidden_states, z1_zero):
        """Take one full Euler step, then rescale each sample's step toward a reference latent.

        For every sample, the spatial positions where the one-step result is
        closest to ``z1_zero`` (within 0.01 of the minimum squared distance) are
        selected. The step is lengthened by the ratio of the remaining gap to
        that sample's own velocity magnitude at those positions.

        Args:
            z0: Starting latents.
            encoder_hidden_states: Text conditioning.
            z1_zero: Reference latent batch; only its first element is used.

        Returns:
            Tuple ``(z_prev, z_final)``: the plain one-step result and the rescaled one.
        """
        z = z0.detach().clone()
        batch_size = z.shape[0]
        t = torch.full((batch_size, ), 999, device=z.device, dtype=torch.long)
        pred = self.run_diffusion_model(z, t, encoder_hidden_states)
        z_prev = z + pred
        ref = z1_zero[0]
        z_final_arr = []
        for z_i, pred_i, z_prev_i in zip(z, pred, z_prev):
            dist = ((z_prev_i - ref) ** 2).mean(0)
            near = (dist - dist.min()) < 0.01  # always contains the argmin, so never empty
            delta_v = (ref[..., near] - z_prev_i[..., near]).abs().mean()
            # Normalise by this sample's own velocity (the old code used sample 0 for every sample).
            # NOTE(review): a zero velocity here gives an inf/nan ratio.
            length_v = pred_i[..., near].abs().mean()
            ratio = delta_v / length_v
            z_final_arr.append(z_i + pred_i * (1 + ratio))
        return z_prev, torch.stack(z_final_arr, 0)

    def run_vae(self, pixel_values):
        """Encode pixels with the VAE (posterior mean, no sampling) and scale to latent space.

        Args:
            pixel_values: Image tensor in pixel space.

        Returns:
            Scaled latents cast to the UNet dtype.
        """
        latents = self.vae.encode(pixel_values.to(dtype=self.vae.dtype)).latent_dist.mean
        latents = latents * self.vae.config.scaling_factor
        latents = latents.to(dtype=self.latent_diffusion_model.dtype)
        return latents

    def run_diffusion_model(
            self, noisy_latents, timesteps, encoder_hidden_states
    ):
        """Run the UNet on latents with text conditioning.

        Args:
            noisy_latents: Input latents.
            timesteps: Integer UNet timesteps, one per sample.
            encoder_hidden_states: Text conditioning (cast to the UNet dtype here).

        Returns:
            The UNet ``sample`` output, used by this class as a velocity prediction.
        """
        encoder_hidden_states = encoder_hidden_states.to(dtype=self.latent_diffusion_model.dtype)

        model_pred = self.latent_diffusion_model(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        ).sample
        return model_pred

    def run_loss(self, model_pred, target):
        """Binary cross-entropy with logits, computed in float32.

        Args:
            model_pred: Logits.
            target: Targets in [0, 1].

        Returns:
            Scalar BCE loss.
        """
        loss = binary_cross_entropy_with_logits(model_pred.float(), target.float())
        return loss

    @torch.no_grad()
    def run_miou(self, model_pred, target, keep_dim=False):
        """Compute per-sample IoU of thresholded predictions against binarised targets.

        Args:
            model_pred: Predicted probabilities, thresholded at ``self.miou_thresh``.
            target: Ground-truth masks, binarised at 0.5.
            keep_dim: Return the per-sample numpy array instead of the batch mean.

        Returns:
            Mean IoU as a float, or a numpy array of per-sample IoU when ``keep_dim``.
            The 1e-3 in the denominator makes an empty prediction on an empty target score 0.
        """
        pred = model_pred > self.miou_thresh
        gt = target > 0.5
        inter = torch.logical_and(pred, gt).int().sum([1, 2, 3])
        union = torch.logical_or(pred, gt).int().sum([1, 2, 3])
        miou = (inter.float() / (union.float() + 1e-3)).cpu()
        if not keep_dim:
            miou = miou.mean().item()
        else:
            miou = miou.numpy()
        return miou

    def save_weight(self, epoch, ckpt_dir):
        """Save LoRA weights (``{epoch}_lora.safetensors``) or the full UNet (``{epoch}_unet``).

        Args:
            epoch: Epoch number (zero-padded to two digits) or any other tag used verbatim.
            ckpt_dir: Output directory.
        """
        try:
            epoch = f"{int(epoch):02d}"
        except (TypeError, ValueError, OverflowError):
            # Non-numeric tags (e.g. "last") are used as-is.
            pass
        if self.config_lora:
            lora_state_dict = convert_state_dict_to_diffusers(
                get_peft_model_state_dict(self.latent_diffusion_model)
            )
            # NOTE(review): saved through FluxPipeline (``transformer_lora_layers``) even though the model is
            # a UNet. SDInference loads these weights with ``load_lora_adapter`` and no explicit prefix, which
            # presumably relies on that prefix matching; confirm before switching to StableDiffusionPipeline.
            FluxPipeline.save_lora_weights(
                save_directory=ckpt_dir,
                transformer_lora_layers=lora_state_dict,
                safe_serialization=True,
                weight_name=f"{epoch}_lora.safetensors",
            )
        else:
            self.latent_diffusion_model.save_pretrained(os.path.join(ckpt_dir, f"{epoch}_unet"))


class SDInference(StableDiffision):
    """Inference-only wrapper that loads a UNet with LoRA weights and predicts masks.

    ``StableDiffision.__init__`` is deliberately bypassed (only
    ``torch.nn.Module.__init__`` runs). No training configuration, Accelerator or
    trainable-parameter bookkeeping is set up. All components are frozen and cast
    to float16.

    Args:
        pretrained_model_name_or_path: Stable Diffusion checkpoint providing the
            ``tokenizer``, ``text_encoder``, ``unet`` and ``vae`` subfolders.
        lora_path: LoRA weights passed to the UNet's ``load_lora_adapter``.
    """

    def __init__(self, pretrained_model_name_or_path, lora_path):

        torch.nn.Module.__init__(self)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
        self.latent_diffusion_model = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="unet")
        # NOTE(review): no ``prefix`` is passed, so the loader's default prefix is used. Weights written by
        # StableDiffision.save_weight go through FluxPipeline; confirm the two prefixes agree.
        self.latent_diffusion_model.load_lora_adapter(lora_path)
        self.vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
        for model in [self.text_encoder, self.latent_diffusion_model, self.vae]:
            model.requires_grad_(False)
            model.to(self.device, torch.float16)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer"
        )
        self.transform = v2.Compose(
            [
                v2.Resize(512),
                # Non-deprecated equivalent of v2.ToTensor(): CHW float32 scaled to [0, 1].
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ]
        )

        self.miou_thresh = 0.5

    @torch.no_grad()
    def predict(self, image: Image.Image, prompt: str, num_steps):
        """Predict a segmentation mask for one image and text prompt.

        Args:
            image: Input PIL image; non-RGB modes are converted to RGB first.
            prompt: Text describing the target to segment.
            num_steps: Number of Euler steps for ``sample_ode``.

        Returns:
            Tensor of shape ``(1, 1, H, W)`` with the decoded mask resized
            (bilinear) to the original image size.
        """
        ow, oh = image.size
        # Masks are replicated to 3 channels for the VAE during training, so feed 3-channel images here too.
        if image.mode != "RGB":
            image = image.convert("RGB")
        x = self.transform(image).unsqueeze(0).to(self.device)
        input_ids = self.tokenizer(
            prompt, max_length=self.tokenizer.model_max_length,
            padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids
        encoder_hidden_states = self.text_encoder(input_ids.to(self.device))[0]
        latents = self.run_vae(x)
        output = self.sample_ode(latents, encoder_hidden_states, num_steps)
        output = output / self.vae.config.scaling_factor
        decoded = self.vae.decode(output.to(dtype=self.vae.dtype), return_dict=False)[0]
        # Keep channel 0 before resizing. Interpolation is per-channel, so this matches
        # resizing all channels and then slicing.
        return F.interpolate(decoded[:, :1], size=(oh, ow), mode="bilinear")
