"""
Evaluate the DINOv3 tokenizer by reconstructing images and computing reconstruction metrics.

Metrics:
- rFID (Reconstruction FID)
- PSNR (Peak Signal-to-Noise Ratio)
- LPIPS (Learned Perceptual Image Patch Similarity)
- SSIM (Structural Similarity Index)
"""

import os
import argparse
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchmetrics import StructuralSimilarityIndexMeasure
from torchvision.datasets import ImageFolder
from torchvision import transforms
from omegaconf import OmegaConf
import logging
from typing import List

# Logging setup: the root handler emits INFO and above; this module logs
# through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


class DINOv3Evaluator:
    """Distributed reconstruction evaluator for the DINOv3 tokenizer.

    Each process encodes and decodes its shard of an ``ImageFolder`` dataset.
    It writes two sets of PNGs, both named by dataset index:

    - the decoded images;
    - the matching reference images (the model input after resize and
      center-crop, with the ImageNet normalization undone).

    It also accumulates LPIPS, SSIM and PSNR. rFID is not computed here; the
    two PNG folders are what an external FID tool needs.

    Expects to be started by a distributed launcher (e.g. ``torchrun``) with
    one CUDA device per process, since the NCCL backend is used.
    """

    # Statistics applied by the input transform. They are also used to undo it
    # when building the [-1, 1] reference images the metrics compare against.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, config_path: str, data_path: str, output_path: str, ckpt_path: str,
                 mask_ratio: float, batch_size: int = 200, num_workers: int = 4):
        """Load config and model, and set up distributed data loading and output folders.

        Args:
            config_path: OmegaConf YAML describing the tokenizer.
            data_path: Root of an ``ImageFolder``-style dataset.
            output_path: Parent folder; results go to ``<output_path>/<config>_<ckpt>/``.
            ckpt_path: Checkpoint path, stored into ``config.ckpt_path``.
            mask_ratio: Written to ``config.model.params.extra_vit_config.mask_ratio``
                when that config section exists.
            batch_size: Per-process batch size.
            num_workers: DataLoader worker processes per rank.
        """
        # Stored first: _prepare_data and _setup_output_dirs read these attributes.
        self.config_path = config_path
        self.data_path = data_path
        self.output_path = output_path
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.config = OmegaConf.load(config_path)
        self.config.ckpt_path = ckpt_path
        if self.config.model.params.extra_vit_config is not None:
            self.config.model.params.extra_vit_config.mask_ratio = mask_ratio

        # Initialize distributed (sets self.rank and self.local_rank)
        self._init_distributed()
        self.device = torch.device(f'cuda:{self.local_rank}')

        # Setup model and data
        self.model = self._load_dinov3()
        self.transform = self._get_transform()
        self.dataset, self.dataloader = self._prepare_data()

        # Setup output directories
        self.save_dir, self.ref_path = self._setup_output_dirs()

        # Metrics; both are fed images in [-1, 1]
        self.lpips = self._load_lpips().eval()
        self.ssim_metric = StructuralSimilarityIndexMeasure(data_range=(-1.0, 1.0)).to(self.device)

    def _init_distributed(self):
        """Join the NCCL process group and bind this process to its GPU.

        ``self.rank`` is the global rank. It is used for sharding, file I/O and
        the rank-0 logging gate. ``self.local_rank`` is the GPU index on this node.
        """
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl')
        self.rank = dist.get_rank()
        # The global rank is not a valid device index on multi-node jobs.
        # Prefer the launcher-provided LOCAL_RANK; otherwise fall back to the
        # rank modulo the number of visible GPUs.
        fallback = self.rank % max(torch.cuda.device_count(), 1)
        self.local_rank = int(os.environ.get("LOCAL_RANK", fallback))
        torch.cuda.set_device(self.local_rank)

    def _load_dinov3(self):
        """Build the DINO_DECODER from the config and return its model in eval mode on this GPU."""
        from SVG.svg_autoencoder.tokenizer.svg_autoencoder import DINO_DECODER
        model = DINO_DECODER(self.config).load().model.to(self.device).eval()
        if self.rank == 0:
            logger.info("Loaded DINOv3 model")
        return model

    def _load_lpips(self):
        """Instantiate the project LPIPS network on this GPU."""
        from models.lpips import LPIPS
        return LPIPS().to(self.device)

    def _get_transform(self):
        """Return the input pipeline: tensor, 256px resize and center-crop, ImageNet normalization."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.Normalize(mean=list(self._IMAGENET_MEAN),
                                 std=list(self._IMAGENET_STD))
        ])

    def _prepare_data(self):
        """Build the dataset and a loader that reads only this rank's shard, in a fixed order."""
        dataset = ImageFolder(root=self.data_path, transform=self.transform)
        # shuffle=False keeps shard contents deterministic (the sampler shuffles by default).
        sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                     rank=self.rank, shuffle=False)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                                sampler=sampler, pin_memory=True)
        return dataset, dataloader

    def _setup_output_dirs(self):
        """Create ``decoded_images`` and ``ref_images`` under a config+checkpoint-named folder.

        Returns:
            ``(save_dir, ref_path)``.
        """
        ckpt_name = os.path.splitext(os.path.basename(self.config.ckpt_path))[0]
        config_name = os.path.splitext(os.path.basename(self.config_path))[0]
        base_dir = os.path.join(self.output_path, f"{config_name}_{ckpt_name}")

        save_dir = os.path.join(base_dir, 'decoded_images')
        ref_path = os.path.join(base_dir, 'ref_images')

        # Every rank may create the folders; exist_ok makes this race-free.
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(ref_path, exist_ok=True)

        if self.rank == 0:
            logger.info(f"Output dir: {save_dir}")
            logger.info(f"Reference dir: {ref_path}")

        return save_dir, ref_path

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        """Encode a normalized image batch into latents (no autograd)."""
        with torch.no_grad():
            return self.model.encode(images)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents back to images (no autograd); the output is treated as [-1, 1]."""
        with torch.no_grad():
            return self.model.decode(latents)

    @classmethod
    def _to_signed_unit_range(cls, images: torch.Tensor) -> torch.Tensor:
        """Undo the ImageNet normalization of an NCHW batch and map [0, 1] to [-1, 1]."""
        mean = images.new_tensor(cls._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = images.new_tensor(cls._IMAGENET_STD).view(1, -1, 1, 1)
        return (images * std + mean) * 2.0 - 1.0

    @staticmethod
    def _to_uint8(images: torch.Tensor) -> np.ndarray:
        """Convert an NCHW batch in [-1, 1] to NHWC uint8 numpy arrays."""
        # +128.0 is +127.5 plus 0.5, so the truncating uint8 cast rounds to nearest.
        pixels = torch.clamp(127.5 * images + 128.0, 0, 255)
        return pixels.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()

    @staticmethod
    def _psnr(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Per-image PSNR in dB for NCHW batches in [-1, 1], with peak value 1 on a [0, 1] scale.

        The MSE is floored at 1e-10, so identical images score 100 dB instead of inf.
        """
        pred01 = (pred.clamp(-1.0, 1.0) + 1.0) / 2.0
        target01 = (target.clamp(-1.0, 1.0) + 1.0) / 2.0
        mse = (pred01 - target01).pow(2).flatten(1).mean(dim=1)
        return -10.0 * torch.log10(mse.clamp_min(1e-10))

    @staticmethod
    def _save_pngs(directory: str, ids: List[int], images: np.ndarray) -> None:
        """Write each HWC uint8 image to ``directory/<dataset index>.png``."""
        for idx, img in zip(ids, images):
            Image.fromarray(img).save(os.path.join(directory, f"{idx:06d}.png"))

    def evaluate(self) -> dict:
        """Reconstruct every dataset sample once, save PNGs and report dataset-level metrics.

        Returns:
            ``{"lpips": ..., "ssim": ..., "psnr": ...}``, averaged over all samples
            across all ranks (identical on every rank). The process group is
            destroyed before returning, including on error.
        """
        try:
            world_size = dist.get_world_size()
            # With drop_last=False, DistributedSampler pads its global index list
            # by repeating samples until every rank gets the same count. Global
            # position p is a real sample iff p < len(dataset). This rank holds
            # positions rank, rank + W, ..., so its real samples are exactly the
            # first `num_real` it receives.
            num_real = len(range(self.rank, len(self.dataset), world_size))
            # Dataset index of each sample this rank receives, in loader order.
            rank_indices = list(self.dataloader.sampler)
            # Running sums: LPIPS, SSIM, PSNR, sample count.
            totals = torch.zeros(4, dtype=torch.float64, device=self.device)
            seen = 0

            if self.rank == 0:
                logger.info("Generating reconstructions...")

            with torch.no_grad():
                for batch in tqdm(self.dataloader, disable=self.rank != 0):
                    start = seen
                    seen += batch[0].shape[0]
                    n_keep = min(seen, num_real) - start
                    if n_keep <= 0:
                        continue  # the whole batch is sampler padding
                    images = batch[0][:n_keep].to(self.device, non_blocking=True)
                    ids = rank_indices[start:start + n_keep]

                    # Metrics and saved PNGs use the same clamped reconstruction,
                    # compared with the un-normalized input in the same [-1, 1] range.
                    decoded = self.decode(self.encode(images)).clamp(-1.0, 1.0)
                    reference = self._to_signed_unit_range(images)

                    # Batch means times batch size give exact sums, so a short
                    # final batch is not over-weighted.
                    totals[0] += self.lpips(decoded, reference).mean().double() * n_keep
                    totals[1] += self.ssim_metric(decoded, reference).double() * n_keep
                    totals[2] += self._psnr(decoded, reference).double().sum()
                    totals[3] += n_keep

                    self._save_pngs(self.save_dir, ids, self._to_uint8(decoded))
                    self._save_pngs(self.ref_path, ids, self._to_uint8(reference))

            # Aggregate metrics across ranks
            dist.all_reduce(totals, op=dist.ReduceOp.SUM)
            count = max(totals[3].item(), 1.0)
            avg_lpips, avg_ssim, avg_psnr = (totals[:3] / count).tolist()

            if self.rank == 0:
                logger.info(f"Final Metrics: LPIPS={avg_lpips:.3f}, SSIM={avg_ssim:.3f}, PSNR={avg_psnr:.2f}")

            return {"lpips": avg_lpips, "ssim": avg_ssim, "psnr": avg_psnr}
        finally:
            dist.destroy_process_group()


def main():
    """Command-line entry point: parse arguments, seed the RNGs and run the evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate DINOv3 tokenizer performance")
    # The four path options are all mandatory strings; register them in the same order.
    for flag in ('--config_path', '--data_path', '--ckpt_path', '--output_path'):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument('--mask_ratio', type=float, default=-1)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    # Seed both torch and numpy from the single --seed value.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    DINOv3Evaluator(
        config_path=args.config_path,
        data_path=args.data_path,
        output_path=args.output_path,
        ckpt_path=args.ckpt_path,
        mask_ratio=args.mask_ratio,
    ).evaluate()


if __name__ == "__main__":
    main()
