import os
from pathlib import Path
import json
from models.model_configs import instantiate_model
import torch
from training.eval_loop import CFGScaledModel
from flow_matching.solver.ode_solver import ODESolver
from matplotlib import pyplot as plt
import matplotlib.animation as animation
import numpy as np


# Locate the checkpoint and the training arguments stored alongside it.
checkpoint_path = Path("./output_dir/checkpoint-cond-699.pth")
args_filepath = checkpoint_path.parent / 'args.json'
with args_filepath.open('r') as args_file:
    args_dict = json.load(args_file)

# Fix the global PyTorch seed before any other torch call in this script.
torch.manual_seed(42)

# A missing 'discrete_flow_matching' entry is treated as False.
# NOTE(review): the 'architechture' spelling presumably matches the
# signature of instantiate_model -- confirm before "fixing" it.
model = instantiate_model(
    architechture=args_dict['dataset'],
    is_discrete=args_dict.get('discrete_flow_matching', False),
    use_ema=args_dict['use_ema'],
)

# NOTE: weights_only=False allows torch.load to unpickle arbitrary objects,
# so this must only be pointed at checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
# Switch off training mode for inference.
model.train(False)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Number of GPUs being used:", torch.cuda.device_count())

model.to(device)

# Wrap the model so the solver queries it through the CFG-scaled interface.
cfg_weighted_model = CFGScaledModel(model=model)
solver = ODESolver(velocity_model=cfg_weighted_model)

# ode_opts aliases args_dict['ode_options'], so the update below also
# records the solver method inside args_dict.
ode_opts = args_dict['ode_options']
ode_opts.update(method=args_dict['ode_method'])

# Set the sampling resolution corresponding to the model.
# An unsupported configuration is rejected here with an explicit message;
# without the final branch `sample_resolution` would stay undefined and the
# script would fail later with an unrelated NameError.
if 'train_blurred_64' in args_dict['data_path'] and args_dict['dataset'] == 'imagenet':
    sample_resolution = 64
elif 'train_blurred_32' in args_dict['data_path'] or args_dict['dataset'] == 'cifar10':
    sample_resolution = 32
else:
    raise ValueError(
        "Cannot infer the sampling resolution for "
        f"dataset={args_dict['dataset']!r}, data_path={args_dict['data_path']!r}"
    )

batch_size = 5

# Every sample in the batch is conditioned on the same class index.
# (A previous assignment of labels 0..batch_size-1 was immediately
# overwritten by this one and has been removed as dead code.)
class_label = 4
labels = torch.full((batch_size,), class_label, dtype=torch.int32, device=device)

# Gaussian noise used as the initial state of the ODE, one RGB image per sample.
x_0 = torch.randn([batch_size, 3, sample_resolution, sample_resolution], dtype=torch.float32, device=device)



# Evenly spaced times in [0, 1] at which the solver is evaluated.
time_steps = 30
time_grid = torch.linspace(0, 1, time_steps).to(device=device)

# Optional solver tolerances / step size: a missing key is forwarded as None.
solver_options = args_dict['ode_options']
synthetic_samples = solver.sample(
    time_grid=time_grid,
    x_init=x_0,
    method=args_dict['ode_method'],
    atol=solver_options.get('atol'),
    rtol=solver_options.get('rtol'),
    step_size=solver_options.get('step_size'),
    label=labels,
    # Intermediate states are requested so the whole trajectory can be animated.
    return_intermediates=True,
    cfg_scale=args_dict['cfg_scale'],
)

# Map the samples from [-1, 1] to [0, 1], clipping anything outside that range.
synthetic_samples = (synthetic_samples * 0.5 + 0.5).clamp(min=0.0, max=1.0)
# Quantise to multiples of 1/255.
synthetic_samples = (synthetic_samples * 255).floor() / 255.0

def _build_sample_grid(frame, batch_size, rows, cols):
    """
    Draw the first ``batch_size`` images of ``frame`` on a ``rows`` x ``cols`` grid of axes.
    :param frame: Array indexed as [sample, height, width, channels]
    :param batch_size: Number of images to draw
    :param rows: Number of grid rows
    :param cols: Number of grid columns
    :return: Tuple (figure, list of image artists in sample order)
    """
    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid needs no special case.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
    axes = axes.ravel()

    artists = []
    for j in range(batch_size):
        artists.append(axes[j].imshow(frame[j]))
        axes[j].set_title(f"Sample {j + 1}")

    # Axis decorations are hidden everywhere, including grid cells left without an image.
    for ax in axes:
        ax.axis('off')

    return fig, artists


def create_timelapse(batch_size, time_steps, synthetic_samples, time_grid):
    """
    Create a timelapse animation of the generated samples over time.
    :param batch_size: Number of images generated (must be at least 1)
    :param time_steps: int, Number of temporal steps in the ODE solution
    :param synthetic_samples: Tensor of shape [time_steps, batch_size, channels, height, width] containing the image
                              data at each time step
    :param time_grid: Tensor containing the time values corresponding to each step
    :return: None. The function creates:
             1. An animation showing the evolution of all samples over the time grid
             2. A final static figure showing the images at the last time step
             3. Saves both as files (MP4/GIF for animation, PNG for final image)
    :raises ValueError: if ``batch_size`` is smaller than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Convert to numpy once and move channels last ([T, B, H, W, C]), which is the
    # layout imshow expects; this avoids re-transposing every image on every frame.
    samples_np = synthetic_samples.detach().cpu().numpy().transpose(0, 1, 3, 4, 2)

    output_path = os.path.join("output_dir", "generated_examples_CIFAR10")
    os.makedirs(output_path, exist_ok=True)

    # Near-square grid large enough to hold every sample.
    cols = int(np.ceil(np.sqrt(batch_size)))
    rows = int(np.ceil(batch_size / cols))

    # Initialize the animated figure with the first frame
    fig, images = _build_sample_grid(samples_np[0], batch_size, rows, cols)
    time_text = fig.suptitle(f'Time Step = {time_grid[0].item():.2f}', fontsize=24)

    def update_frame(frame):
        """Update function for animation"""
        for artist, image in zip(images, samples_np[frame]):
            artist.set_array(image)
        time_text.set_text(f'Time Step = {time_grid[frame].item():.2f}')
        return images + [time_text]

    # Create animation
    ani = animation.FuncAnimation(
        fig, update_frame, frames=time_steps,
        interval=100, blit=False
    )

    # Inline HTML preview for notebooks. IPython is optional: without it the
    # preview is skipped instead of aborting before any file is written.
    try:
        from IPython.display import HTML, display
    except ImportError:
        print("IPython not available; skipping inline animation preview")
    else:
        display(HTML(ani.to_jshtml()))

    # Save as MP4 when FFmpeg is usable; failure here is not fatal because the GIF is written next
    mp4_path = os.path.join(output_path, "generation_timelapse.mp4")
    try:
        writer = animation.FFMpegWriter(fps=10, metadata=dict(artist='Flow Matching'), bitrate=1800)
        ani.save(mp4_path, writer=writer)
        print(f"MP4 saved to {mp4_path}")
    except (FileNotFoundError, RuntimeError) as e:
        print(f"Could not save MP4 (FFmpeg not found): {e}")
        print("Falling back to GIF format only")

    # Save as GIF
    gif_path = os.path.join(output_path, "generation_timelapse.gif")
    ani.save(gif_path, writer='pillow', fps=10)
    print(f"GIF saved to {gif_path}")
    plt.close(fig)

    # Display the final frame (last time step) as a separate static plot
    final_fig, _ = _build_sample_grid(samples_np[-1], batch_size, rows, cols)
    final_fig.suptitle(f'Final Generated Images (Time Step = {time_grid[-1].item():.2f})', fontsize=24)

    # Save the final frame plot
    final_image_path = os.path.join(output_path, "final_generated_images.png")
    final_fig.savefig(final_image_path, dpi=300, bbox_inches='tight')
    print(f"Final frame saved to {final_image_path}")

    plt.show()

# Render and save the timelapse for the trajectory sampled above.
create_timelapse(
    batch_size=batch_size,
    time_steps=time_steps,
    synthetic_samples=synthetic_samples,
    time_grid=time_grid,
)
