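# Train a small converter network that maps latents from the custom (ResNet-encoder) VAE
# onto the Stable Diffusion VAE latent space, then visualize the resulting reconstructions.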
import os, sys
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.extend([
    base_dir,
    os.path.join(base_dir, "models"),
    os.path.join(base_dir, "models", "latent-diffusion"),
    os.path.join(base_dir, "models", "latent-diffusion", "ldm"),
    os.path.join(base_dir, "utils"),])

import torch
import torch.nn as nn
from diffusers import AutoencoderKL
from torch.optim import Adam
from tqdm import tqdm
import datetime
import logging
from PIL import Image
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf
from utils.util import instantiate_from_config
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial

class WrappedDataset(Dataset):
    """Adapter that exposes any sized, indexable object as a PyTorch ``Dataset``."""

    def __init__(self, dataset):
        # Only a reference is kept; the wrapped object is not copied.
        self.data = dataset

    def __getitem__(self, idx):
        # Indexing is delegated unchanged to the wrapped object.
        wrapped = self.data
        return wrapped[idx]

    def __len__(self):
        # Length is whatever the wrapped object reports.
        wrapped = self.data
        return len(wrapped)


class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are created from config nodes.

    Each of ``train`` / ``validation`` / ``test`` / ``predict`` is an optional
    config that is handed to ``instantiate_from_config``. For every split that
    is provided, the matching ``*_dataloader`` hook is bound on the instance.

    Args:
        batch_size: Batch size used by every ``DataLoader`` built here.
        train, validation, test, predict: Optional dataset configs.
        wrap: If true, each instantiated dataset is wrapped in ``WrappedDataset``.
        num_workers: Stored on the instance (defaults to ``batch_size * 2``).
        shuffle_test_loader: Whether ``test_dataloader()`` shuffles.
        use_worker_init_fn: Stored on the instance.
        shuffle_val_dataloader: Whether ``val_dataloader()`` shuffles.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        # NOTE(review): num_workers and use_worker_init_fn are recorded but are not
        # forwarded to any DataLoader in this class (loading is single-process).
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Build ``self.datasets`` and one non-shuffling loader per split in ``self.dataloaders``."""
        self.datasets = {
            k: instantiate_from_config(cfg)
            for k, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])
        self.dataloaders = {}
        for key, dataset in self.datasets.items():
            self.dataloaders[key] = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def _train_dataloader(self):
        """Return the training loader (iteration order is not shuffled)."""
        return DataLoader(self.datasets["train"], batch_size=self.batch_size, shuffle=False)

    def _val_dataloader(self, shuffle=False):
        """Return the validation loader.

        ``shuffle`` must be accepted here because ``__init__`` binds this method
        through ``partial(..., shuffle=...)``.
        """
        return DataLoader(self.datasets["validation"],
                          batch_size=self.batch_size,
                          shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Return the test loader; ``shuffle`` is supplied by the ``partial`` in ``__init__``."""
        return DataLoader(self.datasets["test"], batch_size=self.batch_size, shuffle=shuffle)

    def _predict_dataloader(self):
        """Return the prediction loader; ``__init__`` binds it when ``predict`` is given."""
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size)


class LatentSpaceConverter(nn.Module):
    """Three-layer conv net that maps one latent tensor to mean/log-variance channels.

    The first convolution uses ``stride=2``, so the spatial size of the output
    is roughly half that of the input. The last convolution emits
    ``output_channels * 2`` channels, which callers split into two halves.

    Args:
        ckpt_path: Optional path to a saved ``state_dict`` to restore.
        input_channels: Channel count of the incoming latent.
        output_channels: Channel count of each output half.
        hidden_dim: Width of the first hidden layer (the second uses half of it).
    """

    def __init__(self, 
                 ckpt_path=None,
                 input_channels=3,   # resnetencoder latent channels
                 output_channels=4,   # SD VAE latent channels
                 hidden_dim=256):
        super().__init__()
        self.ckpt_path = ckpt_path
        self.main = nn.Sequential(
            nn.Conv2d(input_channels, hidden_dim, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim//2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim//2, output_channels*2, 3, padding=1) 
        )
        if ckpt_path is not None:
            # Load onto CPU first so a checkpoint saved on another device
            # (e.g. "mps" or "cuda") can be restored on any machine; the caller
            # moves the module to its target device afterwards.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            self.load_state_dict(state_dict)
            print(f"Restored from {ckpt_path}")

    def forward(self, x):
        """Apply the conv stack to ``x`` and return its raw output."""
        return self.main(x)

def align_latent_spaces(
    custom_vae,       
    sd_vae,           
    train_loader,     
    converter=None,
    device="mps",
    lr=1e-4,
    epochs=50
):
    """Train ``converter`` to predict the SD VAE posterior from the custom VAE latent.

    For each batch, both encoders run without gradients. The converter output
    is split into two halves along dim 1 and regressed (MSE) onto the SD
    posterior's ``mean`` and ``logvar``; the log-variance term is weighted 0.5.

    Args:
        custom_vae: Model whose ``encode(images)`` returns an object with ``mode()``.
        sd_vae: Model whose ``encode(images).latent_dist`` exposes ``mean`` and ``logvar``.
        train_loader: Iterable of image batches. Each batch is permuted with
            ``(0, 3, 1, 2)``, so it is presumably channels-last (B, H, W, C) —
            verify against the dataset.
        converter: Module to train. If ``None``, a default ``LatentSpaceConverter``
            is created so the function always returns a usable module.
        device: Device on which the models and batches are placed.
        lr: Adam learning rate.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The trained converter module.
    """
    if converter is None:
        # Previously this case returned None, which made callers fail on
        # ``converter.state_dict()``; build the default converter instead.
        converter = LatentSpaceConverter()
    # Ensure the converter lives on the same device as the batches fed to it.
    converter.to(device).train()

    # Both encoders are fixed targets/feature extractors; only the converter is optimized.
    for param in sd_vae.parameters():
        param.requires_grad_(False)
    sd_vae.eval().to(device)
    custom_vae.eval().to(device)
    
    optimizer = Adam(converter.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        pbar = tqdm(train_loader)
        for batch in pbar:
            images = batch.permute(0, 3, 1, 2).to(device)
            
            with torch.no_grad():
                custom_posterior = custom_vae.encode(images)
                custom_latent = custom_posterior.mode()  

                sd_posterior = sd_vae.encode(images).latent_dist
                sd_mean, sd_logvar = sd_posterior.mean, sd_posterior.logvar
            
            converted = converter(custom_latent)
            # First half of the channels is the predicted mean, second half the log-variance.
            converted_mean, converted_logvar = converted.chunk(2, dim=1)
            
            mean_loss = criterion(converted_mean, sd_mean)
            logvar_loss = criterion(converted_logvar, sd_logvar)
            total_loss = mean_loss + 0.5*logvar_loss  

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            pbar.set_description(f"Epoch {epoch+1} Loss: {total_loss.item():.4f}")

    return converter

def save_images(images, save_dir, prefix, norm_range=(-1, 1)):
    """Write a batch of images to ``save_dir`` as ``{prefix}_{i}.png``.

    Args:
        images: Tensor that is transposed from (B, C, H, W) to (B, H, W, C)
            before saving.
        save_dir: Output directory; created if missing.
        prefix: File-name prefix for every image in the batch.
        norm_range: ``(low, high)`` value range of ``images``. It is mapped
            linearly to ``[0, 1]``. ``None`` means the values are already in
            ``[0, 1]``.

    Raises:
        ValueError: If ``norm_range`` is not an increasing ``(low, high)`` pair.
    """
    os.makedirs(save_dir, exist_ok=True)

    # detach() so tensors that still carry autograd history can be converted to numpy.
    images = images.detach()
    if norm_range is not None:
        low, high = norm_range
        if high <= low:
            raise ValueError(f"norm_range must be (low, high) with high > low, got {norm_range}")
        # Written as scale/offset so the default (-1, 1) is exactly x * 0.5 + 0.5.
        scale = 1.0 / (high - low)
        offset = -low / (high - low)
        images = images * scale + offset
    # Always clamp: values outside [0, 1] would wrap around in the uint8 cast below.
    images = images.clamp(0, 1)
    images = images.cpu().numpy().transpose(0, 2, 3, 1)  # BCHW -> BHWC

    for i, img in enumerate(images):
        img = (img * 255).astype(np.uint8)
        Image.fromarray(img).save(os.path.join(save_dir, f"{prefix}_{i}.png"))

def visualize_comparison(
    custom_vae,
    converter,
    sd_vae,
    dataloader,
    device="mps",
    save_dir="results",
    num_batches=3
):
    """Save originals and two kinds of reconstruction for the first ``num_batches`` batches.

    Three sub-directories are written under ``save_dir``:
    ``original`` (the inputs), ``converted`` (custom encoder -> converter ->
    SD decoder) and ``sd_vae`` (SD encoder -> SD decoder). Files are named
    ``batch_{index}_{i}.png`` by ``save_images``. Everything runs without gradients.
    """

    def _dump(tensor, kind, tag):
        # One output folder per image kind, all under save_dir.
        save_images(tensor, os.path.join(save_dir, kind), tag)

    with torch.no_grad():
        for step, raw in enumerate(dataloader):
            if step >= num_batches:
                break

            tag = f"batch_{step}"
            inputs = raw.permute(0, 3, 1, 2).to(device)
            _dump(inputs, "original", tag)

            # Path 1: custom latent -> converter -> keep the first (mean) half -> SD decoder.
            latent = custom_vae.encode(inputs).mode()
            mean_half, _ = converter(latent).chunk(2, dim=1)
            _dump(sd_vae.decode(mean_half).sample, "converted", tag)

            # Path 2: SD VAE round trip on the same inputs, for reference.
            reference_latent = sd_vae.encode(inputs).latent_dist.mode()
            _dump(sd_vae.decode(reference_latent).sample, "sd_vae", tag)

if __name__ == "__main__":
    # Script entry point: load the config, build the models and data, train the
    # converter, save its weights, then write comparison images.

    import argparse
    parser = argparse.ArgumentParser(description="Load Pretrained Autoencoder and Perform Inference")
    parser.add_argument('--config', type=str, required=True, help="Path to the config YAML file")
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    # Each run gets its own timestamped directory under config.logdir.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    logdir = os.path.join(config.logdir, now)
    os.makedirs(logdir, exist_ok=True)

    # Log to both <logdir>/log.txt and stdout.
    logtxt = os.path.join(logdir, "log.txt")
    logging.basicConfig(filename=logtxt, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    logging.getLogger().addHandler(console_handler)
    logging.info(f"Loaded configuration: \n{OmegaConf.to_yaml(config)}")

    # Prefer an accelerator when one is present; fall back to CPU otherwise.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    custom_vae = instantiate_from_config(config.model)
    custom_vae.to(device).eval()

    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    print("#### Data #####")
    for k in data.datasets:
        print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
    
    sd_vae_path="stabilityai/stable-diffusion-2-1-base"
    sd_vae = AutoencoderKL.from_pretrained(sd_vae_path, subfolder="vae").to(device)

    # Pass the converter explicitly: without one, align_latent_spaces has nothing
    # to train and the state_dict() call below would fail.
    converter = align_latent_spaces(
        custom_vae=custom_vae,
        sd_vae=sd_vae,
        train_loader=data.dataloaders["train"],
        converter=LatentSpaceConverter().to(device),
        device=device,
        epochs=200
    )
    torch.save(converter.state_dict(), os.path.join(logdir, "converter.pth"))
    print("Converter saved successfully.")

    # The validation split is optional in the data config; do not crash after a
    # full training run just because it is absent.
    val_loader = data.dataloaders.get("validation")
    if val_loader is None:
        logging.warning("No validation dataset configured; skipping visualization.")
    else:
        visualize_comparison(
            custom_vae=custom_vae,
            converter=converter,
            sd_vae=sd_vae,
            dataloader=val_loader,
            device=device,
            save_dir=os.path.join(logdir,"alignment_results")
        )
        print("Visualization complete.")
