import os
from types import SimpleNamespace
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

from diffusion import create_diffusion
from mask_generator import VideoMaskGenerator
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

from metrics.metric_utils import visualization
from metrics import evaluate_model


from tqdm import tqdm
from joblib import Parallel, delayed
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from visualization import (
    generate_comparison_plots,
    generate_comparison_plots_per_window,
    plot_chunk,
)


from data_utils import create_timeseries_dataset
from utils import create_logger, load_checkpoint
from utils import EMA, instantiate_model
from unconditional.utils import sample_unconditional


def get_test_paths(args: SimpleNamespace):
    """
    Build the output locations used by a test run.

    The test directory (under ``args.exp_dir``) is created if missing; the plot
    directory path is only computed here, not created.

    Returns:
        tuple(str, str, str): ``(test_dir, generated_timefreq_path, plot_dir)``,
        where ``generated_timefreq_path`` is the cache file for generated samples
        and ``plot_dir`` is the ``plots`` sub-folder of ``test_dir``.
    """
    ddim_tag = str(args.ddim).replace('ddim', '')
    test_dir_name = (
        f"test_msl{args.metrics_seq_len}"
        f"_ovrlapstride{args.overlapping_seqs_stride}"
        f"_ddim{ddim_tag}"
        f"_ckpt{args.ckpt_num}"
    )
    test_dir = os.path.join(args.exp_dir, test_dir_name)
    os.makedirs(test_dir, exist_ok=True)

    # The cache file lives directly in exp_dir (not in test_dir) and is keyed
    # on the raw ddim value rather than the stripped tag used above.
    cache_name = (
        f"generated_timefreq_overlapstride{args.overlapping_seqs_stride}"
        f"_{args.ddim}_ckpt{args.ckpt_num}.pt"
    )
    generated_timefreq_path = os.path.join(args.exp_dir, cache_name)

    plot_dir = os.path.join(test_dir, "plots")
    return test_dir, generated_timefreq_path, plot_dir


def _extract_smaller_window(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Re-window a batch of ``(B, L, K)`` sequences to length ``window_size``.

    Every sequence except the last contributes only its leading window
    ``x[b, :window_size]``; the last sequence contributes every stride-1
    sub-window ``x[-1, i : i + window_size]``.
    NOTE(review): presumably this is because consecutive test sequences overlap,
    so only the last one needs to be fully unrolled to cover the tail — confirm.

    Args:
        x: tensor of shape ``(B, L, K)``.
        window_size: target window length, must be ``<= L``.

    Returns:
        Tensor of shape ``(B - 1 + L - window_size + 1, window_size, K)``.
    """
    _, seq_len, _ = x.shape
    assert window_size <= seq_len

    heads = x[:-1, :window_size, :]
    last = x[-1]
    tail_windows = torch.stack(
        [last[i : i + window_size] for i in range(seq_len - window_size + 1)]
    )
    # One concatenation instead of growing the tensor inside the loop, which
    # copied the accumulated result on every iteration.
    return torch.cat((heads, tail_windows), dim=0)


def _plot_results(
    fake_data: torch.Tensor,
    ori_data: torch.Tensor,
    plot_dir: str,
    logger,
    n_random_chunks: int = 10,
    n_blob_chunks: int = 100,
    n_jobs: int = 16,
) -> None:
    """Write comparison plots of generated vs. original windows into ``plot_dir``.

    Produces per-window comparison plots, ``n_random_chunks`` individual chunk
    plots (``chunk_<idx>.png``), an overlay of up to ``n_blob_chunks`` chunks
    (``chunks_blobs.png``) and the mean over all windows (``chunks_mean.png``).
    Sample counts are clamped to the number of available windows.
    """
    logger.info("Plotting the original and generated time series")
    os.makedirs(plot_dir, exist_ok=True)

    # .cpu() is a no-op for CPU tensors and makes this safe for CUDA tensors.
    fake_np = fake_data.cpu().numpy()
    ori_np = ori_data.cpu().numpy()
    n_windows = fake_np.shape[0]

    generate_comparison_plots_per_window(
        ts_generated=fake_np,
        ts_original=ori_np,
        feature_names=[f"feature_{i}" for i in range(ori_np.shape[-1])],
        plot_dir=plot_dir,
        n_jobs=n_jobs,
    )

    # Sampling without replacement fails if more chunks are requested than exist.
    n_chunks_to_plot = min(n_random_chunks, n_windows)
    idx_to_plot = np.random.choice(n_windows, n_chunks_to_plot, replace=False)
    logger.info(f"Plotting {n_chunks_to_plot} random chunks of the time series")

    tasks = [
        (fake_np[idx], ori_np[idx], os.path.join(plot_dir, f"chunk_{idx}.png"))
        for idx in idx_to_plot
    ]

    # rc_context restores the previous rcParams on exit instead of resetting
    # everything to library defaults.
    with mpl.rc_context({"font.size": 32}):
        Parallel(n_jobs=max(1, min(n_chunks_to_plot, n_jobs)), backend="loky")(
            delayed(plot_chunk)(fake_chunk, ori_chunk, plot_path)
            for fake_chunk, ori_chunk, plot_path in tqdm(tasks, desc="plotting chunks")
        )

        n_blob = min(n_blob_chunks, n_windows)
        logger.info(f"Plotting {n_blob} chunks all together as blob")
        blob_idx = np.random.choice(n_windows, n_blob, replace=False)
        plot_chunk(
            fake_chunk=fake_np[blob_idx],
            ori_chunk=ori_np[blob_idx],
            plot_path=os.path.join(plot_dir, "chunks_blobs.png"),
        )

        logger.info("Plotting chunks average of the time series")
        plot_chunk(
            fake_chunk=fake_data.mean(dim=0).cpu().numpy(),
            ori_chunk=ori_data.mean(dim=0).cpu().numpy(),
            plot_path=os.path.join(plot_dir, "chunks_mean.png"),
        )


def main(args: SimpleNamespace):
    """Evaluate a checkpoint: generate (or load cached) samples, plot them, compute metrics.

    Steps:
        1. Build the test dataset, optionally a random subset of
           ``args.n_samples_to_generate`` items (ignored when missing or None).
        2. Instantiate the model and load ``<args.checkpoint_dir>/<args.ckpt_num>.pt``;
           when ``args.load_ema`` is set, the EMA weights are copied into the model.
        3. Sample with the diffusion sampler and cache the generated tensor at the
           path from ``get_test_paths``; if that file already exists it is loaded instead.
        4. Convert generated and original data to time series via
           ``timeseries_data.get_timeseries_from_timefreq``, re-windowing to
           ``args.metrics_seq_len`` when it is shorter than ``args.seq_len``.
        5. Unless ``args.skip_plots``, write plots to the test directory's ``plots`` folder.
        6. Unless ``args.skip_metrics``, run kernel / t-SNE visualisations and
           ``evaluate_model`` with ``score_file`` set to ``<test_dir>/metrics.txt``.

    Note: ``args`` is mutated in place (``device``, ``metrics_seq_len``, ``logger``,
    ``augment_data``, ``batch_size``) and then rebound to the namespace returned by
    ``instantiate_model``.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    torch.set_grad_enabled(False)
    args.device = "cuda"

    if args.metrics_seq_len is not None:
        assert (
            args.metrics_seq_len <= args.seq_len
        ), "metrics_seq_len must be less than or equal to seq_len"
    else:
        args.metrics_seq_len = args.seq_len

    test_dir, generated_timefreq_path, plot_dir = get_test_paths(args)
    score_file = os.path.join(test_dir, "metrics.txt")

    logger = create_logger(test_dir, filename="test.log")
    logger.info(f"Experiment directory: {test_dir}")
    args.logger = logger

    args.augment_data = False  # no data augmentation during testing
    timeseries_data, timefreq_data, tmps_covariate = create_timeseries_dataset(
        args, logger
    )

    freq_centers = timeseries_data.get_frequencies_centers()
    cov_adj = timeseries_data.get_covariate_adj()
    freq_adj = timeseries_data.get_frequency_adj()

    # getattr(..., None) rather than hasattr: an explicit None must not reach
    # np.random.choice, which would return a scalar and drop the batch dimension.
    n_requested = getattr(args, "n_samples_to_generate", None)
    if n_requested is not None:
        n_available = len(timefreq_data)
        if n_requested > n_available:
            logger.info(
                f"n_samples_to_generate={n_requested} exceeds the {n_available} "
                f"available samples; using all of them."
            )
        test_idxs = np.random.choice(
            n_available, min(n_requested, n_available), replace=False
        )
        timefreq_data = timefreq_data[test_idxs]
        tmps_covariate = tmps_covariate[test_idxs]
        logger.info(f"Using {len(timefreq_data)} samples for testing.")

    test_dataset = torch.utils.data.TensorDataset(timefreq_data, tmps_covariate)

    # Ensure batch size is not larger than dataset size
    args.batch_size = min(args.batch_size, len(test_dataset))

    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True,
        drop_last=False,
    )

    model, args = instantiate_model(
        args,
        timefreq_data.shape,
        tmps_covariate.shape,
        freq_centers=freq_centers,
        cov_adj=cov_adj,
        freq_adj=freq_adj,
    )
    checkpoint_path = os.path.join(args.checkpoint_dir, f"{args.ckpt_num}.pt")
    model, ema, _, _ = load_checkpoint(
        checkpoint_path=checkpoint_path,
        model=model,
        ema=EMA(model, decay=0.995) if args.load_ema else None,
        optimizer=None,
        scheduler=None,
    )
    if args.load_ema:
        assert ema is not None, "EMA model is not loaded, but load_ema is set to True."
        ema.apply_shadow()  # copy EMA weights to the model
        logger.info(f"Loaded EMA model from ckpt: {checkpoint_path}")
    else:
        logger.info(f"Loaded model from ckpt: {checkpoint_path}")

    model.eval()

    # vars() returns the namespace's own __dict__, so these assignments also
    # update args.gaussian_diffusion in place.
    diff_vars = vars(args.__dict__["gaussian_diffusion"])
    diff_vars["num_frames"] = args.num_frames
    diff_vars["timestep_respacing"] = args.ddim

    diffusion = create_diffusion(
        **diff_vars,
        training=False,
    )  # default: 1000 steps, linear noise schedule

    generator = VideoMaskGenerator(
        (args.num_frames, args.input_size[0], args.input_size[1])
    )

    logger.info("Start testing with args:")
    logger.info(args)
    if not os.path.exists(generated_timefreq_path):
        logger.info(f"Total number of batches: {len(test_loader)}")
        # If load_ema is set, ema.apply_shadow() above already put the EMA
        # weights into the model; otherwise ema is None and the raw weights are used.
        generated_timefreq_data, original_timefreq_data = sample_unconditional(
            model=model,
            data_loader=test_loader,
            args=args,
            mask_generator=generator,
            diffusion_sampler=diffusion,
        )
        torch.save(generated_timefreq_data, generated_timefreq_path)
    else:
        logger.info("Loading generated data from file")
        generated_timefreq_data = torch.load(generated_timefreq_path)
        # NOTE(review): with n_samples_to_generate set, this is a fresh random
        # subset, not necessarily the one the cached samples were generated from.
        original_timefreq_data = timefreq_data
        if len(generated_timefreq_data) != len(original_timefreq_data):
            logger.info(
                f"Cached generated data has {len(generated_timefreq_data)} samples "
                f"but {len(original_timefreq_data)} original samples are loaded; "
                f"consider deleting {generated_timefreq_path}."
            )

    fake_data = timeseries_data.get_timeseries_from_timefreq(
        generated_timefreq_data
    )  # B, L, K
    ori_data = timeseries_data.get_timeseries_from_timefreq(
        original_timefreq_data
    )  # (B, L, K)

    if args.metrics_seq_len < args.seq_len:
        fake_data = _extract_smaller_window(fake_data, args.metrics_seq_len)
        ori_data = _extract_smaller_window(ori_data, args.metrics_seq_len)

    if not args.skip_plots:
        _plot_results(fake_data, ori_data, plot_dir, logger)

    #################################################################################
    #                                Compute metrics                                #
    #################################################################################
    if args.skip_metrics:
        logger.info("Skipping metrics computation")
        return

    logger.info("Computing metrics")
    torch.set_grad_enabled(True)

    ori_data = ori_data.cpu().numpy().astype(np.float64)
    fake_data = fake_data.cpu().numpy().astype(np.float64)

    # random shuffle to make the data as iid as possible; the same permutation
    # is applied to both arrays so their row order stays aligned.
    idx = np.random.permutation(ori_data.shape[0])
    ori_data = ori_data[idx]
    fake_data = fake_data[idx]

    for an in ["kernel", "tsne"]:
        visualization(
            ori_data=ori_data,
            generated_data=fake_data,
            analysis=an,
            compare=3000,
            save_dir=test_dir,
        )

    evaluate_model(
        ori_data=ori_data,
        fake_data=fake_data,
        metrics_iterations=args.metrics_iterations,
        score_file=score_file,
    )
