import collections
import os
from torch.utils.checkpoint import checkpoint
import gc
import numpy as np
import logging
import argparse
from pathlib import Path
from typing import Tuple, Dict, List, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel, StableDiffusionPipeline
from transformers import CLIPTokenizer, CLIPTextModel
from torchvision import models
from transformers.agents.python_interpreter import evaluate_python_code
from utils import ReparamModule
from torchvision.utils import save_image
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
import matplotlib.pyplot as plt

# Module-level defaults shared by DiffusionReconstructor.
# NOTE(review): an earlier comment said these are "overridden by args", but
# nothing in this file reads them from the command line -- confirm intent.
DEFAULT_IMAGE_SIZE = 64  # height/width of the trainable noise tensor (_init_gaussian_data)
DEFAULT_NUM_TIMESTEPS = 1000  # scheduler length; also the scale applied to normalized timesteps
DEFAULT_CHANNELS = 4  # channel count of the trainable noise tensor (_init_gaussian_data)


class DiffusionReconstructor:
    """Diffusion Model-based Image Reconstructor"""

    def __init__(self, args: argparse.Namespace):
        """Load the pretrained components and create the trainable tensors.

        Args:
            args: Parsed command-line options. This constructor reads
                ``device``, ``model_path`` and the four learning rates
                (``lr_img``, ``lr_txt``, ``lr_noi``, ``lr_t``); the helper
                methods it calls read further fields.
        """
        self.args = args
        self.device = torch.device(args.device)
        self.logger = self._setup_logger()

        # Per-epoch loss history; train() appends to these and later plots
        # and exports them.
        self.total_losses, self.main_losses = [], []
        self.diversity_losses, self.epochs = [], []

        # All pretrained sub-models are stored under the same directory.
        model_path = args.model_path
        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(self.device)
        self.scheduler = DDIMScheduler(
            beta_schedule="scaled_linear",
            beta_start=0.00085,
            beta_end=0.012,
            num_train_timesteps=DEFAULT_NUM_TIMESTEPS,
        )
        self.model_init, self.model_grad = self._build_models()

        # Optimisation variables. The creation order is kept as-is because
        # _init_dummy_data needs self.vae and the random draws are sequential.
        self.text_emb, self.text_emb_ini = self._init_text_embeddings()
        self.dummy_images = self._init_dummy_data()
        self.gaussian_images = self._init_gaussian_data()
        self.timesteps = self._init_timesteps()

        # One Adam parameter group per variable so each gets its own learning
        # rate; the momentum coefficients are shared.
        shared_betas = (0.8, 0.9)
        trainable = (
            (self.dummy_images, args.lr_img),
            (self.text_emb, args.lr_txt),
            (self.gaussian_images, args.lr_noi),
            (self.timesteps, args.lr_t),
        )
        self.optimizer = torch.optim.Adam(
            [
                {"params": [tensor], "lr": lr, "betas": shared_betas}
                for tensor, lr in trainable
            ],
        )

    def _setup_logger(self) -> logging.Logger:
        """Configure logging system"""
        logger = logging.getLogger("ReconstructionLogger")
        logger.setLevel(logging.DEBUG)

        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

        # File handler
        log_dir = Path(self.args.save_dir) / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_dir / "recons.log")
        file_handler.setFormatter(formatter)

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        return logger

    def _build_models(self) -> tuple:
        """Load the UNet checkpoint and the stored target gradients.

        Both files are read from ``ckpts/model_ckpt/<prompt>/`` using
        ``args.check_epoch`` in the file name.

        Returns:
            A ``(model_init, model_grad)`` pair: the UNet wrapped in
            ``ReparamModule`` (on ``self.device``, eval mode) and the object
            loaded from ``gradients_<check_epoch>.pt``.

        Raises:
            FileNotFoundError: If either checkpoint file is missing
                (raised by ``torch.load``).
        """
        print(f"Loading model from {self.args.model_path}")

        ckpt_dir = Path("ckpts/model_ckpt") / self.args.prompt
        weights_path = ckpt_dir / f"epoch_{self.args.check_epoch}.pt"
        grads_path = ckpt_dir / f"gradients_{self.args.check_epoch}.pt"

        # Model initialization
        model_init = UNet2DConditionModel.from_pretrained(self.args.model_path, subfolder="unet")
        # Load the weights onto the CPU first: without map_location a state
        # dict saved from a GPU cannot be deserialised on a CPU-only host.
        # The model is moved to the target device below.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        model_init.load_state_dict(state_dict)
        model_grad = torch.load(grads_path, map_location=self.device)

        # Use the plain attention processor.
        # NOTE(review): presumably chosen because the gradient of a gradient
        # is taken later (create_graph=True) -- confirm.
        model_init.set_attn_processor(AttnProcessor())

        # Device placement
        model_init.to(self.device).eval()
        model_init = ReparamModule(model_init)

        return model_init, model_grad

    def alpha_bar_cont(self, t: torch.Tensor, beta_start=0.00085, beta_end=0.012):
        """Calculate continuous alpha bar values for diffusion process"""
        k_t = np.sqrt(beta_start) + t / DEFAULT_NUM_TIMESTEPS * (np.sqrt(beta_end) - np.sqrt(beta_start))
        k_0 = torch.zeros_like(k_t) + np.sqrt(beta_start)
        alpha_bar = torch.exp(
            DEFAULT_NUM_TIMESTEPS * (-k_t ** 3 + k_0 ** 3) / (3 * (np.sqrt(beta_end) - np.sqrt(beta_start))))
        return alpha_bar

    def _init_text_embeddings(self) -> tuple:
        """Encode the prompt and the empty string with the text encoder.

        Returns:
            ``(embeddings, embeddings_ini)`` where ``embeddings`` is the
            encoding of ``""`` with ``requires_grad`` enabled (it is one of
            the optimised variables) and ``embeddings_ini`` is the encoding
            of ``args.prompt``, left as a constant.
        """

        def encode(prompt: str) -> torch.Tensor:
            # Pad to the tokenizer's maximum length. truncation=True keeps a
            # long prompt from producing more tokens than max_length.
            tokens = self.tokenizer(
                [prompt],
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            return self.text_encoder(tokens.input_ids.to(self.device))[0]

        with torch.no_grad():
            # Initial text embedding
            embeddings_ini = encode(self.args.prompt)
            # Empty text embedding
            embeddings = encode("")

        # Created under no_grad, so this is a leaf tensor that can be optimised.
        embeddings.requires_grad_(True)
        return embeddings, embeddings_ini

    def _init_dummy_data(self) -> torch.Tensor:
        """Create the trainable latent images from VAE-encoded random noise.

        Random 3x512x512 pixel tensors scaled by ``args.init_scale`` are
        encoded with the VAE, sampled from the latent distribution and
        multiplied by 0.18215 (train() divides by the same factor before
        decoding).

        Returns:
            Latent tensor with ``requires_grad`` enabled.
        """
        pixel_shape = (self.args.batch_size, 3, 512, 512)
        pixels = torch.randn(pixel_shape).to(self.device) * self.args.init_scale

        # Encoding happens outside autograd so the result is a leaf tensor.
        with torch.no_grad():
            posterior = self.vae.encode(pixels).latent_dist
            latents = posterior.sample() * 0.18215

        latents.requires_grad_(True)
        return latents

    def _init_gaussian_data(self) -> torch.Tensor:
        """Initialize trainable gaussian noise images"""
        tensor = torch.randn(self.args.batch_size, DEFAULT_CHANNELS, DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE).to(
            self.device)
        tensor.requires_grad_(True)
        return tensor

    def _init_timesteps(self) -> torch.Tensor:
        """Initialize trainable timesteps"""
        tensor = torch.randint(400, 600, (self.args.batch_size,)).to(
            self.device).float() / DEFAULT_NUM_TIMESTEPS
        tensor = torch.log(tensor) - torch.log(1 - tensor)
        tensor.requires_grad_(True)
        return tensor

    def reinit_gaussian_images(self):
        """Re-initialize gaussian images with new random noise"""
        with torch.no_grad():
            new_data = torch.randn_like(self.gaussian_images)
            self.gaussian_images.data.copy_(new_data)
            if self.gaussian_images.grad is not None:
                self.gaussian_images.grad.detach_()
                self.gaussian_images.grad.zero_()

    def add_noise(
            self,
            original_samples: torch.Tensor,
            noise: torch.Tensor,
            timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Add noise to samples using diffusion process parameters"""
        sqrt_alpha_prod = self.alpha_bar_cont(timesteps) ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - self.alpha_bar_cont(timesteps)) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def _compute_gradients(self, params: torch.Tensor) -> Tensor:
        """Return the gradient of the denoising loss w.r.t. the flat parameters.

        The trainable images are noised with the trainable noise at the
        trainable timesteps, the model predicts the noise using ``params`` as
        its flat parameter vector, and the MSE between prediction and noise is
        differentiated with ``create_graph=True`` so the returned gradient can
        itself be backpropagated through.

        Args:
            params: Flat parameter vector passed to the model as
                ``flat_param``; ``requires_grad`` is switched on in place.

        Returns:
            Gradient tensor with the same shape as ``params``.
        """
        params.requires_grad_(True)

        # Forward diffusion process
        target_noise = self.gaussian_images
        step = torch.sigmoid(self.timesteps) * DEFAULT_NUM_TIMESTEPS
        noisy_latents = self.add_noise(self.dummy_images, target_noise, step)
        conditioning = self.text_emb.expand(self.args.batch_size, -1, -1)

        model_out = self.model_init(noisy_latents, step, conditioning, flat_param=params)
        denoise_loss = F.mse_loss(model_out.sample, target_noise)

        # Differentiable gradient: the graph is kept so later losses can be
        # backpropagated through it.
        (param_grad,) = torch.autograd.grad(denoise_loss, params, create_graph=True)
        return param_grad

    def _compute_similarity_loss(self) -> Tensor:
        """Calculate similarity loss between images in the same batch"""
        batch_size = self.dummy_images.size(0)
        if batch_size < 2:
            return None

        # Constrain pixel values and flatten
        clamped_img = torch.sigmoid(self.dummy_images)
        flattened = clamped_img.view(batch_size, -1)

        # L2 normalization and cosine similarity matrix
        flattened = F.normalize(flattened, p=2, dim=-1)
        sim_matrix = torch.mm(flattened, flattened.T)

        # Get upper triangle as pairwise similarities
        triu_indices = torch.triu_indices(batch_size, batch_size, offset=1)
        pair_similarities = sim_matrix[triu_indices[0], triu_indices[1]]

        # Use absolute mean as loss
        similarity_loss = torch.mean(torch.abs(pair_similarities))

        return similarity_loss

    def reconstruction_loss(self, epoch):
        """Compute the gradient-matching loss plus the optional diversity term.

        Args:
            epoch: Current epoch index; the diversity term is only added to
                the total while ``epoch <= 100``.

        Returns:
            ``(total_loss, main_loss, diversity_loss)``. ``main_loss`` is one
            minus the cosine similarity between the computed and the stored
            gradients. ``diversity_loss`` is whatever
            ``_compute_similarity_loss`` returned (possibly ``None``) and is
            reported even when it is not part of ``total_loss``.
        """
        # Flatten the model's current parameters into a single detached vector.
        flat_chunks = [p.reshape(-1) for p in self.model_init.parameters()]
        init_params = torch.cat(flat_chunks, 0).detach()

        total_grad = self._compute_gradients(init_params)

        gradient_similarity = F.cosine_similarity(
            total_grad.flatten(), self.model_grad.flatten(), dim=0, eps=1e-20
        )
        main_loss = 1 - gradient_similarity

        diversity_loss = self._compute_similarity_loss()

        if diversity_loss is None or epoch > 100:
            total_loss = main_loss
        else:
            total_loss = main_loss + self.args.ind_lambda * diversity_loss

        return total_loss, main_loss, diversity_loss

    def plot_loss_curves(self, save_path):
        """Plot the recorded loss histories against epoch and save the figure.

        Args:
            save_path: Destination passed to ``plt.savefig``.
        """
        plt.figure(figsize=(10, 6))

        # (values, legend label, line colour) for each curve, in draw order.
        curves = (
            (self.total_losses, 'Total Loss', 'blue'),
            (self.main_losses, 'Main Loss', 'red'),
            (self.diversity_losses, 'Diversity Loss', 'green'),
        )
        for values, label, color in curves:
            plt.plot(self.epochs, values, label=label, color=color, linewidth=2)

        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss Curves')
        plt.legend()
        plt.grid(True)

        # Write the figure to disk, then close it so figures do not accumulate.
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

        self.logger.info(f"Loss curves saved to {save_path}")

    def train(self):
        """Run the optimisation loop and write checkpoints, plots and a CSV.

        Outputs go to
        ``<save_dir>/<prompt>/<check_epoch>_<batch_size>_<reini_time>_<ind_lambda>/``.
        Every ``args.reini_time`` epochs (except epoch 0) the trainable noise
        is re-sampled; a value of 0 disables re-sampling. Every 100 epochs a
        checkpoint is written via ``_save_checkpoint``.
        """
        save_dir = Path(self.args.save_dir) / self.args.prompt / f"{self.args.check_epoch}_{self.args.batch_size}_{self.args.reini_time}_{self.args.ind_lambda}"
        save_dir.mkdir(parents=True, exist_ok=True)
        print("Start training")

        reinit_every = self.args.reini_time

        for epoch in range(self.args.num_epochs):

            # An interval of 0 means "never re-sample"; checking it first also
            # avoids a ZeroDivisionError from the modulo.
            if reinit_every != 0 and epoch != 0 and epoch % reinit_every == 0:
                self.reinit_gaussian_images()
                self.logger.info(f"Reinitialized parameters at epoch {epoch}")

            self.optimizer.zero_grad()
            total_loss, main_loss, diversity_loss = self.reconstruction_loss(epoch)
            total_loss.backward()
            self.optimizer.step()

            # Convert each loss to a Python number once and reuse it for both
            # the history and the log line.
            total_value = total_loss.item()
            main_value = main_loss.item()
            diversity_value = diversity_loss.item() if diversity_loss is not None else 0

            self.total_losses.append(total_value)
            self.main_losses.append(main_value)
            self.diversity_losses.append(diversity_value)
            self.epochs.append(epoch)

            self.logger.info(
                f"Epoch {epoch:04d} | Loss: {total_value:.4f}| Main Loss: {main_value:.4f}| Diversity Loss: {diversity_value:.4f}"
            )

            # Save checkpoint
            if epoch % 100 == 0:
                self._save_checkpoint(epoch, save_dir)

        # Save final loss curves after training
        final_loss_plot_path = save_dir / "final_loss_curves.png"
        self.plot_loss_curves(final_loss_plot_path)

        # Save loss data as CSV for further analysis
        loss_data = np.column_stack((self.epochs, self.total_losses, self.main_losses, self.diversity_losses))
        np.savetxt(save_dir / "loss_data.csv", loss_data, delimiter=",",
                   header="epoch,total_loss,main_loss,diversity_loss", comments="")

    def _save_checkpoint(self, epoch, save_dir):
        """Write the latents, text embedding, decoded images and loss plot for one epoch.

        Args:
            epoch: Epoch index used in the output file names.
            save_dir: Existing directory that receives the files.
        """
        torch.save(self.dummy_images, save_dir / f"dummy_images_epoch{epoch}.pth")

        text_emb_path = save_dir / f"text_emb_epoch{epoch}.pth"
        torch.save(self.text_emb.detach().cpu(), text_emb_path)
        self.logger.info(f"Saved text embeddings to {text_emb_path}")

        img_dir = save_dir / f"epoch_{epoch}_images"
        img_dir.mkdir(exist_ok=True)

        # Decode the latents back to pixels: undo the 0.18215 latent scaling,
        # then map the decoder output from [-1, 1] to [0, 1].
        with torch.no_grad():
            decoded_images = self.vae.decode(self.dummy_images.detach() / 0.18215).sample
        decoded_images = (decoded_images / 2 + 0.5).clamp(0, 1)
        save_image(
            decoded_images,
            img_dir / "recons.png",
            nrow=1,
            normalize=False
        )
        self.logger.info(f"Saved reconstructed images to {img_dir}")

        # Save loss curves alongside the checkpoint
        loss_plot_path = save_dir / f"loss_curves_epoch{epoch}.png"
        self.plot_loss_curves(loss_plot_path)


def parse_args(argv: "List[str] | None" = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what existing callers
            get; passing a list allows the parser to be used programmatically.

    Returns:
        Namespace holding every option below, with defaults filled in.
    """
    parser = argparse.ArgumentParser(description="Diffusion Model-based Image Reconstruction")

    parser.add_argument("--prompt", type=str, default="dog",
                        help="Class label name for reconstruction")
    parser.add_argument("--model_path", type=str, default="sd_model/tiny-sd",
                        help="Path to the pre-trained diffusion model")
    parser.add_argument("--save_dir", type=str, default="ckpts/GradCFG",
                        help="Directory to save all outputs (checkpoints, images, logs)")
    parser.add_argument("--ind_lambda", type=float, default=1e-2,
                        help="Weight for diversity regularization term")
    parser.add_argument("--lr_img", type=float, default=1e-1,
                        help="Learning rate for image parameters")
    parser.add_argument("--lr_txt", type=float, default=1e-3,
                        help="Learning rate for text embedding parameters")
    parser.add_argument("--lr_noi", type=float, default=1e-1,
                        help="Learning rate for noise parameters")
    parser.add_argument("--lr_t", type=float, default=1e-1,
                        help="Learning rate for timestep parameters")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Number of samples per batch")
    parser.add_argument("--check_epoch", type=int, default=100,
                        help="Check epoch for model gradients")
    parser.add_argument("--reini_time", type=int, default=100,
                        help="Epoch interval for parameter reinitialization")
    parser.add_argument("--init_scale", type=float, default=5e-1,
                        help="Initialization scale for dummy images")
    parser.add_argument("--num_epochs", type=int, default=4001,
                        help="Total number of training epochs")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for computation (cuda/cpu)")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: read the CLI options, build the reconstructor
    # (which loads the models and checkpoints) and run the training loop.
    args = parse_args()
    reconstructor = DiffusionReconstructor(args)
    reconstructor.train()