
import glob
import random
from PIL import Image
import os 
import glob

import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from tools.finetune_vae import load_visual_tokenizer
# Allow TF32 for CUDA matmuls and for cuDNN. These are process-wide flags and
# are set as a side effect of importing this module.
# NOTE(review): per the PyTorch docs TF32 trades float32 precision for speed,
# which may slightly affect the MSE values this script computes -- confirm
# that is intended.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from tqdm import tqdm
import time
import argparse
import numpy as np
from infinity.dataset.dataset_t2i_iterable import transform
from tools.run_infinity import dynamic_resolution_h_w, h_div_w_templates, add_common_arguments

class SimpleDataset(Dataset):
    """Dataset over the ``*.png`` files found directly inside one directory.

    Each item is whatever ``transform(image, 1024, 1024)`` returns for the
    image at that index (``transform`` is imported at file level from
    ``infinity.dataset.dataset_t2i_iterable``).
    """

    def __init__(self, preprocessed_dir, num_samples=None):
        """Collect the sorted list of PNG paths in ``preprocessed_dir``.

        Args:
            preprocessed_dir: Directory globbed (non-recursively) for ``*.png``.
            num_samples: If not ``None`` and smaller than the number of files
                found, only the first ``num_samples`` paths (in sorted order)
                are kept.
        """
        self.preprocessed_dir = preprocessed_dir
        print(f"Searching for file pairs in: {preprocessed_dir}")
        self.recon_files = sorted(glob.glob(os.path.join(preprocessed_dir, "*.png")))
        if num_samples is not None and num_samples < len(self.recon_files):
            self.recon_files = self.recon_files[:num_samples]
            print(f"Limiting dataset to first {num_samples} samples.")

    def __len__(self):
        """Return the number of PNG paths kept by ``__init__``."""
        return len(self.recon_files)

    def __getitem__(self, idx):
        """Return the transformed image at ``idx``, with a bounded fallback.

        If an image cannot be opened or transformed, the error is printed and
        the previous file (wrapping around to the end) is tried instead. Every
        file is attempted at most once, so a directory full of unreadable
        images ends in a clear error instead of unbounded recursion.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: If no file in the dataset could be loaded.
        """
        total = len(self.recon_files)
        # Keep list-style semantics: negative indices are allowed, anything
        # out of range raises IndexError (this also covers an empty dataset).
        if not -total <= idx < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")

        start = idx % total
        last_error = None
        for offset in range(total):
            recon_path = self.recon_files[(start - offset) % total]
            try:
                print(f"Loading image: {recon_path}", flush=True)
                # The context manager closes the file handle once the image
                # has been transformed.
                with Image.open(recon_path) as recon_image:
                    # Force three channels so RGBA / palette / grayscale PNGs
                    # still produce tensors that can be stacked into a batch.
                    return transform(recon_image.convert("RGB"), 1024, 1024)
            except Exception as e:
                # Deliberately broad: any failure to decode or transform one
                # file should fall back to a neighbour, not abort the run.
                print(f"Error loading image {recon_path}: {e}")
                last_error = e

        raise RuntimeError(
            f"Could not load any image from {self.preprocessed_dir}"
        ) from last_error


def tokenize_and_reconstruct_batch_latent_optim(vae, scale_schedule, original_image_batch, display_img=False, use_quant=True, lr=1e-2, iters=100):
    """Fit a latent to an image batch by gradient descent through ``vae.decode``.

    The latent is initialised from the quantized states returned by
    ``vae.encode_with_internals`` and then optimised with Adam to minimise the
    MSE between the clamped decoder output and the input batch.

    Side effect: the first reconstructed image of the batch is written to
    ``tmp_rec.png`` in the current working directory on every call.

    Args:
        vae: Object providing ``encode_with_internals(images, scale_schedule)``
            (returning a 4-tuple whose third element is the quantized states)
            and ``decode(latent)``.
        scale_schedule: Passed through unchanged to ``vae.encode_with_internals``.
        original_image_batch: Image tensor; it is cloned, not modified. The
            decoder output is clamped to [-1, 1] before comparison, so inputs
            are presumably in that range -- confirm against the caller.
        display_img: Unused; kept for interface compatibility.
        use_quant: Unused; kept for interface compatibility.
        lr: Initial Adam learning rate.
        iters: Number of optimisation steps; must be at least 1.

    Returns:
        Tuple ``(rec_gen_img, fhat_optim, loss)``. ``rec_gen_img`` and ``loss``
        come from the last forward pass (i.e. before the final optimizer step)
        and are detached; ``fhat_optim`` is the optimised ``nn.Parameter``.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    # Without at least one step there is no reconstruction or loss to return.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    image_batch = original_image_batch.clone()

    # Only the detached quantized states are needed, so skip building an
    # autograd graph through the encoder.
    with torch.no_grad():
        _, _, quantized_states, _ = vae.encode_with_internals(image_batch.clone(), scale_schedule)

    # Wrap the tensor as-is, on whatever device the VAE produced it. Calling
    # `.cuda()` on the Parameter would force CUDA and, whenever it actually
    # copies, hand the optimizer a non-leaf tensor.
    fhat_optim = torch.nn.Parameter(quantized_states.detach().clone())
    optimizer = torch.optim.Adam([fhat_optim], lr=lr)

    for i in range(iters):
        optimizer.zero_grad()
        rec_gen_img = torch.clamp(vae.decode(fhat_optim), -1.0, 1.0)
        loss = F.mse_loss(rec_gen_img, image_batch)
        loss.backward()

        optimizer.step()
        print(loss.item())
        # Halve the learning rate after steps 0, 50, 100, ...
        # NOTE(review): this also fires right after the very first step, so
        # the configured `lr` is only used once -- confirm this is intended.
        if i % 50 == 0:
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] * 0.5

    # The graph was already consumed by backward(); drop the references to it.
    rec_gen_img = rec_gen_img.detach()
    loss = loss.detach()

    # Map [-1, 1] -> [0, 255] and dump the first image of the batch (CHW -> HWC).
    # NOTE(review): written regardless of `display_img` -- confirm whether the
    # flag was meant to gate this.
    rec = (((rec_gen_img + 1) / 2).clamp(0, 1).cpu().numpy()[0].transpose(1, 2, 0) * 255.0).astype(np.uint8)
    Image.fromarray(rec).save('tmp_rec.png')
    return rec_gen_img, fhat_optim, loss

def calc_latent_tracer(args, vae, scale_schedule, dataset_name_image_path, device):
    """Compute and save the per-image latent-optimisation MSE for each dataset.

    For every ``name -> directory`` entry, the PNGs in the directory are run in
    batches through ``tokenize_and_reconstruct_batch_latent_optim`` and the MSE
    between each reconstruction and its input is collected. The result is saved
    as a 1-D tensor to ``<args.save_folder>/<name>_mse_list.pt``.

    Args:
        args: Namespace providing ``num_samples``, ``batch_size`` and
            ``save_folder``.
        vae: Forwarded to ``tokenize_and_reconstruct_batch_latent_optim``.
        scale_schedule: Forwarded to ``tokenize_and_reconstruct_batch_latent_optim``.
        dataset_name_image_path: Mapping from dataset name to image directory.
        device: Device the image batches are moved to before optimisation.
    """
    # Create the output folder up front so a missing directory cannot make
    # torch.save fail after all the optimisation work has been done.
    os.makedirs(args.save_folder, exist_ok=True)

    def stack_images(images):
        # Dataset items are already tensors; batching is a plain stack.
        return torch.stack(images, dim=0)

    for dataset, path in dataset_name_image_path.items():
        print(f"Calculating latent tracer for {dataset} from {path}", flush=True)
        gen_dataset = SimpleDataset(path, num_samples=args.num_samples)
        if len(gen_dataset) == 0:
            # torch.cat on an empty list raises, so skip instead of crashing.
            print(f"No images found for {dataset} in {path}; skipping.", flush=True)
            continue

        dataloader = torch.utils.data.DataLoader(
            gen_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=stack_images,
        )
        img_rec_loss_mse_list = []
        for images in tqdm(dataloader):
            # images: [B, 3, H, W]
            images = images.to(device)
            reconstructed_image, _, _ = tokenize_and_reconstruct_batch_latent_optim(vae, scale_schedule, images, False)
            # Detach first so the per-image MSE does not extend an autograd graph.
            img_rec_loss_mse = torch.mean((reconstructed_image.detach() - images) ** 2, dim=[1, 2, 3])
            img_rec_loss_mse_list.append(img_rec_loss_mse.cpu())

        # Export one MSE value per image, in dataset order.
        img_rec_loss_mse_all = torch.cat(img_rec_loss_mse_list, dim=0)
        torch.save(img_rec_loss_mse_all, os.path.join(args.save_folder, f"{dataset}_mse_list.pt"))


def main():
    """Parse CLI arguments, load the dataset config and VAE, and run the tracer.

    The dataset config is a JSON object mapping dataset names to image
    directories. It is read and validated before the VAE is loaded, so a bad
    config path is reported without paying for the model load. On a config
    error a message is printed and the function returns without doing any work.

    Reads ``args.seed`` and ``args.pn``, which are presumably registered by
    ``add_common_arguments`` -- confirm against ``tools.run_infinity``.
    """
    import json

    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument('--num_samples', type=int, default=1000, help='Number of samples to process from each dataset')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for processing images')
    parser.add_argument('--save_folder', type=str, default='./latenttracer', help='Folder to save results')
    parser.add_argument('--dataset_config', type=str, default='dataset_config.json', help='Path to dataset configuration JSON file')
    args = parser.parse_args()

    # Validate the cheap input first; loading the VAE is the expensive step.
    try:
        with open(args.dataset_config, 'r') as f:
            dataset_name_image_path = json.load(f)
    except FileNotFoundError:
        print(f"Dataset config file '{args.dataset_config}' not found. Please create it or specify a valid path.")
        return
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON config file '{args.dataset_config}': {e}")
        return
    if not isinstance(dataset_name_image_path, dict):
        # calc_latent_tracer iterates `.items()`, so anything else would crash later.
        print(f"Dataset config '{args.dataset_config}' must be a JSON object mapping dataset names to image directories.")
        return
    print(f"Loaded dataset configuration from: {args.dataset_config}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed every RNG this script touches.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    vae = load_visual_tokenizer(args)
    vae.to(device)
    vae.eval()

    h_div_w = 1/1 # aspect ratio, height:width

    # Pick the template closest to the requested aspect ratio, then take its
    # scale list for the configured `pn`, forcing the first entry of each scale to 1.
    h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]
    scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
    calc_latent_tracer(args, vae, scale_schedule, dataset_name_image_path, device)

if __name__ == "__main__":
    main()