import sys
import os
# Directory containing this script; root for the extra import paths below.
base_dir = os.path.dirname(os.path.abspath(__file__))
# Append the script directory and its data/model sub-directories (including
# models/latent-diffusion and its ldm directory) to sys.path, in this order,
# so their modules can be imported by bare name.
sys.path.extend(
    os.path.join(base_dir, *sub_dirs)
    for sub_dirs in (
        (),
        ("data",),
        ("data", "datasets"),
        ("models",),
        ("models", "latent-diffusion"),
        ("models", "latent-diffusion", "ldm"),
    )
)
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning, message="torch.utils._pytree._register_pytree_node is deprecated")
import datetime
import argparse
import random
import logging
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter
from diffusers import StableDiffusionPipeline
from models.DiffMorpher.model import DiffMorpherPipeline
from utils.util import instantiate_from_config
from explore_aug import ExploreAugPipeline
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
def print(*args, **kwargs):
    """Route ``print``-style calls in this module to ``logging.info``.

    Shadows the builtin so messages reach whatever handlers are attached to the
    root logger (the ``__main__`` section of this script attaches a log file
    and stdout). Positional arguments are converted with ``str`` and joined
    with ``sep``; a missing or ``None`` sep means a single space, as with the
    builtin. The other builtin keywords (``end``, ``file``, ``flush``) are
    accepted and ignored, because logging handlers add their own terminators.

    NOTE: records are emitted at INFO level, so they are dropped until the
    root logger is configured to INFO or lower.
    """
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    logging.info(sep.join(map(str, args)))
def set_seed(seed):
    """Seed the global PyTorch, NumPy and Python ``random`` generators with ``seed``."""
    seeders = (torch.manual_seed, np.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)

def _label_name(label, categories):
    """Return a printable name for ``label``, mapping in-range integer labels to ``categories``."""
    try:
        # int() accepts Python ints, NumPy integer scalars and one-element tensors.
        index = int(label)
    except (TypeError, ValueError, RuntimeError):
        return str(label)
    # Negative or out-of-range labels are shown numerically instead of being
    # wrapped around to the end of the category list.
    if 0 <= index < len(categories):
        return str(categories[index])
    return str(index)


def _to_display_image(sample_image):
    """Convert a dataset image (tensor or PIL image) into a PIL image for display."""
    if not torch.is_tensor(sample_image):
        return sample_image
    sample_image = sample_image.detach().cpu()
    if sample_image.min() < 0:
        # Heuristic: negative values are taken to mean ImageNet mean/std normalization.
        # NOTE(review): images normalized to [-1, 1] would be de-normalized
        # incorrectly here; confirm against the dataset's transforms.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        sample_image = torch.clamp(sample_image * std + mean, 0, 1)
    return transforms.ToPILImage()(sample_image)


def save_example_images(data, logdir, num_samples=9, image_size=(256, 256)):
    """
    Save example images from the train dataset to logdir/example_images/.

    Writes one PNG per sampled item (sample_<i>_<label>.png) plus a grid of all
    of them (dataset_samples_grid.png). The grid is sized to fit the number of
    samples shown; the default of 9 gives a 3x3 grid.
    Note: Dataset samples are already augmented (if augmentation is enabled).

    Args:
        data: data module whose ``datasets`` dict may contain a 'train' dataset;
            items are indexed as (image, label, ...)
        logdir: log directory
        num_samples: maximum number of example images to save
        image_size: (width, height) used to resize images for visualization

    NOTE(review): sampling draws from NumPy's global RNG, and dataset
    augmentation may draw from other global RNGs. Calling this therefore
    shifts the random state seen by code that runs afterwards.
    """
    example_dir = os.path.join(logdir, "example_images")
    os.makedirs(example_dir, exist_ok=True)

    # Get train dataset
    train_dataset = data.datasets.get('train')
    if train_dataset is None:
        print("No train dataset found, skipping example image saving")
        return

    n_show = min(num_samples, len(train_dataset))
    if n_show <= 0:
        print("Train dataset is empty or num_samples <= 0, skipping example image saving")
        return

    print(f"Saving {n_show} example images to {example_dir}")
    print("Note: These are augmented samples from the dataset")

    indices = np.random.choice(len(train_dataset), n_show, replace=False)

    # Near-square grid large enough for every sample (3x3 for 9 samples).
    ncols = int(np.ceil(np.sqrt(n_show)))
    nrows = int(np.ceil(n_show / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    try:
        fig.suptitle("Dataset Samples (Augmented)", fontsize=16)
        categories = getattr(train_dataset, 'categories', ['cat', 'dog', 'wild'])

        # Blank out grid cells that will not receive a sample.
        for ax in axes.flat[n_show:]:
            ax.axis('off')

        for idx, (ax, sample_idx) in enumerate(zip(axes.flat, indices)):
            sample_data = train_dataset[sample_idx]
            # Items are (image, label) or (image, label, extra); only the first two are used.
            sample_image, label = sample_data[0], sample_data[1]

            sample_pil_resized = _to_display_image(sample_image).resize(image_size, Image.LANCZOS)
            label_name = _label_name(label, categories)

            ax.imshow(sample_pil_resized)
            ax.set_title(f"Label: {label_name}")
            ax.axis('off')

            sample_pil_resized.save(os.path.join(example_dir, f"sample_{idx}_{label_name}.png"))

        fig.tight_layout()
        fig.savefig(os.path.join(example_dir, "dataset_samples_grid.png"), dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even when a sample fails to load or convert.
        plt.close(fig)

    print(f"Example images saved to {example_dir}")
    print("- dataset_samples_grid.png: Grid of dataset samples (augmented)")
    print("- Individual images: sample_*.png")

# Seed the torch, NumPy and Python global RNGs at import time, so later random
# draws in this script (e.g. sample selection in save_example_images) start
# from a fixed state.
set_seed(42)

def _select_device():
    """Pick the torch device to run on: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _load_diffmorpher(sd_path, device, online_repo="stabilityai/stable-diffusion-2-1-base"):
    """Load the DiffMorpher pipeline in float32 on ``device``, preferring a local checkout.

    Downloads ``online_repo`` instead when ``sd_path`` has no model_index.json
    or when loading the local copy raises.
    """
    if os.path.exists(os.path.join(sd_path, "model_index.json")):
        print(f"Found local Stable Diffusion model at: {sd_path}")
        try:
            dm_pipeline = DiffMorpherPipeline.from_pretrained(sd_path, torch_dtype=torch.float32).to(device)
        except Exception as e:
            print(f"Failed to load local model: {e}")
            print("Falling back to online model download...")
        else:
            print("DiffMorpher pipeline loaded successfully from local model")
            return dm_pipeline
    else:
        print(f"Local model not found at: {sd_path}")
        print("Attempting to download from online...")
    dm_pipeline = DiffMorpherPipeline.from_pretrained(online_repo, torch_dtype=torch.float32).to(device)
    print("DiffMorpher pipeline loaded from online model")
    return dm_pipeline


if __name__ == "__main__":
    # Parse the command line and load the experiment configuration.
    parser = argparse.ArgumentParser(description="Load Pretrained Autoencoder and Perform Inference")
    parser.add_argument('--config', type=str, required=True, help="Path to the config YAML file")
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    # A missing key, YAML null, the string 'None' or an empty value all mean
    # "name this run after the current time".
    date = config.get("date", None)
    if date is None or str(date) in ("None", ""):
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    else:
        now = str(date)
    logdir = os.path.join(config.logdir, now)
    os.makedirs(logdir, exist_ok=True)
    num = config.num
    subnum = config.subnum

    # Log to <logdir>/log.txt and mirror every record to stdout.
    logtxt = os.path.join(logdir, "log.txt")
    logging.basicConfig(filename=logtxt, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    logging.getLogger().addHandler(console_handler)
    logging.info(f"Loaded configuration: \n{OmegaConf.to_yaml(config)}")
    # num/subnum may come out of YAML as ints; os.path.join only accepts strings.
    writer = SummaryWriter(log_dir=os.path.join(logdir, "tensorboard", str(num), str(subnum)))
    try:
        device = _select_device()

        # NOTE(review): this relative path is resolved against the current
        # working directory, not base_dir.
        sd_path = "saved_models/Stable-Diffusion/stable-diffusion-2-1-base"
        dm_pipeline = _load_diffmorpher(sd_path, device)

        # Build the classifier from the config: config.model.classifier.
        classifier = instantiate_from_config(config.model.classifier).to(device)

        data = instantiate_from_config(config.data)
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # Saving example images is best-effort: a failure is reported but does not stop the run.
        try:
            save_example_images(data, logdir, num_samples=9)
        except Exception as e:
            print(f"Warning: Failed to save example images: {e}")

        pipeline = ExploreAugPipeline(
            classifier=classifier,
            dm=dm_pipeline,
            data=data,
            config=config,
            writer=writer,
            logdir=logdir,
            num=num,
            subnum=subnum,
            device=device
        )
        pipeline.run()
    finally:
        # Flush pending TensorBoard events even when the run fails part-way.
        writer.close()
