import os
import sys
from argparse import ArgumentParser
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from PIL import Image
import io
from datasets import load_dataset
from reward_distill.model.clip_mlp_model import MLP
from reward_models.reward_interface import UnifiedReward
from customize_scheduler.customize_euler import CustomEuler
from customize_pipeline.custom_sdxl_pipeline import CustomizeStableDiffusionXLPipeline
from tqdm import tqdm
import clip
from accelerate import Accelerator
from torch.nn import MSELoss
from torch.utils.data import DataLoader


import ipdb

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize


if __name__ == "__main__":
    # Command line: a single optional --device flag. Its default is the first
    # CUDA device when CUDA is available and the CPU otherwise.
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    parser = ArgumentParser()
    parser.add_argument("--device", default=default_device, type=str)
    args = parser.parse_args()

    # Accelerator is created with its default configuration.
    accelerator = Accelerator()

    # Device requested on the command line (separate from accelerator.device).
    device = torch.device(args.device)

    # Reward model selected by the name "pickscore".
    # NOTE(review): appears to be referenced only by commented-out code in this script.
    ground_truth = UnifiedReward("pickscore")

    # dataset = load_dataset("THUDM/ImageRewardDB", "1k")
    # dataset = load_dataset("yuvalkirstain/pickapic_v1", num_proc=1)

    # Function to load the preprocessed dataset
    def load_processed_dataset(save_path):
        """Load the preprocessed dataset stored at ``save_path``.

        Args:
            save_path: Path of a file readable by ``torch.load``.

        Returns:
            The object stored in the file. Tensors are mapped to the CPU so the
            file loads regardless of the device it was written from; the caller
            is responsible for moving batches to the training device.

        Raises:
            FileNotFoundError: If ``save_path`` does not exist.
        """
        if not os.path.exists(save_path):
            raise FileNotFoundError(f"{save_path} not found. Please preprocess the dataset first.")

        print(f"Loading preprocessed dataset from {save_path}")
        # NOTE: torch.load relies on pickle; only load files from a trusted source.
        # map_location="cpu" avoids failing (or pinning GPU memory) when the
        # dataset was serialized with tensors living on an accelerator.
        return torch.load(save_path, map_location="cpu")

    # Load the preprocessed dataset from its fixed location in the working directory.
    save_path = "processed_dataset.pt"
    processed_dataset = load_processed_dataset(save_path)

    # Load the SDXL-Turbo pipeline onto the accelerator device and keep handles
    # to the sub-modules used by the training loop.
    pipeline = CustomizeStableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo").to(accelerator.device)
    vae = pipeline.vae  # Get the VAE model
    unet = pipeline.unet
    # NOTE(review): only the first text encoder/tokenizer of the pipeline is
    # used; confirm this matches what the UNet expects as conditioning.
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer
    # Swap in the project's Euler scheduler, built from the pipeline's scheduler config.
    pipeline.scheduler = CustomEuler.from_config(pipeline.scheduler.config)

    # Reward-distillation model and its optimizer. The meaning of the two
    # constructor arguments is defined by MLP and is not visible in this file.
    model = MLP(768, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # The optimizer updates `model`, so `model` must be prepared together with
    # it: this places it on accelerator.device (the same device the batches are
    # moved to), wraps it for distributed runs, and registers it so that
    # accelerator.save_state() includes its weights. Previously only `unet`
    # was prepared and `model` was placed on the --device value instead.
    model, unet, optimizer = accelerator.prepare(model, unet, optimizer)
    mse_loss = MSELoss()

    # def collate_fn(batch):
    #     image_0_list, image_1_list, label_list, prompt_list, score_0_list, score_1_list = [], [], [], [], [], []
    #     for data in batch:
    #         image_bytes_0 = data['jpg_0']
    #         image_bytes_1 = data['jpg_1']
    #         prompt = data['caption']
    #         label = data['label_0']

    #         # Process image 0
    #         image_0 = Image.open(io.BytesIO(image_bytes_0)).convert("RGB")
    #         processed_image_0 = preprocess(image_0)
    #         image_0_list.append(processed_image_0)
    #         score_0 = ground_truth.score(ToTensor()(image_0), prompt)
    #         score_0_list.append(score_0)

    #         # Process image 1
    #         image_1 = Image.open(io.BytesIO(image_bytes_1)).convert("RGB")
    #         processed_image_1 = preprocess(image_1)
    #         image_1_list.append(processed_image_1)
    #         score_1 = ground_truth.score(ToTensor()(image_1), prompt)
    #         score_1_list.append(score_1)

    #         label_list.append(bool(label))
    #         prompt_list.append(prompt)

    #     return (
    #         torch.stack(image_0_list),
    #         torch.stack(image_1_list),
    #         label_list,
    #         prompt_list,
    #         torch.tensor(score_0_list),
    #         torch.tensor(score_1_list),
    #     )

    # Use DataLoader for batching
    # train_dataloader = DataLoader(dataset['train'], batch_size=16, shuffle=True, collate_fn=collate_fn)


    def collate_fn(batch):
        """Merge a list of per-sample dicts into batched fields.

        Each sample is indexed with the keys "image_0", "image_1", "label",
        "prompt", "score_0" and "score_1". Image and score entries are stacked
        with ``torch.stack``; labels and prompts are returned as plain lists.

        Returns:
            Tuple ``(images_0, images_1, labels, prompts, scores_0, scores_1)``.
        """
        def _stacked(key):
            # Stack one tensor-valued field along a new leading dimension.
            return torch.stack([sample[key] for sample in batch])

        def _listed(key):
            # Collect one field as a Python list, in batch order.
            return [sample[key] for sample in batch]

        return (
            _stacked("image_0"),
            _stacked("image_1"),
            _listed("label"),
            _listed("prompt"),
            _stacked("score_0"),
            _stacked("score_1"),
        )

    train_dataloader = DataLoader(processed_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

    # Loop-invariant: upper bound (exclusive) for the randomly drawn timesteps.
    num_train_timesteps = pipeline.scheduler.config.num_train_timesteps

    for epoch in range(1):
        # Stays None when the dataloader yields no batch, so the logging below
        # cannot hit an unbound name.
        loss = None

        for image_0_batch, image_1_batch, labels, prompts, scores_0, scores_1 in tqdm(train_dataloader):
            # Move images and their target scores to the accelerator device.
            image_0_batch = image_0_batch.to(accelerator.device)
            image_1_batch = image_1_batch.to(accelerator.device)
            scores_0 = scores_0.to(accelerator.device)
            scores_1 = scores_1.to(accelerator.device)

            with torch.no_grad():
                # Encode both images into VAE latents, scaled by the VAE's configured factor.
                latents_0_batch = vae.encode(image_0_batch).latent_dist.sample() * vae.config.scaling_factor
                latents_1_batch = vae.encode(image_1_batch).latent_dist.sample() * vae.config.scaling_factor

            # Per sample, keep the latents of image 0 when its label is truthy, otherwise those of image 1.
            winning_responses = torch.stack([latents_0 if label else latents_1 for latents_0, latents_1, label in zip(latents_0_batch, latents_1_batch, labels)])

            text_inputs = tokenizer(prompts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
            # The text encoder is not part of the optimizer, so no autograd graph is needed for it.
            with torch.no_grad():
                text_embeddings = text_encoder(text_inputs.input_ids.to(accelerator.device))[0]

            # Distillation loss: the model scores noised latents of each image
            # and is regressed onto the stored target scores.
            timesteps_0 = torch.randint(0, num_train_timesteps, (latents_0_batch.size(0),), device=latents_0_batch.device)
            noisy_latents_0 = pipeline.scheduler.add_noise(latents_0_batch, torch.randn_like(latents_0_batch), timesteps_0)
            predicted_scores_0 = model(noisy_latents_0, timesteps_0)

            timesteps_1 = torch.randint(0, num_train_timesteps, (latents_1_batch.size(0),), device=latents_1_batch.device)
            noisy_latents_1 = pipeline.scheduler.add_noise(latents_1_batch, torch.randn_like(latents_1_batch), timesteps_1)
            predicted_scores_1 = model(noisy_latents_1, timesteps_1)

            loss = 0.5 * mse_loss(predicted_scores_0, scores_0) + 0.5 * mse_loss(predicted_scores_1, scores_1)

            # Score-matching regularization: noise the winning latents and let
            # the UNet predict that noise.
            # NOTE(review): the optimizer in this script is built from
            # model.parameters() only, so this term changes the reported loss
            # but not the parameter update — confirm that is intended.
            # NOTE(review): the UNet is called with embeddings from a single
            # text encoder and no extra conditioning; confirm the custom SDXL
            # pipeline's UNet accepts this.
            timesteps = torch.randint(0, num_train_timesteps, (winning_responses.size(0),), device=winning_responses.device)
            noise = torch.randn_like(winning_responses)
            noisy_latents = pipeline.scheduler.add_noise(winning_responses, noise, timesteps)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss += mse_loss(model_pred, noise)  # a weighting hyperparameter could balance the two losses

            # Backward pass
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

        # Logging (reports the loss of the last batch of the epoch).
        if loss is not None and epoch % 10 == 0:
            accelerator.print(f"Epoch [{epoch+1}], Loss: {loss.item():.4f}")

        if epoch % 20 == 0:
            accelerator.save_state(f"checkpoint_epoch_{epoch+1}")



