import torch
import numpy as np
from typing import List
import yaml
import copy
from ast import literal_eval
import ast
import torch.nn.functional as F
import torchvision.transforms as v2
from torchmetrics.functional.image import peak_signal_noise_ratio as PSNR
from ignite.metrics import SSIM
from collections import defaultdict
import math
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from torchvision.utils import save_image
from fld.metrics.FID import FID
from matplotlib import cm
import os
import warnings
import tqdm
import mlflow.pytorch as mlpt
import lpips
import ot
from torchdiffeq import odeint
import random
import torch.backends.cudnn as cudnn
from denflow.models import UNet, MLP
from denflow.degradations import *


class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.

    Nested plain ``dict`` values are converted to ``CfgNode`` recursively, so
    ``cfg.section.option`` works at any depth. The mapping passed to the
    constructor is left unmodified.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """Wrap ``init_dict`` (default: empty) as a CfgNode.

        ``key_list`` is the key path of this node; it is only forwarded to the
        nested nodes (not stored). ``new_allowed`` is accepted for API
        compatibility and is not used.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        # Build a converted copy instead of rewriting the caller's dict in place.
        # ``type(v) is dict`` (not isinstance) leaves existing CfgNodes and other
        # dict subclasses untouched.
        converted = {
            k: CfgNode(v, key_list=key_list + [k]) if type(v) is dict else v
            for k, v in init_dict.items()
        }
        super(CfgNode, self).__init__(converted)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: fall back to the keys.
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        # Attribute assignment stores a key, so ``cfg.x = 1`` == ``cfg['x'] = 1``.
        self[name] = value

    def __str__(self):
        """Render as ``key: value`` lines sorted by key, nested nodes indented by 2."""
        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(
            self.__class__.__name__, super(
                CfgNode, self).__repr__())


def _decode_cfg_value(v):
    if not isinstance(v, str):
        return v
    try:
        v = literal_eval(v)
    except ValueError:
        pass
    except SyntaxError:
        pass
    return v


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    original_type = type(original)

    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    casts = [(tuple, list), (list, tuple)]
    try:
        casts.append((str, unicode))  # noqa: F821
    except Exception:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(
            original_type, replacement_type, original, replacement, full_key
        )
    )


def load_cfg_from_cfg_file(file: str):
    """Load a sectioned YAML config and flatten it into one ``CfgNode``.

    The file is expected to map section names to mappings of options, e.g.::

        data:
          dataset: celeba
        train:
          lr: 0.0001

    All options are merged into a single flat namespace (section names are
    dropped); if two sections define the same option, the later one wins.
    An empty file or an empty section contributes no options.

    Raises:
        AssertionError: if ``file`` is not an existing ``.yaml`` file.
    """
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document.
        cfg_from_file = yaml.safe_load(f) or {}

    cfg = {}
    for section in cfg_from_file.values():
        # A section written as ``name:`` with no body loads as None.
        cfg.update(section or {})

    return CfgNode(cfg)


def merge_cfg_from_list(cfg: CfgNode,
                        cfg_list: List[str]):
    """Return a deep copy of ``cfg`` with ``[key, value, key, value, ...]`` overrides applied.

    Only the last dotted component of each key is used (``a.b.lr`` -> ``lr``).
    Values are decoded with ``_decode_cfg_value``. When the key already exists
    in ``cfg`` the value is type-checked/coerced against the existing one;
    unknown keys are added as-is. ``cfg`` itself is not modified.
    """
    merged = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0, cfg_list

    override_keys = cfg_list[0::2]
    override_values = cfg_list[1::2]
    for full_key, raw in zip(override_keys, override_values):
        name = full_key.rsplit('.', 1)[-1]
        decoded = _decode_cfg_value(raw)
        if name in cfg:
            decoded = _check_and_coerce_cfg_value_type(
                decoded, cfg[name], name, full_key
            )
        setattr(merged, name, decoded)
    return merged


def define_model(args):
    """Build the network selected by ``args`` and return it as ``(model, None)``.

    Selection order:
      * ``args.dim_image == 2``      -> ``MLP(dim=2, time_varying=True)``;
      * ``args.dataset == "cifar10"`` -> torchcfm's CIFAR UNet wrapper when
        ``args.torchcfm_model`` is truthy, otherwise deepinv's ``DRUNet``;
      * anything else                -> the project ``UNet`` sized from
        ``args.num_channels`` and ``args.dim_image``.

    The second tuple element is always ``None``.
    """
    if args.dim_image == 2:
        return (MLP(dim=2, time_varying=True), None)

    if args.dataset != "cifar10":
        net = UNet(input_channels=args.num_channels,
                   input_height=args.dim_image,
                   ch=32,
                   ch_mult=(1, 2, 4, 8),
                   num_res_blocks=6,
                   attn_resolutions=(16, 8),
                   resamp_with_conv=True,
                   )
        return (net, None)

    if not args.torchcfm_model:
        # CIFAR-10 fallback: DRUNet from deepinv.
        from deepinv.models.drunet import DRUNet
        net = DRUNet(in_channels=3, out_channels=3, nc=(32, 64, 128, 128), nb=2)
        return (net, None)

    # Import and construct inside fork_rng so the global torch RNG state is
    # restored afterwards.
    with torch.random.fork_rng():
        from torchcfm.models.unet.unet import UNetModelWrapper

        class UNetModelWrapperXTCifar(UNetModelWrapper):
            def __init__(self):
                return super().__init__(dim=(3, 32, 32), num_res_blocks=2,
                                        num_channels=128,
                                        channel_mult=[1, 2, 2, 2],
                                        num_heads=4,
                                        num_head_channels=64,
                                        attention_resolutions="16",
                                        dropout=0.1)

            # Exposes a (x, t) call order; the parent forward takes (t, x).
            def forward(self,  x, t, y=None, *args, **kwargs):
                return super().forward(t, x, y=y)

        net = UNetModelWrapperXTCifar()
    return (net, None)


def load_model_runid(model, run_id, model_fold="model_final", device='cuda'):
    """Load the PyTorch model logged under ``runs:/{run_id}/{model_fold}`` and move it to ``device``.

    The ``model`` argument is not used (it is kept for call compatibility);
    the freshly loaded MLflow model is returned.
    """
    model_uri = f"runs:/{run_id}/{model_fold}"
    loaded = mlpt.load_model(model_uri)
    loaded.to(device)
    return loaded


def save_samples(samples, train_samples, path, args):
    """Save a figure with model samples (left half) next to training samples (right half).

    Both batches are permuted from NCHW to NHWC for plotting. With
    ``N = samples.shape[0]``, each half is a grid of ``rows x cols`` where
    ``cols = floor(sqrt(N))`` and ``rows = N // cols``, filled column-major.
    Training samples beyond what is available are left blank. Single-channel
    images (``args.num_channels == 1``) are shown in gray over [0, 1].

    The figure is written to ``path`` and always closed.
    """
    samples = samples.clone().permute(0, 2, 3, 1).cpu().data.numpy()
    train_samples = train_samples.clone().permute(0, 2, 3, 1).cpu().data.numpy()
    batch_samples_size = samples.shape[0]
    cols = int(math.sqrt(batch_samples_size))  # Number of columns per half
    rows = int(batch_samples_size / cols)      # Number of rows
    # squeeze=False keeps ``ax`` 2-D even when rows == 1 (e.g. a single sample).
    fig, ax = plt.subplots(rows, 2 * cols, figsize=(20, 20), squeeze=False)

    def _show(axis, img):
        if args.num_channels == 1:
            axis.imshow(img.squeeze(-1), cmap='gray', vmin=0, vmax=1)
        else:
            axis.imshow(img)

    try:
        for i in range(rows):
            for j in range(cols):
                idx = i + j * rows
                _show(ax[i, j], samples[idx])
                # Right half uses columns cols .. 2*cols - 1 (the old loop ran
                # one column past the grid).
                if idx < train_samples.shape[0]:
                    _show(ax[i, cols + j], train_samples[idx])

        ax[0, 0].set_title("Model samples")
        ax[0, cols].set_title("Training samples")

        for ax_ in ax.flatten():
            ax_.set_xticks([])
            ax_.set_yticks([])

        plt.savefig(path)
    finally:
        plt.close(fig)


def _to_hwc_cpu(img):
    """Return an NCHW batch permuted to NHWC, moved to CPU and detached from autograd."""
    return img.permute(0, 2, 3, 1).cpu().data


def _imshow_hwc(axis, img, args):
    """Show one HWC image on ``axis``; single-channel images use a gray map over [0, 1]."""
    if args.num_channels == 1:
        axis.imshow(img.squeeze(-1).numpy(), cmap='gray', vmin=0, vmax=1)
    else:
        axis.imshow(img.numpy())


def _plot_image_batch(images, args):
    """Lay out a batch of HWC images on a new figure (ticks hidden) and return it.

    One or two images go side by side; larger batches use a ``rows x cols``
    grid with ``cols = floor(sqrt(N))``, ``rows = N // cols``, filled
    column-major (images beyond ``rows * cols`` are not shown).
    """
    batch_size = images.shape[0]
    if batch_size <= 2:
        fig, axes = plt.subplots(1, batch_size, squeeze=False)
        for j in range(batch_size):
            _imshow_hwc(axes[0, j], images[j], args)
    else:
        cols = int(math.sqrt(batch_size))
        rows = batch_size // cols
        # squeeze=False keeps ``axes`` 2-D when rows or cols is 1 (e.g. N == 3).
        fig, axes = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False)
        for i in range(rows):
            for j in range(cols):
                _imshow_hwc(axes[i, j], images[i + j * rows], args)
    for axis in axes.flat:
        axis.set_xticks([])
        axis.set_yticks([])
    return fig


def _save_and_close(fig, path, **savefig_kwargs):
    """Save ``fig`` to ``path`` and close it even if saving fails."""
    try:
        fig.savefig(path, **savefig_kwargs)
    finally:
        plt.close(fig)


def _save_single_image(img, path, args):
    """Save one HWC image to ``path`` without axes or padding."""
    fig, axis = plt.subplots()
    _imshow_hwc(axis, img, args)
    axis.axis('off')
    _save_and_close(fig, path, bbox_inches='tight', pad_inches=0)


def save_images(clean_img, noisy_img, rec_img, args, H_adj, iter='final'):
    """Save figures of a restoration batch to ``args.save_path_ip``.

    * Intermediate call (``iter != 'final'``): one figure of the
      reconstructions, ``{problem}_{method}_batch{batch}_iter{iter}.png``.
    * Final call: one figure each for the clean, noisy and reconstructed
      batches, ``{problem}_{clean|noisy|method}_batch{batch}_final.png``.
    * When ``args.eval_split == 'test'`` and ``args.batch < 4`` (on every
      call), each reconstruction is also saved on its own with its PSNR in the
      file name. Clean and noisy images are saved individually only for the
      ``pnp_flow_denoiser`` method.

    The input batches are permuted from NCHW to NHWC for plotting and mapped
    to [0, 1] with ``postprocess``. For super-resolution problems the
    per-image noisy PSNR is measured on ``H_adj`` applied to the observation.
    """
    is_superres = args.problem in ('superresolution', 'superresolution_bicubic')

    # Apply the adjoint to the raw observation (before postprocessing), as
    # compute_ssim does. The previous code applied it to an all-ones image,
    # which made the reported super-resolution noisy PSNR meaningless.
    H_adj_noisy_img = (_to_hwc_cpu(postprocess(H_adj(noisy_img), args))
                       if is_superres else None)

    clean_img = _to_hwc_cpu(postprocess(clean_img.clone(), args))
    noisy_img = _to_hwc_cpu(postprocess(noisy_img.clone(), args))
    rec_img = _to_hwc_cpu(postprocess(rec_img.clone(), args))
    batch_size = clean_img.shape[0]
    list_word = ['clean', 'noisy', args.method]

    if iter != 'final':
        fig = _plot_image_batch(rec_img, args)
        _save_and_close(fig, os.path.join(
            args.save_path_ip,
            f"{args.problem}_{args.method}_batch{args.batch}_iter{iter}.png"))
    else:
        for word, img in zip(list_word, (clean_img, noisy_img, rec_img)):
            fig = _plot_image_batch(img, args)
            _save_and_close(fig, os.path.join(
                args.save_path_ip,
                f"{args.problem}_{word}_batch{args.batch}_final.png"))

    # Individual images (with PSNR in the file name) for the first test batches.
    if args.eval_split != 'test' or args.batch >= 4:
        return

    print('Saving images one by one')
    noisy_reference = H_adj_noisy_img if is_superres else noisy_img
    for i in range(batch_size):
        psnr_noisy = PSNR(clean_img[i], noisy_reference[i], data_range=1.)
        psnr_rec = PSNR(clean_img[i], rec_img[i], data_range=1.)

        if args.method == 'pnp_flow_denoiser':
            _save_single_image(clean_img[i], os.path.join(
                args.save_path_ip,
                f"{args.problem}_{list_word[0]}_batch{args.batch}_im{i}.png"), args)
            _save_single_image(noisy_img[i], os.path.join(
                args.save_path_ip,
                f"{args.problem}_{list_word[1]}_batch{args.batch}_im{i}_pnsr{psnr_noisy:4.2f}.png"), args)

        rec_path = os.path.join(
            args.save_path_ip,
            f"{args.problem}_{list_word[2]}_batch{args.batch}_im{i}_iter{iter}_pnsr{psnr_rec:4.2f}.png")
        # Print the path that is actually written (it previously omitted _iter).
        print(rec_path)
        _save_single_image(rec_img[i], rec_path, args)


def postprocess(img, args):
    """Map images from the model range [-1, 1] back to [0, 1].

    Both former branches computed ``(img + 1) / 2``: the torchvision
    ``Normalize(mean=-1, std=2)`` used for non-``afhq_cat`` datasets is exactly
    that map, but its hard-coded 3-element mean/std failed on single-channel
    batches. The affine map is now applied directly, so any channel count works.

    Args:
        img: image tensor (any shape).
        args: run configuration; kept for signature compatibility (the
            ``dataset`` choice no longer changes the result).

    Returns:
        A new tensor; ``img`` is not modified.
    """
    return (img + 1) / 2


def save_memory_use(dict_mem,  args):
    """Append ``str(dict_mem)`` as one line to ``memory_stats.txt`` in ``args.save_path_ip``."""
    target = os.path.join(args.save_path_ip, 'memory_stats.txt')
    with open(target, "a") as handle:
        print(dict_mem, file=handle)


def save_time_use(dict_mem,  args):
    """Append ``str(dict_mem)`` as one line to ``time_stats.txt`` in ``args.save_path_ip``."""
    target = os.path.join(args.save_path_ip, 'time_stats.txt')
    with open(target, "a") as handle:
        print(dict_mem, file=handle)


def compute_psnr(clean_img, noisy_img, rec_img, args, H_adj, iter='final'):
    """Log PSNR of the reconstruction and of the observation against the clean batch.

    Images are mapped to [0, 1] with ``postprocess``, permuted to NHWC and
    compared with ``data_range=1`` over dims (1, 2, 3). The result is reduced
    by torchmetrics' default ``reduction``. For super-resolution problems, the
    observation is replaced by ``postprocess(H_adj(noisy_img))`` so it matches
    the clean resolution.

    Appends ``"<iter> <psnr>"`` to ``psnr_rec_batch{args.batch}.txt`` and
    ``psnr_noisy_batch{args.batch}.txt`` in ``args.save_path_ip``.
    """
    is_superres = args.problem in ('superresolution', 'superresolution_bicubic')
    if is_superres:
        # Apply H_adj to the raw observation and postprocess only once, matching
        # compute_ssim. Postprocessing before H_adj rescaled the image twice.
        noisy_ref = postprocess(H_adj(noisy_img), args)
    else:
        noisy_ref = postprocess(noisy_img.clone(), args)

    clean_img = postprocess(clean_img.clone(), args).permute(0, 2, 3, 1).cpu().data
    noisy_ref = noisy_ref.permute(0, 2, 3, 1).cpu().data
    rec_img = postprocess(rec_img.clone(), args).permute(0, 2, 3, 1).cpu().data

    psnr_rec = PSNR(rec_img, clean_img, data_range=1.0, dim=(1, 2, 3))
    psnr_noisy = PSNR(noisy_ref, clean_img, data_range=1.0, dim=(1, 2, 3))

    for word, value in (('rec', psnr_rec), ('noisy', psnr_noisy)):
        log_path = os.path.join(
            args.save_path_ip, f'psnr_{word}_batch{args.batch}.txt')
        with open(log_path, 'a') as file:
            file.write(f'{iter} {value}\n')


def compute_average_psnr(args):
    """Average the per-batch PSNR logs and record the final values.

    For each of ``rec`` and ``noisy``:
      * read ``psnr_{word}_batch{b}.txt`` for every ``b < args.max_batch``
        (lines ``"<iteration> <psnr>"``);
      * average the values per iteration;
      * append ``"<iteration> <avg:.4f>"`` lines to ``psnr_{word}_average.txt``;
      * use the last line of that file as the final value.

    Finally, append one row to ``final_psnr.txt`` in ``args.save_path``
    (writing a header first if the file is empty). The row holds the rec and
    noisy final values followed by the values of ``args.dict_cfg_method``.
    """
    final_scores = {}
    for word in ['rec', 'noisy']:
        scores_per_iter = defaultdict(list)
        for batch_idx in range(args.max_batch):
            log_path = os.path.join(
                args.save_path_ip, f'psnr_{word}_batch{batch_idx}.txt')
            with open(log_path, 'r') as log_file:
                for raw_line in log_file:
                    it_token, score_token = raw_line.strip().split()
                    scores_per_iter[int(float(it_token))].append(float(score_token))

        averages = {it: np.mean(scores) for it, scores in scores_per_iter.items()}

        avg_path = os.path.join(args.save_path_ip, f'psnr_{word}_average.txt')
        with open(avg_path, 'a') as avg_file:
            avg_file.writelines(
                f'{it} {avg:.4f}\n' for it, avg in sorted(averages.items()))

        # Read back the whole file (it may also hold rows from earlier runs)
        # and keep its last recorded value.
        with open(avg_path, 'r') as avg_file:
            recorded = [float(row.split()[1]) for row in avg_file.readlines()]
        final_scores[word] = recorded[-1]

    summary_path = os.path.join(args.save_path, 'final_psnr.txt')
    with open(summary_path, 'a') as summary:
        if os.stat(summary_path).st_size == 0:
            header = ['psnr_rec', 'psnr_noisy', *args.dict_cfg_method.keys()]
            summary.write(''.join(f'{name} ' for name in header) + '\n')
        row = [final_scores['rec'], final_scores['noisy'],
               *args.dict_cfg_method.values()]
        summary.write(''.join(f'{cell} ' for cell in row) + '\n')


# Shared LPIPS metric (AlexNet variant), built once at import time on the GPU
# when CUDA is available, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Default torch hub checkpoint cache. It is only inspected to decide whether to
# print a first-download notice; any actual download presumably happens inside
# lpips.LPIPS.
lpips_model_path = os.path.expanduser('~/.cache/torch/hub/checkpoints/')
if not (os.path.exists(lpips_model_path)
        and any(fname.startswith('alex') for fname in os.listdir(lpips_model_path))):
    print("Downloading LPIPS model for the first time...")
loss_fn_alex = lpips.LPIPS(net='alex').to(DEVICE)


def compute_lpips(clean_img, noisy_img, rec_img, args, H_adj, iter='final'):
    """Log batch-mean LPIPS (AlexNet) of the reconstruction and of the observation.

    Images are mapped to [0, 1] with ``postprocess`` and moved to ``DEVICE``.
    For super-resolution problems, the observation is replaced by
    ``postprocess(H_adj(noisy_img))`` so it matches the clean resolution.

    Appends ``"<iter> <lpips>"`` to ``lpips_rec_batch{args.batch}.txt`` and
    ``lpips_noisy_batch{args.batch}.txt`` in ``args.save_path_ip``.
    """
    if args.problem in ['superresolution', 'superresolution_bicubic']:
        # Apply H_adj to the raw observation and postprocess only once, matching
        # compute_ssim. Postprocessing before H_adj rescaled the image twice.
        noisy_ref = postprocess(H_adj(noisy_img), args)
    else:
        noisy_ref = postprocess(noisy_img.clone(), args)

    clean_ref = postprocess(clean_img.clone(), args).to(DEVICE)
    rec = postprocess(rec_img.clone(), args).to(DEVICE)
    noisy_ref = noisy_ref.to(DEVICE)

    # Inputs are in [0, 1]; ``normalize=True`` makes LPIPS rescale them to
    # [-1, 1] itself. Rescaling manually as well (as before) fed [-3, 1].
    with torch.no_grad():
        lpips_rec = loss_fn_alex(clean_ref, rec, normalize=True).mean().item()
        lpips_noisy = loss_fn_alex(
            clean_ref, noisy_ref, normalize=True).mean().item()

    for word, value in (('rec', lpips_rec), ('noisy', lpips_noisy)):
        log_path = os.path.join(
            args.save_path_ip, f'lpips_{word}_batch{args.batch}.txt')
        with open(log_path, 'a') as file:
            file.write(f'{iter} {value}\n')


def compute_average_lpips(args):
    """Average the per-batch LPIPS logs and record the final values.

    For each of ``rec`` and ``noisy``:
      * read ``lpips_{word}_batch{b}.txt`` for every ``b < args.max_batch``
        (lines ``"<iteration> <lpips>"``);
      * average per iteration;
      * append ``"<iteration> <avg:.4f>"`` lines to ``lpips_{word}_average.txt``;
      * take the last line of that file as the final value.

    Finally, append one row to ``final_lpips.txt`` in ``args.save_path``
    (writing a header first if the file is empty). The row holds the rec and
    noisy final values followed by the values of ``args.dict_cfg_method``.
    """
    final_scores = {}
    for word in ['rec', 'noisy']:
        per_iteration = defaultdict(list)
        for batch_idx in range(args.max_batch):
            log_path = os.path.join(
                args.save_path_ip, f'lpips_{word}_batch{batch_idx}.txt')
            with open(log_path, 'r') as log_file:
                for raw_line in log_file:
                    it_token, score_token = raw_line.strip().split()
                    per_iteration[int(float(it_token))].append(float(score_token))

        means = {it: np.mean(scores) for it, scores in per_iteration.items()}

        avg_path = os.path.join(args.save_path_ip, f'lpips_{word}_average.txt')
        with open(avg_path, 'a') as avg_file:
            avg_file.writelines(
                f'{it} {avg:.4f}\n' for it, avg in sorted(means.items()))

        # Read back the whole file (it may also hold rows from earlier runs)
        # and keep its last recorded value.
        with open(avg_path, 'r') as avg_file:
            recorded = [float(row.split()[1]) for row in avg_file.readlines()]
        final_scores[word] = recorded[-1]

    summary_path = os.path.join(args.save_path, 'final_lpips.txt')
    with open(summary_path, 'a') as summary:
        if os.stat(summary_path).st_size == 0:
            header = ['lpips_rec', 'lpips_noisy', *args.dict_cfg_method.keys()]
            summary.write(''.join(f'{name} ' for name in header) + '\n')
        row = [final_scores['rec'], final_scores['noisy'],
               *args.dict_cfg_method.values()]
        summary.write(''.join(f'{cell} ' for cell in row) + '\n')


def compute_ssim(clean_img, noisy_img, rec_img, args, H_adj, iter='final'):
    """Log SSIM of the reconstruction and of the observation against the clean batch.

    All images are mapped to [0, 1] with ``postprocess`` and moved to the CPU.
    SSIM is computed with ignite's ``SSIM(data_range=1.0)``, using a separate
    metric object per comparison. For super-resolution problems, the
    observation is replaced by ``postprocess(H_adj(noisy_img))``.

    Appends ``"<iter> <ssim>"`` to ``ssim_rec_batch{args.batch}.txt`` and
    ``ssim_noisy_batch{args.batch}.txt`` in ``args.save_path_ip``.
    """
    # H_adj is applied to the raw observation, before postprocessing.
    adjoint_view = postprocess(H_adj(noisy_img), args).cpu()
    clean = postprocess(clean_img.clone(), args).cpu()
    observed = postprocess(noisy_img.clone(), args).cpu()
    restored = postprocess(rec_img.clone(), args).cpu()

    if args.problem in ('superresolution', 'superresolution_bicubic'):
        observed = adjoint_view

    rec_metric = SSIM(data_range=1.0)
    noisy_metric = SSIM(data_range=1.0)

    rec_metric.update((restored, clean))
    ssim_rec = rec_metric.compute()
    noisy_metric.update((observed, clean))
    ssim_noisy = noisy_metric.compute()

    for word, score in (('rec', ssim_rec), ('noisy', ssim_noisy)):
        out_path = os.path.join(
            args.save_path_ip, f'ssim_{word}_batch{args.batch}.txt')
        with open(out_path, 'a') as out:
            out.write(f'{iter} {score}\n')


def compute_average_ssim(args):
    """Average the per-batch SSIM logs and record the final values.

    For each of ``rec`` and ``noisy``:
      * read ``ssim_{word}_batch{b}.txt`` for every ``b < args.max_batch``
        (lines ``"<iteration> <ssim>"``);
      * average per iteration;
      * append ``"<iteration> <avg:.4f>"`` lines to ``ssim_{word}_average.txt``;
      * take the last line of that file as the final value.

    Finally, append one row to ``final_ssim.txt`` in ``args.save_path``
    (writing a header first if the file is empty). The row holds the rec and
    noisy final values followed by the values of ``args.dict_cfg_method``.
    """
    final_scores = {}
    for word in ['rec', 'noisy']:
        collected = defaultdict(list)
        for batch_idx in range(args.max_batch):
            log_path = os.path.join(
                args.save_path_ip, f'ssim_{word}_batch{batch_idx}.txt')
            with open(log_path, 'r') as log_file:
                for raw_line in log_file:
                    it_token, score_token = raw_line.strip().split()
                    collected[int(float(it_token))].append(float(score_token))

        means = {it: np.mean(scores) for it, scores in collected.items()}

        avg_path = os.path.join(args.save_path_ip, f'ssim_{word}_average.txt')
        with open(avg_path, 'a') as avg_file:
            avg_file.writelines(
                f'{it} {avg:.4f}\n' for it, avg in sorted(means.items()))

        # Read back the whole file (it may also hold rows from earlier runs)
        # and keep its last recorded value.
        with open(avg_path, 'r') as avg_file:
            recorded = [float(row.split()[1]) for row in avg_file.readlines()]
        final_scores[word] = recorded[-1]

    summary_path = os.path.join(args.save_path, 'final_ssim.txt')
    with open(summary_path, 'a') as summary:
        if os.stat(summary_path).st_size == 0:
            header = ['ssim_rec', 'ssim_noisy', *args.dict_cfg_method.keys()]
            summary.write(''.join(f'{name} ' for name in header) + '\n')
        row = [final_scores['rec'], final_scores['noisy'],
               *args.dict_cfg_method.values()]
        summary.write(''.join(f'{cell} ' for cell in row) + '\n')


def compute_average_time(args):
    """Append the mean per-batch time to ``time_average.txt`` in ``args.save_path_ip``.

    ``time_stats.txt`` holds one dict literal per line (as written by
    ``save_time_use``). For each batch index ``b`` in ``range(args.max_batch)``,
    the ``time_per_batch`` value of the first line whose ``'batch'`` equals
    ``b`` is used. A batch with no record still counts as 0 in the mean, but a
    warning now names the missing batches. The stats file is read once rather
    than once per batch, and blank lines are skipped.
    """
    array_times = torch.zeros(args.max_batch)
    filename = os.path.join(args.save_path_ip, 'time_stats.txt')
    wanted = range(args.max_batch)

    first_record = {}
    if args.max_batch > 0:
        with open(filename, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                # Each line is the repr of a dict written by save_time_use.
                data = ast.literal_eval(line)
                batch = data.get('batch')
                # Keep only the first record per requested batch.
                if batch in wanted and batch not in first_record:
                    first_record[batch] = data['time_per_batch']

    missing = [b for b in wanted if b not in first_record]
    if missing:
        warnings.warn(
            f"compute_average_time: no 'time_per_batch' record for batch(es) "
            f"{missing}; they are counted as 0 in the average.")
    for b in wanted:
        if b in first_record:
            array_times[b] = first_record[b]

    avg_filename = os.path.join(args.save_path_ip, 'time_average.txt')
    with open(avg_filename, 'a') as f:
        f.write(f'average time: {array_times.mean().item():.4f}\n')


def compute_average_memory(args):
    """Append the mean per-batch peak memory to ``max_memory_average.txt`` in ``args.save_path_ip``.

    ``memory_stats.txt`` holds one dict literal per line (as written by
    ``save_memory_use``). For each batch index ``b`` in
    ``range(args.max_batch)``, the ``max_allocated`` value of the first line
    whose ``'batch'`` equals ``b`` is used. A batch with no record still counts
    as 0 in the mean, but a warning now names the missing batches. The stats
    file is read once rather than once per batch, and blank lines are skipped.
    """
    array_max_mem = torch.zeros(args.max_batch)
    filename = os.path.join(args.save_path_ip, 'memory_stats.txt')
    wanted = range(args.max_batch)

    first_record = {}
    if args.max_batch > 0:
        with open(filename, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                # Each line is the repr of a dict written by save_memory_use.
                data = ast.literal_eval(line)
                batch = data.get('batch')
                # Keep only the first record per requested batch.
                if batch in wanted and batch not in first_record:
                    first_record[batch] = data['max_allocated']

    missing = [b for b in wanted if b not in first_record]
    if missing:
        warnings.warn(
            f"compute_average_memory: no 'max_allocated' record for batch(es) "
            f"{missing}; they are counted as 0 in the average.")
    for b in wanted:
        if b in first_record:
            array_max_mem[b] = first_record[b]

    avg_filename = os.path.join(args.save_path_ip, 'max_memory_average.txt')
    with open(avg_filename, 'a') as f:
        f.write(f'average mem: {array_max_mem.mean().item():.4f}\n')


def get_save_path_ip(dict_cfg_method):
    """Turn method settings into a nested relative path.

    ``{k1: v1, k2: v2}`` becomes ``k1=v1/k2=v2`` (joined with ``os.path.join``,
    in dict order). An empty dict gives ``""``.
    """
    segments = [f"{key}={value}" for key, value in dict_cfg_method.items()]
    return os.path.join("", *segments)


def colored_line(x, y, c, ax, **lc_kwargs):
    """
    Draw the polyline through (x, y) with its colour varying according to ``c``.

    Every point gets its own three-vertex segment. The segment runs from the
    midpoint with the previous point, through the point itself, to the
    midpoint with the next point; the first and last points stand in for the
    missing outer midpoints. Consecutive segments therefore meet without gaps.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points.
    c : array-like
        One colour value per point (same length as x and y).
    ax : Axes
        Axes that receives the collection.
    **lc_kwargs
        Extra keyword arguments for ``matplotlib.collections.LineCollection``;
        ``capstyle`` defaults to ``"butt"``. An ``array`` entry is replaced by
        ``c`` (a warning is issued).

    Returns
    -------
    matplotlib.collections.LineCollection
        The collection added to ``ax``.
    """
    if "array" in lc_kwargs:
        warnings.warn(
            'The provided "array" keyword argument will be overridden')

    # "butt" caps let neighbouring segments join flush; callers may override.
    collection_kwargs = {"capstyle": "butt", **lc_kwargs}

    px = np.asarray(x)
    py = np.asarray(y)
    # Midpoints between neighbours, padded with the first/last point.
    mx = np.hstack((px[0], 0.5 * (px[1:] + px[:-1]), px[-1]))
    my = np.hstack((py[0], 0.5 * (py[1:] + py[:-1]), py[-1]))

    # Shape (N, 3, 2): [segment start, the point, segment end] per point.
    segments = np.stack(
        (
            np.column_stack((mx[:-1], my[:-1])),
            np.column_stack((px, py)),
            np.column_stack((mx[1:], my[1:])),
        ),
        axis=1,
    )

    collection = LineCollection(segments, **collection_kwargs)
    collection.set_array(c)  # per-segment colour values
    return ax.add_collection(collection)


def plot_test_data(ax, data):
    """Scatter the 2-D points ``(data[:, 0, 0, 0], data[:, 0, 1, 0])`` in red on ``ax`` and hide the ticks."""
    xs, ys = data[:, 0, 0, 0], data[:, 0, 1, 0]
    ax.scatter(xs, ys, s=2, color='red', alpha=0.6)
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_ticks([])


def plot_paths(ax, traj):
    """Draw each sample trajectory in ``traj`` as a viridis-coloured line on ``ax``.

    ``traj`` is indexed as ``traj[step, sample, 0, coord, 0]`` with ``coord``
    0/1 giving x/y. The colour runs along the time steps. Each path's start and
    end points are marked with the first and last viridis colour, and the axis
    ticks are hidden.
    """
    euler_steps = traj.shape[0]
    color_range = np.linspace(0, 2, euler_steps)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in both old and new versions.
    cmap = plt.get_cmap("viridis")
    # Normalise the [0, 2] colour values to [0, 1] for direct cmap lookup.
    first_color = cmap(color_range[0] / 2)
    last_color = cmap(color_range[-1] / 2)

    for i in range(traj.shape[1]):
        x = traj[:, i, 0, 0, 0].detach().cpu().numpy()
        y = traj[:, i, 0, 1, 0].detach().cpu().numpy()

        # Line coloured along the trajectory.
        colored_line(x, y, color_range, ax, linewidth=0.6,
                     cmap="viridis", alpha=0.75)

        # End point, in the colour of the last step.
        ax.plot(x[-1], y[-1], marker='o', color=last_color, markersize=1.5)
        # Starting point, in the colour of the first step.
        ax.plot(x[0], y[0], marker='o', color=first_color, markersize=1.5)

    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])


def compute_W2(traj, test_data):
    """Exact 2-Wasserstein distance between generated end points and test points.

    Generated points are ``traj[-1][:, 0, :, 0]`` (shape ``(n_gen, d)``).
    Reference points are ``test_data[:n_gen, 0, :, 0]``; ``test_data`` may be
    a NumPy array or a torch tensor. Both empirical measures get uniform
    weights (sized separately, so fewer test points than generated ones is
    fine). The distance is ``sqrt(ot.emd2(...))`` under the squared Euclidean
    cost.
    """
    gen_samples = traj[-1, :, :, :, :][:, 0, :, 0].detach().cpu().numpy()
    n_gen_samples = gen_samples.shape[0]

    test_data = test_data[:n_gen_samples, 0, :, 0]
    # POT requires both inputs on the same backend; convert tensors to NumPy.
    if torch.is_tensor(test_data):
        test_data = test_data.detach().cpu().numpy()
    n_test_samples = test_data.shape[0]

    a = np.ones(n_gen_samples) / n_gen_samples
    b = np.ones(n_test_samples) / n_test_samples
    M = ot.dist(gen_samples, test_data, metric="sqeuclidean")
    wasserstein_dist = np.sqrt(ot.emd2(a, b, M))
    return wasserstein_dist


def compute_path_energy(traj, model, device, use_trapz=True, eps=1e-12):
    """Estimate the kinetic energy of the trajectories and its excess over the OT cost.

    The energy is ``sum_i w_i * dt * ||v(traj[i], t_i)||^2`` per sample,
    averaged over the batch. Here ``v = model.get_velocity``, the ``t_i`` are
    ``T`` evenly spaced times from ``model.time_start`` to ``model.time_end``
    (assumes ``traj`` was sampled on that grid — TODO confirm), and ``w``
    halves the end weights when ``use_trapz`` and ``T > 1``. The reference
    cost is the exact squared-Euclidean OT cost between the flattened
    ``traj[0]`` and ``traj[-1]`` with uniform weights.

    Returns:
        (energy, (energy - ot_cost) / max(ot_cost, eps))
    """
    # traj is indexed as [T, B, C, H, W]
    n_steps, batch = traj.shape[0], traj.shape[1]

    # Flatten end points to (B, D) for the OT problem.
    start = traj[0].reshape(batch, -1).detach().cpu().numpy()
    end = traj[-1].reshape(batch, -1).detach().cpu().numpy()
    w_start = np.ones(batch) / batch
    w_end = np.ones(batch) / batch
    ot_cost = ot.emd2(w_start, w_end, ot.dist(end, start, metric="sqeuclidean"))

    t_first, t_last = model.time_start, model.time_end
    times = torch.linspace(t_first, t_last, n_steps, device=device)
    if n_steps > 1:
        dt = (t_last - t_first) / (n_steps - 1)
    else:
        dt = 0.0

    weights = torch.ones(n_steps, device=device)
    if use_trapz and n_steps > 1:
        weights[0] = 0.5
        weights[-1] = 0.5

    energy = torch.zeros(batch, device=device)
    for step in range(n_steps):
        t_batch = times[step].expand(batch)
        velocity = model.get_velocity(traj[step], t_batch)
        energy += (weights[step] * dt) * torch.sum(velocity**2, dim=(1, 2, 3))

    energy_mean = energy.mean().item()
    excess = (energy_mean - ot_cost) / max(ot_cost, eps)
    return energy_mean, excess


def build_degradation_and_noise(args, device):
    """Gaussian-only variants (as per your simplified code)."""
    if args.problem == "denoising":
        return Denoising(), 0.2

    if args.problem == "inpainting":
        sigma_noise = 0.05
        half_size_mask = {64: 16, 128: 20, 256: 40}.get(args.dim_image)
        if half_size_mask is None:
            raise ValueError(
                f"Unsupported dim_image for inpainting: {args.dim_image}")
        return BoxInpainting(half_size_mask), sigma_noise

    if args.problem == "paintbrush_inpainting":
        return PaintbrushInpainting(), 0.05

    if args.problem == "random_inpainting":
        return RandomInpainting(p=0.7), 0.01

    if args.problem == "superresolution":
        sf = {64: 2, 128: 2, 256: 4}.get(args.dim_image)
        if sf is None:
            raise ValueError(
                f"Unsupported dim_image for superresolution: {args.dim_image}")
        return Superresolution(sf, args.dim_image), 0.05

    if args.problem == "gaussian_deblurring_FFT":
        sigma_blur = {64: 1.0, 128: 1.0, 256: 3.0}.get(args.dim_image)
        if sigma_blur is None:
            raise ValueError(
                f"Unsupported dim_image for deblurring: {args.dim_image}")
        kernel_size = 61
        degr = GaussianDeblurring(
            sigma_blur, kernel_size, "fft", args.num_channels, args.dim_image, device)
        return degr, 0.05

    raise ValueError(f"Unknown problem: {args.problem}")


def set_seed(seed: int | None) -> None:
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False


def ema(source, target, decay):
    """Update ``target`` in place as an exponential moving average of ``source``.

    For floating-point (and complex) state entries:
    ``target <- decay * target + (1 - decay) * source``.
    Other entries (e.g. BatchNorm's integer ``num_batches_tracked``) are copied
    from ``source``. Blending integers as floats and truncating them back
    produced meaningless counts.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src in source_dict.items():
        tgt = target_dict[key]
        if torch.is_floating_point(tgt) or torch.is_complex(tgt):
            tgt.data.copy_(tgt.data * decay + src.data * (1 - decay))
        else:
            tgt.data.copy_(src.data)


def model_mul(scale, source, target):  # scale*source -> target
    """Overwrite every state entry of ``target`` with ``scale`` times the matching entry of ``source``."""
    source_state = source.state_dict()
    target_state = target.state_dict()
    for name, tensor in source_state.items():
        target_state[name].data.copy_(tensor.data * scale)
