import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import pickle
import argparse
from tqdm import tqdm

# Function to preprocess images for VGG
def vgg_preprocess(batch):
    """
    Preprocesses a batch of images to match the input requirements of VGG networks.

    The batch is rescaled from [-1, 1] to [0, 1], resized to 224x224 with
    bilinear interpolation, and normalized per channel with the ImageNet
    mean and standard deviation.

    Args:
        batch: Floating-point tensor of shape (N, 3, H, W). Values are
            assumed to lie in [-1, 1]; this is not checked.

    Returns:
        A new tensor of shape (N, 3, 224, 224) on the same device and with
        the same dtype as ``batch``. The input tensor is not modified.
    """
    # Convert from [-1, 1] to [0, 1]
    batch = (batch + 1) / 2.0
    # Resize to 224x224 as expected by VGG
    batch = F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)
    # Normalize using ImageNet mean and std. The statistics are shaped
    # (1, 3, 1, 1) so they broadcast over the whole batch in a single
    # vectorized operation instead of normalizing image by image in Python.
    mean = torch.tensor([0.485, 0.456, 0.406],
                        dtype=batch.dtype, device=batch.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225],
                       dtype=batch.dtype, device=batch.device).view(1, 3, 1, 1)
    return (batch - mean) / std

def compute_vgg_loss(vgg, instancenorm, img, target):
    """
    Returns the DIPD (Domain Invariant Perceptual Distance) between two image batches.

    Each batch is passed through ``vgg_preprocess`` and then ``vgg``. The two
    feature maps are normalized with ``instancenorm`` and the mean of their
    squared element-wise difference is returned as a scalar tensor.

    Args:
        vgg: Callable feature extractor applied to the preprocessed batches.
        instancenorm: Callable normalization applied to the extracted features.
        img: First image batch, in the format ``vgg_preprocess`` expects.
        target: Second image batch, in the same format as ``img``.
    """
    # Extract features for each batch: preprocess first, then run VGG.
    features_a = vgg(vgg_preprocess(img))
    features_b = vgg(vgg_preprocess(target))
    # Compare the instance-normalized features with a mean squared difference.
    residual = instancenorm(features_a) - instancenorm(features_b)
    return residual.pow(2).mean()

def main():
    """
    Command-line entry point: estimate the average DIPD of a generator checkpoint.

    Loads the ``'G_ema'`` entry from the pickled checkpoint, generates pairs
    of images from shared latent vectors under two fixed conditioning
    vectors, and prints the mean ``compute_vgg_loss`` over all batches.

    Command-line arguments:
        checkpoint_path: Path to the pickled checkpoint file.
        --batch-size: Number of latent vectors per iteration (default 16).
        --num-iterations: Number of batches to average over (default 1000).

    Requires a CUDA device: the generator, the VGG network and all tensors
    are placed on the GPU.
    """
    parser = argparse.ArgumentParser(description='Compute DIPD metric for a given checkpoint.')
    parser.add_argument('checkpoint_path', type=str, help='Path to the checkpoint file.')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Number of latent vectors per iteration (default: 16).')
    parser.add_argument('--num-iterations', type=int, default=1000,
                        help='Number of batches to average over (default: 1000).')
    args = parser.parse_args()

    batch_size = args.batch_size
    num_iterations = args.num_iterations
    # Reject non-positive values up front; num_iterations == 0 would otherwise
    # end in a division by zero when averaging.
    if batch_size < 1 or num_iterations < 1:
        parser.error('--batch-size and --num-iterations must be positive integers.')

    device = torch.device('cuda')

    # Load the generator model.
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only run this script on checkpoints from a trusted source.
    with open(args.checkpoint_path, 'rb') as f:
        G = pickle.load(f)['G_ema'].to(device)
    G.eval()  # Set the generator to evaluation mode

    # Load VGG model for perceptual loss. Newer torchvision releases deprecate
    # `pretrained=True` in favour of the `weights=` argument; IMAGENET1K_V1 is
    # the weight set that `pretrained=True` selects. Older releases without
    # the weights enum fall back to the original call.
    vgg19_weights = getattr(models, 'VGG19_Weights', None)
    if vgg19_weights is not None:
        vgg_model = models.vgg19(weights=vgg19_weights.IMAGENET1K_V1)
    else:
        vgg_model = models.vgg19(pretrained=True)
    vgg = vgg_model.features.to(device).eval()
    for param in vgg.parameters():
        param.requires_grad = False
    instancenorm = nn.InstanceNorm2d(512, affine=False).to(device)

    dipd_total = 0.0

    # Prepare constant conditional inputs c0 and c1, created directly on the GPU
    c0 = torch.tensor([[1, 0]] * batch_size, dtype=torch.int64, device=device)  # Domain A
    c1 = torch.tensor([[0, 1]] * batch_size, dtype=torch.int64, device=device)  # Domain B

    # Pure evaluation: disable autograd so no graph or intermediate
    # activations are kept for the generator or VGG forward passes.
    with torch.no_grad():
        # Progress bar for iterations
        for _ in tqdm(range(num_iterations), desc='Computing DIPD'):
            # Generate latent vectors z, shared by both domains
            z = torch.randn([batch_size, G.z_dim], device=device)

            # Generate images for both domains
            img = G(z, c0)  # Images conditioned on c0
            target = G(z, c1)  # Images conditioned on c1

            # Compute DIPD for the batch
            loss = compute_vgg_loss(vgg, instancenorm, img, target)
            dipd_total += loss.item()

    dipd_average = dipd_total / num_iterations
    print(f'Average DIPD over {num_iterations} iterations: {dipd_average:.6f}')

# Run the DIPD computation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
