from pathlib import Path
import json
from models.model_configs import instantiate_model
import torch
from training.eval_loop import CFGScaledModel
from flow_matching.solver.ode_solver import ODESolver
from matplotlib import pyplot as plt


# Seed the global torch RNG first: everything that draws random numbers after
# this point (including the initial noise sampled later in the script) depends
# on the RNG state, so the order of the statements below matters.
seed = 42
torch.manual_seed(seed)

# The training arguments are expected to sit next to the checkpoint file.
checkpoint_path = Path("./output_dir/checkpoint-cond-699.pth")
args_filepath = checkpoint_path.parent / 'args.json'
with args_filepath.open('r') as args_file:
    args_dict = json.load(args_file)

# 'discrete_flow_matching' may be absent from args.json; treat that as False.
is_discrete = args_dict.get('discrete_flow_matching', False)
model = instantiate_model(
    architechture=args_dict['dataset'],
    is_discrete=is_discrete,
    use_ema=args_dict['use_ema'],
)

# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])

# Put the model in inference mode.
model.train(False)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Number of GPUs being used:", torch.cuda.device_count())
model.to(device=device)

# Wrap the network so the solver can call it as a velocity field
# (presumably applying classifier-free guidance via `cfg_scale` -- confirm in
# training.eval_loop).
cfg_weighted_model = CFGScaledModel(model=model)
solver = ODESolver(velocity_model=cfg_weighted_model)

# NOTE: `ode_opts` aliases args_dict['ode_options'], so the update below also
# adds a "method" entry to that nested dict.
ode_opts = args_dict['ode_options']
ode_opts.update(method=args_dict['ode_method'])

channels = 3

# Set the sampling resolution corresponding to the model.
# The resolution is inferred from the training data path / dataset name stored
# in args.json.
data_path = args_dict['data_path']
dataset = args_dict['dataset']
if 'train_blurred_64' in data_path and dataset == 'imagenet':
    sample_resolution = 64
elif 'train_blurred_32' in data_path or dataset == 'cifar10':
    sample_resolution = 32
else:
    # Fail here with a clear message instead of hitting a NameError on
    # `sample_resolution` further down when the noise tensor is created.
    raise ValueError(
        "Cannot infer the sampling resolution from "
        f"dataset={dataset!r} and data_path={data_path!r}"
    )



# Draw one shared batch of Gaussian noise; it is reused as the starting point
# for every label sampled later in the script.
batch_size = 5
noise_shape = [batch_size, channels, sample_resolution, sample_resolution]
x_0 = torch.randn(noise_shape, dtype=torch.float32, device=device)

# Plotting the Noise
plt.figure(figsize=(15, 3))
for column, noise in enumerate(x_0.cpu(), start=1):
    plt.subplot(1, batch_size, column)
    # Channels-last numpy array, as expected by imshow.
    pixels = noise.permute(1, 2, 0).numpy()
    # Min-max rescale each image independently to [0, 1] for display only.
    lowest, highest = pixels.min(), pixels.max()
    pixels = (pixels - lowest) / (highest - lowest)
    plt.imshow(pixels)
    plt.axis('off')

plt.suptitle("Randomly Generated Noise (No Correlation)", fontsize=24)
plt.tight_layout()
plt.subplots_adjust(bottom=0.14)  # Leave extra margin at the bottom of the figure
plt.show()

# The integration grid and the solver options are identical for every label,
# so build them once instead of on each loop iteration.
time_steps = 30
time_grid = torch.linspace(0, 1, time_steps).to(device=device)
solver_options = args_dict['ode_options']

# Generate one row of images per class label, always starting from the same
# noise batch `x_0` so rows differ only by the conditioning label.
# NOTE(review): the label count is hard-coded to 10 -- confirm it matches the
# number of classes the checkpoint was trained with.
for label in range(10):
    labels = torch.full((batch_size,), label, dtype=torch.int32, device=device)

    synthetic_samples = solver.sample(
        time_grid=time_grid,
        x_init=x_0,
        method=args_dict['ode_method'],
        # Optional solver settings: pass None when args.json does not set them.
        atol=solver_options.get('atol'),
        rtol=solver_options.get('rtol'),
        step_size=solver_options.get('step_size'),
        label=labels,
        return_intermediates=False,
        cfg_scale=args_dict['cfg_scale'],
    )

    # Scaling to [0, 1] from [-1, 1]
    synthetic_samples = torch.clamp(
        synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0
    )
    # Quantize to multiples of 1/255.
    synthetic_samples = torch.floor(synthetic_samples * 255) / 255.0

    plt.figure(figsize=(15, 3))
    for j in range(batch_size):
        plt.subplot(1, batch_size, j + 1)  # 1 row, batch_size columns
        image = synthetic_samples[j].cpu().permute(1, 2, 0).numpy()
        plt.imshow(image)
        plt.axis('off')

    plt.suptitle(f'Images Generated for Label {label}', fontsize=24)
    plt.tight_layout()

    plt.show()

print(f"Generated Using Seed: {seed}")