from plot_utils import *

import argparse
import random
from statistics import mean, median
from copy import deepcopy

import torch
import torchvision.transforms as transforms
from torchrl.data import ReplayBuffer, LazyMemmapStorage, PrioritizedSampler

# import clip

import os

import utils
from utils import *

from buffer import StatesReplayBuffer
from langevin import langevin_dynamics
from models import GFN

import gflownet_losses
from energies import *
from evaluations import *

import shutil
import cv2
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FFMpegWriter
from tqdm import trange
import wandb

from PIL import Image

# Command-line interface for the training / evaluation script.
parser = argparse.ArgumentParser(description="GFN Linear Regression")
parser.add_argument("--eval_every_n_steps", type=int, default=1000)

# Optimisation and model-size hyperparameters
################################################################
parser.add_argument("--lr_policy", type=float, default=1e-3)
parser.add_argument("--lr_flow", type=float, default=1e-2)
parser.add_argument("--lr_back_multiplier", type=float, default=1e-1)
parser.add_argument("--hidden_dim", type=int, default=64)
parser.add_argument("--s_emb_dim", type=int, default=64)
parser.add_argument("--t_emb_dim", type=int, default=64)
parser.add_argument("--harmonics_dim", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--epochs", type=int, default=25000)
parser.add_argument("--buffer_traj_size", type=int, default=5000)
parser.add_argument("--buffer_size", type=int, default=300 * 1000 * 2)
# T is handed to the discretizer as its second argument.
parser.add_argument("--T", type=int, default=10)
parser.add_argument("--subtb_lambda", type=int, default=2)
parser.add_argument("--t_scale", type=float, default=5.0)
parser.add_argument("--log_var_range", type=float, default=4.0)

# Target energy
################################################################
parser.add_argument(
    "--energy",
    type=str,
    default="rgmm",
    choices=(
        "rgmm",
        "2gmm",
        "2cgmm",
        "9gmm",
        "25gmm",
        "40gmm",
        "distorted_gmm",
        "easy_funnel",
        "hard_funnel",
        "gan_ffhq",
        "many_well",
        "distorted_many_well",
    ),
)
parser.add_argument("--distortion_coef", type=float, default=0.0)
parser.add_argument("--data_dim", type=int, default=2)
parser.add_argument("--gan_magic_const", default=1, type=float)

parser.add_argument("--gan_prompt", type=str, default="A person with medium length hair")
parser.add_argument("--dataset_path", type=str)

parser.add_argument("--weights_path", type=str)

parser.add_argument("--is_dcgan", action="store_true", default=False)

# Loss selection. The chosen names are resolved with getattr on `gflownet_losses`.
# NOTE(review): the "" defaults of --pb_mode_fwd / --pb_mode_bwd are not among the
# choices (argparse does not validate defaults against `choices`), so these options
# apparently have to be given explicitly -- confirm the intended default.
################################################################
parser.add_argument("--pf_mode_fwd", type=str, default="tb", choices=("tb", "tb_avg", "db", "subtb", "pis"))
parser.add_argument("--pb_mode_fwd", type=str, default="", choices=("no_update", "tb", "tlm", "tb_avg", "db", "mle"))
parser.add_argument("--pf_mode_bwd", type=str, default="tb", choices=("tb", "tb_avg", "db", "subtb", "pis"))
parser.add_argument("--pb_mode_bwd", type=str, default="", choices=("no_update", "tb", "tlm", "tb_avg", "db", "mle"))
parser.add_argument("--both_ways", action="store_true", default=False)
parser.add_argument("--bwd", action="store_true", default=False)
parser.add_argument("--learn_pb", action="store_true", default=False)
parser.add_argument("--huber_loss_quantile", type=float, default=1.0)
parser.add_argument(
    "--process_param",
    type=str,
    default="standard",
    choices=(
        "standard",
        "log_variance",
    ),
)

parser.add_argument("--gamma", type=float, default=1)
parser.add_argument("--pb_scale_range", type=float, default=0.9)
parser.add_argument("--pb_scale_policy", type=str, default="const", choices=("const", "linear"))
parser.add_argument("--learned_variance", action="store_true", default=False)
# Number of extra gradient steps per training step that replay stored trajectories.
parser.add_argument("--replay_ratio_n", type=int, default=0)
parser.add_argument("--tau", type=float, default=0.05)
parser.add_argument("--target_pf", action="store_true", default=False)
parser.add_argument("--target_pb", action="store_true", default=False)
parser.add_argument("--share_backbone", action="store_true", default=False)
parser.add_argument("--use_2optimizers", action="store_true", default=False)

# Gradient clipping
parser.add_argument("--clip_grad_norm", type=float, default=np.inf)
parser.add_argument("--clip_grad_quantile", type=float, default=0.95)
parser.add_argument("--grad_history_window_size", type=int, default=1000)


# For local search
################################################################

parser.add_argument("--local_search", action="store_true", default=False)

# How many iterations to run local search
parser.add_argument("--max_iter_ls", type=int, default=200)

# How many iterations to burn in before making local search
parser.add_argument("--burn_in", type=int, default=100)

# How frequently to make local search
parser.add_argument("--ls_cycle", type=int, default=100)
# Langevin step size
parser.add_argument("--ld_step", type=float, default=0.001)
parser.add_argument("--ld_schedule", action="store_true", default=False)

# Target acceptance rate
parser.add_argument("--target_acceptance_rate", type=float, default=0.574)


# For replay buffer
################################################################
# A high beta gives steep prioritization in reward-prioritized replay sampling
parser.add_argument("--beta", type=float, default=1.0)

# A low rank_weight gives steep prioritization in rank-based replay sampling
parser.add_argument("--rank_weight", type=float, default=1e-2)

# Three kinds of replay training: random, reward prioritized, rank-based
parser.add_argument("--prioritized", type=str, default="rank", choices=("none", "reward", "rank"))


################################################################

parser.add_argument("--exploratory", action="store_true", default=False)
parser.add_argument("--sampling", type=str, default="buffer", choices=("sleep_phase", "energy", "buffer"))
parser.add_argument("--langevin", action="store_true", default=False)
parser.add_argument("--langevin_scaling_per_dimension", action="store_true", default=False)
parser.add_argument("--conditional_flow_model", action="store_true", default=False)
parser.add_argument("--partial_energy", action="store_true", default=False)
parser.add_argument("--exploration_factor", type=float, default=0.1)
parser.add_argument("--exploration_wd", action="store_true", default=False)
parser.add_argument("--clipping", action="store_true", default=False)
parser.add_argument("--lgv_clip", type=float, default=1e2)
parser.add_argument("--gfn_clip", type=float, default=1e4)
parser.add_argument("--zero_init", action="store_true", default=False)
parser.add_argument("--pis_architectures", action="store_true", default=False)
parser.add_argument("--lgv_layers", type=int, default=3)
parser.add_argument("--joint_layers", type=int, default=2)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--weight_decay", type=float, default=1e-7)
parser.add_argument("--use_weight_decay", action="store_true", default=False)


# Time discretization. The chosen name is resolved with getattr on `utils`, so the
# default has to be one of the listed function names: argparse does not validate
# defaults against `choices`, and the previous default "random" was not a valid choice.
parser.add_argument(
    "--discretizer",
    type=str,
    default="random_discretizer",
    choices=(
        "random_discretizer",
        "uniform_discretizer",
        "harmonic_discretizer",
        "sqrt_harmonic_discretizer",
        "first_big_step_discretizer",
        "low_discrepancy_discretizer",
        "low_discrepancy2_discretizer",
        "equidistant_discretizer",
    ),
)
parser.add_argument("--discretizer_max_ratio", type=float, default=10.0)
parser.add_argument("--traj_length_strategy", type=str, default="static", choices=("static", "dynamic"))
parser.add_argument("--min_traj_length", type=int, default=10)
parser.add_argument("--max_traj_length", type=int, default=100)


# Evaluation-only mode
parser.add_argument("--eval", action="store_true", default=False)
parser.add_argument("--model_it", type=int, default=25000)
parser.add_argument("--perfect", action="store_true", default=False)


# For visualization
################################################################

parser.add_argument("--seeds", type=int, default=1)
parser.add_argument("--granularity_n", type=int, default=11)
parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--version", type=int, default=1)
parser.add_argument("--visualization_type", default="new_gen", choices=("new_gen", "new_noised", "new-grid", "quiver", "streamplot"))

parser.add_argument("--eval_data_size", type=int, default=2048)
parser.add_argument("--plot_data_size", type=int, default=2048)
parser.add_argument("--final_eval_data_size", type=int, default=2048)
parser.add_argument("--final_plot_data_size", type=int, default=2048)


parser.add_argument("--wandb", action="store_true", default=False)

args = parser.parse_args()

# Sample counts used by the evaluation and plotting code below.
eval_data_size = args.eval_data_size
plot_data_size = args.plot_data_size
final_eval_data_size = args.final_eval_data_size
final_plot_data_size = args.final_plot_data_size

# Apply the per-task SLURM offset *before* seeding: seeding first would give
# every task the same random stream while each reports a different args.seed.
if "SLURM_PROCID" in os.environ:
    args.seed += int(os.environ["SLURM_PROCID"])
set_seed(args.seed)

# The PIS architectures always use zero initialisation.
if args.pis_architectures:
    args.zero_init = True

# --debug forces CPU even when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() and not args.debug else "cpu")
coef_matrix = cal_subtb_coef_matrix(args.subtb_lambda, args.T).to(device)

# --both_ways takes precedence over --bwd.
if args.both_ways and args.bwd:
    args.bwd = False

# Local search implies training in both directions.
if args.local_search:
    args.both_ways = True

def get_energy():
    """Build the target energy selected by ``args.energy``.

    Reads the module-level ``args`` and ``device``.

    Returns:
        The energy object constructed for the requested name.

    Raises:
        ValueError: if ``args.energy`` is not one of the handled names
            (previously this surfaced as an ``UnboundLocalError``).
    """
    name = args.energy
    if name == "rgmm":
        return RandomGaussianMixture(device=device)
    if name == "2gmm":
        return TwoGaussianMixture(device=device)
    if name == "2cgmm":
        return TwoCloseGaussianMixture(device=device)
    if name == "9gmm":
        return NineGaussianMixture(device=device)
    if name == "25gmm":
        return GaussianMixture(device=device)
    if name == "40gmm":
        return FortyGaussianMixture(dim=args.data_dim, device=device)
    if name == "distorted_gmm":
        return DistortedGaussianMixture(device=device, distortion_coef=args.distortion_coef, dim=args.data_dim)
    if name == "hard_funnel":
        return HardFunnel(device=device)
    if name == "easy_funnel":
        return EasyFunnel(device=device)
    if name == "gan_ffhq":
        return FFHQGANLatent(get_name(args), args.dataset_path, args.weights_path, args.gan_prompt, args.gan_magic_const)
    if name == "many_well":
        return ManyWell(device=device)
    if name == "distorted_many_well":
        return DistortedManyWell(device=device, distortion_coef=args.distortion_coef, dim=args.data_dim)
    raise ValueError(f"Unknown energy: {name!r}")


def eval_step(eval_data, energy, gfn_model, discretizer, final_eval=False):
    """Draw samples and compute evaluation metrics for ``gfn_model``.

    The model is switched to eval mode and is not switched back here.

    Args:
        eval_data: reference samples, passed to the backward-trajectory and
            sample-distance metrics.
        energy: target energy object.
        gfn_model: model to evaluate.
        discretizer: callable invoked as ``discretizer(bsz, args.T)``; its
            ``__name__`` is used in the metric key prefix.
        final_eval: selects ``final_eval_data_size`` and the ``final_eval/``
            key prefix instead of ``eval_data_size`` and ``eval/``.

    Returns:
        ``(samples, metrics, gm)`` where ``gm`` is the second value returned by
        ``add_gaussian_metrics`` when ``"gmm"`` is in ``args.energy``, otherwise
        ``None``.
    """

    def lambda_discretizer(bsz):
        return discretizer(bsz, args.T)

    gfn_model.eval()
    metrics = dict()

    # Chosen once so the `--perfect` branch also honours `final_eval`
    # (it previously always drew `eval_data_size` samples).
    data_size = final_eval_data_size if final_eval else eval_data_size

    if args.perfect:
        # Use ground-truth samples from the energy instead of the model.
        samples = energy.sample(data_size).to(device)
    else:
        log_dir = "final_eval" if final_eval else "eval"
        init_state = torch.zeros(data_size, energy.data_ndim, device=device)
        prefix_name = f"{log_dir}/{discretizer.__name__}"

        samples, log_z, log_z_lb, log_z_learned = get_forward_trajectory_metrics(
            init_state, gfn_model, energy, lambda_discretizer
        )
        metrics[f"{prefix_name}_log_Z"] = log_z
        metrics[f"{prefix_name}_log_Z_lb"] = log_z_lb
        metrics[f"{prefix_name}_log_Z_learned"] = log_z_learned

        if energy.access_to_gt_samples:
            mean_log_likelihood, eubo = get_backward_trajectory_metrics(eval_data, gfn_model, energy, lambda_discretizer)
            metrics[f"{prefix_name}_mean_log_likelihood"] = mean_log_likelihood
            metrics[f"{prefix_name}_eubo"] = eubo

    if energy.compute_distribution_distances:
        metrics.update(get_sample_metrics(samples, eval_data, final_eval))

    if "gmm" in args.energy:
        gm_metrics, gm = add_gaussian_metrics(samples, energy, metrics, args.wandb)
        metrics.update(gm_metrics)
    else:
        gm = None

    # Disabled CLIP-diversity metric for gan_ffhq (requires `clip`, whose import
    # is commented out at the top of the file):
    #
    # if not args.perfect and args.energy in ["gan_ffhq"]:
    #     model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    #     with torch.no_grad():
    #         objects = energy.generate(samples)
    #     clip_features = []
    #     for gen_img in objects:
    #         image = preprocess_clip(transforms.functional.to_pil_image(gen_img)).unsqueeze(0).to(device)
    #         with torch.no_grad():
    #             image_features = model_clip.encode_image(image)
    #         clip_features.append(image_features)
    #     clip_features = torch.cat(clip_features, dim=0)
    #
    #     clip_features_norm = clip_features / clip_features.norm(dim=1)[:, None]
    #     pairwise_similarities = torch.mm(clip_features_norm, clip_features_norm.transpose(0, 1))
    #     average_distance = (1 - pairwise_similarities).sum() / pairwise_similarities.shape[0] / (pairwise_similarities.shape[0] - 1)
    #
    #     metrics[f"{prefix_name}_CLIP_diversity"] = average_distance.item()

    return samples, metrics, gm


def train_step(
    energy,
    gfn_model,
    target_gfn_model,
    gfn_optimizer,
    it,
    exploratory,
    buffer_traj,
    buffer,
    buffer_ls,
    exploration_factor,
    exploration_wd,
    grad_history_window,
):
    """Run one training step: ``1 + args.replay_ratio_n`` gradient updates.

    The first update (``grad_it == 0``) uses a fresh batch, forward or backward
    depending on ``args.both_ways`` / ``args.bwd`` and the parity of ``it``.
    Every later update replays trajectories sampled from ``buffer_traj`` and is
    treated as a backward update. After each update the clipped loss is written
    back to ``buffer_traj`` as the priority of the trajectories involved.

    Args:
        energy: target energy object.
        gfn_model: model being trained.
        target_gfn_model: target network, soft-updated when ``args.target_pf``
            or ``args.target_pb`` is set.
        gfn_optimizer: optimizer wrapper exposing ``step`` and ``lr_scheduler_step``.
        it: index of the current training step.
        exploratory, exploration_factor, exploration_wd: forwarded to
            ``get_exploration_std``.
        buffer_traj: trajectory replay buffer with priorities.
        buffer: replay buffer of terminal states.
        buffer_ls: replay buffer of local-search samples.
        grad_history_window: list of recent gradient norms. It is appended to
            and trimmed **in place** to the last
            ``args.grad_history_window_size`` entries.

    Returns:
        ``(loss, pf_loss, pb_loss, clip_losses, grad_norms, clip_grad_norms,
        is_fwd, metrics)``. ``loss``, ``pf_loss``, ``pb_loss`` and ``is_fwd``
        come from the last gradient update; the three lists hold one entry per
        update.
    """
    clip_losses, clip_grad_norms, grad_norms = [], [], []
    gfn_model.train()
    is_fwd = None
    metrics = {}

    discretizer = getattr(utils, args.discretizer)

    def lambda_discretizer(bsz):
        return discretizer(bsz, args.T)

    # Keep at least one entry so the percentile below is always defined.
    window_size = max(args.grad_history_window_size, 1)

    for grad_it in range(args.replay_ratio_n + 1):
        gfn_model.zero_grad()

        exploration_std = get_exploration_std(it, exploratory, exploration_factor, exploration_wd)
        log_r = None

        if grad_it > 0:
            # Replay update: reuse stored trajectories and their log-rewards.
            start_states = torch.zeros(args.batch_size, energy.data_ndim).to(device)
            (states, log_r), buffer_samples_info = buffer_traj.sample(return_info=True)
            if isinstance(states, list):
                states = [s.to(device) for s in states]
            else:
                states = states.to(device)
            log_r = log_r.to(device)
            is_fwd = False
        else:
            states = None
            if (args.both_ways and it % 2 == 0) or (not args.both_ways and not args.bwd):
                is_fwd = True
                start_states = torch.zeros(args.batch_size, energy.data_ndim).to(device)
            else:
                is_fwd = False
                if args.sampling == "sleep_phase":
                    start_states = gfn_model.sleep_phase_sample(args.batch_size, exploration_std).to(device)
                elif args.sampling == "energy":
                    start_states = energy.sample(args.batch_size).to(device)
                elif args.sampling == "buffer":
                    if args.local_search:
                        # Refresh the local-search buffer on the first two
                        # steps of every ls_cycle.
                        if it % args.ls_cycle < 2:
                            start_states, _ = buffer.sample()
                            local_search_samples, log_r = langevin_dynamics(start_states, energy.log_reward, device, args)
                            buffer_ls.add(local_search_samples, log_r)

                        start_states, log_r = buffer_ls.sample()
                    else:
                        start_states, log_r = buffer.sample()

        pf_mode = args.pf_mode_fwd if is_fwd else args.pf_mode_bwd
        pb_mode = args.pb_mode_fwd if is_fwd else args.pb_mode_bwd
        pf_loss, pb_loss, states, _, _, _, log_r = get_gfn_loss(
            is_fwd,
            args.learn_pb,
            getattr(gflownet_losses, pf_mode),
            getattr(gflownet_losses, pb_mode),
            start_states,
            gfn_model,
            target_gfn_model,
            energy,
            coef_matrix,
            lambda_discretizer,
            args.target_pf,
            args.target_pb,
            exploration_std,
            True,
            states,
            args.huber_loss_quantile,
            log_r,
            (pf_mode == "pis"),
        )
        loss = pf_loss + pb_loss

        loss.backward()

        # Global L2 norm of all gradients. float() accepts both a 0-dim tensor
        # and the plain 0 that sum() yields when no parameter has a gradient.
        squared_sum_grad_norm = sum(
            torch.norm(p.grad) ** 2 for p in gfn_model.parameters() if p is not None and p.grad is not None
        )
        grad_norm = np.sqrt(float(squared_sum_grad_norm))
        grad_norms.append(grad_norm)

        # Bounded history of recent gradient norms (trimmed in place so the
        # caller's list stays bounded as well).
        grad_history_window.append(grad_norm)
        overflow = len(grad_history_window) - window_size
        if overflow > 0:
            del grad_history_window[:overflow]

        # Clip to the smallest of: the current norm, the chosen quantile of
        # the recent norms, and the fixed cap.
        clip_grad_norm = min(
            grad_norm,
            np.percentile(grad_history_window, 100 * args.clip_grad_quantile),
            args.clip_grad_norm,
        )
        torch.nn.utils.clip_grad_norm_(gfn_model.parameters(), max_norm=clip_grad_norm)
        clip_grad_norms.append(clip_grad_norm)

        # Loss rescaled by the same factor the gradient was scaled by.
        if grad_norm > 0:
            clip_loss = (clip_grad_norm / grad_norm) * loss
        else:
            clip_loss = torch.tensor(0.0, device=device)
        clip_losses.append(clip_loss.item())

        if grad_it == 0:
            # Store the fresh trajectories (on CPU) with their priority.
            if isinstance(states, list):
                states = [s.detach().cpu() for s in states]
            else:
                states = states.detach().cpu()
            buffer_samples_idxs = buffer_traj.extend((states, log_r.detach().cpu()))
            buffer_traj.update_priority(buffer_samples_idxs, clip_loss)
            if is_fwd and args.sampling == "buffer":
                # Also keep the terminal states for backward sampling.
                if isinstance(states, list):
                    buffer.add(states[0][:, -1], log_r)
                else:
                    buffer.add(states[:, -1], log_r)
        else:
            # Refresh the priority of the replayed trajectories.
            buffer_traj.update_priority(buffer_samples_info["index"], clip_loss)

        gfn_optimizer.step()

        if args.target_pf or args.target_pb:
            update_target_network(gfn_model, target_gfn_model, args.tau)

        if grad_it == 0 and is_fwd:
            metrics.update(energy.get_train_metrics(log_r))

    gfn_optimizer.lr_scheduler_step()

    return loss, pf_loss, pb_loss, clip_losses, grad_norms, clip_grad_norms, is_fwd, metrics


def train():
    """Train the GFN sampler, evaluating and logging metrics periodically.

    Builds the energy, model, optimizer and replay buffers from the module-level
    ``args``, then runs ``args.epochs + 1`` iterations. Each iteration performs
    one ``train_step`` (skipped when ``args.perfect``). Every
    ``args.eval_every_n_steps`` iterations it runs ``eval_step``, appends the
    current metrics to per-key ``.npy`` files under ``metrics/<name>/`` and,
    when ``args.wandb`` is set, logs them to Weights & Biases.

    NOTE(review): no model checkpoint is written here, while ``eval()`` loads
    ``f"{get_name(args)}model{args.model_it}.pt"`` -- confirm where checkpoints
    are produced.
    """
    name = get_name(args)
    wandb_name = get_wandb_name(args)
    print(f"{name=}")
    print(f"{wandb_name=}")
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(name, exist_ok=True)

    energy = get_energy()
    eval_data = energy.sample(eval_data_size).to(device)

    # This is the namespace's own dict, so "Experiment" also becomes an
    # attribute of `args`.
    config = args.__dict__
    config["Experiment"] = f"{args.energy}"
    if args.wandb:
        wandb.init(project="GFN Energy", config=config, name=wandb_name)

    grad_history_window = []

    gfn_model = GFN(
        energy.data_ndim,
        args.s_emb_dim,
        args.hidden_dim,
        args.harmonics_dim,
        args.t_emb_dim,
        clipping=args.clipping,
        lgv_clip=args.lgv_clip,
        gfn_clip=args.gfn_clip,
        langevin=args.langevin,
        learned_variance=args.learned_variance,
        partial_energy=args.partial_energy,
        log_var_range=args.log_var_range,
        pb_scale_range=args.pb_scale_range,
        pb_scale_policy=args.pb_scale_policy,
        t_scale=args.t_scale,
        langevin_scaling_per_dimension=args.langevin_scaling_per_dimension,
        conditional_flow_model=args.conditional_flow_model,
        learn_pb=args.learn_pb,
        process_param=args.process_param,
        pis_architectures=args.pis_architectures,
        share_backbone=args.share_backbone,
        lgv_layers=args.lgv_layers,
        joint_layers=args.joint_layers,
        zero_init=args.zero_init,
        device=device,
    ).to(device)
    target_gfn_model = deepcopy(gfn_model)

    gfn_optimizer = CustomGFNOptimizer(
        gfn_model,
        args.lr_policy,
        args.lr_flow,
        args.lr_back_multiplier,
        args.learn_pb,
        args.conditional_flow_model,
        args.use_weight_decay,
        args.weight_decay,
        args.use_2optimizers,
        args.share_backbone,
        args.gamma,
    )

    print(gfn_model)

    # Latest value of every metric; entries persist across iterations until
    # they are overwritten or removed.
    metrics = dict()

    def add_mean_max_metrics(prefix, values):
        """Record the mean and max of ``values`` under ``prefix``."""
        metrics[f"{prefix}_mean"] = mean(values)
        metrics[f"{prefix}_max"] = max(values)

    # Trajectory replay buffer with prioritized sampling.
    buffer_traj = ReplayBuffer(
        storage=LazyMemmapStorage(args.buffer_traj_size),
        sampler=PrioritizedSampler(
            max_capacity=args.buffer_traj_size,
            alpha=1,
            beta=0.1,
        ),
        batch_size=args.batch_size,
    )
    # Terminal-state buffer and its local-search counterpart share one configuration.
    buffer = StatesReplayBuffer(
        args.buffer_size,
        device,
        energy.log_reward,
        args.batch_size,
        data_ndim=energy.data_ndim,
        beta=args.beta,
        rank_weight=args.rank_weight,
        prioritized=args.prioritized,
    )
    buffer_ls = StatesReplayBuffer(
        args.buffer_size,
        device,
        energy.log_reward,
        args.batch_size,
        data_ndim=energy.data_ndim,
        beta=args.beta,
        rank_weight=args.rank_weight,
        prioritized=args.prioritized,
    )
    gfn_model.train()

    # Per-key list of the values seen at each evaluation.
    metrics_history = {}
    for i in trange(args.epochs + 1):
        gfn_model.current_it = i
        if not args.perfect:
            loss, pf_loss, pb_loss, clip_losses, grad_norms, clip_grad_norms, is_fwd, train_metrics = train_step(
                energy,
                gfn_model,
                target_gfn_model,
                gfn_optimizer,
                i,
                args.exploratory,
                buffer_traj,
                buffer,
                buffer_ls,
                args.exploration_factor,
                args.exploration_wd,
                grad_history_window,
            )
            traj_direction = "fwd" if is_fwd else "bwd"
            metrics[f"train/{traj_direction}_loss"] = loss
            metrics[f"train/{traj_direction}_pf_loss"] = pf_loss
            metrics[f"train/{traj_direction}_pb_loss"] = pb_loss

            add_mean_max_metrics(f"train/{traj_direction}_clip_loss", clip_losses)
            add_mean_max_metrics(f"train/{traj_direction}_grad_norm", grad_norms)
            add_mean_max_metrics(f"train/{traj_direction}_clip_grad_norm", clip_grad_norms)

            metrics.update(train_metrics)

        if i % args.eval_every_n_steps != 0:
            continue

        discretizer = getattr(utils, args.discretizer)

        samples, current_metrics, gm = eval_step(eval_data, energy, gfn_model, discretizer, final_eval=False)
        metrics.update(current_metrics)

        # if args.discretizer != "uniform_discretizer":
        #     samples, current_metrics_wuniform_discritezer, gm = eval_step(
        #         eval_data, energy, gfn_model, uniform_discretizer, final_eval=False
        #     )
        #     metrics.update(current_metrics_wuniform_discritezer)

        # The learned log Z is dropped from the metrics whenever a tb_avg loss is used.
        if "tb_avg" in [args.pf_mode_fwd, args.pb_mode_fwd, args.pf_mode_bwd, args.pb_mode_bwd]:
            for key in [k for k in metrics if "log_Z_learned" in k]:
                del metrics[key]

        # NOTE(review): entries that plot_step added at an earlier evaluation are
        # still in `metrics` here, so they are written to the .npy history too --
        # confirm this is intended.
        for k, v in metrics.items():
            metrics_history.setdefault(k, []).append(v.item() if isinstance(v, torch.Tensor) else v)
        os.makedirs(f"metrics/{name}", exist_ok=True)
        for k, v in metrics_history.items():
            np.save(f"metrics/{name}/{k.replace('/', '_')}.npy", np.array(v))

        if args.wandb:
            # Plotting is best-effort: a failure is reported but does not stop training.
            try:
                if i % 1000 == 0:

                    def lambda_discretizer(bsz):
                        return discretizer(bsz, args.T)

                    images = plot_step(energy, gfn_model, name, args.perfect, plot_data_size, lambda_discretizer, device, gm)
                    metrics.update(images)
                else:
                    # Do not re-log images from an earlier evaluation.
                    for visualizations_key in [k for k in metrics if "visualization/" in k]:
                        metrics.pop(visualizations_key)
            except Exception as e:
                print("Failed on plot_step")
                print(e)
            plt.close("all")

            wandb.log(metrics, step=i)
            if args.use_2optimizers:
                wandb.log(
                    {
                        "pf_lr": gfn_optimizer.pf_lr_scheduler.get_last_lr()[0],
                        "pb_lr": gfn_optimizer.pb_lr_scheduler.get_last_lr()[0],
                    },
                    step=i,
                )
            else:
                wandb.log({"pf_lr": gfn_optimizer.lr_scheduler.get_last_lr()[0]}, step=i)



def eval():
    print(device)
    plt.rcParams.update(plt.rcParamsDefault)
    # matplotlib.use("ps")
    # plt.rc("text", usetex=True)
    plt.rc("text.latex", preamble=r"\usepackage[dvipsnames]{xcolor}")
    plt.rc("text.latex", preamble=r"\usepackage{amsmath,amssymb}")
    plt.rcParams.update({"font.size": 40})

    print("Evaluating script is running...")
    energy = get_energy()

    # print("Sampling eval_data...")
    # eval_data = energy.sample(eval_data_size).to(device)

    if args.energy == "2gmm":
        bounds = (-10.0, 10.0)
        n_contour_levels = 50
        ticks = [-5, 0, 5]
    else:
        bounds = energy.bounds
        print(f"{bounds}")
        n_contour_levels = 100
        ticks = torch.arange(-10, 10 + 0.1, 2.5)

    granularity_n = args.granularity_n
    assert granularity_n % 2 == 1
    xs = torch.linspace(bounds[0], bounds[1], granularity_n)
    ys = torch.linspace(bounds[0], bounds[1], granularity_n)
    x, y = torch.meshgrid(xs, ys)
    if "gmm" in args.energy:
        if "new" in args.visualization_type:
            figures = [(get_figures(n=1, m=args.seeds, bounds=bounds)) for _ in range(args.T + 1)]
        else:
            figures = [(get_figures(n=4, m=args.seeds, bounds=bounds)) for _ in range(args.T + 1)]
    else:
        figures = [(get_figures(n=1, m=2, bounds=bounds)) for _ in range(args.T + 1)]
    blue_vars, red_vars, wass_dists = (
        torch.zeros(args.seeds, args.T + 1, energy.data_ndim),
        torch.zeros(args.seeds, args.T + 1, energy.data_ndim),
        torch.zeros(args.seeds, args.T + 1),
    )

    dir_path = f"visualizations-bs={args.batch_size}/{args.energy}"
    os.makedirs(dir_path, exist_ok=True)

    file_path = f"{dir_path}/{args.seed=}-bs={args.batch_size}-{args.visualization_type=}-T={args.T}-pb_param={args.process_param}".replace(
        "args.", ""
    )

    discretizer = getattr(utils, args.discretizer)
    lambda_discretizer = lambda bsz: discretizer(bsz, args.T)

    for seed_i in trange(args.seeds):
        seed = args.seed + seed_i
        dir_path = get_name(args)
        # weights_path = f"{dir_path}model_final.pt"
        weights_path = f"{dir_path}model{args.model_it}.pt"
        # weights_path = get_name(args).replace(f"seed_{args.seed}", f"seed_{seed}") + "model.pt"
        # get_wandb_name(args).replace(f"seed_{args.seed}", f"seed_{seed}") + "model.pt"
        gfn_model = GFN(
            energy.data_ndim,
            args.s_emb_dim,
            args.hidden_dim,
            args.harmonics_dim,
            args.t_emb_dim,
            clipping=args.clipping,
            lgv_clip=args.lgv_clip,
            gfn_clip=args.gfn_clip,
            langevin=args.langevin,
            learned_variance=args.learned_variance,
            partial_energy=args.partial_energy,
            log_var_range=args.log_var_range,
            pb_scale_range=args.pb_scale_range,
            pb_scale_policy=args.pb_scale_policy,
            t_scale=args.t_scale,
            langevin_scaling_per_dimension=args.langevin_scaling_per_dimension,
            conditional_flow_model=args.conditional_flow_model,
            learn_pb=args.learn_pb,
            process_param=args.process_param,
            pis_loss=(args.pf_mode_fwd == "pis"),
            pis_architectures=args.pis_architectures,
            share_backbone=args.share_backbone,
            lgv_layers=args.lgv_layers,
            joint_layers=args.joint_layers,
            zero_init=args.zero_init,
            device=device,
        ).to(device)

        gfn_model.load_state_dict(torch.load(weights_path))
        gfn_model.eval()

        if energy.is_gan:
            n_sample = args.final_eval_data_size
            samples = []
            for _ in trange(args.final_eval_data_size // args.batch_size + 1):
                with torch.no_grad():
                    current_batch_size = min(args.batch_size, n_sample)
                    samples.append(gfn_model.sample(current_batch_size, lambda_discretizer, energy.log_reward).detach().cpu())
                    n_sample -= current_batch_size
            torch.save(torch.cat(samples), f"{dir_path}samples.pt")

        exploration_std = None
        # get_exploration_std(args.epochs, args.exploratory, args.exploration_factor, args.exploration_wd)

        if energy.is_gan:
            gt_states = energy.prior.rsample((args.batch_size,)).to(device)
        else:
            gt_states = energy.sample(args.batch_size).to(device)
        grid_states = torch.stack((x, y), -1).flatten(0, 1).to(device)
        # _, _, _, all_pf_mean, all_pflogvars, all_pblogvars, _, back_var_correction = gfn_model.predict(
        #     grid_states.unsqueeze(1).repeat(1, args.T + 1, 1), energy.log_reward
        # )
        dt = 1 / args.T

        if not "new" in args.visualization_type:
            # normalise
            min_value = -5
            all_pflogvars = torch.maximum(all_pflogvars, torch.full_like(all_pflogvars, min_value))
            all_pblogvars = torch.maximum(all_pblogvars, torch.full_like(all_pblogvars, min_value))

        with torch.no_grad():
            discretizer = getattr(utils, args.discretizer)
            lambda_discretizer = lambda bsz: discretizer(bsz, args.T)
            initial_states = torch.zeros(args.batch_size, energy.data_ndim).to(device)
            (
                generated_transitions,
                _,
                _,
                _,
                all_pf_mean_gen,
                all_pfvars_gen,
                _,
                _,
                min_logvar_generated,
                max_logvar_generated,
            ) = gfn_model.get_trajectory_fwd(initial_states, lambda_discretizer, exploration_std, energy.log_reward, return_min_max=True)

            noised_gt, _, _, _, all_pb_mean_noised, all_pbstds_noised, _, _, min_logvar_noised, max_logvar_noised = (
                gfn_model.get_trajectory_bwd(gt_states, lambda_discretizer, energy.log_reward, return_min_max=True)
            )
        # print(torch.abs(all_pb_mean_noised).mean(0), all_pbstds_noised.mean(0))
        np.set_printoptions(precision=5)
        print("!", file_path)
        with open(f"{file_path}.txt", "w") as fout:
            print(all_pf_mean_gen.shape, all_pfvars_gen.shape)
            print(all_pb_mean_noised.shape, all_pbstds_noised.shape)
            # print(all_pf_mean_gen.shape, all_pfvars_gen.shape)
            # print(all_pb_mean_noised.shape, all_pbstds_noised.shape)
            for time in range(args.T + 1):
                #  np.sqrt(dt) * (all_pflogvars_gen[i, time] / 2).exp()
                print(time, file=fout)
                # print("1.", dt * all_pf_mean_gen[:, time], file=fout)
                # print("2.", dt * all_pb_mean_noised[:, time], file=fout)
                # print("3.", dt * all_pfvars_gen[:, time], file=fout)
                # print("4.", dt * all_pbstds_noised[:, time], file=fout)
                print("1.", torch.abs(all_pf_mean_gen[:, time]).mean(dim=0), file=fout)
                print("2.", torch.abs(all_pb_mean_noised[:, time]).mean(dim=0), file=fout)
                print("3.", torch.abs(all_pfvars_gen[:, time]).mean(dim=0), file=fout)
                print("4.", torch.abs(all_pbstds_noised[:, time]).mean(dim=0), file=fout)
                print(file=fout)
                print(file=fout)
                # print(
                #     f"T={time}/{args.T}. "
                #     f"pf_mu: {torch.abs(dt * all_pf_mean_gen).mean(0)[time].cpu().numpy()} "
                #     f"+- {torch.abs(dt * all_pf_mean_gen).std(0)[time].cpu().numpy()}\t",
                #     #
                #     f"pf_std: {all_pfvars_gen.mean(0)[time].cpu().numpy()} "
                #     f"+- {all_pfvars_gen.std(0)[time].cpu().numpy()}\t",
                #     #
                #     f"pb_mu: {torch.abs(all_pb_mean_noised).mean(0)[time].cpu().numpy()} "
                #     f"+- {torch.abs(all_pb_mean_noised).std(0)[time].cpu().numpy()}\t",
                #     #
                #     f"pb_std: {all_pbstds_noised.mean(0)[time].cpu().numpy()} "
                #     f"+- {all_pbstds_noised.std(0)[time].cpu().numpy()}",
                #     file=fout,
                # )

        min_logvar = min(min_logvar_generated, min_logvar_noised)
        max_logvar = max(max_logvar_generated, max_logvar_noised)

        def get_colors(plogvars):
            """Turn per-state log-variances into one RGB color tensor per row.

            Entries 0 and 1 of every row of ``plogvars`` are min-max normalized
            against ``min_logvar`` / ``max_logvar`` (taken from the enclosing
            scope) and placed in the red and blue channels respectively; the
            green channel is 0.  A row with any entry outside
            ``[min_logvar, max_logvar]`` is painted pure green instead.

            Args:
                plogvars: tensor iterated row by row; each row is indexed at
                    positions 0 and 1.  # assumes shape (n_states, 2) — TODO confirm

            Returns:
                A list with one 3-element tensor per row of ``plogvars``.
            """
            logvar_span = max_logvar - min_logvar
            # MIN == MAX would make the normalization divide by zero; map the
            # single admissible value to 0.0 in that case.
            degenerate_range = bool(logvar_span == 0)

            def normalize(logvar):
                if degenerate_range:
                    return 0.0
                return (logvar - min_logvar) / logvar_span

            colors = []
            for row in plogvars:
                if row.min() < min_logvar or row.max() > max_logvar:
                    # Float literals keep the marker in a floating dtype like
                    # the normalized colors.  An integer marker would give an
                    # integer color array when every row is out of range.
                    # NOTE(review): presumably it could also break stacking
                    # with float rows on PyTorch versions without dtype
                    # promotion in stack — confirm.
                    colors.append(torch.tensor([0.0, 1.0, 0.0]))
                else:
                    colors.append(torch.tensor([normalize(row[0]), 0, normalize(row[1])]))
            return colors

        dot_size = 50

        # print(noised_gt[:, args.T, :])
        # print("---")
        # print(gt_states)

        for time in range(args.T + 1):
            print(f"!!{time=}")
            fig, axs = figures[time]

            if "gmm" in args.energy:
                if "new" in args.visualization_type:
                    if args.seeds == 1:
                        ax = axs
                    else:
                        ax = axs[seed_i]
                    ax.set_title(f"seed = {seed_i}")
                else:
                    if args.seeds == 1:
                        ax = axs
                    else:
                        ax = axs[:, seed_i]
            else:
                print("!")
                ax = axs
                ax[0].set_title(f"samples13")
                ax[0].set_xlabel(f"coordinate #1")
                ax[0].set_ylabel(f"coordinate #3")
                ax[1].set_title(f"samples23")
                ax[1].set_xlabel(f"coordinate #2")
                ax[1].set_ylabel(f"coordinate #3")

            print("?")
            # SUPER LOW QUALITY RESEARCH CODE! ACHTUNG.
            if "new" in args.visualization_type:
                if "gmm" in args.energy:
                    plot_contours(energy.log_reward, ax=ax, bounds=bounds, n_contour_levels=n_contour_levels, device=device)
                    coord_pairs = [[0, 1]]
                else:
                    coord_pairs = [[0, 2], [1, 2]]

                for coord_pair_idx, coord_pair in enumerate(coord_pairs):
                    print("??")
                    if args.visualization_type == "new_gen":
                        plot_samples(
                            generated_transitions[:, time, coord_pair],
                            ax=ax[coord_pair_idx],
                            color="b",
                            size=dot_size,
                            bounds=bounds,
                        )
                    elif args.visualization_type == "new_noised":
                        plot_samples(
                            noised_gt[:, args.T - time, coord_pair],
                            ax=ax[coord_pair_idx],
                            color="b",
                            size=dot_size,
                            bounds=bounds,
                        )

                    if args.visualization_type == "new_gen":
                        # TODO
                        x_coords = generated_transitions[:, time, coord_pair[0]].detach().cpu()
                        y_coords = generated_transitions[:, time, coord_pair[1]].detach().cpu()
                        x_shifts = all_pf_mean_gen[:, time, coord_pair[0]].detach().cpu()
                        y_shifts = all_pf_mean_gen[:, time, coord_pair[1]].detach().cpu()
                        # print(x_coords.shape, y_coords.shape, x_shifts.shape, y_shifts.shape)
                    elif args.visualization_type == "new_noised":
                        x_coords = noised_gt[:, args.T - time, coord_pair[0]].detach().cpu()
                        y_coords = noised_gt[:, args.T - time, coord_pair[1]].detach().cpu()
                        x_shifts = all_pb_mean_noised[:, args.T - time, coord_pair[0]].detach().cpu()
                        y_shifts = all_pb_mean_noised[:, args.T - time, coord_pair[1]].detach().cpu()
                    # print(flush=True)
                    # print(x_coords, flush=True)
                    # print(y_coords, flush=True)
                    # print(x_shifts, flush=True)
                    # print(y_shifts, flush=True)
                    ax[coord_pair_idx].quiver(
                        x_coords,
                        y_coords,
                        x_shifts,
                        y_shifts,
                        scale=1,
                        angles="xy",
                        units="xy",
                        scale_units="xy",
                        headwidth=3,  # Uniform width of arrowhead
                        headlength=5,  # Uniform length of arrowhead
                        headaxislength=4.5,  # Uniform axis length of arrowhead
                    )
                    for i in range(x_coords.shape[0]):
                        scale = np.sqrt(5.991)

                        # Width and height of the ellipse
                        if args.visualization_type == "new_gen":
                            std = all_pfvars_gen[i, time].detach().cpu()
                        elif args.visualization_type == "new_noised":
                            std = all_pbstds_noised[i, args.T - time].detach().cpu()
                        else:
                            std = np.sqrt(dt) * (all_pflogvars[i, time] / 2).exp()
                        width = 2 * scale * std[0]
                        height = 2 * scale * std[1]

                        skip_gen = time == args.T and args.visualization_type == "new_gen"
                        skip_noised = args.T - time <= 1 and args.visualization_type == "new_noised"
                        # print(std)
                        if not (skip_gen or skip_noised):
                            ellipse = matplotlib.patches.Ellipse(
                                xy=(x_coords[i] + x_shifts[i], y_coords[i] + y_shifts[i]),
                                width=width,
                                height=height,
                                angle=0,
                                edgecolor="blue",
                                fc="None",
                                lw=2,
                                # label="95% Confidence Ellipse",
                            )
                            ax[coord_pair_idx].add_patch(ellipse)

            else:
                x_coords = grid_states.reshape((granularity_n, granularity_n, 2))[..., 0].flatten(0, 1).cpu()
                y_coords = grid_states.reshape((granularity_n, granularity_n, 2))[..., 1].flatten(0, 1).cpu()
                x_shifts = dt * all_pf_mean[:, time, :].reshape((granularity_n, granularity_n, 2))[..., 0].flatten(0, 1)
                y_shifts = dt * all_pf_mean[:, time, :].reshape((granularity_n, granularity_n, 2))[..., 1].flatten(0, 1)

                fig.suptitle(
                    f"Time={time} (out of {args.T}). "
                    r"$\text{MIN} := \min(\min_{t, x, y}{\;\log \overset{\rightarrow}{\sigma}^{t}_{x,y}}, \min_{t, x, y}{\;\log \overset{\leftarrow}{\sigma}^{t}_{x,y}})$. "
                    # "\n"
                    r"$\text{MAX} := \max(\max_{t, x, y}{\;\log \overset{\rightarrow}{\sigma}^{t}_{x,y}}, \max_{t, x, y}{\;\log \overset{\leftarrow}{\sigma}^{t}_{x,y}})$. "
                    # "\n"
                    r"Normalized$(\log \sigma^{t}_{x,y}) := "
                    r"\frac{\log \sigma^{t}_{x,y} - \text{MIN}}"
                    r"{\text{MAX} - \text{MIN}}$. "
                )
                ax[0].set_title(f"seed={args.seed}\n" "Generation Field\n(the exact applied drift).")
                ax[1].set_title(
                    # f"seed={args.seed}\n"
                    "Noise on Generation.\n"
                    r"Normalized$(\log \overset{\rightarrow}{\sigma}^{t}_{x,y})$. "
                    r"blue $\updownarrow$. red $\leftrightarrow$"
                )
                ax[2].set_title(r"blue: (gen.) 0 $\rightarrow$ samples. red: (destr.) g.t. $\rightarrow$ 0.")
                ax[3].set_title(
                    "Noise on Destruction.\n"
                    r"Normalized$(\log \overset{\leftarrow}{\sigma}^{t}_{x,y})$. "
                    r"blue $\updownarrow$. red $\leftrightarrow$"
                )

                # add arrows
                assert grid_states.shape[0] == all_pf_mean.shape[0]
                # print("!", all_pf_mean.shape)
                if args.quiver:
                    ax[0].quiver(
                        x_coords,
                        y_coords,
                        x_shifts,
                        y_shifts,
                        scale=1,
                        angles="xy",
                        units="xy",
                        scale_units="xy",
                        headwidth=3,  # Uniform width of arrowhead
                        headlength=5,  # Uniform length of arrowhead
                        headaxislength=4.5,  # Uniform axis length of arrowhead
                    )
                else:
                    speed = (x_shifts**2 + y_shifts**2).sqrt().detach().cpu().numpy()
                    print(speed)
                    print(max_speed)
                    speed = np.clip(speed, a_min=None, a_max=max_speed.item()) / max_speed.item()

                    ax[0].streamplot(
                        xs.numpy(),
                        ys.numpy(),
                        dt * all_pf_mean[:, time, :].reshape((granularity_n, granularity_n, 2))[..., 0].numpy(),
                        dt * all_pf_mean[:, time, :].reshape((granularity_n, granularity_n, 2))[..., 1].numpy(),
                        linewidth=2,
                        arrowsize=10,
                        start_points=grid_states.reshape((granularity_n, granularity_n, 2)).flatten(0, 1).detach().cpu().numpy(),
                        color=speed,
                        cmap="autumn",
                    )

                # add contours
                plot_contours(energy.log_reward, ax=ax[2], bounds=bounds, n_contour_levels=n_contour_levels, device=device)
                # add transitions
                print("!!", generated_transitions.shape, noised_gt.shape)
                print(blue_vars[seed, time].shape, generated_transitions[:, time, :].var(0).shape)
                blue_vars[seed, time] = generated_transitions[:, time, :].var(0)
                red_vars[seed, time] = noised_gt[:, time, :].var(0)

                plot_samples(generated_transitions[:, time, :], ax=ax[2], color="b", size=dot_size, bounds=bounds)
                # plot_samples(noised_gt[:, time, :], ax=ax[2], color="r", size=dot_size, bounds=bounds)

                # add color
                if time == args.T:
                    pflogvars_color = torch.stack(
                        [
                            torch.zeros(granularity_n**2),
                            torch.zeros(granularity_n**2),
                            torch.zeros(granularity_n**2),
                        ],
                        dim=1,
                    )
                else:
                    pflogvars_color = torch.stack(get_colors(all_pflogvars[:, time, :]), dim=0)
                if time <= 1:
                    pblogvars_color = torch.stack(
                        [
                            torch.zeros(granularity_n**2),
                            torch.zeros(granularity_n**2),
                            torch.zeros(granularity_n**2),
                        ],
                        dim=1,
                    )
                else:
                    pblogvars_color = torch.stack(get_colors(all_pblogvars[:, time, :]), dim=0)

                ax[1].pcolormesh(x, y, pflogvars_color.reshape((granularity_n, granularity_n, 3)))
                ax[3].pcolormesh(x, y, pblogvars_color.reshape((granularity_n, granularity_n, 3)))
                # ax_contour_colored.pcolormesh(x, y, color.reshape((granularity_n, granularity_n, 3)))

                ax[0].tick_params(left=True, bottom=True)  # Enable ticks on ax[0]
                ax[2].tick_params(left=True, bottom=True)  # Enable ticks on ax[2]

                ax[0].set_xticks(ticks)  # Set specific x-ticks
                ax[0].set_yticks(ticks)  # Set specific y-ticks
                ax[0].grid(c="black", linewidth=0.5, linestyle="--", alpha=1)  # Configure the gridlines

                # Set the gridlines for ax[2]
                ax[2].set_xticks(ticks)  # Set specific x-ticks
                ax[2].set_yticks(ticks)  # Set specific y-ticks
                ax[2].grid(c="black", linewidth=0.5, linestyle="--", alpha=1)  # Configure the gridlines

        # wass_dists[i] = compute_distribution_distances(generated_transitions, noised_gt, False)["2-Wasserstein"]

    frames = []
    for time in range(args.T + 1):
        fig, axs = figures[time]
        fig.tight_layout(rect=[0, 0, 1, 1])
        import io

        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        img = Image.open(buf)
        frames.append(img)

    # plt.close(fig)
    print("All frames are created.")

    final_frames = frames

    # Read the first image to get the size
    height, width = final_frames[0].height, final_frames[0].width
    # Initialize video writer
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Use 'mp4v' for MP4 format
    FPS = 1
    video = cv2.VideoWriter(f"{file_path}.mp4", fourcc, FPS, (width, height))
    # Add final_frames to the video
    for frame in final_frames:
        video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    # Release the video writer
    video.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    # Script entry point: run evaluation when args.eval is truthy, training otherwise.
    (eval if args.eval else train)()
