import os
import os.path as osp
import random

import numpy as np
import PIL.Image as PImage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from models import build_vae_var
import lpips
import wandb
import torchvision
import copy

def setup_environment(gpu_id, seed):
    """Seed the random number generators and select the compute device.

    Seeds Python's ``random``, NumPy and PyTorch with ``seed``. When CUDA is
    available, ``gpu_id`` becomes the current CUDA device; otherwise the
    function falls back to the CPU.

    Args:
        gpu_id: Index of the CUDA device to use when CUDA is available.
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch``.

    Returns:
        The selected ``torch.device``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # Only touch the CUDA runtime when it exists: calling set_device()
        # unconditionally fails on machines without CUDA, which made the CPU
        # fallback below unreachable.
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    return device


def _download_checkpoint(url, dest_dir):
    """Fetch ``url`` into ``dest_dir`` with ``wget``; raise if the download fails."""
    import subprocess  # local import keeps this block self-contained

    os.makedirs(dest_dir, exist_ok=True)
    # An argument list (no shell) passes the path and URL through verbatim, and
    # check=True turns a failed download into an immediate error instead of a
    # confusing missing-file failure in torch.load() later on.
    subprocess.run(["wget", "-P", dest_dir, url], check=True)


def load_models(device, model_depth):
    """Download checkpoints if missing, then build and load the VAE and VAR.

    Args:
        device: Device passed to ``build_vae_var`` and to which the VAR is moved.
        model_depth: VAR depth; selects the ``var_d{model_depth}.pth``
            checkpoint and is forwarded as ``depth`` to ``build_vae_var``.

    Returns:
        A ``(vae, var)`` tuple. Both have their checkpoint weights loaded. The
        VAR is additionally put in eval mode, frozen and moved to ``device``;
        the VAE is returned without any mode or ``requires_grad`` changes.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
        FileNotFoundError: If ``wget`` is not installed.
    """
    data_dir = "[VAR_MODEL_PATH]"  # Update this path if needed
    hf_home = "https://huggingface.co/FoundationVision/var/resolve/main"

    vae_ckpt = "vae_ch160v4096z32.pth"
    var_ckpt = f"var_d{model_depth}.pth"
    vae_ckpt_path = osp.join(data_dir, vae_ckpt)
    var_ckpt_path = osp.join(data_dir, var_ckpt)

    # Download whichever checkpoint is not already present locally.
    pending = (
        ("VAE", vae_ckpt, vae_ckpt_path),
        ("VAR", var_ckpt, var_ckpt_path),
    )
    for label, ckpt_name, ckpt_path in pending:
        if not osp.exists(ckpt_path):
            print(f"Downloading {label} checkpoint...")
            _download_checkpoint(f"{hf_home}/{ckpt_name}", data_dir)

    patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
    vae, var = build_vae_var(
        V=4096,
        Cvae=32,
        ch=160,
        share_quant_resi=4,  # hard-coded VQVAE hyperparameters
        device=device,
        patch_nums=patch_nums,
        num_classes=1000,
        depth=model_depth,
        shared_aln=False,
    )

    # Load on CPU first; strict=True makes any key mismatch fail loudly.
    vae.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"), strict=True)
    var.load_state_dict(torch.load(var_ckpt_path, map_location="cpu"), strict=True)
    print("VAE and VAR models loaded successfully.")

    # Set VAR to eval mode and freeze it
    var.eval()
    for p in var.parameters():
        p.requires_grad_(False)
    var.to(device)

    return vae, var


class VarGeneratedDataset(Dataset):
    """Images on disk paired with their saved multi-scale token indices.

    ``data_dir`` must contain ``generated_token_indices.pt`` plus one image per
    sample named ``00000.png``, ``00001.png``, ... in sample order.
    """

    def __init__(self, data_dir, transform=None):
        """Load the token indices and precompute the image paths.

        Args:
            data_dir: Directory holding the token file and the PNG images.
            transform: Optional callable applied to each PIL image. When
                ``None``, images are converted to tensors scaled to [-1, 1].
        """
        self.data_dir = data_dir
        self.transform = transform
        # Load all token indices on CPU (list of 10 tensors, each [N, patch_count])
        self.token_indices = torch.load(
            os.path.join(data_dir, 'generated_token_indices.pt'), map_location='cpu'
        )
        # The first scale's leading dimension defines the number of samples.
        self.length = self.token_indices[0].shape[0]
        # Precompute image file names
        self.image_files = [
            os.path.join(data_dir, f"{i:05d}.png") for i in range(self.length)
        ]

    def __len__(self):
        """Return the number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(image, tokens)`` for sample ``idx``.

        ``tokens`` is a list with one CPU tensor per entry of the loaded token
        file, each cloned so callers cannot mutate the stored data.
        """
        # Close the file handle deterministically; convert() returns an
        # independent in-memory image, so it stays valid after the block.
        with PImage.open(self.image_files[idx]) as raw:
            image = raw.convert('RGB')

        # Test against None explicitly so a transform object that happens to
        # be falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        else:
            # Default: ToTensor and normalize to [-1, 1]
            image = transforms.ToTensor()(image)
            image = image * 2 - 1

        # Get tokens for all patch scales (as a list of tensors) - keep on CPU
        tokens = [t[idx].clone() for t in self.token_indices]
        return image, tokens


def main():
    """Fine-tune the VQ-VAE against features rebuilt from saved token indices.

    Steps, in order:
      1. Build the config, start a wandb run and read the (possibly
         wandb-overridden) values back.
      2. Load the VAE/VAR, freeze the whole VAE, then unfreeze the components
         selected by ``TRAIN_ENCODER`` / ``TRAIN_DECODER``.
      3. For ``TOTAL_STEPS`` steps: rebuild the target ``f_hat`` and its decoded
         image from the stored tokens (no gradients), run the encoder path on
         that decoded image, and minimise the weighted sum of feature MSE,
         image MSE and LPIPS.
      4. Save the VAE ``state_dict`` into ``CKPT_SAVE_DIR`` and close the run.
    """
    # --- Configuration ---
    config = {
        "GENERATED_DATA_DIR": "[GENERATED_DATA_DIR]",
        "CKPT_SAVE_DIR": "[VAR_CKPT_PATH]",
        "GPU_ID": 0,
        "SEED": 42,
        "TOTAL_STEPS": 20000,
        "BATCH_SIZE": 16,
        "LEARNING_RATE": 5e-5,
        "IMAGE_SIZE": 256,
        "LOG_IMAGE_FREQ": 50,  # Log images every N steps
        "TRAIN_ENCODER": True,
        "TRAIN_DECODER": False,
        "LPIPS_WEIGHT": 0.0,
        "MSE_IMG_WEIGHT": 0.0,
        "MSE_FEAT_WEIGHT": 1.0,
        "MODEL_DEPTH": 16,
        "CFG": 4.0,
        "MORE_SMOOTH": False,
    }

    # --- Initialize Wandb ---
    wandb.init(project="vqvae-finetuning", config=config)
    config = wandb.config  # Allow wandb to update config

    GENERATED_DATA_DIR = config.GENERATED_DATA_DIR
    CKPT_SAVE_DIR = config.CKPT_SAVE_DIR
    GPU_ID = config.GPU_ID
    SEED = config.SEED
    TOTAL_STEPS = config.TOTAL_STEPS
    BATCH_SIZE = config.BATCH_SIZE
    LEARNING_RATE = config.LEARNING_RATE
    IMAGE_SIZE = config.IMAGE_SIZE
    LOG_IMAGE_FREQ = config.LOG_IMAGE_FREQ
    TRAIN_ENCODER = config.TRAIN_ENCODER
    TRAIN_DECODER = config.TRAIN_DECODER
    MODEL_DEPTH = config.MODEL_DEPTH
    # CFG and MORE_SMOOTH remain in the logged config but nothing in this
    # function reads them, so they are not unpacked into locals.

    device = setup_environment(GPU_ID, SEED)
    vae, var = load_models(device, MODEL_DEPTH)

    # --- Watch Model ---
    wandb.watch(vae, log="all", log_freq=100)

    # --- Freeze/Unfreeze Model Components ---
    print("Setting trainable parameters...")
    for p in vae.parameters():
        p.requires_grad_(False)

    trainable_params = []
    if TRAIN_ENCODER:
        print("Unfreezing ENCODER.")
        vae.encoder.train()
        for p in vae.encoder.parameters():
            p.requires_grad_(True)
        trainable_params.extend(vae.encoder.parameters())
        # quant_conv sits between the encoder and the quantizer, so it is
        # trained together with the encoder.
        vae.quant_conv.train()
        vae.quant_conv.requires_grad_(True)
        trainable_params.extend(vae.quant_conv.parameters())

    if TRAIN_DECODER:
        print("Unfreezing DECODER.")
        for p in vae.decoder.parameters():
            p.requires_grad_(True)
        trainable_params.extend(vae.decoder.parameters())

    if not trainable_params:
        print("WARNING: No components selected for training. Exiting.")
        # Close the wandb run explicitly instead of leaving it open.
        wandb.finish()
        return

    # Create the output directory up front so a bad path fails before the
    # training run rather than at torch.save() after all steps are done.
    os.makedirs(CKPT_SAVE_DIR, exist_ok=True)

    vae.to(device)

    # --- Setup Data ---
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1),  # Normalize to [-1, 1]
    ])
    dataset = VarGeneratedDataset(GENERATED_DATA_DIR, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)

    # --- Setup Optimizer and Scheduler ---
    print("Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        trainable_params, lr=LEARNING_RATE, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=TOTAL_STEPS
    )

    # --- Training Loop ---
    print("\nStarting fine-tuning...")
    # NOTE(review): this puts the *whole* VAE in train mode, including the
    # frozen parts — confirm that is intended for the quantizer/decoder.
    vae.train()
    lpips_loss_fn = lpips.LPIPS(net="vgg").to(device)

    progress_bar = tqdm(range(TOTAL_STEPS), desc="Fine-tuning", leave=False)
    data_iter = iter(dataloader)
    for step in progress_bar:
        # Step-based training: restart the loader whenever an epoch runs out.
        try:
            image_batch, token_batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            image_batch, token_batch = next(data_iter)
        image_batch = image_batch.to(device)
        # token_batch is a list of 10 tensors, each [B, patch_count]
        # Convert to the format expected by vae.idxBl_to_embedhat (list of tensors on device)
        token_batch = [t.to(device) for t in token_batch]

        optimizer.zero_grad()

        with torch.no_grad():
            # Target: Reconstruct f_hat from the original tokens
            ms_h_BChw = vae.idxBl_to_embedhat(token_batch)
            f_hat_target = vae.quantize.embedhat_to_fhat(
                ms_h_BChw, all_to_max_scale=True, last_one=True
            ).detach()
            # Also decode the target f_hat to get the target image reconstruction
            image_reconstructed_target = vae.decoder(vae.post_quant_conv(f_hat_target)).clamp(-1, 1)

        # Prediction: Get f_hat from passing the image through the VAE's encoder and quantizer
        # NOTE(review): the encoder input is the decoded target, not
        # image_batch (which is only logged below) — confirm this is intended.
        f_pred = vae.quant_conv(vae.encoder(image_reconstructed_target))
        f_hat_pred, _, _ = vae.quantize(f_pred)
        # reconstruct the image from the f_hat_pred
        image_reconstructed_pred = vae.decoder(vae.post_quant_conv(f_hat_pred)).clamp(-1, 1)

        # Calculate Loss
        # The feature loss compares the pre-quantization features (f_pred)
        # with the target f_hat rebuilt from the stored tokens.
        mse_feat_loss = F.mse_loss(f_pred, f_hat_target)
        # Compare the predicted reconstruction to the target reconstruction
        mse_img_loss = F.mse_loss(image_reconstructed_pred, image_reconstructed_target)
        lpips_val = lpips_loss_fn(image_reconstructed_pred, image_reconstructed_target).mean()

        loss = (
            config.MSE_FEAT_WEIGHT * mse_feat_loss
            + config.MSE_IMG_WEIGHT * mse_img_loss
            + config.LPIPS_WEIGHT * lpips_val
        )

        loss.backward()
        optimizer.step()
        scheduler.step()

        # --- Log Metrics to Wandb ---
        wandb.log(
            {
                "total_loss": loss.item(),
                "mse_feat_loss": mse_feat_loss.item(),
                "mse_img_loss": mse_img_loss.item(),
                "lpips_loss": lpips_val.item(),
                "lr": scheduler.get_last_lr()[0],
                "step": step,
            }
        )

        progress_bar.set_postfix(
            {"loss": f"{loss.item():.6f}", "lr": scheduler.get_last_lr()[0]}
        )

        # --- Log Images to Wandb periodically ---
        if step % LOG_IMAGE_FREQ == 0:
            with torch.no_grad():
                original_images = [wandb.Image(img) for img in image_batch.cpu()]
                reconstructed_pred_images = [
                    wandb.Image(img) for img in image_reconstructed_pred.cpu()
                ]
                reconstructed_gt_images = [
                    wandb.Image(img) for img in image_reconstructed_target.cpu()
                ]
                wandb.log(
                    {
                        "Original Images": original_images,
                        "Reconstructed Prediction": reconstructed_pred_images,
                        "Reconstructed Ground Truth": reconstructed_gt_images,
                        "step": step,
                    }
                )

    print("Fine-tuning completed.")

    # --- Save Model ---
    # Create filename based on config parameters
    train_components = []
    if TRAIN_ENCODER:
        train_components.append("encoder")
    if TRAIN_DECODER:
        train_components.append("decoder")

    filename = (
        f"vqvae_finetuned_fmap_lpips{config.LPIPS_WEIGHT}"
        f"_mse_img{config.MSE_IMG_WEIGHT}_mse_feat{config.MSE_FEAT_WEIGHT}"
        f"_steps{TOTAL_STEPS}_{'_'.join(train_components)}"
        f"_lr{LEARNING_RATE}_bs{BATCH_SIZE}_with_dataset.pth"
    )
    output_path = osp.join(CKPT_SAVE_DIR, filename)
    torch.save(vae.state_dict(), output_path)
    print(f"Fine-tuned VQ-VAE model saved to {output_path}")

    # --- Finish Wandb Run ---
    wandb.finish()


# Script entry point: run the fine-tuning only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main() 