import sys
import os
import torchvision.transforms as transforms
from torchvision import datasets
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from architectures.mlp import Generator as Generator_mlp
from architectures.mlp import Critic as Critic_mlp
from magnitude import *
from model_loader import ModelLoader

class GAN:
    def __init__(self, batch_size=64, lr=0.00005, n_cpu: int = 8, latent_dim: int = 100, img_size: int = 28, channels: int = 1, n_critic: int = 5, step: int = 50, device=None, name: str = 'WGAN_mlp', dataset_name: str = 'MNIST'):
        self.batch_size = batch_size
        self.lr = lr
        self.n_cpu = n_cpu
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.n_critic = n_critic
        self.step = step
        self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.dataset_name = dataset_name 

        
        if dataset_name == 'MNIST':
            print(f'Preparing {dataset_name}')
            self.img_shape = (self.channels, self.img_size, self.img_size)
            mnist_dir = os.path.join(os.path.dirname(__file__), '../mnist_data')
            mnist_dir = os.path.abspath(mnist_dir)
            os.makedirs(mnist_dir, exist_ok=True)
            self.dataloader = torch.utils.data.DataLoader(
                datasets.MNIST(
                    mnist_dir,
                    train=True,
                    download=True,
                    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
                ),
                batch_size=self.batch_size,
                shuffle=True,
            )

        self.model_name = name
        if name.endswith('mlp'):
            self.G = Generator_mlp(self.latent_dim, self.img_shape).to(self.device)
            self.C = Critic_mlp(self.img_shape).to(self.device)

            self.optimizer_G = torch.optim.RMSprop(self.G.parameters(), lr=self.lr)
            self.optimizer_C = torch.optim.RMSprop(self.C.parameters(), lr=self.lr)

        self.loss_C_list = []
        self.loss_G_list = []
        self.generator_grad_norm_list = []
        self.gen_data_list = []

    # Calculates the gradient-penalty loss term for WGAN-GP.
    def gradient_penalty(self, critic, real_samples, fake_samples, device):
        alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = critic(interpolates)

        fake = torch.ones(d_interpolates.shape).to(device)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    
    def train_WGAN(self, n_epochs, n_critic, batch_size: int = 64, clip_value=0.01):
        """Train with the weight-clipping WGAN objective, or load an existing run.

        If ``ModelLoader`` reports that the critic, generator and loss-list files
        for ``n_epochs`` already exist, ``load_model`` is called and its result is
        returned without training.

        Each batch performs one critic step on ``-mean(C(real)) + mean(C(fake))``.
        After it, every critic parameter is clamped to ``[-clip_value, clip_value]``.

        On batches whose index ``i`` satisfies ``i % n_critic == 0``, the generator
        also takes one step on ``-mean(C(G(z)))``. At that point the critic loss,
        generator loss and generator gradient norm are appended to the history
        lists.

        Every ``self.step`` epochs, samples are generated and a checkpoint is saved.
        A final save follows the last epoch.

        Args:
            n_epochs: Number of passes over ``self.dataloader``.
            n_critic: Generator update interval, in batches.
            batch_size: Not used for training. Batches come from
                ``self.dataloader``, which was built with ``self.batch_size``.
                Kept so existing calls keep working.
            clip_value: Critic weight-clipping bound.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``
            after training, or the return value of ``ModelLoader.load_model``.
        """
        self.n_epochs = n_epochs
        self.clip_value = clip_value
        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        # Report the batch size the dataloader actually uses, not the unused argument.
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        for epoch in range(1, self.n_epochs + 1):
            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)

                # ---- Critic step ----
                self.optimizer_C.zero_grad()
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()  # detached: this step must not update G
                loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))
                loss_C.backward()
                self.optimizer_C.step()
                # Clip critic weights (weight clipping, not gradient clipping).
                for p in self.C.parameters():
                    p.data.clamp_(-self.clip_value, self.clip_value)

                # ---- Generator step, once every n_critic batches ----
                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)  # same noise as the critic step, now with gradients to G
                    loss_G = -torch.mean(self.C(gen_imgs))
                    loss_G.backward()
                    # Track generator gradient norms.
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data. Record the n_critic the loop
                # actually used (the argument), not the constructor's self.n_critic.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic,
                                             clip_value=self.clip_value, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}")

        # Final save, recording the same hyper-parameters as the periodic saves.
        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic,
                                     clip_value=self.clip_value, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list

    def train_WGAN_GP(self, n_epochs, n_critic, batch_size: int = 64, lambda_gp = 10):
        """Train with the WGAN-GP objective, or load an existing run.

        If ``ModelLoader`` reports that the critic, generator and loss-list files
        for ``n_epochs`` already exist, ``load_model`` is called and its result is
        returned without training.

        Each batch performs one critic step on
        ``-mean(C(real)) + mean(C(fake)) + lambda_gp * gradient_penalty``, with no
        weight clipping.

        On batches whose index ``i`` satisfies ``i % n_critic == 0``, the generator
        also takes one step on ``-mean(C(G(z)))``. At that point the losses and the
        generator gradient norm are recorded.

        Every ``self.step`` epochs, samples are generated and a checkpoint is saved.
        A final save follows the last epoch.

        Args:
            n_epochs: Number of passes over ``self.dataloader``.
            n_critic: Generator update interval, in batches.
            batch_size: Not used for training. Batches come from
                ``self.dataloader``, which was built with ``self.batch_size``.
                Kept so existing calls keep working.
            lambda_gp: Weight of the gradient penalty.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``
            after training, or the return value of ``ModelLoader.load_model``.
        """
        self.n_epochs = n_epochs
        self.lambda_gp = lambda_gp
        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        # Report the batch size the dataloader actually uses, not the unused argument.
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        for epoch in range(1, self.n_epochs + 1):
            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)

                # ---- Critic step ----
                self.optimizer_C.zero_grad()
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()  # detached: this step must not update G
                gp = self.gradient_penalty(self.C, real_imgs, fake_imgs, self.device)
                penalty = lambda_gp * gp
                loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + penalty
                loss_C.backward()
                self.optimizer_C.step()

                # ---- Generator step, once every n_critic batches ----
                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)  # same noise as the critic step, now with gradients to G
                    loss_G = -torch.mean(self.C(gen_imgs))
                    loss_G.backward()
                    # Track generator gradient norms.
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data. Record the n_critic the loop
                # actually used (the argument), not the constructor's self.n_critic.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic,
                                             lambda_gp=self.lambda_gp, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}")

        # Final save, recording the same hyper-parameters as the periodic saves.
        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic,
                                     lambda_gp=self.lambda_gp, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list
            

    def train_MagGAN(self, n_epochs, n_critic, batch_size: int = 64, clip_value=0.01, t=1.0, normalize=False, mode: str = 'beta_critic', beta = [0.1, 0.1], minimax = False):
        """Train a weight-clipping WGAN regularised by a magnitude distance, or load a run.

        Below, ``d(x, y)`` is ``norm_diff_magnitude_distance_grad`` applied to both
        batches flattened to ``(batch, -1)``, with the given ``t`` and ``normalize``
        and ``eps=0``. The loss terms depend on ``mode``:

        * ``'beta_critic'``: critic loss ``+= beta[0] * sign * d(real, fake)``;
          the generator loss is plain WGAN.
        * ``'beta'``: critic loss as in ``'beta_critic'``; generator loss
          ``+= beta[1] * d(real, G(z))``.
        * ``'generator_regularization'``: the critic loss is plain WGAN; generator
          loss ``+= d(real, G(z))``.

        ``sign`` is ``-1`` when ``minimax`` is true, else ``1``.

        ``self.model_name`` is replaced by ``self._get_maggan_model_name(...)``
        before the checkpoint paths are resolved. If a finished run exists, it is
        loaded and returned.

        Critic weights are clamped to ``[-clip_value, clip_value]`` after every
        critic step. The generator is updated on batches with ``i % n_critic == 0``.

        Args:
            n_epochs: Number of passes over ``self.dataloader``.
            n_critic: Generator update interval, in batches.
            batch_size: Not used for training; the dataloader uses
                ``self.batch_size``. Kept so existing calls keep working.
            clip_value: Critic weight-clipping bound.
            t: Forwarded to the magnitude distance.
            normalize: Forwarded to the magnitude distance.
            mode: One of ``'beta_critic'``, ``'beta'`` or
                ``'generator_regularization'``.
            beta: ``[critic weight, generator weight]``; it is not modified.
            minimax: Flip the sign of the critic-side distance term.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``
            after training, or the return value of ``ModelLoader.load_model``.

        Raises:
            ValueError: If ``mode`` is not one of the values listed above.
        """
        # Fail fast. An unknown mode would otherwise leave loss_C undefined on the
        # first batch.
        if mode not in ('beta_critic', 'beta', 'generator_regularization'):
            raise ValueError(f"Unknown mode {mode!r}; expected 'beta_critic', 'beta' or 'generator_regularization'.")
        self.n_epochs = n_epochs
        self.clip_value = clip_value

        # Set the model name based on the parameters.
        # NOTE(review): self.model_name is overwritten here, so a second train_MagGAN*
        # call on the same instance passes the already-derived name as `name`.
        # Confirm that _get_maggan_model_name expects that.
        self.model_name = self._get_maggan_model_name(mode = mode, beta = beta, t = t, normalize = normalize, name = self.model_name, minimax = minimax)

        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        # Report the batch size the dataloader actually uses, not the unused argument.
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        sign = -1 if minimax else 1

        for epoch in range(1, self.n_epochs + 1):
            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)
                real_flat = real_imgs.view(real_imgs.size(0), -1)

                # ---- Critic step ----
                self.optimizer_C.zero_grad()
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()  # detached: this step must not update G
                if mode == 'generator_regularization':
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))
                else:  # 'beta_critic' or 'beta'
                    # NOTE(review): real_imgs comes from the dataloader and fake_imgs is
                    # detached, so this term does not depend on any critic parameter.
                    # It adds no gradient to the critic update and only shifts the
                    # logged loss_C. Since `sign` is applied only here, `minimax` does
                    # not affect the updates either. Confirm this is intended.
                    loss_distance = sign * norm_diff_magnitude_distance_grad(real_flat, fake_imgs.view(fake_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + beta[0] * loss_distance
                loss_C.backward()
                self.optimizer_C.step()
                # Clip critic weights (weight clipping, not gradient clipping).
                for p in self.C.parameters():
                    p.data.clamp_(-self.clip_value, self.clip_value)

                # ---- Generator step, once every n_critic batches ----
                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)  # same noise as the critic step, now with gradients to G
                    if mode == 'beta_critic':
                        loss_G = -torch.mean(self.C(gen_imgs))
                    else:  # 'generator_regularization' or 'beta'
                        # gen_imgs is not detached, so (assuming the distance is
                        # differentiable) this term also pushes gradients into G.
                        loss_distance = norm_diff_magnitude_distance_grad(real_flat, gen_imgs.view(gen_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                        if mode == 'beta':
                            loss_distance = beta[1] * loss_distance
                        loss_G = -torch.mean(self.C(gen_imgs)) + loss_distance
                    loss_G.backward()
                    # Track generator gradient norms.
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data. Record the n_critic the loop
                # actually used (the argument), not the constructor's self.n_critic.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic,
                                             clip_value=self.clip_value, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}")

        # Final save, recording the same hyper-parameters as the periodic saves.
        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic,
                                     clip_value=self.clip_value, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list
        


    def train_MagGAN_GP(self, n_epochs, n_critic, batch_size: int = 64, lambda_gp = 10, t=1.0, normalize=False, mode: str = 'beta_critic', beta = [0.1, 0.1], minimax = False):
        """Train a WGAN-GP regularised by a magnitude distance, or load a run.

        Below, ``d(x, y)`` is ``norm_diff_magnitude_distance_grad`` applied to both
        batches flattened to ``(batch, -1)``, with the given ``t`` and ``normalize``
        and ``eps=0``. The loss terms depend on ``mode``:

        * ``'beta_critic'``: critic loss ``+= beta[0] * sign * d(real, fake)``;
          the generator loss is plain WGAN.
        * ``'beta'``: critic loss as in ``'beta_critic'``; generator loss
          ``+= beta[1] * d(real, G(z))``.
        * ``'generator_regularization'``: the critic loss is plain WGAN; generator
          loss ``+= d(real, G(z))``.

        In every mode the critic loss also gets ``lambda_gp * gradient_penalty``,
        and no weight clipping is applied. ``sign`` is ``-1`` when ``minimax`` is
        true, else ``1``.

        ``self.model_name`` is replaced by ``self._get_maggan_model_name(...)``
        before the checkpoint paths are resolved. If a finished run exists, it is
        loaded and returned. The generator is updated on batches with
        ``i % n_critic == 0``.

        Args:
            n_epochs: Number of passes over ``self.dataloader``.
            n_critic: Generator update interval, in batches.
            batch_size: Not used for training; the dataloader uses
                ``self.batch_size``. Kept so existing calls keep working.
            lambda_gp: Weight of the gradient penalty.
            t: Forwarded to the magnitude distance.
            normalize: Forwarded to the magnitude distance.
            mode: One of ``'beta_critic'``, ``'beta'`` or
                ``'generator_regularization'``.
            beta: ``[critic weight, generator weight]``; it is not modified.
            minimax: Flip the sign of the critic-side distance term.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``
            after training, or the return value of ``ModelLoader.load_model``.

        Raises:
            ValueError: If ``mode`` is not one of the values listed above.
        """
        # Fail fast. An unknown mode would otherwise leave loss_C undefined on the
        # first batch.
        if mode not in ('beta_critic', 'beta', 'generator_regularization'):
            raise ValueError(f"Unknown mode {mode!r}; expected 'beta_critic', 'beta' or 'generator_regularization'.")
        self.n_epochs = n_epochs
        self.lambda_gp = lambda_gp

        # Set the model name based on the parameters.
        # NOTE(review): self.model_name is overwritten here, so a second train_MagGAN*
        # call on the same instance passes the already-derived name as `name`.
        # Confirm that _get_maggan_model_name expects that.
        self.model_name = self._get_maggan_model_name(mode = mode, beta = beta, t = t, normalize = normalize, name = self.model_name, minimax = minimax)

        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        # Report the batch size the dataloader actually uses, not the unused argument.
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        sign = -1 if minimax else 1

        for epoch in range(1, self.n_epochs + 1):
            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)
                real_flat = real_imgs.view(real_imgs.size(0), -1)

                # ---- Critic step ----
                self.optimizer_C.zero_grad()
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()  # detached: this step must not update G
                if mode == 'generator_regularization':
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))
                else:  # 'beta_critic' or 'beta'
                    # NOTE(review): real_imgs comes from the dataloader and fake_imgs is
                    # detached, so this term does not depend on any critic parameter.
                    # It adds no gradient to the critic update and only shifts the
                    # logged loss_C. Since `sign` is applied only here, `minimax` does
                    # not affect the updates either. Confirm this is intended.
                    loss_distance = sign * norm_diff_magnitude_distance_grad(real_flat, fake_imgs.view(fake_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + beta[0] * loss_distance

                # The gradient penalty takes the place of weight clipping.
                gp = self.gradient_penalty(self.C, real_imgs, fake_imgs, self.device)
                penalty = lambda_gp * gp
                loss_C = loss_C + penalty
                loss_C.backward()
                self.optimizer_C.step()

                # ---- Generator step, once every n_critic batches ----
                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)  # same noise as the critic step, now with gradients to G
                    if mode == 'beta_critic':
                        loss_G = -torch.mean(self.C(gen_imgs))
                    else:  # 'generator_regularization' or 'beta'
                        # gen_imgs is not detached, so (assuming the distance is
                        # differentiable) this term also pushes gradients into G.
                        loss_distance = norm_diff_magnitude_distance_grad(real_flat, gen_imgs.view(gen_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                        if mode == 'beta':
                            loss_distance = beta[1] * loss_distance
                        loss_G = -torch.mean(self.C(gen_imgs)) + loss_distance
                    loss_G.backward()
                    # Track generator gradient norms.
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data. Record the n_critic the loop
                # actually used (the argument), not the constructor's self.n_critic.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic,
                                             lambda_gp=self.lambda_gp, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}")

        # Final save, recording the same hyper-parameters as the periodic saves.
        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic,
                                     lambda_gp=self.lambda_gp, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list


    def train_MagGAN_with_scheduler(self, n_epochs, n_critic, batch_size: int = 64, clip_value=0.01, t=1.0, normalize=False, mode: str = 'beta_critic', beta = [0.1, 0.1], minimax = False, beta_schedule_config = None):
        """Like ``train_MagGAN``, but ``beta[0]`` follows a per-epoch schedule.

        At the start of each epoch, the critic weight is set to
        ``self.schedule_beta(epoch, beta[0], beta_schedule_config)``. It is printed
        on epoch 1 and whenever it changes, and appended to
        ``self.beta_schedule_list``. ``beta[1]`` (the generator weight) is not
        scheduled, and the caller's ``beta`` list is not modified.

        Below, ``d(x, y)`` is ``norm_diff_magnitude_distance_grad`` applied to both
        batches flattened to ``(batch, -1)``, with the given ``t`` and ``normalize``
        and ``eps=0``. The loss terms depend on ``mode``:

        * ``'beta_critic'``: critic loss ``+= beta_0(epoch) * sign * d(real, fake)``;
          the generator loss is plain WGAN.
        * ``'beta'``: critic loss as in ``'beta_critic'``; generator loss
          ``+= beta[1] * d(real, G(z))``.
        * ``'generator_regularization'``: the critic loss is plain WGAN; generator
          loss ``+= d(real, G(z))``.

        ``sign`` is ``-1`` when ``minimax`` is true, else ``1``. Critic weights are
        clamped to ``[-clip_value, clip_value]`` after every critic step. The
        generator is updated on batches with ``i % n_critic == 0``. If a finished
        run exists, it is loaded and returned.

        Args:
            n_epochs: Number of passes over ``self.dataloader``.
            n_critic: Generator update interval, in batches.
            batch_size: Not used for training; the dataloader uses
                ``self.batch_size``. Kept so existing calls keep working.
            clip_value: Critic weight-clipping bound.
            t: Forwarded to the magnitude distance.
            normalize: Forwarded to the magnitude distance.
            mode: One of ``'beta_critic'``, ``'beta'`` or
                ``'generator_regularization'``.
            beta: ``[initial critic weight, generator weight]``.
            minimax: Flip the sign of the critic-side distance term.
            beta_schedule_config: Passed to ``self.schedule_beta`` and
                ``self._get_maggan_model_name``.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``
            after training, or the return value of ``ModelLoader.load_model``.

        Raises:
            ValueError: If ``mode`` is not one of the values listed above.
        """
        # Fail fast. An unknown mode would otherwise leave loss_C undefined on the
        # first batch.
        if mode not in ('beta_critic', 'beta', 'generator_regularization'):
            raise ValueError(f"Unknown mode {mode!r}; expected 'beta_critic', 'beta' or 'generator_regularization'.")
        self.n_epochs = n_epochs
        self.clip_value = clip_value

        # Set the model name based on the parameters.
        # NOTE(review): self.model_name is overwritten here, so a second train_MagGAN*
        # call on the same instance passes the already-derived name as `name`.
        # Confirm that _get_maggan_model_name expects that.
        self.model_name = self._get_maggan_model_name(mode = mode, beta = beta, t = t, normalize = normalize, name = self.model_name, minimax = minimax, beta_schedule_config = beta_schedule_config)

        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        # Report the batch size the dataloader actually uses, not the unused argument.
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        sign = -1 if minimax else 1
        # Work on a copy so the caller's beta list is left untouched.
        current_beta = beta.copy()
        initial_beta_0 = beta[0]
        # Scheduled beta[0] per epoch, for later inspection.
        self.beta_schedule_list = []

        for epoch in range(1, self.n_epochs + 1):
            prev_beta = current_beta[0]
            current_beta[0] = self.schedule_beta(epoch, initial_beta_0, beta_schedule_config)
            # Log beta changes.
            if epoch == 1 or current_beta[0] != prev_beta:
                print(f"Epoch {epoch}: beta[0] = {current_beta[0]:.6f}")
            self.beta_schedule_list.append(current_beta[0])

            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)
                real_flat = real_imgs.view(real_imgs.size(0), -1)

                # ---- Critic step ----
                self.optimizer_C.zero_grad()
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()  # detached: this step must not update G
                if mode == 'generator_regularization':
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))
                else:  # 'beta_critic' or 'beta'
                    # NOTE(review): real_imgs comes from the dataloader and fake_imgs is
                    # detached, so this term does not depend on any critic parameter.
                    # It adds no gradient to the critic update and only shifts the
                    # logged loss_C. The beta[0] schedule and `minimax` (the only use
                    # of `sign`) therefore do not affect the updates. Confirm this is
                    # intended.
                    loss_distance = sign * norm_diff_magnitude_distance_grad(real_flat, fake_imgs.view(fake_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + current_beta[0] * loss_distance
                loss_C.backward()
                self.optimizer_C.step()
                # Clip critic weights (weight clipping, not gradient clipping).
                for p in self.C.parameters():
                    p.data.clamp_(-self.clip_value, self.clip_value)

                # ---- Generator step, once every n_critic batches ----
                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)  # same noise as the critic step, now with gradients to G
                    if mode == 'beta_critic':
                        loss_G = -torch.mean(self.C(gen_imgs))
                    else:  # 'generator_regularization' or 'beta'
                        # gen_imgs is not detached, so (assuming the distance is
                        # differentiable) this term also pushes gradients into G.
                        loss_distance = norm_diff_magnitude_distance_grad(real_flat, gen_imgs.view(gen_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                        if mode == 'beta':
                            loss_distance = beta[1] * loss_distance  # generator weight is not scheduled
                        loss_G = -torch.mean(self.C(gen_imgs)) + loss_distance
                    loss_G.backward()
                    # Track generator gradient norms.
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data. Record the n_critic the loop
                # actually used (the argument), not the constructor's self.n_critic.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic,
                                             clip_value=self.clip_value, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}, beta[0]: {current_beta[0]:.6f}")

        # Final save, recording the same hyper-parameters as the periodic saves.
        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic,
                                     clip_value=self.clip_value, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list
        


    def train_MagGAN_GP_with_scheduler(self, n_epochs, n_critic, batch_size: int = 64, lambda_gp = 10, t=1.0, normalize=False, mode: str = 'beta_critic', beta = [0.1, 0.1], minimax = False, beta_schedule_config = None):
        """Like ``train_MagGAN_GP``, but ``beta[0]`` follows a per-epoch schedule.

        At the start of each epoch, the critic weight is set to
        ``self.schedule_beta(epoch, beta[0], beta_schedule_config)``. It is printed
        on epoch 1 and whenever it changes, and appended to
        ``self.beta_schedule_list``. ``beta[1]`` (the generator weight) is not
        scheduled, and the caller's ``beta`` list is not modified.

        Below, ``d(x, y)`` is ``norm_diff_magnitude_distance_grad`` applied to both
        batches flattened to ``(batch, -1)``, with the given ``t`` and ``normalize``
        and ``eps=0``. The loss terms depend on ``mode``:

        * ``'beta_critic'``: critic loss ``+= beta_0(epoch) * sign * d(real, fake)``;
          the generator loss is plain WGAN.
        * ``'beta'``: critic loss as in ``'beta_critic'``; generator loss
          ``+= beta[1] * d(real, G(z))``.
        * ``'generator_regularization'``: the critic loss is plain WGAN; generator
          loss ``+= d(real, G(z))``.

        In every mode the critic loss also gets ``lambda_gp * gradient_penalty``,
        and no weight clipping is applied. ``sign`` is ``-1`` when ``minimax`` is
        true, else ``1``. The generator is updated on batches with
        ``i % n_critic == 0``. If a finished run exists, it is loaded and returned.

        Args:
            n_epochs: Number of passes over ``self.dataloader``.
            n_critic: Generator update interval, in batches.
            batch_size: Not used for training; the dataloader uses
                ``self.batch_size``. Kept so existing calls keep working.
            lambda_gp: Weight of the gradient penalty.
            t: Forwarded to the magnitude distance.
            normalize: Forwarded to the magnitude distance.
            mode: One of ``'beta_critic'``, ``'beta'`` or
                ``'generator_regularization'``.
            beta: ``[initial critic weight, generator weight]``.
            minimax: Flip the sign of the critic-side distance term.
            beta_schedule_config: Passed to ``self.schedule_beta`` and
                ``self._get_maggan_model_name``.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``
            after training, or the return value of ``ModelLoader.load_model``.

        Raises:
            ValueError: If ``mode`` is not one of the values listed above.
        """
        # Fail fast. An unknown mode would otherwise leave loss_C undefined on the
        # first batch.
        if mode not in ('beta_critic', 'beta', 'generator_regularization'):
            raise ValueError(f"Unknown mode {mode!r}; expected 'beta_critic', 'beta' or 'generator_regularization'.")
        self.n_epochs = n_epochs
        self.lambda_gp = lambda_gp

        # Set the model name based on the parameters.
        # NOTE(review): self.model_name is overwritten here, so a second train_MagGAN*
        # call on the same instance passes the already-derived name as `name`.
        # Confirm that _get_maggan_model_name expects that.
        self.model_name = self._get_maggan_model_name(mode = mode, beta = beta, t = t, normalize = normalize, name = self.model_name, minimax = minimax, beta_schedule_config = beta_schedule_config)

        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        # Report the batch size the dataloader actually uses, not the unused argument.
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {self.batch_size} and {n_critic} critic updates.')

        sign = -1 if minimax else 1
        # Work on a copy so the caller's beta list is left untouched.
        current_beta = beta.copy()
        initial_beta_0 = beta[0]
        # Scheduled beta[0] per epoch, for later inspection.
        self.beta_schedule_list = []

        for epoch in range(1, self.n_epochs + 1):
            prev_beta = current_beta[0]
            current_beta[0] = self.schedule_beta(epoch, initial_beta_0, beta_schedule_config)
            # Log beta changes.
            if epoch == 1 or current_beta[0] != prev_beta:
                print(f"Epoch {epoch}: beta[0] = {current_beta[0]:.6f}")
            self.beta_schedule_list.append(current_beta[0])

            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)
                real_flat = real_imgs.view(real_imgs.size(0), -1)

                # ---- Critic step ----
                self.optimizer_C.zero_grad()
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()  # detached: this step must not update G
                if mode == 'generator_regularization':
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))
                else:  # 'beta_critic' or 'beta'
                    # NOTE(review): real_imgs comes from the dataloader and fake_imgs is
                    # detached, so this term does not depend on any critic parameter.
                    # It adds no gradient to the critic update and only shifts the
                    # logged loss_C. The beta[0] schedule and `minimax` (the only use
                    # of `sign`) therefore do not affect the updates. Confirm this is
                    # intended.
                    loss_distance = sign * norm_diff_magnitude_distance_grad(real_flat, fake_imgs.view(fake_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + current_beta[0] * loss_distance

                # The gradient penalty takes the place of weight clipping.
                gp = self.gradient_penalty(self.C, real_imgs, fake_imgs, self.device)
                penalty = lambda_gp * gp
                loss_C = loss_C + penalty
                loss_C.backward()
                self.optimizer_C.step()

                # ---- Generator step, once every n_critic batches ----
                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)  # same noise as the critic step, now with gradients to G
                    if mode == 'beta_critic':
                        loss_G = -torch.mean(self.C(gen_imgs))
                    else:  # 'generator_regularization' or 'beta'
                        # gen_imgs is not detached, so (assuming the distance is
                        # differentiable) this term also pushes gradients into G.
                        loss_distance = norm_diff_magnitude_distance_grad(real_flat, gen_imgs.view(gen_imgs.size(0), -1), device=self.device, t=t, normalize=normalize, eps = 0)
                        if mode == 'beta':
                            loss_distance = beta[1] * loss_distance  # generator weight is not scheduled
                        loss_G = -torch.mean(self.C(gen_imgs)) + loss_distance
                    loss_G.backward()
                    # Track generator gradient norms.
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if epoch % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses and generated data. Record the n_critic the loop
                # actually used (the argument), not the constructor's self.n_critic.
                self.model_loader.save_model(self, epoch, self.batch_size, n_critic,
                                             lambda_gp=self.lambda_gp, info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}, beta[0]: {current_beta[0]:.6f}")

        # Final save, recording the same hyper-parameters as the periodic saves.
        self.model_loader.save_model(self, self.n_epochs, self.batch_size, n_critic,
                                     lambda_gp=self.lambda_gp, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list

# ---- Training variants with a per-epoch scheduled magnitude parameter t ----

    def train_MagGAN_with_t_scheduler(self, n_epochs, n_critic, batch_size: int = 64, clip_value=0.01, normalize=False, mode: str = 'beta_critic', beta = [0.1, 0.1], minimax = False, t_schedule_config = None):
        """Train a weight-clipped Magnitude WGAN while scheduling the magnitude parameter ``t``.

        Each epoch, ``t`` is obtained from ``self.schedule_t(epoch, t_schedule_config)`` and
        passed to ``norm_diff_magnitude_distance_grad`` for the critic and/or generator loss.
        If checkpoints for the derived model name already exist, they are loaded and
        returned instead of training.

        Args:
            n_epochs: Number of training epochs.
            n_critic: The generator is updated on every batch index divisible by this value.
            batch_size: Used for logging and the final ``save_model`` call only; the
                dataloader itself was built with ``self.batch_size``.
            clip_value: Critic weights are clamped to ``[-clip_value, clip_value]`` after
                every critic step.
            normalize: Forwarded to the magnitude-distance function.
            mode: ``'beta_critic'`` (distance term in the critic loss only), ``'beta'``
                (distance term in both losses) or ``'generator_regularization'``
                (unweighted distance term in the generator loss only).
            beta: ``beta[0]`` weights the critic's distance term; ``beta[1]`` weights the
                generator's distance term in ``'beta'`` mode.
            minimax: If True, the critic's distance term is negated.
            t_schedule_config: Dict consumed by ``schedule_t``; required.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``, or the
            result of ``ModelLoader.load_model`` when a finished run is found on disk.

        Raises:
            ValueError: If ``mode`` is unknown or ``t_schedule_config`` is None.
        """
        # Validate before touching any state: an unknown mode used to surface as an
        # UnboundLocalError on loss_C mid-training, a missing config as AttributeError.
        valid_modes = ('beta_critic', 'beta', 'generator_regularization')
        if mode not in valid_modes:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}")
        if t_schedule_config is None:
            raise ValueError("t_schedule_config is required for t-scheduled training")

        self.n_epochs = n_epochs
        self.clip_value = clip_value

        # Set model name based on parameters (t=0 is unused when a t schedule is given)
        self.model_name = self._get_maggan_model_name(mode = mode, beta = beta, t = 0, normalize = normalize, name = self.model_name, minimax = minimax, t_schedule_config = t_schedule_config)

        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        # Reuse a finished run if all final artifacts already exist
        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {batch_size} and {n_critic} critic updates.')

        # In minimax mode the critic's distance term is negated
        sign = -1 if minimax else 1

        def flat_distance(real, fake, t):
            # Flatten each batch to (batch, -1) before computing the magnitude distance
            return norm_diff_magnitude_distance_grad(
                real.view(real.size(0), -1), fake.view(fake.size(0), -1),
                device=self.device, t=t, normalize=normalize, eps = 0)

        # Track the t value used in each epoch
        self.t_schedule_list = []
        current_t = t_schedule_config.get('initial_t', 1.0)

        for epoch in range(1, self.n_epochs+1):
            prev_t = current_t
            current_t = self.schedule_t(epoch, t_schedule_config)

            # Log t changes
            if epoch == 1 or current_t != prev_t:
                print(f"Epoch {epoch}: t = {current_t:.6f}")

            self.t_schedule_list.append(current_t)
            self.model_loader.set_paths(epoch)
            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)
                self.optimizer_C.zero_grad()
                # Sample noise as generator input and generate a (detached) fake batch
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()

                # Critic loss; the distance term is only computed when it is used
                if mode in ('beta_critic', 'beta'):
                    loss_distance = sign * flat_distance(real_imgs, fake_imgs, current_t)
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + beta[0] * loss_distance
                else:  # 'generator_regularization': plain WGAN critic loss
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))

                loss_C.backward()
                self.optimizer_C.step()
                # Clip critic weights, note that this is not the same as the gradient clipping
                for p in self.C.parameters():
                    p.data.clamp_(-self.clip_value, self.clip_value)

                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)

                    if mode == 'beta_critic':
                        loss_G = -torch.mean(self.C(gen_imgs))
                    else:
                        # Differentiable w.r.t. the generator through gen_imgs
                        loss_distance = flat_distance(real_imgs, gen_imgs, current_t)
                        if mode == 'generator_regularization':
                            loss_G = -torch.mean(self.C(gen_imgs)) + loss_distance
                        else:  # 'beta'
                            loss_G = -torch.mean(self.C(gen_imgs)) + beta[1] * loss_distance

                    loss_G.backward()
                    # Track generator gradient norms (measured before the optimizer step)
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if (epoch) % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses, and generated data
                self.model_loader.save_model(self, epoch, self.batch_size, self.n_critic,
                clip_value = self.clip_value,  info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}, beta[0]: {beta[0]:.6f}")

        self.model_loader.save_model(self, self.n_epochs, batch_size, n_critic,
            clip_value = self.clip_value, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list
        


    def train_MagGAN_GP_with_t_scheduler(self, n_epochs, n_critic, batch_size: int = 64, lambda_gp = 10, normalize=False, mode: str = 'beta_critic', beta = [0.1, 0.1], minimax = False, t_schedule_config = None):
        """Train a gradient-penalty Magnitude WGAN while scheduling the magnitude parameter ``t``.

        Each epoch, ``t`` is obtained from ``self.schedule_t(epoch, t_schedule_config)`` and
        passed to ``norm_diff_magnitude_distance_grad`` for the critic and/or generator loss.
        The critic loss additionally includes ``lambda_gp * self.gradient_penalty(...)``.
        If checkpoints for the derived model name already exist, they are loaded and
        returned instead of training.

        Args:
            n_epochs: Number of training epochs.
            n_critic: The generator is updated on every batch index divisible by this value.
            batch_size: Used for logging and the final ``save_model`` call only; the
                dataloader itself was built with ``self.batch_size``.
            lambda_gp: Weight of the gradient penalty in the critic loss.
            normalize: Forwarded to the magnitude-distance function.
            mode: ``'beta_critic'`` (distance term in the critic loss only), ``'beta'``
                (distance term in both losses) or ``'generator_regularization'``
                (unweighted distance term in the generator loss only).
            beta: ``beta[0]`` weights the critic's distance term; ``beta[1]`` weights the
                generator's distance term in ``'beta'`` mode.
            minimax: If True, the critic's distance term is negated.
            t_schedule_config: Dict consumed by ``schedule_t``; required.

        Returns:
            ``(loss_C_list, loss_G_list, generator_grad_norm_list, gen_data_list)``, or the
            result of ``ModelLoader.load_model`` when a finished run is found on disk.

        Raises:
            ValueError: If ``mode`` is unknown or ``t_schedule_config`` is None.
        """
        # Validate before touching any state: an unknown mode used to surface as an
        # UnboundLocalError on loss_C mid-training, a missing config as AttributeError.
        valid_modes = ('beta_critic', 'beta', 'generator_regularization')
        if mode not in valid_modes:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}")
        if t_schedule_config is None:
            raise ValueError("t_schedule_config is required for t-scheduled training")

        self.n_epochs = n_epochs
        self.lambda_gp = lambda_gp

        # Set model name based on parameters (t=0 is unused when a t schedule is given)
        self.model_name = self._get_maggan_model_name(mode = mode, beta = beta, t = 0, normalize = normalize, name = self.model_name, minimax = minimax, t_schedule_config = t_schedule_config)

        self.model_loader = ModelLoader(self.dataset_name, self.model_name, self.device, self.step)
        self.model_loader.set_paths(self.n_epochs)

        # Reuse a finished run if all final artifacts already exist
        if os.path.exists(self.model_loader.critic_path) and os.path.exists(self.model_loader.generator_path) and os.path.exists(self.model_loader.list_path):
            return self.model_loader.load_model(self.n_epochs, self.C, self.G)
        print(f'Critic path: {self.model_loader.critic_path}, exist: {os.path.exists(self.model_loader.critic_path)}')
        print(f'Generator path: {self.model_loader.generator_path}, exist: {os.path.exists(self.model_loader.generator_path)}')
        print(f'Loss and gradient norms path: {self.model_loader.list_path}, exist: {os.path.exists(self.model_loader.list_path)}')
        print(f'Starting training {self.model_name} for {self.n_epochs} epochs with batch size {batch_size} and {n_critic} critic updates.')

        # In minimax mode the critic's distance term is negated
        sign = -1 if minimax else 1

        def flat_distance(real, fake, t):
            # Flatten each batch to (batch, -1) before computing the magnitude distance
            return norm_diff_magnitude_distance_grad(
                real.view(real.size(0), -1), fake.view(fake.size(0), -1),
                device=self.device, t=t, normalize=normalize, eps = 0)

        # Track the t value used in each epoch
        self.t_schedule_list = []
        current_t = t_schedule_config.get('initial_t', 1.0)

        for epoch in range(1, self.n_epochs+1):
            prev_t = current_t
            current_t = self.schedule_t(epoch, t_schedule_config)

            # Log t changes
            if epoch == 1 or current_t != prev_t:
                print(f"Epoch {epoch}: t = {current_t:.6f}")

            self.t_schedule_list.append(current_t)
            self.model_loader.set_paths(epoch)

            for i, (imgs, _) in enumerate(self.dataloader):
                real_imgs = imgs.to(self.device)
                self.optimizer_C.zero_grad()
                # Sample noise as generator input and generate a (detached) fake batch
                z = torch.randn(imgs.shape[0], self.latent_dim, device=self.device)
                fake_imgs = self.G(z).detach()

                # Critic loss; the distance term is only computed when it is used
                if mode in ('beta_critic', 'beta'):
                    loss_distance = sign * flat_distance(real_imgs, fake_imgs, current_t)
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs)) + beta[0] * loss_distance
                else:  # 'generator_regularization': plain WGAN critic loss
                    loss_C = -torch.mean(self.C(real_imgs)) + torch.mean(self.C(fake_imgs))

                # Gradient penalty replaces weight clipping in this variant
                gp = self.gradient_penalty(self.C, real_imgs, fake_imgs, self.device)
                loss_C = loss_C + lambda_gp * gp

                loss_C.backward()
                self.optimizer_C.step()

                if i % n_critic == 0:
                    self.optimizer_G.zero_grad()
                    gen_imgs = self.G(z)

                    if mode == 'beta_critic':
                        loss_G = -torch.mean(self.C(gen_imgs))
                    else:
                        # Differentiable w.r.t. the generator through gen_imgs
                        loss_distance = flat_distance(real_imgs, gen_imgs, current_t)
                        if mode == 'generator_regularization':
                            loss_G = -torch.mean(self.C(gen_imgs)) + loss_distance
                        else:  # 'beta'
                            loss_G = -torch.mean(self.C(gen_imgs)) + beta[1] * loss_distance

                    loss_G.backward()
                    # Track generator gradient norms (measured before the optimizer step)
                    total_norm = self._calculate_gradient_norm(self.G)
                    self.optimizer_G.step()

                    self.loss_C_list.append(loss_C.item())
                    self.loss_G_list.append(loss_G.item())
                    self.generator_grad_norm_list.append(total_norm)

            if (epoch) % self.step == 0:
                self.gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=50, save=False)
                self.gen_data_list.append(self.gen_data)
                # Save model, losses, and generated data
                self.model_loader.save_model(self, epoch, self.batch_size, self.n_critic,
                lambda_gp = self.lambda_gp,  info=False)
                print(f"Epoch {epoch}/{self.n_epochs}| D loss: {loss_C.item():.6f} | G loss: {loss_G.item():.6f} | G grad norm: {total_norm:.4f}, beta[0]: {beta[0]:.6f}")

        self.model_loader.save_model(self, self.n_epochs, batch_size, n_critic,
            lambda_gp = self.lambda_gp, info=True)
        return self.loss_C_list, self.loss_G_list, self.generator_grad_norm_list, self.gen_data_list






    def schedule_beta(self, epoch, initial_beta, schedule_config):
        """
        Schedule beta value based on epoch and configuration.

        Args:
            epoch: Current epoch number (training loops in this class count from 1).
            initial_beta: Initial beta value.
            schedule_config: Dictionary with scheduling parameters:
                - 'type': 'constant', 'step', 'interval_step', 'linear', 'exponential'
                  or 'cosine' (default 'constant')
                - 'milestones': list of (epoch, beta_value) tuples ('step'); the epochs
                  are absolute, i.e. not shifted by the warmup
                - 'interval': epochs between updates ('interval_step', default 50)
                - 'step_size': change per interval ('interval_step'; if omitted it is
                  chosen so the schedule reaches final_beta at the last epoch)
                - 'final_beta': target beta for 'interval_step'/'linear'/'exponential'/
                  'cosine' (default 1.0)
                - 'warmup_epochs': beta ramps linearly from 0 to initial_beta over this
                  many epochs (default 0)

        Reads ``self.n_epochs`` to size the post-warmup part of the schedule.

        Returns:
            Current beta value. Unknown schedule types print a warning and return
            initial_beta.
        """
        import math

        schedule_type = schedule_config.get('type', 'constant')

        # Warmup phase
        warmup_epochs = schedule_config.get('warmup_epochs', 0)
        if epoch <= warmup_epochs:
            if warmup_epochs > 0:
                return initial_beta * (epoch / warmup_epochs)
            return initial_beta

        # Adjust epoch for warmup
        adjusted_epoch = epoch - warmup_epochs
        adjusted_n_epochs = self.n_epochs - warmup_epochs

        if schedule_type == 'constant':
            return initial_beta

        elif schedule_type == 'step':
            # Step schedule with milestones: last milestone reached wins
            milestones = schedule_config.get('milestones', [])
            current_beta = initial_beta
            for milestone_epoch, beta_value in milestones:
                if epoch >= milestone_epoch:
                    current_beta = beta_value
            return current_beta

        elif schedule_type == 'interval_step':
            # Change beta by step_size every `interval` epochs
            interval = schedule_config.get('interval', 50)
            final_beta = schedule_config.get('final_beta', 1.0)
            step_size = schedule_config.get('step_size', None)

            num_steps = adjusted_epoch // interval
            if step_size is None:
                # Calculate step size to reach final_beta
                total_steps = adjusted_n_epochs // interval
                step_size = (final_beta - initial_beta) / total_steps if total_steps > 0 else 0

            current_beta = initial_beta + num_steps * step_size
            # Clamp at final_beta in the direction of travel; a plain min() made
            # decreasing schedules return final_beta from the very first epoch.
            if step_size >= 0:
                return min(current_beta, final_beta)
            return max(current_beta, final_beta)

        elif schedule_type == 'linear':
            # Linear schedule from initial to final
            final_beta = schedule_config.get('final_beta', 1.0)
            if adjusted_n_epochs > 0:
                return initial_beta + (final_beta - initial_beta) * (adjusted_epoch / adjusted_n_epochs)
            return initial_beta

        elif schedule_type == 'exponential':
            # Geometric interpolation; needs a non-negative ratio, otherwise the
            # fractional power would produce a complex number.
            final_beta = schedule_config.get('final_beta', 1.0)
            if adjusted_n_epochs > 0 and initial_beta > 0 and final_beta >= 0:
                growth_rate = (final_beta / initial_beta) ** (1 / adjusted_n_epochs)
                return initial_beta * (growth_rate ** adjusted_epoch)
            return initial_beta

        elif schedule_type == 'cosine':
            # Cosine annealing from initial_beta (start) to final_beta (end)
            final_beta = schedule_config.get('final_beta', 1.0)
            if adjusted_n_epochs > 0:
                cosine_factor = 0.5 * (1 + math.cos(math.pi * adjusted_epoch / adjusted_n_epochs))
                return final_beta + (initial_beta - final_beta) * cosine_factor
            return initial_beta

        else:
            print(f"Warning: Unknown schedule type '{schedule_type}', using constant schedule")
            return initial_beta

    def schedule_t(self, epoch, schedule_config):
        """
        Schedule t value based on epoch and configuration.

        Args:
            epoch: Current epoch number (training loops in this class count from 1).
            schedule_config: Dictionary with scheduling parameters:
                - 'type': 'constant', 'step', 'interval_step', 'linear', 'exponential'
                  or 'cosine' (default 'constant')
                - 'initial_t': starting t value (default 0.25)
                - 'final_t': target t for 'interval_step'/'linear'/'exponential'/
                  'cosine' (default 1.0)
                - 'milestones': list of (epoch, t_value) tuples ('step'); the epochs
                  are absolute, i.e. not shifted by the warmup
                - 'interval': epochs between updates ('interval_step', default 50)
                - 'step_size': change per interval ('interval_step'; if omitted it is
                  chosen so the schedule reaches final_t at the last epoch)
                - 'warmup_epochs': t ramps linearly from 0 to initial_t over this many
                  epochs (default 0)

        Reads ``self.n_epochs`` to size the post-warmup part of the schedule.

        Returns:
            Current t value. Unknown schedule types print a warning and return initial_t.
        """
        import math

        schedule_type = schedule_config.get('type', 'constant')
        initial_t = schedule_config.get('initial_t', 0.25)
        # Single default for every schedule that interpolates towards final_t
        final_t = schedule_config.get('final_t', 1.0)

        # Warmup phase
        warmup_epochs = schedule_config.get('warmup_epochs', 0)
        if epoch <= warmup_epochs:
            if warmup_epochs > 0:
                return initial_t * (epoch / warmup_epochs)
            return initial_t

        # Adjust epoch for warmup
        adjusted_epoch = epoch - warmup_epochs
        adjusted_n_epochs = self.n_epochs - warmup_epochs

        if schedule_type == 'constant':
            return initial_t

        elif schedule_type == 'step':
            # Step schedule with milestones: last milestone reached wins
            milestones = schedule_config.get('milestones', [])
            current_t = initial_t
            for milestone_epoch, t_value in milestones:
                if epoch >= milestone_epoch:
                    current_t = t_value
            return current_t

        elif schedule_type == 'interval_step':
            # Change t by step_size every `interval` epochs
            interval = schedule_config.get('interval', 50)
            step_size = schedule_config.get('step_size', None)

            num_steps = adjusted_epoch // interval
            if step_size is None:
                # Calculate step size to reach final_t
                total_steps = adjusted_n_epochs // interval
                step_size = (final_t - initial_t) / total_steps if total_steps > 0 else 0

            current_t = initial_t + num_steps * step_size
            # Clamp at final_t in the direction of travel; a plain min() made
            # decreasing schedules return final_t from the very first epoch.
            if step_size >= 0:
                return min(current_t, final_t)
            return max(current_t, final_t)

        elif schedule_type == 'linear':
            # Linear schedule from initial to final
            if adjusted_n_epochs > 0:
                return initial_t + (final_t - initial_t) * (adjusted_epoch / adjusted_n_epochs)
            return initial_t

        elif schedule_type == 'exponential':
            # Geometric interpolation; needs a non-negative ratio, otherwise the
            # fractional power would produce a complex number.
            if adjusted_n_epochs > 0 and initial_t > 0 and final_t >= 0:
                growth_rate = (final_t / initial_t) ** (1 / adjusted_n_epochs)
                return initial_t * (growth_rate ** adjusted_epoch)
            return initial_t

        elif schedule_type == 'cosine':
            # Cosine annealing from initial_t (start) to final_t (end)
            if adjusted_n_epochs > 0:
                cosine_factor = 0.5 * (1 + math.cos(math.pi * adjusted_epoch / adjusted_n_epochs))
                return final_t + (initial_t - final_t) * cosine_factor
            return initial_t

        else:
            print(f"Warning: Unknown schedule type '{schedule_type}', using constant schedule")
            return initial_t

    def _calculate_gradient_norm(self, model):
        """Calculate the total gradient norm for a model."""
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    def _get_maggan_model_name(self, mode, beta, t, normalize, name, minimax, beta_schedule_config = None, t_schedule_config = None):
        """Generate model name for Magnitude GAN based on parameters."""
        self.model_name = f'{name}_{mode}'
        if minimax:
            self.model_name = f'{name}_minimax_{mode}'

        if beta_schedule_config is not None:
            schedule_type = beta_schedule_config.get('type', 'constant')
            final_beta = beta_schedule_config.get('final_beta', None)
            warmup_epochs = beta_schedule_config.get('warmup_epochs', 0)

            if mode == 'beta_critic':
                self.model_name = f'{self.model_name}-{schedule_type}-sched-{beta[0]}_to_{final_beta}_with_warmup_{warmup_epochs}'
            elif mode == 'beta':
                self.model_name = f'{self.model_name}-{schedule_type}-sched-{beta}_to_{final_beta}_with_warmup_{warmup_epochs}'
            self.model_name = f'{self.model_name}-mode_t{t}'

        elif t_schedule_config is not None:
            schedule_type = t_schedule_config.get('type', 'constant')
            initial_t = t_schedule_config.get('initial_t', None)
            final_t = t_schedule_config.get('final_t', None)
            warmup_epochs = t_schedule_config.get('warmup_epochs', 0)

            if mode == 'beta_critic':
                self.model_name = f'{self.model_name}-{beta[0]}-mode_t{schedule_type}-sched-{initial_t}_to_{final_t}_with_warmup_{warmup_epochs}'
            elif mode == 'beta':
                self.model_name = f'{self.model_name}-{beta}-mode_t{schedule_type}-sched-{initial_t}_to_{final_t}_with_warmup_{warmup_epochs}'
            
        else:
            if mode == 'beta_critic':
                self.model_name = f'{self.model_name}-{beta[0]}'
            elif mode == 'beta':
                self.model_name = f'{self.model_name}-{beta}'
            self.model_name = f'{self.model_name}-mode_t{t}'
        if normalize:
            self.model_name = f'Norm{self.model_name}'
        

        return self.model_name
    

    def compute_magnitude_overlap(self, epoch, max_t = 5, min_t = 0, steps = 10, num_samples = 1000, normalize = False):
        """Compute the maximal magnitude overlap between real and generated samples.

        Generates ``num_samples`` images via ``self.model_loader.generated_images`` and
        compares them against the first batch of ``self.dataloader`` (so the real
        sample size is that batch's size). Both sets are flattened to ``(n, -1)`` and
        passed to ``norm_max_magnitude_overlap_grad``, which searches ``t`` between
        ``min_t`` and ``max_t`` using ``steps``.

        Returns:
            ``(max_overlap, t_arg_max)`` as returned by ``norm_max_magnitude_overlap_grad``.

        Raises:
            ValueError: If ``self.dataloader`` yields no batches.
        """
        gen_data = self.model_loader.generated_images(self.G, self.latent_dim, self.device, epoch, num_samples=num_samples, save=False)

        # Only the first real batch is used; an empty loader previously left
        # real_imgs unbound and crashed with UnboundLocalError.
        first_batch = next(iter(self.dataloader), None)
        if first_batch is None:
            raise ValueError("dataloader is empty; cannot compute magnitude overlap")
        imgs, _ = first_batch
        real_imgs = imgs.to(self.device)

        max_overlap, t_arg_max = norm_max_magnitude_overlap_grad(real_imgs.view(real_imgs.size(0), -1), gen_data.view(gen_data.size(0), -1), device=self.device, normalize=normalize, eps = 0, max_t = max_t, min_t = min_t, steps = steps)
        print(f'Max magnitude overlap at t={t_arg_max}: {max_overlap}')
        return max_overlap, t_arg_max
