# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
#          Alexander Tong

import os
import sys

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from absl import app, flags
from cleanfid import fid
from torchdiffeq import odeint
from torchdyn.core import NeuralODE

from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Evaluation
# NOTE(review): `input_dir`, `model` and `step` are not read by the visible
# code (the checkpoint path is hard-coded further down) -- confirm whether
# they are meant to drive the checkpoint path.
flags.DEFINE_string("input_dir", "./results", help="input directory containing the checkpoints")
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_integer("integration_steps", 20, help="number of inference steps")
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID")

# Parse the command line right away: the module-level code below reads FLAGS
# values while it builds the networks.
FLAGS(sys.argv)


# Select the compute device: first CUDA GPU when available, CPU otherwise.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# First-stage (pretrained) UNet; its weights are restored from the checkpoint
# later in the script.
pretrained_net = UNetModelWrapper(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=FLAGS.num_channel,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
)
pretrained_net = pretrained_net.to(device)

# Checkpoint to evaluate.
# NOTE(review): the path is hard-coded; the --input_dir/--model/--step flags
# are not used to build it. Earlier candidates are kept below for reference.
# PATH = f"{FLAGS.input_dir}/{FLAGS.model}/myfinetune_{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
# PATH = "results3/otcfm/otcfm_cifar10_weights_step_41020000.pt"
# PATH = "results_MLE/otcfm/otcfm_cifar10_weights_step_95120.0.pt"
PATH = "results_MLE_Res/otcfm/otcfm_cifar10_weights_step_400.pt"
print("path: ", PATH)

# Load the checkpoint onto the selected device. The code below reads the keys
# "res_model_type" (optional), "pretrained_model" and "ema_res_model" from it.
checkpoint = torch.load(PATH, map_location=device)

# Pick the residual architecture from the checkpoint's "res_model_type" entry;
# a checkpoint without that key falls back to "unet".
res_model_type = checkpoint.get("res_model_type", "unet")

if res_model_type == "mlp":
    # Simple MLP model for residual learning
    class MLPResidualModel(nn.Module):
        """Three-layer ReLU MLP applied to the flattened input.

        The time argument is accepted so the module can be called like an ODE
        vector field ``f(t, x)``, but it is not used: the output depends on
        ``x`` only.

        Args:
            input_dim: number of features per sample after flattening.
            hidden_dim: width of the two hidden layers.
            output_dim: number of output features per sample. It must equal
                the per-sample element count of ``x`` so the output can be
                reshaped back to ``x``'s shape.
        """

        def __init__(self, input_dim=3 * 32 * 32, hidden_dim=4 * 512, output_dim=3 * 32 * 32):
            super().__init__()
            # The layer order fixes the state_dict keys (net.0, net.2, net.4),
            # so it must stay as is for saved weights to load.
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )

        def forward(self, t, x, **kwargs):
            """Apply the MLP to ``x`` and return a tensor shaped like ``x``.

            Args:
                t: time value; ignored.
                x: input tensor whose first dimension is the batch.
                **kwargs: ignored; accepted for call-signature compatibility.
            """
            # reshape (rather than view) so that non-contiguous inputs, e.g.
            # permuted tensors, are flattened instead of raising.
            x_flat = x.reshape(x.shape[0], -1)
            output = self.net(x_flat)
            return output.reshape(x.shape)
    
    res_net = MLPResidualModel().to(device)
else:  # default to unet
    # res_net = UNetModelWrapper(
    #     dim=(3, 32, 32),
    #     num_res_blocks=2,
    #     num_channels=FLAGS.num_channel,
    #     channel_mult=[1, 2, 2, 2],
    #     num_heads=4,
    #     num_head_channels=64,
    #     attention_resolutions="16",
    #     dropout= 0.1,
    # ).to(device)
    res_net = UNetModelWrapper(
            dim=(3, 32, 32),
            num_res_blocks=1,
            num_channels=FLAGS.num_channel,
            channel_mult=[1, 2],
            num_heads=4,
            num_head_channels=64,
            # attention_resolutions="",
            dropout=0.1,
        ).to(device)

def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key."""
    from collections import OrderedDict

    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _load_weights(net, state_dict):
    """Load ``state_dict`` into ``net``, tolerating ``module.``-prefixed keys.

    The state dict is first loaded as-is. If that raises ``RuntimeError`` and
    at least one key starts with ``module.`` (e.g. weights saved from a model
    wrapped in DataParallel), the prefix is stripped and the load is retried.

    Raises:
        RuntimeError: if the keys carry no ``module.`` prefix (the original
            error is re-raised unchanged), or if the retry also fails.
    """
    try:
        net.load_state_dict(state_dict)
    except RuntimeError:
        if not any(key.startswith("module.") for key in state_dict):
            # Nothing to strip: a retry would fail identically, so surface the
            # original error rather than a chained duplicate.
            raise
        net.load_state_dict(_strip_module_prefix(state_dict))


# Load pretrained model state_dict
pretrained_state_dict = checkpoint["pretrained_model"]
_load_weights(pretrained_net, pretrained_state_dict)

# Load residual model EMA state_dict
res_state_dict = checkpoint["ema_res_model"]
_load_weights(res_net, res_state_dict)

# Switch both networks to evaluation mode before sampling.
pretrained_net.eval()
res_net.eval()


# The fixed-step Euler path goes through torchdyn's NeuralODE wrapper; the
# wrappers are only created (and only used) when that solver is selected.
if FLAGS.integration_method == "euler":
    pretrained_node, res_node = (
        NeuralODE(net, solver=FLAGS.integration_method)
        for net in (pretrained_net, res_net)
    )


def gen_1_img(unused_latent):
    """Generate one batch of images with the two-stage flow.

    Gaussian noise is integrated from t=0 to t=1 with the pretrained network,
    and the result is integrated again from t=0 to t=1 with the residual
    network. The final state is mapped to pixel values with
    ``x * 127.5 + 128``, clipped to [0, 255] and cast to ``uint8``.

    Args:
        unused_latent: ignored; the noise is drawn inside the function.

    Returns:
        A ``uint8`` tensor holding ``FLAGS.batch_size_fid`` generated samples.
    """
    with torch.no_grad():
        noise = torch.randn(FLAGS.batch_size_fid, 3, 32, 32, device=device)

        if FLAGS.integration_method == "euler":
            # Fixed-step Euler through the NeuralODE wrappers, using the same
            # time grid for both stages.
            print("Use method: ", FLAGS.integration_method, " steps: ", FLAGS.integration_steps)
            times = torch.linspace(0, 1, FLAGS.integration_steps + 1, device=device)
            stage_one = pretrained_node.trajectory(noise, t_span=times)[-1]
            samples = res_node.trajectory(stage_one, t_span=times)[-1]
        else:
            # Any other solver goes through torchdiffeq's odeint; only the
            # endpoints t=0 and t=1 are requested.
            print("Use method: ", FLAGS.integration_method)
            times = torch.linspace(0, 1, 2, device=device)
            solver_options = dict(
                rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
            )
            stage_one = odeint(pretrained_net, noise, times, **solver_options)[-1]
            samples = odeint(res_net, stage_one, times, **solver_options)[-1]

        return (samples * 127.5 + 128).clip(0, 255).to(torch.uint8)


print("Start computing FID")
# FID of generated samples against the CIFAR-10 training split at 32x32,
# using clean-fid's "legacy_tensorflow" mode.
_fid_kwargs = dict(
    gen=gen_1_img,
    dataset_name="cifar10",
    batch_size=FLAGS.batch_size_fid,
    dataset_res=32,
    num_gen=FLAGS.num_gen,
    dataset_split="train",
    mode="legacy_tensorflow",
)
score = fid.compute_fid(**_fid_kwargs)
print("\nFID has been computed")
# NFE reporting is disabled: it referenced `new_net`, which is not defined in
# the visible script.
print("\nFID: ", score)
