# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
#          Alexander Tong

import os
import sys
import numpy as np

import matplotlib.pyplot as plt
import torch
from absl import app, flags
from torchdiffeq import odeint
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Checkpoint selection and evaluation settings
flags.DEFINE_string(
    "input_dir",
    "/slurm-storage/teoreu/git/variance_flows/train_cifar10/new_results",
    # The checkpoint is read from <input_dir>/<model>/, so this is an input location.
    help="directory containing the trained model checkpoints",
)
flags.DEFINE_string("model", "sbm", help="flow matching model type")
flags.DEFINE_integer(
    "integration_steps",
    100,
    # Only the euler method builds its time grid from this value.
    help="number of integration steps (used by the euler method only)",
)
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 240000, help="training step of the checkpoint to load")
flags.DEFINE_integer("batch_size", 10, help="batch size for evaluation")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
flags.DEFINE_string("output_file", "log_likelihood_results.txt", help="File to save log likelihood results")
# Parse the command line immediately so the module-level code below can read FLAGS.
FLAGS(sys.argv)


# Select the compute device and build the UNet that will receive the checkpoint weights
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
model = FLAGS.model

# Architecture hyper-parameters; only the base channel width is configurable via flags
new_net = UNetModelWrapper(
    dim=(3, 32, 32),
    num_channels=FLAGS.num_channel,
    num_res_blocks=2,
    channel_mult=[1, 2, 2, 2],
    attention_resolutions="16",
    num_heads=4,
    num_head_channels=64,
    dropout=0.1,
)
new_net = new_net.to(device)


# Load the model
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/sbm_backward_cifar10_weights_step_{FLAGS.step}.pt"
print("path: ", PATH)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load(PATH, map_location=device)
state_dict = checkpoint["ema_model"]
try:
    new_net.load_state_dict(state_dict)
except RuntimeError:
    from collections import OrderedDict

    # Retry after dropping a leading "module." from the keys (presumably added
    # by a DataParallel-style wrapper at training time -- TODO confirm).
    # Keys without that prefix are kept unchanged instead of being truncated.
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[new_key] = v
    new_net.load_state_dict(new_state_dict)
# Evaluation mode: the network is only used for inference below.
new_net.eval()

# Integration starts at 0, except for the 'vpfm' model which starts slightly after it
t_min = 1e-5 if model == 'vpfm' else 0

# Only the euler method needs an explicit NeuralODE wrapper; every other method
# is integrated by calling odeint directly, so nothing is built for it here.
if FLAGS.integration_method == "euler":
    node = NeuralODE(new_net, solver=FLAGS.integration_method)


def calculate_log_likelihood(model, x, t_span, method="dopri5", rtol=1e-5, atol=1e-5):
    """Return the per-sample negative log-density of the transported images under N(0, I).

    The images are integrated along ``t_span`` and the final state is scored
    under a standard normal with identity covariance.

    NOTE: this is NOT the full data log-likelihood. The change-of-variables
    term ``log |det(dz/dx)|`` is not computed, so the returned value is only
    the NLL of the final state under the Gaussian.

    Args:
        model: The ODE function passed to ``odeint``. It is ignored when
            ``method == "euler"``: that path uses the module-level ``node``.
        x: Input images from the CIFAR10 dataset (normalized to [-1, 1]).
        t_span: Time points for the integration; the state at the last one is scored.
        method: Integration method.
        rtol, atol: Tolerances for the ODE solver (unused on the euler path).

    Returns:
        Tensor with one negative log-probability per sample in the batch.
    """
    # Function-scope import, as in the original, so the file header stays untouched.
    from torch.distributions import Independent, Normal

    with torch.no_grad():
        # Forward pass through the ODE to map the images to the Gaussian space
        if method == "euler":
            # Relies on the module-level NeuralODE, which only exists when the
            # script was configured with the euler integration method.
            traj = node.trajectory(x, t_span=t_span)
        else:
            traj = odeint(model, x, t_span, rtol=rtol, atol=atol, method=method)
        z = traj[-1]  # Final state of the trajectory

        # Diagnostics: the final state should look roughly standard normal
        print(torch.mean(z))
        print(torch.var(z))

        # Flatten every sample to a vector: (batch_size, feature_dim)
        z_flat = z.reshape(z.shape[0], -1)

        # A diagonal standard normal (Independent over the feature dimension) is
        # equivalent to a MultivariateNormal with identity covariance without
        # materialising a feature_dim x feature_dim matrix.
        gaussian_dist = Independent(
            Normal(loc=torch.zeros_like(z_flat), scale=torch.ones_like(z_flat)), 1
        )

        # Negative log-probability per sample; lower means a better fit
        return -gaussian_dist.log_prob(z_flat)


# Build the CIFAR10 training set, pick a random slice of it and wrap it in a loader
print("Loading CIFAR10 dataset...")
# Convert to tensors, then normalize to [-1, 1]
_cifar_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
# Training split (50,000 images), downloaded on first use
dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=_cifar_transform)

print('len full dataset: ', len(dataset))
permuted_id = torch.randperm(len(dataset))
print(permuted_id)

# The evaluated subset is the `subset_number`-th chunk of `subdataset_size`
# indices taken from the random permutation.
# NOTE(review): the permutation is not seeded here, so the chunk differs between runs.
subdataset_size = 100
subset_number = 1
_chunk_start = subset_number * subdataset_size
_chunk_stop = _chunk_start + subdataset_size
selected_id = permuted_id[_chunk_start:_chunk_stop]

print('selected range of permuted to build subdataset: ', _chunk_start, _chunk_stop)

subset_cifar10 = torch.utils.data.Subset(dataset, selected_id)
print('len subdataset: ', len(subset_cifar10))


# Evaluation loader: fixed order, every sample kept (the last batch may be smaller)
dataloader = DataLoader(
    subset_cifar10,
    batch_size=FLAGS.batch_size,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)

# Only the selected subset is evaluated, so report its size rather than the
# size of the full CIFAR10 training set.
num_images = len(subset_cifar10)
print(f"Computing log likelihood for {num_images} CIFAR10 images")

# Setup for integration
# The euler method gets a uniform grid of integration_steps intervals; every
# other method only gets the two end points of the interval.
num_time_points = FLAGS.integration_steps + 1 if FLAGS.integration_method == "euler" else 2
t_span = torch.linspace(t_min, 1, num_time_points).to(device)

# Compute NLL for all batches
all_nlls = []
total_samples = 0

for batch_idx, (images, _) in enumerate(dataloader):
    images = images.to(device)
    batch_nlls = calculate_log_likelihood(
        new_net,
        images,
        t_span,
        method=FLAGS.integration_method,
        rtol=FLAGS.tol,
        atol=FLAGS.tol,
    )

    all_nlls.append(batch_nlls)
    total_samples += images.size(0)

    # Progress report every 10 batches
    if (batch_idx + 1) % 10 == 0:
        print(f"Processed {total_samples}/{num_images} images")

# Concatenate all negative log likelihoods into one vector (one entry per image)
all_nlls = torch.cat(all_nlls, dim=0)

# Calculate statistics
mean_nll = all_nlls.mean().item()
std_nll = all_nlls.std().item()

print("\nNegative Log Likelihood (NLL) Calculation Complete")
print(f"Mean NLL: {mean_nll:.4f} (lower is better)")
print(f"Std Dev NLL: {std_nll:.4f}")

# Save results to a file
with open(FLAGS.output_file, "w", encoding="utf-8") as f:
    f.write(f"Model: {FLAGS.model}\n")
    f.write(f"Checkpoint: {PATH}\n")
    f.write(f"Integration Method: {FLAGS.integration_method}\n")
    f.write(f"Integration Steps: {FLAGS.integration_steps}\n")
    # Number of images that actually went through the model
    f.write(f"Number of Images: {total_samples}\n")
    f.write(f"Mean NLL: {mean_nll:.4f}\n")
    f.write(f"Std Dev NLL: {std_nll:.4f}\n")

print(f"Results saved to {FLAGS.output_file}")

# Also keep the individual per-image NLLs next to the summary
np.save(f"{FLAGS.model}_nll_values.npy", all_nlls.cpu().numpy())