import os
from pathlib import Path
import json

from torch.backends.mkl import verbose

from models.model_configs import instantiate_model
import torch
from training.eval_loop import CFGScaledModel
from flow_matching.solver.ode_solver import ODESolver
from matplotlib import pyplot as plt


# ---------------------------------------------------------------------------
# Module-level setup: load the training arguments and checkpoint, build the
# model, wrap it for classifier-free guidance and create the ODE solver.
# ---------------------------------------------------------------------------
checkpoint_path = Path("./output_dir/checkpoint-1899.pth")
# The training arguments are expected next to the checkpoint file.
args_filepath = checkpoint_path.parent / 'args.json'
with open(args_filepath, 'r') as f:
    args_dict = json.load(f)

model = instantiate_model(
    architechture=args_dict['dataset'],  # keyword spelling as used by this call site
    # A missing key is treated as "not discrete".
    is_discrete=args_dict.get('discrete_flow_matching', False),
    use_ema=args_dict['use_ema'],
)
# NOTE(security): weights_only=False unpickles arbitrary Python objects from the
# checkpoint file. Only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
model.train(False)  # inference mode (same as model.eval())

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Number of GPUs being used:", torch.cuda.device_count())
model.to(device=device)

cfg_weighted_model = CFGScaledModel(model=model)

solver = ODESolver(velocity_model=cfg_weighted_model)
# NOTE(review): ode_opts aliases args_dict['ode_options'], so the assignment
# below also adds a "method" key to that dict. ode_opts itself is not read
# again in the code visible here.
ode_opts = args_dict['ode_options']
ode_opts["method"] = args_dict['ode_method']

# Set the sampling resolution corresponding to the model
if 'train_blurred_64' in args_dict['data_path'] and args_dict['dataset'] == 'imagenet':
    sample_resolution = 64
elif 'train_blurred_32' in args_dict['data_path'] or args_dict['dataset'] == 'cifar10':
    sample_resolution = 32
else:
    # Fail here with a clear message instead of a NameError at first use.
    raise ValueError(
        "Cannot determine the sampling resolution for "
        f"dataset={args_dict['dataset']!r}, data_path={args_dict['data_path']!r}."
    )

def spatial_mutual_noise(batch_size=5, verbose=0, direction='horizontal', seed=None, filename=0):
    """
    Generate images from noise latents that share a controlled portion of their pixels.

    Two independent Gaussian noise latents are sampled. Every latent in the batch starts as a copy of the
    first one, and everything at or beyond a boundary row/column is replaced by the second one. The boundary
    moves in equal steps across the batch, so the first latent is purely noise 1 and the last is purely
    noise 2. With the default ``batch_size=5`` the splits are 100/0, 75/25, 50/50, 25/75 and 0/100.

    The spliced latents are integrated with the module-level ``solver``. The resulting images are saved as a
    PDF under ``output_dir/generated_examples_CIFAR10`` and displayed.

    :param batch_size: Number of images to generate. Must be at least 2.
    :param verbose: If > 0, displays the generated noise latents.
    :param direction: 'horizontal' (latents are split along rows, giving a horizontal boundary line) or
                      'vertical' (split along columns, giving a vertical boundary line).
    :param seed: Random seed for reproducibility, if None a random seed is generated.
    :param filename: Suffix for the output filename
    :return: None. The function:
             1. Creates noise latents with controlled spatial correlations.
             2. Generates images using flow matching from these noise latents.
    :raises ValueError: If ``direction`` is not 'horizontal' or 'vertical', or if ``batch_size`` < 2.
    """
    # Validate up front so that bad arguments do not touch the global RNG state.
    if direction not in ('horizontal', 'vertical'):
        raise ValueError("Direction must be either 'horizontal' or 'vertical'.")
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 (one image per pure noise latent).")

    if seed is None:
        seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
    torch.manual_seed(seed)

    # Sample 2 noise latents (noise_0 first, so a given seed always produces the same pair).
    latent_shape = [1, 3, sample_resolution, sample_resolution]
    noise_0 = torch.randn(latent_shape, dtype=torch.float32, device=device)
    noise_1 = torch.randn(latent_shape, dtype=torch.float32, device=device)

    # boundaries[i] is the first row/column of image i that is taken from noise_1.
    # It runs from sample_resolution (nothing replaced) down to 0 (everything replaced).
    steps = batch_size - 1
    boundaries = [(sample_resolution * (steps - i)) // steps for i in range(batch_size)]
    # Dimension of the (N, C, H, W) batch along which the two latents are spliced.
    split_dim = 2 if direction == 'horizontal' else 3

    x_0 = noise_0.repeat(batch_size, 1, 1, 1)
    for i, boundary in enumerate(boundaries):
        source_region = [slice(None)] * 4
        source_region[split_dim] = slice(boundary, None)
        target_region = list(source_region)
        target_region[0] = slice(i, i + 1)
        x_0[tuple(target_region)] = noise_1[tuple(source_region)]

    if verbose > 0:
        # Visualize the latent noise
        latent_fig = plt.figure(figsize=(14, 4))
        draw_boundary = plt.axhline if direction == 'horizontal' else plt.axvline

        for i, boundary in enumerate(boundaries):
            plt.subplot(1, batch_size, i + 1)
            img = x_0[i].cpu().permute(1, 2, 0).numpy()  # Convert tensor to numpy for visualization
            img = (img - img.min()) / (img.max() - img.min())  # Normalize to [0,1] for visualization
            plt.imshow(img)

            # Add boundary lines on the mixed images only.
            if 0 < i < steps:
                line_style = dict(color='white', linestyle='--', linewidth=4)
                if i == 1:
                    # Label a single line so the legend has exactly one entry.
                    line_style['label'] = 'Boundary between noise patterns'
                # imshow centres pixel k on coordinate k, so the edge between
                # pixels boundary-1 and boundary lies at boundary - 0.5.
                draw_boundary(boundary - 0.5, **line_style)

            if i == 0:
                plt.title("100% Noise 1")
            elif i == steps:
                plt.title("100% Noise 2")
            else:
                share_of_noise_0 = round(100 * (steps - i) / steps)
                plt.title(f"{share_of_noise_0}% N1, {100 - share_of_noise_0}% N2")

            plt.axis('off')

        if steps > 1:  # at least one mixed image, hence one labelled boundary line
            legend = plt.figlegend(loc='lower center', fontsize=14, frameon=True)
            # Dark background to make white line visible
            frame = legend.get_frame()
            frame.set_facecolor('darkgray')
            frame.set_edgecolor('black')

        plt.suptitle("Noise Latents (to be passed through Flow Matching)", fontsize=24)

        plt.tight_layout()
        plt.subplots_adjust(bottom=0.14)  # Make room for the legend and title
        plt.show()
        plt.close(latent_fig)  # do not accumulate open figures across calls

    # NOTE(review): the original author marks the labels as required by the model call
    # but not affecting the result; presumably that depends on cfg_scale - confirm.
    labels = torch.arange(batch_size, dtype=torch.int32, device=device)
    time_grid = torch.linspace(0, 1, 10).to(device=device)
    ode_options = args_dict['ode_options']
    synthetic_samples = solver.sample(
        time_grid=time_grid,
        x_init=x_0,
        method=args_dict['ode_method'],
        atol=ode_options.get('atol'),
        rtol=ode_options.get('rtol'),
        step_size=ode_options.get('step_size'),
        cfg_scale=args_dict['cfg_scale'],
        label=labels,
        return_intermediates=False,
    )
    # Scaling to [0, 1] from [-1, 1]
    synthetic_samples = torch.clamp(
        synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0
    )
    # Quantise to 256 levels.
    synthetic_samples = torch.floor(synthetic_samples * 255) / 255.0

    output_path = os.path.join("output_dir", "generated_examples_CIFAR10")
    os.makedirs(output_path, exist_ok=True)
    samples_fig = plt.figure(figsize=(15, 3))
    for j in range(batch_size):
        plt.subplot(1, batch_size, j + 1)  # 1 row, batch_size columns
        image = synthetic_samples[j].cpu().permute(1, 2, 0).numpy()
        plt.imshow(image)
        plt.axis('off')

    plt.suptitle('Images generated from Noise', fontsize=24)
    plt.tight_layout()

    # Save as high-quality PDF in the specified directory
    pdf_name = f'Spatial_Mutual_Info_{direction.capitalize()}_{filename}.pdf'
    full_path = os.path.join(output_path, pdf_name)
    plt.savefig(full_path, format='pdf', bbox_inches='tight')
    plt.show()
    plt.close(samples_fig)  # do not accumulate open figures across calls

    print(f"Generated Using Seed: {seed}")


import warnings

# Hide the FutureWarning whose message mentions torch.cuda.amp.autocast
# before any sampling runs are started.
warnings.filterwarnings(
    "ignore",
    message=".*torch.cuda.amp.autocast.*",
    category=FutureWarning,
)

batch_size = 5

# Five runs per splice direction, all horizontal runs first. Only the first
# run of each direction (suffix 0) also displays the noise latents.
for splice_direction in ('horizontal', 'vertical'):
    for run_index in range(5):
        spatial_mutual_noise(
            batch_size=batch_size,
            verbose=1 if run_index == 0 else 0,
            direction=splice_direction,
            filename=run_index,
        )








# recovered_noise = solver.sample(
#     time_grid=torch.tensor([1.0, 0.0], device=device),
#     x_init=synthetic_samples,  # start at t=1 (the generated image)
#     method=args_dict['ode_method'],
#     atol=args_dict['ode_options']['atol'] if 'atol' in args_dict['ode_options'] else None,
#     rtol=args_dict['ode_options']['rtol'] if 'rtol' in args_dict['ode_options'] else None,
#     step_size=args_dict['ode_options']['step_size'] if 'step_size' in args_dict['ode_options'] else None,
#     label=labels,
#     cfg_scale=args_dict['cfg_scale'],
# )

# from torch.distributions.normal import Normal
# gaussian_log_density = Normal(torch.zeros(size=[3, 32, 32], device=device), torch.ones(size=[3, 32, 32], device=device)).log_prob
# x_0_recovered, log_p1 = solver.compute_likelihood(
#     x_1=synthetic_samples,
#     log_p0=gaussian_log_density,
#     time_grid=torch.tensor([1.0, 0.0], device=device),
#     method=args_dict['ode_method'],
#     step_size=args_dict['ode_options']['step_size'] if 'step_size' in args_dict['ode_options'] else None,
#     label=labels,
#     cfg_scale=args_dict['cfg_scale']
#     )




