import os
from pathlib import Path
import json

from models.model_configs import instantiate_model
import torch
from training.eval_loop import CFGScaledModel
from flow_matching.solver.ode_solver import ODESolver
from matplotlib import pyplot as plt
from dataset_generation.mutual_information import create_mutual_info_covariance, modify_covariance_block


# --- Load the trained model together with the arguments it was trained with ---
checkpoint_path = Path("./output_dir/checkpoint-cond-699.pth")
# The training arguments are read from args.json in the checkpoint's directory.
args_filepath = checkpoint_path.parent / 'args.json'
with open(args_filepath, 'r') as f:
    args_dict = json.load(f)

# NOTE(review): the keyword spelling `architechture` is kept exactly as in the
# original call; check instantiate_model's signature before "correcting" it.
model = instantiate_model(
    architechture=args_dict['dataset'],
    # Treated as non-discrete when the key is absent from args.json.
    is_discrete='discrete_flow_matching' in args_dict and args_dict['discrete_flow_matching'],
    use_ema=args_dict['use_ema'],
)
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
# Only load checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
model.train(False)  # switch to inference mode

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Number of GPUs being used:", torch.cuda.device_count())
model.to(device=device)

cfg_weighted_model = CFGScaledModel(model=model)

solver = ODESolver(velocity_model=cfg_weighted_model)
# ode_opts is the same dict object as args_dict['ode_options'], so the
# "method" key written here is visible through args_dict as well.
ode_opts = args_dict['ode_options']
ode_opts["method"] = args_dict['ode_method']

# Set the sampling resolution corresponding to the model
if 'train_blurred_64' in args_dict['data_path'] and args_dict['dataset'] == 'imagenet':
    sample_resolution = 64
elif 'train_blurred_32' in args_dict['data_path'] or args_dict['dataset'] == 'cifar10':
    sample_resolution = 32
else:
    # Fail here with a clear message instead of leaving sample_resolution
    # undefined and hitting a NameError later inside gaussian_mutual_noise.
    raise ValueError(
        "Cannot infer the sampling resolution from "
        f"dataset={args_dict['dataset']!r} and data_path={args_dict['data_path']!r}"
    )



def _draw_correlated_noise(num_corr_images, channels, cov_images, cov_theta, theta_var, epsilon):
    """Draw one joint Gaussian sample and split it into image latents and theta.

    Reads the module-level ``sample_resolution`` and ``device``.

    :return: Tuple ``(x_0, theta_value)``. ``x_0`` has shape
             (num_corr_images, channels, sample_resolution, sample_resolution)
             and lives on ``device``; ``theta_value`` is the last entry of the
             sample as a Python float.
    """
    dim = channels * sample_resolution * sample_resolution  # flattened size of one image

    mu, Sigma = create_mutual_info_covariance(
        num_corr_images=num_corr_images,
        dim=dim,
        cov_images=cov_images,
        cov_theta=cov_theta,
        theta_var=theta_var,
        epsilon=epsilon
    )

    # Sample from the multivariate normal distribution
    mvn = torch.distributions.MultivariateNormal(mu, covariance_matrix=Sigma)
    samples = mvn.sample()

    # The first num_corr_images * dim entries hold the image latents back to
    # back; one reshape is equivalent to slicing each image and stacking them.
    x_0 = samples[:num_corr_images * dim].reshape(
        num_corr_images, channels, sample_resolution, sample_resolution
    )

    # Extract Theta (last element of the sample)
    theta_value = samples[-1].item()

    # The model was moved to `device` at module level, so the initial latents
    # must live there too (a no-op when they already do).
    return x_0.to(device=device), theta_value


def _show_noise_latents(x_0, num_corr_images):
    """Display the noise latents side by side, each min-max normalized to [0, 1]."""
    fig = plt.figure(figsize=(14, 4))

    for i in range(num_corr_images):
        plt.subplot(1, num_corr_images, i + 1)
        img = x_0[i].cpu().permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
        img = (img - img.min()) / (img.max() - img.min())  # Normalize to [0,1] for visualization
        plt.imshow(img)

        plt.title(f"Noise {i + 1}")
        plt.axis('off')

    plt.suptitle("Noise Latents (to be passed through Flow Matching)", fontsize=24)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.14)  # Leave extra room at the bottom of the figure
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _save_generated_images(synthetic_samples, num_corr_images, cov_images, cov_theta, filename):
    """Plot the generated images in one row, save the figure as a PDF and show it."""
    output_path = os.path.join("output_dir", "generated_examples_CIFAR10")
    os.makedirs(output_path, exist_ok=True)

    fig = plt.figure(figsize=(15, 3))
    for j in range(num_corr_images):
        plt.subplot(1, num_corr_images, j + 1)  # 1 row, num_corr_images columns
        image = synthetic_samples[j].cpu().permute(1, 2, 0).numpy()
        plt.imshow(image)
        plt.axis('off')

    plt.suptitle(
        rf'Images Generated (Cov(X,Y) = {cov_images}, Cov($\theta$) = {cov_theta})',
        fontsize=20
    )

    plt.tight_layout()

    # Save as a PDF in the output directory
    pdf_name = f'Gaussian_Mutual_Info_Horizontal_{filename}.pdf'
    full_path = os.path.join(output_path, pdf_name)
    plt.savefig(full_path, format='pdf', bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def gaussian_mutual_noise(num_corr_images=2, channels=3, cov_images=None, cov_theta=None, theta_var=1, labels=None, epsilon=0, verbose=0, seed=None, filename=0):
    """
    Sample correlated Gaussian noise latents, push them through the flow-matching
    ODE solver and save the resulting images as a PDF.

    One joint sample is drawn from a multivariate Gaussian whose mean and
    covariance come from ``create_mutual_info_covariance``. The sample is split
    into ``num_corr_images`` image latents, and its last entry is read as theta.

    Relies on the module-level ``sample_resolution``, ``device``, ``solver`` and
    ``args_dict``.

    :param num_corr_images: Number of correlated images to generate.
    :param channels: Number of channels per image latent.
    :param cov_images: Forwarded to ``create_mutual_info_covariance``. The calls in
                       this file pass a nested list with one row per image.
    :param cov_theta: Forwarded to ``create_mutual_info_covariance``. The calls in
                      this file pass a list with one entry per image.
    :param theta_var: Forwarded to ``create_mutual_info_covariance``.
    :param labels: Class label per image. Defaults to ``0 .. num_corr_images - 1``.
    :param epsilon: Forwarded to ``create_mutual_info_covariance``.
    :param verbose: If > 0, displays the generated noise latents.
    :param seed: Random seed for reproducibility. A random seed is drawn (and
                 printed at the end) when None.
    :param filename: Suffix for the output filename.
    :return: None. The function:
             1. Creates correlated noise latents from a multivariate Gaussian distribution.
             2. Generates images using flow matching from these noise latents.
             3. Saves the images to
                output_dir/generated_examples_CIFAR10/Gaussian_Mutual_Info_Horizontal_<filename>.pdf
                and prints the theta value and the seed.
    """

    if seed is None:
        seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
    torch.manual_seed(seed)

    x_0, theta_value = _draw_correlated_noise(
        num_corr_images, channels, cov_images, cov_theta, theta_var, epsilon
    )

    if verbose > 0:
        # Visualize the latent noise
        _show_noise_latents(x_0, num_corr_images)

    if labels is None:
        labels = list(range(num_corr_images))

    labels = torch.tensor(labels, dtype=torch.int32, device=device)

    ode_options = args_dict['ode_options']
    time_grid = torch.linspace(0, 1, 10).to(device=device)
    synthetic_samples = solver.sample(
        time_grid=time_grid,
        x_init=x_0,
        method=args_dict['ode_method'],
        # Solver tolerances and step size are optional; absent keys become None.
        atol=ode_options.get('atol'),
        rtol=ode_options.get('rtol'),
        step_size=ode_options.get('step_size'),
        cfg_scale=args_dict['cfg_scale'],
        label=labels,
        return_intermediates=False,
    )
    # Scaling to [0, 1] from [-1, 1]
    synthetic_samples = torch.clamp(
        synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0
    )
    # Snap to the multiples of 1/255 in [0, 1]
    synthetic_samples = torch.floor(synthetic_samples * 255) / 255.0

    _save_generated_images(synthetic_samples, num_corr_images, cov_images, cov_theta, filename)

    print (f"Theta Value: {theta_value}")
    print(f"Generated Using Seed: {seed}")



import warnings

# Silence the torch.cuda.amp.autocast FutureWarning before sampling starts.
warnings.filterwarnings("ignore",
                       message=".*torch.cuda.amp.autocast.*",
                       category=FutureWarning)

# A necessary condition for the covariance matrix to be positive definite is:
#     theta_var > (dim * rho^2) / diag
# where dim = num_corr_images * (pixels per image), rho is the correlation
# coefficient, and diag is the diagonal of the covariance matrix (usually 1,
# but it can take other values).
# NOTE(review): rho presumably corresponds to the cov_theta entries passed
# below - confirm against create_mutual_info_covariance.

# Each run is a (cov_theta, labels) pair; everything else is shared.
sampling_runs = [
    ([0, 0], [6, 6]),
    ([0.9, 0.9], [6, 6]),
    ([0, 0], [1, 6]),
    ([0.9, 0.9], [1, 6]),
]

for run_index, (run_cov_theta, run_labels) in enumerate(sampling_runs):
    gaussian_mutual_noise(
        num_corr_images=2,
        cov_theta=run_cov_theta,
        theta_var=3318,
        epsilon=0.0,
        verbose=1,
        # Use a distinct suffix per run; with a shared suffix every run
        # overwrote the PDF written by the previous one.
        filename=run_index,
        cov_images=[[1.0, 0.999],
                    [0.999, 1.0]],
        labels=run_labels,
        seed=0,
    )

# gaussian_mutual_noise(num_corr_images=2, cov_theta=[0, 0], theta_var=60000000, epsilon=0, verbose=1, filename=0,
#                       cov_images=[[1.0, 0.999],
#                                   [0.999, 1.0]]
#                       )

# gaussian_mutual_noise(num_corr_images=2, cov_theta=[0.9, 0.9], theta_var=3318, epsilon=0.0, verbose=1, filename=0, ## Proof that the math checks out.
#                       cov_images=[[1, 0.5],
#                                   [0.5, 1.0]]
#                       )


