import math
import copy
from pathlib import Path
import random 
from functools import partial
from collections import namedtuple, Counter
from multiprocessing import cpu_count
import os
import numpy as np
import csv
import timeit
import json
import argparse
from collections import defaultdict
from contextlib import nullcontext
from datetime import timedelta

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import AdamW

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA

from transformers import get_scheduler, AutoTokenizer, PreTrainedTokenizerBase, T5ForConditionalGeneration, MT5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.bart.modeling_bart import BartForConditionalGeneration

from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

#import constant as constant
from latent_lang_diff.optimizer import get_adamw_optimizer
#import utils.file_utils as file_utils

import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import numpy as np



# Result record returned by GaussianDiffusion.diffusion_model_predictions:
# the predicted noise, the predicted clean latent and the predicted v
# (pred_v may be None for the 'pred_noise' objective).
ModelPrediction = namedtuple(
    'ModelPrediction',
    ('pred_noise', 'pred_x_start', 'pred_v'),
)

def create_attention_mask(padded_sequence, pad_value=-1000):
    """
    Build a boolean attention mask that hides padded positions.

    A position is treated as valid only when none of its entries along the
    last dimension equals `pad_value`; a single matching entry is enough to
    mark the whole position as padding.

    Args:
        padded_sequence (torch.Tensor): Tensor whose last dimension holds the
            features of each position (documented by the caller as (B, L, D)).
        pad_value (int, optional): Value that marks padding. Default is -1000.

    Returns:
        torch.Tensor: Boolean tensor with the last dimension reduced away;
        True marks valid positions and False marks padded ones.
    """
    is_real_feature = padded_sequence.ne(pad_value)
    return is_real_feature.all(dim=-1).bool()

# helper functions

def exists(x):
    """Return True when `x` is anything other than None (falsy values still count)."""
    is_missing = x is None
    return not is_missing

def default(val, d):
    """
    Return `val` unless it is None, otherwise fall back to `d`.

    Args:
        val: Preferred value.
        d: Fallback. When it is callable it is invoked with no arguments and
            its result is returned, so the fallback is only built on demand.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def identity(t, *args, **kwargs):
    """Return `t` unchanged; any extra positional or keyword arguments are accepted and ignored."""
    del args, kwargs  # only present so the function fits any call signature
    return t

def cycle(dl):
    """Yield the items of `dl` endlessly, starting a fresh pass each time it is exhausted."""
    while True:
        yield from dl

def has_int_squareroot(num):
    """
    Return True when `num` is a perfect square.

    Integers are checked exactly with `math.isqrt`, so the answer is correct
    for values too large to be represented as a float (the previous
    `sqrt(num) ** 2 == num` round trip failed for e.g. 10**30). Non-integer
    numbers fall back to the float square root and only count when that root
    is a whole number, so 4.0 passes and 2.25 does not.

    Args:
        num: Non-negative number to test.

    Returns:
        bool: Whether the square root of `num` is an integer.

    Raises:
        ValueError: If `num` is negative.
    """
    try:
        root = math.isqrt(num)
    except TypeError:
        # Not an integer type (e.g. float): use the floating point root instead.
        return math.sqrt(num).is_integer()
    return root * root == num

def num_to_groups(num, divisor):
    """
    Split `num` into chunk sizes of at most `divisor`.

    Returns:
        list: `num // divisor` entries equal to `divisor`, followed by the
        remainder as a final entry when it is greater than zero.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def l2norm(t):
    """Rescale `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p = 2.0, dim = -1)

def log(t, eps = 1e-12):
    """Natural log of `t` after raising every entry to at least `eps`, so zeros stay finite."""
    floored = torch.clamp(t, min = eps)
    return floored.log()

def right_pad_dims_to(x, t):
    """
    Append trailing singleton dimensions to `t` until it has as many dims as `x`.

    Returns `t` itself when it already has at least as many dimensions as `x`.
    """
    missing = x.ndim - t.ndim
    if missing > 0:
        return t.view(*t.shape, *([1] * missing))
    return t

# gaussian diffusion helpers

def extract(a, t, x_shape):
    """
    Gather the entries of `a` at indices `t` and shape them for broadcasting.

    Args:
        a: Tensor indexed along its last dimension.
        t: Index tensor; its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with `len(x_shape)` dimensions.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims)

# normalize variance of noised latent, if scale is not 1

def normalize_z_t_variance(z_t, mask, eps = 1e-5):
    """
    Divide each sample of `z_t` by the standard deviation of its unmasked prefix.

    For sample i the first `sum(mask[i])` positions are reduced to a single
    population (biased) standard deviation over both remaining axes; the
    sample is then divided by that value, floored at `eps`.

    Args:
        z_t: Batch of noised latents, indexed as z_t[i] -> ('l', 'd').
        mask: Per-sample mask; only its per-row sum is used, as a prefix length.
        eps: Lower bound applied to the standard deviation before dividing.
    """
    population_std = partial(torch.std, unbiased = False)
    per_sample_std = []
    for i in range(z_t.shape[0]):
        valid_len = torch.sum(mask[i])
        per_sample_std.append(reduce(z_t[i][:valid_len], 'l d -> 1 1', population_std))
    std = rearrange(per_sample_std, 'b 1 1 -> b 1 1')
    return z_t / std.clamp(min = eps)
    

# noise schedules

def simple_linear_schedule(t, clip_min = 1e-9):
    """Linear schedule: returns `1 - t`, floored at `clip_min`."""
    remaining = 1 - t
    return torch.clamp(remaining, min = clip_min)

def beta_linear_schedule(t, clip_min = 1e-9):
    """Schedule `exp(-1e-4 - 10 * t**2)`, clamped to the range [clip_min, 1]."""
    exponent = -1e-4 - 10 * (t ** 2)
    return torch.exp(exponent).clamp(min = clip_min, max = 1.)

def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    """
    Cosine schedule evaluated at times `t`.

    `t` is mapped linearly from [0, 1] onto [start, end], passed through
    `cos(. * pi / 2) ** (2 * tau)` and rescaled with the values of that curve
    at `start` and `end`; the result is floored at `clip_min`.
    """
    power = 2 * tau

    def curve_at(u):
        # scalar value of the cosine curve at a fixed position
        return math.cos(u * math.pi / 2) ** power

    v_start = curve_at(start)
    v_end = curve_at(end)
    angle = (t * (end - start) + start) * math.pi / 2
    curve = torch.cos(angle) ** power
    rescaled = (v_end - curve) / (v_end - v_start)
    return rescaled.clamp(min = clip_min)

def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
    """
    Sigmoid schedule evaluated at times `t`.

    `t` is mapped linearly from [0, 1] onto [start, end], divided by `tau`
    and passed through a sigmoid; the result is flipped and rescaled with the
    sigmoid values at `start / tau` and `end / tau`, then clamped to
    [clamp_min, 1].
    """
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    logits = (t * (end - start) + start) / tau
    gamma = (-logits.sigmoid() + v_end) / (v_end - v_start)
    # gamma is a fresh tensor, so clamping it in place does not touch `t`
    return gamma.clamp_(min = clamp_min, max = 1.)

# converting gamma to alpha, sigma or logsnr

def log_snr_to_alpha(log_snr):
    """Convert a log signal-to-noise ratio to alpha by applying a sigmoid."""
    return torch.sigmoid(log_snr)

def alpha_to_shifted_log_snr(alpha, scale = 1):
    """
    Convert alpha to a log-SNR and shift it by `2 * ln(scale)`.

    The unshifted log-SNR `log(alpha / (1 - alpha))` is clamped to [-15, 15]
    before the shift is added.
    """
    odds = alpha / (1 - alpha)
    log_snr = log(odds).clamp(min=-15, max=15)
    shift = 2*np.log(scale).item()
    return log_snr + shift

def time_to_alpha(t, alpha_schedule, scale):
    """
    Map times `t` to alpha values under a scale-shifted schedule.

    The schedule's alpha is converted to a log-SNR, shifted according to
    `scale`, and converted back to alpha.

    Args:
        t: Times passed to `alpha_schedule`.
        alpha_schedule: Callable mapping times to alpha.
        scale: Scale forwarded to `alpha_to_shifted_log_snr`.
    """
    base_alpha = alpha_schedule(t)
    return log_snr_to_alpha(alpha_to_shifted_log_snr(base_alpha, scale = scale))

def set_seeds(seed):
    """Seed Python's `random`, torch, and the current torch CUDA device with `seed`."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model,
        *,
        max_seq_len,
        sampling_timesteps = 250,
        loss_type = 'l1',
        objective = 'pred_noise',
        train_schedule = 'cosine',
        sampling_schedule = None,
        scale = 1.,
        sampler = 'ddpm',
        train_prob_self_cond = 0.5,
        seq2seq_unconditional_prob = 0.1,
        using_latent_model = True
    ):
        """
        Wrap a denoising network with continuous-time Gaussian diffusion logic.

        Args:
            model: Denoising network. It must expose `class_conditional`,
                `latent_dim` and `self_condition` (and
                `class_unconditional_prob` when it is class conditional).
            max_seq_len: Sequence length enforced by `forward` and used for
                the shape of generated samples.
            sampling_timesteps: Number of steps taken by the samplers.
            loss_type: 'l1', 'l2' or 'smooth_l1' (resolved by `loss_fn`).
            objective: How the network output is interpreted.
            train_schedule: Training alpha schedule name: 'simple_linear',
                'beta_linear', 'cosine' or 'sigmoid'.
            sampling_schedule: Same choices; None reuses the training schedule.
            scale: Forwarded to `time_to_alpha` for both schedules.
            sampler: 'ddim', 'ddpm' or 'dpmpp'.
            train_prob_self_cond: Probability of running the extra
                self-conditioning pass in `forward`.
            seq2seq_unconditional_prob: Stored on the instance; not read
                inside this class.
            using_latent_model: When True the samplers use an all-True mask
                instead of one built from the requested lengths.

        Raises:
            ValueError: If a schedule name is not recognised.
        """
        super().__init__()
        assert sampler in {'ddim', 'ddpm', 'dpmpp'}, 'sampler must be one of ddim, ddpm, dpmpp'
        self.sampler = sampler

        self.diffusion_model = model
        denoiser = self.diffusion_model
        if denoiser.class_conditional and denoiser.class_unconditional_prob > 0:
            # Distribution sampled in forward() to decide which class labels to drop
            self.class_unconditional_bernoulli = torch.distributions.Bernoulli(probs=denoiser.class_unconditional_prob)

        self.latent_dim = denoiser.latent_dim
        self.self_condition = denoiser.self_condition

        self.max_seq_len = max_seq_len
        self.l2_normalize = False
        self.using_latent_model = using_latent_model
        self.objective = objective
        self.loss_type = loss_type

        assert objective in {'pred_noise', 'pred_x0', 'pred_v', 'pred_v_dual'}, 'objective must be one of pred_noise, pred_x0, pred_v, pred_v_dual'

        named_schedules = (
            ('simple_linear', simple_linear_schedule),
            ('beta_linear', beta_linear_schedule),
            ('cosine', cosine_schedule),
            ('sigmoid', sigmoid_schedule),
        )

        def find_schedule(name):
            # Return the schedule function registered under `name`, or None.
            for known_name, schedule_fn in named_schedules:
                if name == known_name:
                    return schedule_fn
            return None

        # Training schedule
        alpha_schedule = find_schedule(train_schedule)
        if alpha_schedule is None:
            raise ValueError(f'invalid noise schedule {train_schedule}')
        self.train_schedule = partial(time_to_alpha, alpha_schedule=alpha_schedule, scale=scale)

        # Sampling schedule: falls back to the training schedule when not given
        if sampling_schedule is None:
            self.sampling_schedule = self.train_schedule
        else:
            sampling_alpha_schedule = find_schedule(sampling_schedule)
            if sampling_alpha_schedule is None:
                raise ValueError(f'invalid sampling schedule {sampling_schedule}')
            self.sampling_schedule = partial(time_to_alpha, alpha_schedule=sampling_alpha_schedule, scale=scale)

        # Original author's note: the scale follows the main finding of Ting Chen's
        # paper, that higher resolution images require more noise for better training.
        self.scale = scale

        self.sampling_timesteps = sampling_timesteps

        # Probabilities used during training
        self.train_prob_self_cond = train_prob_self_cond
        self.seq2seq_unconditional_prob = seq2seq_unconditional_prob

        # Buffers for latent mean and scale values (identity normalisation by default)
        self.register_buffer('latent_mean', torch.tensor([0] * self.latent_dim, dtype=torch.float32))
        self.register_buffer('latent_scale', torch.tensor(1, dtype=torch.float32))

    def predict_start_from_noise(self, z_t, t, noise, sampling=False):
        """
        Estimate the clean latent from `z_t` and a noise estimate.

        Computes `(z_t - sqrt(1 - alpha) * noise) / sqrt(alpha)` with the
        denominator floored at 1e-8; alpha comes from the sampling schedule
        when `sampling` is True and from the training schedule otherwise.
        """
        schedule = self.sampling_schedule if sampling else self.train_schedule
        alpha = right_pad_dims_to(z_t, schedule(t))

        noise_part = (1-alpha).sqrt() * noise
        signal_scale = alpha.sqrt().clamp(min = 1e-8)
        return (z_t - noise_part) / signal_scale
        
    def predict_noise_from_start(self, z_t, t, x0, sampling=False):
        """
        Estimate the noise in `z_t` given an estimate of the clean latent.

        Computes `(z_t - sqrt(alpha) * x0) / sqrt(1 - alpha)` with the
        denominator floored at 1e-8; alpha comes from the sampling schedule
        when `sampling` is True and from the training schedule otherwise.
        """
        schedule = self.sampling_schedule if sampling else self.train_schedule
        alpha = right_pad_dims_to(z_t, schedule(t))

        signal_part = alpha.sqrt() * x0
        noise_scale = (1-alpha).sqrt().clamp(min = 1e-8)
        return (z_t - signal_part) / noise_scale

    def predict_start_from_v(self, z_t, t, v, sampling=False):
        """
        Estimate the clean latent from `z_t` and a v-prediction.

        Computes `sqrt(alpha) * z_t - sqrt(1 - alpha) * v`; alpha comes from
        the sampling schedule when `sampling` is True and from the training
        schedule otherwise.
        """
        schedule = self.sampling_schedule if sampling else self.train_schedule
        alpha = right_pad_dims_to(z_t, schedule(t))
        return alpha.sqrt() * z_t - (1-alpha).sqrt() * v
    
    def predict_noise_from_v(self, z_t, t, v, sampling=False):
        """
        Estimate the noise in `z_t` from a v-prediction.

        Computes `sqrt(1 - alpha) * z_t + sqrt(alpha) * v`; alpha comes from
        the sampling schedule when `sampling` is True and from the training
        schedule otherwise.
        """
        schedule = self.sampling_schedule if sampling else self.train_schedule
        alpha = right_pad_dims_to(z_t, schedule(t))
        return (1-alpha).sqrt() * z_t + alpha.sqrt() * v
    
    def predict_v_from_start_and_eps(self, z_t, t, x, noise, sampling=False):
        """
        Build the v target from a clean latent `x` and `noise`.

        Computes `sqrt(alpha) * noise - x * sqrt(1 - alpha)`. `z_t` is only
        used to decide how many trailing dimensions alpha needs for
        broadcasting.
        """
        schedule = self.sampling_schedule if sampling else self.train_schedule
        alpha = right_pad_dims_to(z_t, schedule(t))
        return alpha.sqrt() * noise - x * (1-alpha).sqrt()

    def normalize_latent(self, x_start):
        eps = 1e-5 
                
        return (x_start-self.latent_mean)/(self.latent_scale).clamp(min=eps)
    
    def unnormalize_latent(self, x_start):
        eps = 1e-5 
        
        return x_start*(self.latent_scale.clamp(min=eps))+self.latent_mean

    def diffusion_model_predictions(self, z_t, mask, t, *, x_self_cond = None,  class_id=None, seq2seq_cond=None, seq2seq_mask=None, sampling=False, cls_free_guidance=1.0, l2_normalize=False):
        """
        Run the denoising network and express its output in every parameterization.

        Args:
            z_t: Noised latents.
            mask: Mask forwarded to the network.
            t: Diffusion times; the network is conditioned on the alpha the
                active schedule assigns to them.
            x_self_cond: Optional self-conditioning input forwarded to the network.
            class_id: Optional class labels.
            seq2seq_cond, seq2seq_mask: Optional conditioning sequence and its mask.
            sampling: Use the sampling schedule instead of the training one.
            cls_free_guidance: Guidance weight. Any value other than 1.0
                triggers a second, unconditional network pass (class ids
                replaced by `num_classes`, no seq2seq conditioning) and the
                two outputs are mixed as `w * cond + (1 - w) * uncond`.
            l2_normalize: Rescale the predicted clean latent to norm
                sqrt(dim) along the last axis; only allowed while sampling.

        Returns:
            ModelPrediction: `(pred_noise, pred_x_start, pred_v)`; `pred_v` is
            None for the 'pred_noise' objective unless `l2_normalize` is set.

        Raises:
            ValueError: If `self.objective` is not handled here.
        """
        schedule = self.sampling_schedule if sampling else self.train_schedule
        time_cond = schedule(t)
        model_output = self.diffusion_model(z_t, mask, time_cond, x_self_cond, class_id=class_id, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask)

        if cls_free_guidance != 1.0:
            # The id one past the last real class stands for "no class"
            unc_class_id = torch.full_like(class_id, fill_value=self.diffusion_model.num_classes) if exists(class_id) else None
            unc_model_output = self.diffusion_model(z_t, mask, time_cond, x_self_cond, class_id=unc_class_id, seq2seq_cond=None, seq2seq_mask=None)
            model_output = model_output*cls_free_guidance + unc_model_output*(1-cls_free_guidance)

        objective = self.objective
        pred_v = None
        if objective == 'pred_v':
            pred_v = model_output
            x_start = self.predict_start_from_v(z_t, t, pred_v, sampling=sampling)
            pred_noise = self.predict_noise_from_v(z_t, t, pred_v, sampling=sampling)
        elif objective == 'pred_x0':
            x_start = model_output
            pred_noise = self.predict_noise_from_start(z_t, t, x_start, sampling=sampling)
            pred_v = self.predict_v_from_start_and_eps(z_t, t, x_start, pred_noise, sampling=sampling)
        elif objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(z_t, t, pred_noise, sampling=sampling)
        else:
            raise ValueError(f'invalid objective {self.objective}')

        if l2_normalize:
            assert sampling
            # Project x0 onto the sphere of radius sqrt(dim), then re-derive the rest from it
            x_start = F.normalize(x_start, dim=-1) * math.sqrt(x_start.shape[-1])
            pred_noise = self.predict_noise_from_start(z_t, t, x_start, sampling=sampling)
            pred_v = self.predict_v_from_start_and_eps(z_t, t, x_start, pred_noise, sampling=sampling)

        return ModelPrediction(pred_noise, x_start, pred_v)

    def get_sampling_timesteps(self, batch, *, device, invert = False):
        """
        Build the (current time, next time) pairs walked by the samplers.

        Times are `sampling_timesteps + 1` evenly spaced values from 1 down to
        0 (reversed when `invert` is True), repeated for every batch element.

        Returns:
            tuple: One entry per step; each entry is a (2, batch) tensor with
            the current times in row 0 and the next times in row 1.
        """
        grid = torch.linspace(1., 0., self.sampling_timesteps + 1, device = device)
        if invert:
            grid = grid.flip(dims = (0,))
        grid = repeat(grid, 't -> b t', b = batch)
        current, upcoming = grid[:, :-1], grid[:, 1:]
        return torch.stack((current, upcoming), dim = 0).unbind(dim = -1)

    @torch.no_grad()
    def ddim_sample(self, shape, lengths, class_id, seq2seq_cond, seq2seq_mask, cls_free_guidance=1.0, l2_normalize=False, invert=False, z_t=None):
        """
        Deterministic DDIM sampling loop (or its inversion).

        Args:
            shape: Shape of the latents to generate; `shape[0]` is the batch
                size and `shape[1]` the sequence length.
            lengths: Per-sample lengths; only used to build the mask when
                `using_latent_model` is False.
            class_id, seq2seq_cond, seq2seq_mask: Conditioning forwarded to
                the network.
            cls_free_guidance: Guidance weight forwarded to
                `diffusion_model_predictions`.
            l2_normalize: Forwarded to `diffusion_model_predictions`.
            invert: Walk time from 0 to 1 instead, mapping the given `z_t`
                towards noise; requires `z_t`.
            z_t: Optional starting latents; drawn from a standard normal when
                omitted.

        Returns:
            tuple: `(z_t, mask)` - the final latents and the mask that was used.
        """
        print('DDIM sampling')
        batch, device = shape[0], next(self.diffusion_model.parameters()).device

        time_pairs = self.get_sampling_timesteps(batch, device = device, invert=invert)
        if invert:
            assert exists(z_t)

        if not exists(z_t):
            z_t = torch.randn(shape, device=device)

        if self.using_latent_model:
            mask = torch.ones((shape[0], shape[1]), dtype=torch.bool, device=device)
        else:
            # True for the first `length` positions of each row, False for the padding
            rows = [[True]*length + [False]*(self.max_seq_len-length) for length in lengths]
            mask = torch.tensor(rows, dtype=torch.bool, device=device)

        # Previous x0 estimate, fed back as the self-conditioning input
        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.sampling_timesteps):
            model_output = self.diffusion_model_predictions(z_t, mask, time, class_id=class_id, x_self_cond=x_start, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask, sampling=True, cls_free_guidance=cls_free_guidance, l2_normalize=l2_normalize)

            x_start = model_output.pred_x_start
            eps = model_output.pred_noise

            if (not invert) and time_next[0] <= 0:
                # Last denoising step: return the clean estimate itself
                z_t = x_start
            elif invert and time_next[0] >= 1:
                # Last inversion step: return the noise estimate itself
                z_t = eps
            else:
                # Re-noise the x0 estimate to the next time using the predicted noise
                alpha_next = right_pad_dims_to(z_t, self.sampling_schedule(time_next))
                z_t = x_start * alpha_next.sqrt() + eps * (1-alpha_next).sqrt()
        return (z_t, mask)


    @torch.no_grad()
    def ddpm_sample(self, shape, lengths, class_id, seq2seq_cond, seq2seq_mask, cls_free_guidance=1.0, l2_normalize=False, invert=False, z_t=None):
        """
        Stochastic ancestral (DDPM) sampling loop.

        Args:
            shape: Shape of the latents to generate; `shape[0]` is the batch
                size and `shape[1]` the sequence length.
            lengths: Per-sample lengths; only used to build the mask when
                `using_latent_model` is False.
            class_id, seq2seq_cond, seq2seq_mask: Conditioning forwarded to
                the network.
            cls_free_guidance: Guidance weight forwarded to
                `diffusion_model_predictions`.
            l2_normalize: Forwarded to `diffusion_model_predictions`.
            invert: Accepted so all samplers share one signature; not used here.
            z_t: Optional starting latents; drawn from a standard normal when
                omitted.

        Returns:
            tuple: `(z_t, mask)` - the final latents and the mask that was used.
        """
        batch, device = shape[0], next(self.diffusion_model.parameters()).device

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        if not exists(z_t):
            z_t = torch.randn(shape, device=device)

        if self.using_latent_model:
            mask = torch.ones((shape[0], shape[1]), dtype=torch.bool, device=device)
        else:
            # True for the first `length` positions of each row, False for the padding
            rows = [[True]*length + [False]*(self.max_seq_len-length) for length in lengths]
            mask = torch.tensor(rows, dtype=torch.bool, device=device)

        # Previous x0 estimate, fed back as the self-conditioning input
        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.sampling_timesteps):
            model_output = self.diffusion_model_predictions(z_t, mask, time, class_id=class_id, x_self_cond=x_start, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask, sampling=True, cls_free_guidance=cls_free_guidance, l2_normalize=l2_normalize)

            x_start = model_output.pred_x_start
            eps = model_output.pred_noise

            if time_next[0] <= 0:
                # Last step: return the clean estimate without adding noise
                z_t = x_start
                continue

            alpha, alpha_next = (right_pad_dims_to(z_t, self.sampling_schedule(moment)) for moment in (time, time_next))
            # Ratio of cumulative alphas, i.e. the alpha of this single step
            alpha_now = alpha/alpha_next

            noise = torch.randn_like(z_t)
            posterior_mean = 1/alpha_now.sqrt() * (z_t - (1-alpha_now)/(1-alpha).sqrt() * eps)
            z_t = posterior_mean + torch.sqrt(1 - alpha_now) * noise
        return (z_t, mask)
    

    @torch.no_grad()
    def dpmpp_sample(self, shape, lengths, class_id, seq2seq_cond, seq2seq_mask, cls_free_guidance=1.0, l2_normalize=False, invert=False, z_t=None):
        """
        DPM-Solver++ style sampling loop working in half-log-SNR space.

        Args:
            shape: Shape of the latents to generate; `shape[0]` is the batch
                size and `shape[1]` the sequence length.
            lengths: Per-sample lengths; only used to build the mask when
                `using_latent_model` is False.
            class_id, seq2seq_cond, seq2seq_mask: Conditioning forwarded to
                the network.
            cls_free_guidance: Guidance weight forwarded to
                `diffusion_model_predictions`.
            l2_normalize: Forwarded to `diffusion_model_predictions`.
            invert: Accepted so all samplers share one signature; not used here.
            z_t: Optional starting latents; drawn from a standard normal when
                omitted.

        Returns:
            tuple: `(z_t, mask)` - the final latents and the mask that was used.
        """
        batch, device = shape[0], next(self.diffusion_model.parameters()).device

        time_pairs = self.get_sampling_timesteps(batch, device = device)

        if not exists(z_t):
            z_t = torch.randn(shape, device=device)

        if self.using_latent_model:
            mask = torch.ones((shape[0], shape[1]), dtype=torch.bool, device=device)
        else:
            # True for the first `length` positions of each row, False for the padding
            rows = [[True]*length + [False]*(self.max_seq_len-length) for length in lengths]
            mask = torch.tensor(rows, dtype=torch.bool, device=device)

        # Previous x0 estimate, fed back as the self-conditioning input
        x_start = None

        # History for the two-step correction below.
        # NOTE(review): nothing appends to these lists, so the correction branch
        # never runs and every step is first order. Populating them would change
        # the generated samples, so that decision is left to the model owner.
        old_pred_x = []
        old_hs = []

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.sampling_timesteps):
            model_output = self.diffusion_model_predictions(z_t, mask, time, class_id=class_id, x_self_cond=x_start, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask, sampling=True, cls_free_guidance=cls_free_guidance, l2_normalize=l2_normalize)

            # Refresh the estimate BEFORE the final-step check. Previously the
            # check came first, so the sampler returned the estimate of the
            # step before (or None when there was only one step).
            x_start = model_output.pred_x_start

            if time_next[0] <= 0:
                z_t = x_start
                continue

            alpha, alpha_next = (right_pad_dims_to(z_t, self.sampling_schedule(moment)) for moment in (time, time_next))
            sigma, sigma_next = 1-alpha, 1-alpha_next

            # Half log-SNR at the current and next time, and the step between them
            lambda_now = (log(alpha) - log(1-alpha))/2
            lambda_next = (log(alpha_next) - log(1-alpha_next))/2
            h = lambda_next - lambda_now

            phi_1 = torch.expm1(-h)
            if len(old_pred_x) < 2:
                denoised_x = x_start
            else:
                # Extrapolate from the previous estimate (see NOTE above)
                h_0 = old_hs[-1]
                r0 = h_0/h
                gamma = -1/(2*r0)
                denoised_x = (1-gamma)*x_start + gamma*old_pred_x[-1]

            z_t = (sigma_next.sqrt()/sigma.sqrt()) * z_t - alpha_next.sqrt() * phi_1 * denoised_x
        return (z_t, mask)
    

    @torch.no_grad()
    def sample(self, batch_size, length, class_id=None, seq2seq_cond=None, seq2seq_mask=None, cls_free_guidance=1.0, l2_normalize=False):
        """
        Generate latents with the sampler selected at construction time.

        Args:
            batch_size: Number of samples to generate.
            length: Per-sample lengths, forwarded to the sampler.
            class_id, seq2seq_cond, seq2seq_mask: Optional conditioning.
            cls_free_guidance: Guidance weight forwarded to the sampler.
            l2_normalize: Forwarded to the sampler.

        Returns:
            Whatever the selected sampler returns for a latent shape of
            `(batch_size, max_seq_len, latent_dim)`.

        Raises:
            ValueError: If `self.sampler` is not 'ddim', 'ddpm' or 'dpmpp'.
        """
        latent_shape = (batch_size, self.max_seq_len, self.latent_dim)

        sampler_name = self.sampler
        if sampler_name not in ('ddim', 'ddpm', 'dpmpp'):
            raise ValueError(f'invalid sampler {self.sampler}')

        # Each sampler is implemented by the method named '<sampler>_sample'
        sample_fn = getattr(self, f'{sampler_name}_sample')
        return sample_fn(latent_shape, length, class_id, seq2seq_cond, seq2seq_mask, cls_free_guidance, l2_normalize)

    @property
    def loss_fn(self):
        if self.loss_type == 'l1':
            return F.l1_loss
        elif self.loss_type == 'l2':
            return F.mse_loss
        elif self.loss_type == 'smooth_l1':
            return F.smooth_l1_loss
        else:
            raise ValueError(f'invalid loss type {self.loss_type}')

    def forward(self, txt_latent, mask, class_id, seq2seq_cond=None, seq2seq_mask=None, return_x_start=False, *args, **kwargs):
        """
        Compute the diffusion training loss for a batch of latents.

        A time is drawn uniformly from [0, 1) for every sample, the latents
        are noised accordingly, and the network prediction is compared with
        the target implied by `self.objective`. The loss is averaged over the
        first `sum(mask[i])` positions of each sample and then over the batch.

        Args:
            txt_latent: Clean latents with three dimensions; the second must
                equal `max_seq_len`.
            mask: Per-sample mask; its per-row sum is used as the number of
                leading positions that contribute to the loss.
            class_id: Class labels, required when the network is class
                conditional with a positive unconditional probability. The
                caller's tensor is not modified.
            seq2seq_cond, seq2seq_mask: Optional conditioning sequence and its mask.
            return_x_start: Also return the predicted clean latents.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The scalar loss, or `(loss, pred_x_start)` when `return_x_start` is True.

        Raises:
            ValueError: If `self.objective` has no training target defined here.
        """
        batch, seq_len, _ = txt_latent.shape
        device = txt_latent.device
        assert seq_len == self.max_seq_len, f'length must be {self.max_seq_len}'

        # sample random times
        times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)

        # noise sample
        noise = torch.randn_like(txt_latent)

        alpha = right_pad_dims_to(txt_latent, self.train_schedule(times))
        z_t = alpha.sqrt() * txt_latent + (1-alpha).sqrt() * noise

        # Drop class labels with some probability so the model also learns the unconditional case
        if self.diffusion_model.class_conditional and self.diffusion_model.class_unconditional_prob > 0:
            assert exists(class_id)
            drop_mask = self.class_unconditional_bernoulli.sample(class_id.shape).bool().to(class_id.device)
            # masked_fill builds a new tensor, so the labels owned by the caller
            # are no longer overwritten in place
            class_id = class_id.masked_fill(drop_mask, self.diffusion_model.num_classes)

        self_cond = None
        if self.self_condition and (random.random() < self.train_prob_self_cond):
            # First pass without gradients produces the x0 estimate used as conditioning
            with torch.no_grad():
                model_output = self.diffusion_model_predictions(z_t, mask, times, class_id=class_id, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask)
                self_cond = model_output.pred_x_start.detach()
                if self.l2_normalize:
                    self_cond = F.normalize(self_cond, dim=-1) * math.sqrt(self_cond.shape[-1])

        # predict and take gradient step
        predictions = self.diffusion_model_predictions(z_t, mask, times, x_self_cond=self_cond, class_id=class_id, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask)
        if self.objective == 'pred_x0':
            target = txt_latent
            pred = predictions.pred_x_start
        elif self.objective == 'pred_noise':
            target = noise
            pred = predictions.pred_noise
        elif self.objective == 'pred_v':
            target = alpha.sqrt() * noise - (1-alpha).sqrt() * txt_latent
            assert exists(predictions.pred_v)
            pred = predictions.pred_v
        else:
            # Fail with a clear message instead of an unbound-variable error below
            raise ValueError(f'invalid objective {self.objective}')

        loss = self.loss_fn(pred, target, reduction = 'none')
        # Average each sample over its unmasked prefix only
        per_sample = [reduce(loss[i][:torch.sum(mask[i])], 'l d -> 1', 'mean') for i in range(batch)]
        loss = rearrange(per_sample, 'b 1 -> b 1')

        if return_x_start:
            return loss.mean(), predictions.pred_x_start
        return loss.mean()

# trainer class
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim import AdamW
import os

class Trainer(object):
    def __init__(
        self,
        args,
        diffusion,
        train_dataloader,
        val_dataloader,
        train_val_dataloader,
        *,
        train_batch_size=16,
        eval_batch_size=64,
        gradient_accumulate_every=1,
        train_lr=1e-4,
        train_num_steps=100000,
        lr_schedule='cosine',
        num_warmup_steps=500,
        ema_update_every=10,
        ema_decay=0.995,
        adam_betas=(0.9, 0.99),
        adam_weight_decay=0.01,
        save_and_sample_every=5000,
        val_every=5000,
        amp=False,
        mixed_precision='no',
        decoding_loss=False,
        decoding_loss_weight=1.0,
        results_folder='./results',
        num_samples=512,
        distributed=False,
        rank=0,
        world_size=1
    ):
        """
        Set up everything needed to train a diffusion model.

        Args:
            args: Caller-supplied configuration object; stored as `self.args`.
            diffusion: The `GaussianDiffusion` module to train.
            train_dataloader, val_dataloader, train_val_dataloader: Loaders
                whose datasets are re-wrapped with a `DistributedSampler`.
            train_batch_size: Batch size used for all re-wrapped loaders.
            eval_batch_size: Stored on the instance.
            gradient_accumulate_every: Micro-batches per optimizer step.
            train_lr, adam_betas, adam_weight_decay: AdamW settings.
            train_num_steps: Number of optimizer steps to run.
            lr_schedule, num_warmup_steps: Learning-rate scheduler settings;
                both step counts are multiplied by `world_size`.
            ema_update_every, ema_decay: EMA settings (rank 0 only).
            save_and_sample_every: Checkpoint interval in steps.
            val_every: Validation interval in steps.
            amp, mixed_precision, decoding_loss, decoding_loss_weight:
                Accepted but not used by this constructor.
            results_folder: Directory for checkpoints; created when missing.
            num_samples: Default number of samples for `sample_seq2seq`.
            distributed: Initialise a process group and wrap the model in DDP.
            rank: Index of this process, also used as its CUDA device index.
            world_size: Total number of processes.
        """
        super().__init__()

        set_seeds(42)
        self.args = args
        self.rank = rank
        self.world_size = world_size
        self.distributed = distributed
        self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

        if self.distributed:
            self.setup_distributed()

        # exist_ok avoids the check-then-create race when several ranks start together
        os.makedirs(results_folder, exist_ok=True)
        self.results_folder = results_folder
        self.val_interval = val_every

        # Initialize diffusion model
        self.diffusion = diffusion
        self.seq2seq = self.diffusion.diffusion_model.seq2seq
        self.class_conditional = self.diffusion.diffusion_model.class_conditional
        self.seq2seq_unconditional_prob = self.diffusion.seq2seq_unconditional_prob

        # Set training parameters
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps
        self.save_and_sample_every = save_and_sample_every
        # sample_seq2seq falls back to self.num_samples, which was never stored before
        self.num_samples = num_samples

        # Initialize EMA
        if self.rank == 0:  # Only the main process handles saving/loading and EMA
            self.ema = EMA(diffusion, beta=ema_decay, update_every=ema_update_every, power=3/4)

        # Dataloaders
        self.train_dataloader = self.get_distributed_dataloader(train_dataloader, rank, world_size)
        self.val_dataloader = self.get_distributed_dataloader(val_dataloader, rank, world_size)
        self.train_val_dataloader = self.get_distributed_dataloader(train_val_dataloader, rank, world_size)

        # Optimizer and scheduler
        self.opt = AdamW(diffusion.parameters(), lr=train_lr, betas=adam_betas, weight_decay=adam_weight_decay)
        self.lr_scheduler = get_scheduler(
            lr_schedule,
            optimizer=self.opt,
            num_warmup_steps=num_warmup_steps * self.world_size,
            num_training_steps=train_num_steps * self.world_size,
        )

        # Move model to device and wrap with DistributedDataParallel if needed
        self.setup_model()

        # Step counter
        self.step = 0

    def setup_distributed(self):
        """
        Initialize the NCCL process group for distributed training.

        `MASTER_ADDR` and `MASTER_PORT` default to localhost:12355 but are
        only set when the environment does not already define them, so values
        provided by a launcher are respected instead of being overwritten.
        """
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')
        dist.init_process_group(backend='nccl', rank=self.rank, world_size=self.world_size)

    def get_distributed_dataloader(self, loader, rank, world_size):
        """
        Wrap the DataLoader with a DistributedSampler for distributed training.
        """
        sampler = DistributedSampler(loader.dataset, num_replicas=world_size, rank=rank)
        return DataLoader(loader.dataset, sampler=sampler, batch_size=self.train_batch_size, num_workers=4)

    def setup_model(self):
        """
        Move model to device and wrap with DistributedDataParallel if using distributed training.
        """
        if self.distributed:
            torch.cuda.set_device(self.rank)
            self.diffusion = self.diffusion.to(self.rank)
            self.diffusion = DDP(self.diffusion, device_ids=[self.rank], find_unused_parameters=True)
        else:
            self.diffusion = self.diffusion.to(self.device)

    def save(self, best=False):
        if self.rank != 0:
            return

        data = {
            'step': self.step,
            'model': self.diffusion.module.state_dict() if self.distributed else self.diffusion.state_dict(),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict() if hasattr(self, 'ema') else None,
            'scheduler': self.lr_scheduler.state_dict(),
        }

        # Fix the path construction to avoid TypeError
        save_filename = 'best_model.pt' if best else 'model.pt'
        save_path = os.path.join(self.results_folder, save_filename)

        torch.save(data, save_path)


    def load(self, file_path=None, best=False, init_only=False):
        file_path = Path(file_path) if file_path else self.results_folder
        data = torch.load(str(file_path / ('best_model.pt' if best else 'model.pt')), map_location=self.device)
        
        if self.distributed:
            self.diffusion.module.load_state_dict(data['model'])
        else:
            self.diffusion.load_state_dict(data['model'])

        self.opt.load_state_dict(data['opt'])
        if hasattr(self, 'ema') and data['ema'] is not None:
            self.ema.load_state_dict(data['ema'])

        if init_only:
            return

        self.step = data['step']
        self.lr_scheduler.load_state_dict(data['scheduler'])

    def cleanup_distributed(self):
        """
        Tear down the process group used for distributed training.

        Does nothing when no process group has been initialised, so the
        method can be called unconditionally at the end of a run.
        """
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

    def train(self, train_dataloader, val_dataloader):
        """
        Run the optimisation loop until `self.step` reaches `self.train_num_steps`.

        Every step accumulates gradients over `gradient_accumulate_every`
        micro-batches, clips the gradient norm to 1.0 and updates the
        optimizer and scheduler. Rank 0 additionally updates the EMA, runs
        validation every `self.val_interval` steps, saves a checkpoint every
        `self.save_and_sample_every` steps and drives the progress bar.

        Args:
            train_dataloader: Iterable of `(data, label)` batches.
            val_dataloader: Iterable of `(data, label)` validation batches.
        """
        device = torch.device(f'cuda:{self.rank}' if torch.cuda.is_available() else 'cpu')
        is_main = self.rank == 0

        def endless_batches():
            # Keep ONE iterator alive per pass over the data. Calling
            # next(iter(loader)) for every micro-batch restarted the workers each
            # time and, with a DistributedSampler that is never re-seeded, kept
            # returning the same first batch.
            sampler = getattr(train_dataloader, 'sampler', None)
            epoch = 0
            while True:
                if hasattr(sampler, 'set_epoch'):
                    # A different shuffling order for every pass
                    sampler.set_epoch(epoch)
                produced = False
                for batch in train_dataloader:
                    produced = True
                    yield batch
                if not produced:
                    # Empty loader: stop instead of spinning forever
                    return
                epoch += 1

        batch_iter = endless_batches()

        # Validation runs on rank 0 only, so it must bypass the DDP wrapper,
        # whose forward pass can trigger collective communication.
        eval_model = self.diffusion.module if self.distributed else self.diffusion

        # Progress bar only for rank 0 (main process)
        pbar = tqdm(initial=self.step, total=self.train_num_steps) if is_main else None
        if is_main and hasattr(self, 'ema'):
            self.ema.to(device)

        while self.step < self.train_num_steps:
            total_loss = 0.

            # Set model to train mode
            self.diffusion.train()

            # Gradient accumulation loop
            for grad_accum_step in range(self.gradient_accumulate_every):
                data, label = next(batch_iter)
                data, label = data.to(device), label.to(device)

                seq2seq_cond, seq2seq_mask = None, None

                # Use the label as conditioning, except for the unconditional fraction of steps
                if self.seq2seq and random.random() < (1 - self.seq2seq_unconditional_prob):
                    seq2seq_cond = label
                    seq2seq_mask = create_attention_mask(label).to(device)

                # Create attention mask
                mask = create_attention_mask(data).to(device)

                # Forward pass through the diffusion model
                loss = self.diffusion(data, mask, class_id=None, seq2seq_cond=seq2seq_cond, seq2seq_mask=seq2seq_mask)
                loss = loss / self.gradient_accumulate_every  # Normalize by accumulation steps
                total_loss += loss.item()

                # Backpropagation
                loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.diffusion.parameters(), max_norm=1.0)

            # Optimizer step
            self.opt.step()
            self.lr_scheduler.step()
            self.opt.zero_grad()

            # Increment step counter
            self.step += 1

            # EMA and logging only in rank 0 (main process)
            if is_main:
                logs = {
                    "loss": total_loss,
                    "learning_rate": self.lr_scheduler.get_last_lr()[0],
                    "step": self.step,
                    "epoch": self.step * self.gradient_accumulate_every / len(train_dataloader),
                    "samples": self.step * self.train_batch_size * self.gradient_accumulate_every * self.world_size,
                }

                # EMA update
                if hasattr(self, 'ema'):
                    self.ema.update()

                # Perform validation every `val_interval` steps
                if self.step % self.val_interval == 0:
                    self.diffusion.eval()
                    total_val_loss, total_val_ema_loss = 0., 0.

                    with torch.no_grad():
                        for val_data, val_label in val_dataloader:
                            val_data = val_data.to(device)

                            # NOTE(review): validation uses an all-True mask while training
                            # masks padded positions - confirm this is intended.
                            mask = torch.ones(val_data.shape[:-1], dtype=torch.bool).to(device)
                            val_loss = eval_model(val_data, mask, class_id=None, seq2seq_cond=None, seq2seq_mask=None)
                            total_val_loss += val_loss.item()

                            # Forward pass through the EMA model
                            if hasattr(self, 'ema'):
                                ema_val_loss = self.ema.ema_model(val_data, mask, class_id=None, seq2seq_cond=None, seq2seq_mask=None)
                                total_val_ema_loss += ema_val_loss.item()

                    # Guard against an empty validation loader
                    num_val_batches = max(len(val_dataloader), 1)
                    logs["val_loss"] = total_val_loss / num_val_batches
                    logs["val_ema_loss"] = total_val_ema_loss / num_val_batches if hasattr(self, 'ema') else None
                    # Report the averaged loss rather than the last batch's tensor
                    print("Val loss:", logs["val_loss"])

                    self.diffusion.train()

                # Save model at regular intervals
                if self.step % self.save_and_sample_every == 0:
                    self.save()

                # Update the progress bar
                pbar.update(1)
                pbar.set_postfix(loss=total_loss)

            # Synchronize across all processes after every iteration
            if self.distributed:
                dist.barrier()

        # Final save
        if is_main:
            self.save()
            print('Training complete.')
            pbar.close()

    @torch.no_grad()
    def sample_seq2seq(self, dataloader=None, num_samples=None, split='val', seed=42, num_candidates=1, cls_free_guidance=1.0, return_sample=False):
        """Sample latents from the EMA model conditioned on labels and report similarity metrics.

        For each batch ``(data, label)`` the label is used as the seq2seq
        condition and ``num_candidates`` latent batches are drawn from
        ``self.ema.ema_model.sample``. On rank 0 the generated latents are
        compared with ``data`` (MSE, cosine similarity, Euclidean-distance
        histogram, t-SNE plot) and the metrics are printed.

        Note: this calls ``torch.manual_seed(seed)``, which reseeds the global
        torch RNG.

        Args:
            dataloader: Optional iterable of ``(data, label)`` batches. When
                given it overrides ``split`` and the output prefix starts with
                ``'custom'``.
            num_samples: Number of samples to keep; defaults to ``self.num_samples``.
            split: One of ``'train'``, ``'val'``, ``'test'``. Without a custom
                ``dataloader`` only ``'val'`` and ``'train'`` select a loader;
                ``'test'`` raises ``ValueError``.
            seed: Seed for torch and for t-SNE.
            num_candidates: Number of candidate draws per batch; ``None`` falls
                back to ``self.seq2seq_candidates``.
            cls_free_guidance: Guidance value forwarded to ``sample``; values
                other than 1.0 are appended to the output prefix.
            return_sample: When true, return the generated latents.

        Returns:
            The tensor of generated latents (truncated to
            ``num_samples * num_candidates`` rows) if ``return_sample`` is
            true, otherwise ``None``.

        Raises:
            ValueError: If no ``dataloader`` is given and ``split`` is ``'test'``.
            RuntimeError: If the dataloader yields no batches.
        """
        assert split in ['train', 'val', 'test']
        num_samples = num_samples if num_samples is not None else self.num_samples
        num_candidates = num_candidates if num_candidates is not None else self.seq2seq_candidates
        device = torch.device(f'cuda:{self.rank}' if torch.cuda.is_available() else 'cpu')

        # Set model to evaluation mode
        ema_model = self.ema.ema_model
        ema_model.eval()
        torch.manual_seed(seed)
        prefix = ''

        # Choose the appropriate dataloader based on the split
        if dataloader is not None:
            prefix = 'custom'
        elif split == 'val':
            dataloader = self.val_dataloader
        elif split == 'train':
            dataloader = self.train_val_dataloader
        else:
            raise ValueError(f'Invalid split {split}')

        # Add guidance information to the prefix
        prefix += f'guide{cls_free_guidance}/' if cls_free_guidance != 1.0 else ''

        # Loop-invariant: whether the model asks for L2-normalised latents.
        l2_normalize = getattr(ema_model, 'l2_normalize_latents', False)

        # Total number of rows to keep, and how many have been generated so far.
        target_rows = num_samples * num_candidates
        collected_rows = 0

        pred_latents = []
        all_data = []

        # Generate samples from the latent diffusion model
        for data, label in dataloader:
            data, label = data.to(device), label.to(device)

            # Labels act as the seq2seq condition; padded positions are masked out.
            seq2seq_cond = label
            seq2seq_mask = create_attention_mask(label).to(device)

            for _ in range(num_candidates):
                latents, _ = ema_model.sample(
                    batch_size=data.shape[0],
                    length=None,
                    seq2seq_cond=seq2seq_cond,
                    seq2seq_mask=seq2seq_mask,
                    cls_free_guidance=cls_free_guidance,
                    l2_normalize=l2_normalize
                )

                pred_latents.append(latents)
                all_data.append(data)
                collected_rows += latents.shape[0]

            # Stop once enough rows (not batches) have been generated; anything
            # beyond target_rows would be discarded by the truncation below.
            if collected_rows >= target_rows:
                break

        if not pred_latents:
            raise RuntimeError('sample_seq2seq: dataloader yielded no batches, nothing to sample')

        # Stack the per-batch results and truncate to the requested size
        pred_latents = torch.cat(pred_latents, dim=0)[:target_rows]
        all_data = torch.cat(all_data, dim=0)[:target_rows]

        # Gather latents and data across all processes
        if self.distributed:
            pred_latents_list = [torch.zeros_like(pred_latents) for _ in range(self.world_size)]
            all_data_list = [torch.zeros_like(all_data) for _ in range(self.world_size)]
            dist.all_gather(pred_latents_list, pred_latents)
            dist.all_gather(all_data_list, all_data)
            pred_latents = torch.cat(pred_latents_list, dim=0)[:target_rows]
            all_data = torch.cat(all_data_list, dim=0)[:target_rows]

        # Only the main process (rank 0) should perform the following actions
        if self.rank == 0:
            def ensure_parent_dir(path):
                # `prefix` may contain '/', so the target directory may not exist yet.
                parent = os.path.dirname(path)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                return path

            # Flatten each sample to a vector for the metric computations
            pred_latents_flat = pred_latents.reshape(pred_latents.size(0), -1).cpu().numpy()
            data_flat = all_data.reshape(all_data.size(0), -1).cpu().numpy()

            mse = mean_squared_error(data_flat, pred_latents_flat)
            print(f"MSE between data and generated latents: {mse}")

            # t-SNE Visualization of data and predicted latents in 2D.
            # Perplexity must stay below the number of points, so it is capped
            # for small sample counts (the default of 30 is used otherwise).
            combined = np.vstack((data_flat, pred_latents_flat))
            perplexity = min(30.0, float(combined.shape[0] - 1))
            tsne = TSNE(n_components=2, random_state=seed, perplexity=perplexity)
            tsne_result = tsne.fit_transform(combined)

            # Split the t-SNE result back into data and pred_latents
            num_real = data_flat.shape[0]
            tsne_data, tsne_pred_latents = tsne_result[:num_real], tsne_result[num_real:]

            # Plotting t-SNE
            fig = plt.figure(figsize=(10, 7))
            plt.scatter(tsne_data[:, 0], tsne_data[:, 1], label="Data", alpha=0.6, c='blue')
            plt.scatter(tsne_pred_latents[:, 0], tsne_pred_latents[:, 1], label="Predicted Latents", alpha=0.6, c='red')
            plt.legend()
            plt.title('t-SNE of Data and Predicted Latents')
            plt.savefig(ensure_parent_dir(f'{prefix}_tsne_plot.png'))
            plt.show()
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

            # Compute and plot the histogram of Euclidean distances
            distances = np.linalg.norm(data_flat - pred_latents_flat, axis=1)

            fig = plt.figure(figsize=(10, 7))
            plt.hist(distances, bins=50, alpha=0.75, color='green')
            plt.title('Histogram of Euclidean Distances between Real and Generated Data')
            plt.xlabel('Distance')
            plt.ylabel('Frequency')
            plt.savefig(ensure_parent_dir(f'{prefix}_distance_histogram.png'))
            plt.show()
            plt.close(fig)

            # Mean per-sample cosine similarity (epsilon guards against zero norms)
            dots = np.sum(data_flat * pred_latents_flat, axis=1)
            norms = np.linalg.norm(data_flat, axis=1) * np.linalg.norm(pred_latents_flat, axis=1)
            cos_sim = np.mean(dots / (norms + 1e-8))
            print(f"Cosine Similarity between data and generated latents: {cos_sim}")

            # Save generated latents if needed
            if split == 'test':
                save_path = os.path.join(self.results_folder, f'{prefix}_seq2seq_{split}_latents.pt')
                torch.save(pred_latents, ensure_parent_dir(save_path))

            # Log the metrics
            metrics = {
                f"model/seq2seq/{prefix}mse": mse,
                f"model/seq2seq/{prefix}cosine_similarity": cos_sim
            }
            print(metrics)

        # Synchronize across all processes
        if self.distributed:
            dist.barrier()

        # Clean up the CUDA memory (done before returning so it always runs)
        torch.cuda.empty_cache()

        # Optionally return the generated latent samples
        if return_sample:
            return pred_latents
