from random import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import argparse
import numpy as np
from torch.utils.data import DataLoader, Dataset
import tqdm
import glob
import time
from torchvision import transforms
from PIL import Image
from infinity.dataset.dataset_t2i_iterable import transform
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
from tools.run_infinity import add_common_arguments

# Absolute directory that contains this script. It is appended to the module
# search path; this happens after the project imports above have already been
# resolved, so it only influences imports performed later.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(current_dir)


class ProcessDataset(Dataset):
    """Dataset of (reconstruction image, target) pairs stored on disk.

    Every ``<name>.png`` directly inside ``preprocessed_dir`` is paired with
    ``<name>_summed_codes.pt`` from the same directory. Images without a
    matching target file are skipped and counted in a printed summary.

    Args:
        preprocessed_dir: Directory holding the ``.png`` and ``.pt`` files.
        num_samples: Optional cap. When given and smaller than the number of
            pairs found, only the first ``num_samples`` pairs (in sorted
            image-path order) are kept.
    """

    def __init__(self, preprocessed_dir, num_samples=None):
        self.preprocessed_dir = preprocessed_dir
        print(f"Searching for file pairs in: {preprocessed_dir}")
        start_time = time.time()
        self.recon_files = sorted(glob.glob(os.path.join(preprocessed_dir, "*.png")))

        self.file_pairs = []
        skipped_count = 0
        for recon_f in tqdm.tqdm(self.recon_files, desc="Finding file pairs"):
            target_f = self._target_path_for(recon_f, preprocessed_dir)
            if os.path.exists(target_f):
                self.file_pairs.append((recon_f, target_f))
            else:
                skipped_count += 1
        end_time = time.time()

        print(f"Found {len(self.file_pairs)} file pairs.")
        if skipped_count > 0:
            print(f"Skipped {skipped_count} files because corresponding target was missing.")
        print(f"File pair search took {end_time - start_time:.2f} seconds.")

        if num_samples is not None and num_samples < len(self.file_pairs):
            self.file_pairs = self.file_pairs[:num_samples]
            print(f"Limiting dataset to first {num_samples} samples.")

    @staticmethod
    def _target_path_for(recon_path, preprocessed_dir):
        """Return the expected ``_summed_codes.pt`` path for an image path."""
        # splitext removes only the final extension; a plain
        # str.replace(".png", "") would also delete ".png" occurring in the
        # middle of the name and point at the wrong target file.
        stem = os.path.splitext(os.path.basename(recon_path))[0]
        return os.path.join(preprocessed_dir, f"{stem}_summed_codes.pt")

    def __len__(self):
        """Return the number of usable (image, target) pairs."""
        return len(self.file_pairs)

    def __getitem__(self, idx):
        """Load the pair at ``idx``.

        Returns:
            ``(image, target)`` where ``image`` is the result of
            ``transform(pil_image, 1024, 1024)`` and ``target`` is the object
            loaded from the ``.pt`` file with ``map_location='cpu'``; or
            ``(None, None)`` if anything fails while loading, so the collate
            step can drop the sample instead of aborting the epoch.
        """
        recon_path, target_path = self.file_pairs[idx]
        try:
            # Image.open is lazy and keeps the file open; the context manager
            # guarantees the handle is released once the pixels were consumed
            # by transform().
            with Image.open(recon_path) as recon_image:
                # NOTE(review): torch.load unpickles the file; only point
                # preprocessed_dir at data you trust.
                target_tensor = torch.load(target_path, map_location='cpu')
                recon_tensor = transform(recon_image, 1024, 1024)
            return recon_tensor, target_tensor
        except Exception as e:
            # Deliberate best-effort: a corrupt sample must not kill training.
            print(f"Error loading item at index {idx} (Paths: {recon_path}, {target_path}): {e}. Returning None.")
            return None, None


def collate_fn(batch):
    """Stack the valid ``(recon, target)`` samples of a batch.

    Samples that are ``None`` or that contain a ``None`` element (the value a
    failed ``__getitem__`` returns) are dropped before stacking.

    Args:
        batch: List of ``(recon_tensor, target_tensor)`` samples.

    Returns:
        ``(recon_tensors, target_tensors)``, each stacked along a new
        dimension 0, or ``(None, None)`` when no valid sample remains or the
        remaining samples cannot be stacked.
    """
    valid = [
        item for item in batch
        if item is not None and item[0] is not None and item[1] is not None
    ]
    if not valid:
        # torch.stack rejects an empty list; report the case explicitly
        # rather than relying on a caught exception.
        print("Warning: collate_fn received no valid samples; skipping batch.")
        return None, None

    try:
        recon_tensors = torch.stack([item[0] for item in valid], dim=0)
        target_tensors = torch.stack([item[1] for item in valid], dim=0)
    except Exception as e:
        # Best-effort: the caller skips (None, None) batches, but the reason
        # (typically mismatched shapes) is logged instead of being swallowed.
        print(f"Warning: could not stack batch of {len(valid)} samples: {e}. Skipping batch.")
        return None, None
    return recon_tensors, target_tensors


def load_visual_tokenizer(args, device):
    """Instantiate the BSQ-VAE visual tokenizer described by ``args``.

    Args:
        args: Namespace providing ``vae_path``, ``vae_type`` (used as the
            codebook dimension) and ``apply_spatial_patchify``.
        device: Device the model is moved to via ``.to(device)``.

    Returns:
        The object returned by ``vae_model(...)``, moved to ``device``.

    Raises:
        ValueError: If ``args.vae_type`` is not one of the supported values.
    """
    supported_vae_types = (14, 16, 18, 20, 24, 32, 64)

    print(f"Loading pretrained model from {args.vae_path}")
    if args.vae_type not in supported_vae_types:
        # Previously this fell through to `return vae` with `vae` unbound and
        # surfaced as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported vae_type={args.vae_type!r}; expected one of {list(supported_vae_types)}"
        )

    # Imported lazily so the project dependency is only needed when a model
    # is actually built.
    from infinity.models.bsq_vae.vae import vae_model

    schedule_mode = "dynamic"
    codebook_dim = args.vae_type
    codebook_size = 2**codebook_dim
    if args.apply_spatial_patchify:
        patch_size = 8
        encoder_ch_mult = [1, 2, 4, 4]
        decoder_ch_mult = [1, 2, 4, 4]
    else:
        patch_size = 16
        encoder_ch_mult = [1, 2, 4, 4, 4]
        decoder_ch_mult = [1, 2, 4, 4, 4]

    vae = vae_model(
        args.vae_path,
        schedule_mode,
        codebook_dim,
        codebook_size,
        patch_size=patch_size,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        test_mode=True,
    ).to(device)
    return vae


def fine_tune_encoder_on_preprocessed(args):
    """Fine-tune only the tokenizer's encoder against precomputed targets.

    The visual tokenizer is loaded and fully frozen, then its ``encoder``
    sub-module is unfrozen and trained with Adam to minimise the MSE between
    the first value returned by ``model.encode_for_raw_features`` and the
    target tensors stored next to each image in ``args.preprocessed_dir``.

    Per-epoch average losses are appended to ``args.log_file`` inside
    ``args.output_dir``; encoder/optimizer checkpoints are written to
    ``args.output_dir`` every ``args.save_interval`` epochs and after the
    final epoch.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``log_file``,
            ``output_dir``, ``pn``, ``finetune_lr``, ``preprocessed_dir``,
            ``batch_size``, ``num_workers``, ``epochs``, ``grad_clip``,
            ``save_interval`` plus whatever ``load_visual_tokenizer`` needs.

    Exits the process with status 1 when the dataset is empty.
    """
    import traceback

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # The --log-file option is documented as relative to the output directory;
    # os.path.join leaves an absolute --log-file untouched.
    os.makedirs(args.output_dir, exist_ok=True)
    log_file_path = os.path.join(args.output_dir, args.log_file)

    h_div_w = 1/1 # aspect ratio, height:width

    # Pick the template closest to the requested aspect ratio. The distance
    # must be taken over all templates: argmin over a single scalar is always
    # 0 and would select the first template regardless of h_div_w.
    nearest_idx = np.argmin(np.abs(np.asarray(h_div_w_templates) - h_div_w))
    h_div_w_template_ = h_div_w_templates[nearest_idx]
    scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]
    model = load_visual_tokenizer(args, device)

    # Freeze everything, then re-enable gradients for the encoder only.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    model.encoder.train()
    for param in model.encoder.parameters():
        param.requires_grad = True

    # Materialise once; reused by the optimizer and by gradient clipping.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=args.finetune_lr, betas=(0.9, 0.95))

    start_epoch = 0

    print("Initializing Dataset...")

    preprocessed_dataset = ProcessDataset(preprocessed_dir=args.preprocessed_dir)
    if len(preprocessed_dataset) == 0:
        print(f"Error: No valid data found in {args.preprocessed_dir}. Please check the directory and file naming convention.")
        sys.exit(1)

    print(f"Dataset size: {len(preprocessed_dataset)}")
    print("Initializing DataLoader...")
    dataloader = DataLoader(preprocessed_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            # Pinned memory only helps host-to-GPU copies.
                            pin_memory=(device == "cuda"),
                            collate_fn=collate_fn,
                            drop_last=True)
    print(f"DataLoader initialized with Batch Size: {args.batch_size}, Num Workers: {args.num_workers}")

    criterion = nn.MSELoss()

    print(f"Starting training from epoch {start_epoch + 1} for {args.epochs} total epochs.")
    for epoch in range(start_epoch, args.epochs):
        print(f"\n--- Epoch {epoch+1}/{args.epochs} ---")
        model.encoder.train()

        progress_bar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}", unit="batch", leave=False)
        total_loss = 0.0
        num_valid_batches = 0
        batch_load_errors = 0

        for batch_idx, (gen_img, f_z) in enumerate(progress_bar):

            # collate_fn yields (None, None) for batches it could not build.
            if gen_img is None or f_z is None:
                batch_load_errors += 1
                continue

            gen_img = gen_img.to(device, non_blocking=True)
            # assumes each stored target carries a singleton dim that ends up
            # at position 1 after batching — TODO confirm against the
            # preprocessing script.
            f_z = f_z.to(device, non_blocking=True).squeeze(1)
            optimizer.zero_grad()

            try:
                raw_features, _, _ = model.encode_for_raw_features(gen_img, scale_schedule)

                loss = criterion(raw_features, f_z.detach())

                loss.backward()

                if args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clip)

                optimizer.step()

                loss_value = loss.item()
                total_loss += loss_value
                num_valid_batches += 1

                progress_bar.set_postfix(
                    loss=f"{loss_value:.4f}",
                    avg_loss=f"{total_loss / num_valid_batches:.4f}"
                )

            except Exception as e:
                # Best-effort: report the failing batch and move on; stale
                # gradients are cleared so they cannot leak into the next step.
                print(f"\nError during training step in batch {batch_idx}: {e}")
                print(f"  Input shape: {gen_img.shape if gen_img is not None else 'None'}")
                print(f"  Target shape: {f_z.shape if f_z is not None else 'None'}")
                traceback.print_exc()
                optimizer.zero_grad()
                continue

        progress_bar.close()

        if batch_load_errors > 0:
             print(f"Warning: Skipped {batch_load_errors} batches in epoch {epoch+1} due to loading/collation errors.")

        if num_valid_batches > 0:
            avg_epoch_loss = total_loss / num_valid_batches

            print(f"Epoch {epoch+1} finished. Processed {num_valid_batches} batches.")
            print(f"  Avg Loss: {avg_epoch_loss:.6f}")

            try:
                with open(log_file_path, 'a') as log_f:
                    log_f.write(f"Epoch {epoch+1}: Avg Loss: {avg_epoch_loss:.6f}\n")
            except Exception as e:
                print(f"Warning: Could not write to log file {log_file_path}: {e}")

        else:
            print(f"Epoch {epoch+1} finished, but no valid batches were processed. Check data loading and collation.")

        # A non-positive interval disables periodic saves (and avoids a
        # modulo-by-zero); the final epoch is always saved.
        on_interval = args.save_interval > 0 and (epoch + 1) % args.save_interval == 0
        is_last_epoch = (epoch + 1) == args.epochs
        if on_interval or is_last_epoch:
            encoder_save_path = os.path.join(args.output_dir, f"encoder_finetuned_target_zq_epoch_{epoch+1}.pt")
            try:
                save_data = {
                    'epoch': epoch + 1,
                    'encoder_state_dict': model.encoder.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'args': args
                }
                torch.save(save_data, encoder_save_path)
                print(f"Checkpoint saved: {encoder_save_path}")
            except Exception as e:
                print(f"Error saving checkpoint: {e}")


if __name__ == "__main__":
    # The file-level `from random import random` binds the name `random` to a
    # function, which has no `seed` attribute. Import the module itself under
    # a distinct alias so the stdlib RNG can actually be seeded.
    import random as _random_module

    parser = argparse.ArgumentParser()
    add_common_arguments(parser)

    parser.add_argument("--finetune-lr", type=float, default=1e-5, help="Learning rate for the Adam optimizer")
    parser.add_argument("--epochs", type=int, default=10, help="Total number of epochs to train")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size per GPU")
    parser.add_argument("--grad-clip", type=float, default=0, help="Gradient clipping value (0 to disable)")

    parser.add_argument("--num-workers", type=int, default=4, help="Number of DataLoader workers")
    parser.add_argument("--output-dir", type=str, default="./finetuned_encoder_target_zq_dual_loss", help="Directory to save checkpoints and logs")
    parser.add_argument("--save-interval", type=int, default=1, help="Save checkpoint every N epochs")
    parser.add_argument("--log-file", type=str, default="training_loss_dual.log", help="Filename for logging epoch average loss (relative to output-dir)")
    parser.add_argument('--preprocessed-dir', type=str, required=True, help="Directory containing preprocessed .png and _summed_codes.pt files")
    args = parser.parse_args()

    # Seed every RNG in use so runs are repeatable.
    seed = args.seed
    _random_module.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Checkpoints are written here; create it up front.
    os.makedirs(args.output_dir, exist_ok=True)

    print("----- Configuration -----")
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")
    print("-------------------------")


    fine_tune_encoder_on_preprocessed(args)

    print("Training finished.")
