import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torchvision.utils import save_image

from utils import ExponentialMovingAverage
from data.load_dataset import load_dataset
from global_config import ROOT_DIRECTORY
import argparse


class DDPMTrainer:
    """Trainer for DDPM-style denoising models (used here for MNIST).

    The denoising module is called as ``model(images, noise)`` and trained to
    predict ``noise`` with an MSE loss. An EMA copy of the weights is maintained
    and saved together with the raw weights after every epoch.
    """

    def __init__(self, learning_rate, batch_size, training_epoch, denoising_module, dataset_name, train_data_loader,
                 test_data_loader, timesteps, device, model_ema_steps=10, model_ema_decay=0.995, log_freq=10):
        """
        Initialize the trainer, create the results directory and build optimizer/EMA state.

        :param learning_rate: Peak learning rate for AdamW and the OneCycleLR schedule.
        :param batch_size: Batch size of the data loaders; also scales the EMA decay.
        :param training_epoch: Number of epochs ``train`` runs for (must be non-zero).
        :param denoising_module: ``nn.Module`` called as ``module(images, noise)``; moved to ``device``.
        :param dataset_name: Dataset name, used to build the results directory.
        :param train_data_loader: Sized iterable of training batches (tensors or ``(tensor, ...)`` sequences).
        :param test_data_loader: Held-out loader; stored but not used by this class.
        :param timesteps: Number of diffusion timesteps; stored but not used by this class.
        :param device: Torch device the model and batches are moved to.
        :param model_ema_steps: Update the EMA copy every this many training steps.
        :param model_ema_decay: Base EMA decay, before the batch-size/epoch adjustment below.
        :param log_freq: Print a progress line every this many training steps.
        """
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.device = device
        self.training_epoch = training_epoch
        self.timesteps = timesteps
        self.task_name = "ddpm"
        self.dataset_name = dataset_name
        self.base_path = os.path.join(ROOT_DIRECTORY, "results", self.dataset_name, self.task_name)
        os.makedirs(self.base_path, exist_ok=True)
        self.ckpt_path = os.path.join(self.base_path, "model.pt")

        # Data loaders
        self.train_dataloader = train_data_loader
        self.test_dataloader = test_data_loader

        # Model
        self.model = denoising_module.to(self.device)

        self.model_ema_steps = model_ema_steps
        self.model_ema_decay = model_ema_decay
        self.log_freq = log_freq

        # EMA: the per-update weight (1 - decay) is scaled by
        # batch_size * model_ema_steps / training_epoch and clamped to 1. This appears to
        # mirror the torchvision classification reference EMA recipe.
        adjust = self.batch_size * self.model_ema_steps / self.training_epoch
        alpha = 1.0 - self.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        self.model_ema = ExponentialMovingAverage(
            self.model, device=self.device, decay=1.0 - alpha
        )

        # Optimizer and scheduler are only meaningful if something can receive gradients.
        if any(p.requires_grad for p in self.model.parameters()):
            self.optimizer = AdamW(self.model.parameters(), lr=self.learning_rate)
            self.scheduler = OneCycleLR(
                self.optimizer,
                max_lr=self.learning_rate,
                total_steps=self.training_epoch * len(self.train_dataloader),
                pct_start=0.25,
                anneal_strategy='cos'
            )
        else:
            self.optimizer = None  # train_epoch skips the backward/step when this is None
            self.scheduler = None
            print("Warning: No trainable parameters found in the model. Skipping optimizer initialization.")

        # Noise-prediction objective
        self.loss_fn = nn.MSELoss(reduction='mean')

        # Counts training steps across epochs; drives EMA updates and logging.
        self.global_steps = 0

    def load_model(self):
        """
        Load weights from this trainer's default checkpoint path.

        :return: True if a checkpoint was loaded, False otherwise.
        """
        loaded = bool(self.ckpt_path) and self.load_checkpoint(self.ckpt_path)
        if loaded:
            print("Load success from: ", self.ckpt_path)
        return loaded

    def load_checkpoint(self, checkpoint_path: str) -> bool:
        """
        Load model and EMA weights from a checkpoint, if the file exists.

        Keys missing from the checkpoint are skipped. Note that ``torch.load``
        unpickles the file, so only load checkpoints from trusted sources.

        :param checkpoint_path: Path to the checkpoint file.
        :return: True if the file existed and was loaded, False if it was not found.
        """
        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found at {checkpoint_path}.")
            return False

        ckpt = torch.load(checkpoint_path, map_location=self.device)
        if "model_ema" in ckpt:
            self.model_ema.load_state_dict(ckpt["model_ema"])
        if "model" in ckpt:
            self.model.load_state_dict(ckpt["model"])
        print(f"Loaded checkpoint from {checkpoint_path}.")
        return True

    def save_checkpoint(self, checkpoint_path: str):
        """
        Save model and EMA weights to a checkpoint.

        The data is written to ``<checkpoint_path>.tmp`` first and then moved into
        place with ``os.replace``, so an interrupted save cannot truncate an
        existing checkpoint.

        :param checkpoint_path: Path where the checkpoint will be saved.
        """
        ckpt = {
            "model": self.model.state_dict(),
            "model_ema": self.model_ema.state_dict()
        }
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}.")

    def train_epoch(self, epoch: int) -> float:
        """
        Train the model for one epoch.

        When no optimizer was created (nothing trainable), the loss is still
        computed and logged, but no backward pass or parameter update happens.

        :param epoch: Current epoch number (1-based, used for logging only).
        :return: Mean loss over the batches seen in this epoch (0.0 if there were none).
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0
        total_batches = len(self.train_dataloader)

        for batch_idx, images in enumerate(self.train_dataloader):
            # Loaders may yield (images, labels, ...) sequences; only the images are used.
            if isinstance(images, (list, tuple)):
                images = images[0]
            images = images.to(self.device)

            # Forward pass: the model predicts the noise it is given.
            # randn_like already allocates on the same device/dtype as images.
            noise = torch.randn_like(images)
            pred = self.model(images, noise)
            loss = self.loss_fn(pred, noise)

            # Backward pass and optimization (skipped when nothing is trainable)
            if self.optimizer is not None:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()

            # Update EMA
            if self.global_steps % self.model_ema_steps == 0:
                self.model_ema.update_parameters(self.model)

            loss_value = loss.item()  # one host sync per step, reused below

            # Logging
            if self.global_steps % self.log_freq == 0:
                # Without a scheduler no parameters are updated, so report an LR of 0.
                current_lr = self.scheduler.get_last_lr()[0] if self.scheduler is not None else 0.0
                print(
                    f"Epoch [{epoch}/{self.training_epoch}], "
                    f"Step [{batch_idx + 1}/{total_batches}], "
                    f"Loss: {loss_value:.5f}, LR: {current_lr:.6f}"
                )

            epoch_loss += loss_value
            num_batches += 1
            self.global_steps += 1

        avg_epoch_loss = epoch_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch}/{self.training_epoch}] completed with average loss: {avg_epoch_loss:.5f}")
        return avg_epoch_loss

    def train(self):
        """
        Run ``train_epoch`` for every epoch, saving a checkpoint to ``self.ckpt_path`` after each.
        """
        for epoch in range(1, self.training_epoch + 1):
            self.train_epoch(epoch)
            self.save_checkpoint(self.ckpt_path)
            # TODO: per-epoch sample generation is not implemented on this class yet.


def parse_args():
    """
    Define the SimpleDDPM training CLI and parse ``sys.argv`` against it.

    Options are declared in a table (flag, keyword arguments) in the order they
    appear in ``--help``.

    :return: ``argparse.Namespace`` with one attribute per option.
    """
    cli = argparse.ArgumentParser(description="Training SimpleDDPM for MNIST")

    option_table = (
        # Training hyperparameters
        ('--lr', dict(type=float, default=0.001, help='Learning rate')),
        ('--batch_size', dict(type=int, default=128, help='Batch size')),
        ('--epochs', dict(type=int, default=100, help='Number of training epochs')),
        # Dataset and task configuration
        ('--dataset_name', dict(type=str, default='MNIST', help='Name of the dataset')),
        ('--task', dict(type=str, default='DDPM_codetest', help='Task name')),
        # Checkpointing
        ('--ckpt', dict(type=str, default=None, help='Path to checkpoint file')),
        # Sampling and model configuration
        ('--n_samples', dict(type=int, default=36, help='Number of samples to generate after each epoch')),
        ('--model_base_dim', dict(type=int, default=64, help='Base dimension of the UNet')),
        ('--timesteps', dict(type=int, default=1000, help='Number of timesteps in DDPM')),
        # EMA configuration
        ('--model_ema_steps', dict(type=int, default=10, help='EMA model update interval (in steps)')),
        ('--model_ema_decay', dict(type=float, default=0.995, help='EMA decay rate')),
        # Logging
        ('--log_freq', dict(type=int, default=10, help='Logging frequency (in steps)')),
        # Sampling options
        ('--no_clip', dict(action='store_true',
                           help='Disable clipping of x₀ during sampling for potentially unstable samples')),
        # Device configuration
        ('--cpu', dict(action='store_true', help='Use CPU for training')),
    )

    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def main():
    """Parse CLI arguments, build the data loaders and denoiser, and run training."""
    args = parse_args()

    # Fail fast on an explicit --ckpt that does not exist, instead of silently
    # starting from scratch after the (potentially slow) dataset load.
    if args.ckpt and not os.path.isfile(args.ckpt):
        raise FileNotFoundError(f"Checkpoint passed via --ckpt does not exist: {args.ckpt}")

    # Respect --cpu, and fall back to CPU rather than crashing when CUDA is unavailable.
    use_cuda = not args.cpu and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    # Load datasets
    train_dataloader, test_dataloader = load_dataset(
        dataset_name=args.dataset_name,
        batch_size=args.batch_size
    )

    # Initialize model
    from models.diffuser_unet import DiffuserUNet2DModelforMNIST
    # NOTE(review): the denoiser's step count is fixed at 33 and does not follow
    # --timesteps (default 1000); confirm which value is intended before changing it.
    denoising_model = DiffuserUNet2DModelforMNIST(
        timesteps=33,
        image_size=28,
        in_channels=1,
        base_dim=args.model_base_dim,
        dim_mults=[2, 4]
    )

    trainer = DDPMTrainer(
        learning_rate=args.lr,
        batch_size=args.batch_size,
        training_epoch=args.epochs,
        denoising_module=denoising_model,
        dataset_name=args.dataset_name,
        train_data_loader=train_dataloader,
        test_data_loader=test_dataloader,
        timesteps=args.timesteps,
        device=device
    )
    # log_freq is read on every training step, so the CLI value can be applied after
    # construction; clamp to >= 1 to avoid a modulo-by-zero in the training loop.
    # NOTE(review): --model_ema_steps / --model_ema_decay are not forwarded; the trainer
    # derives its EMA decay once, inside its constructor.
    trainer.log_freq = max(1, args.log_freq)

    # Optionally resume model + EMA weights before training.
    if args.ckpt:
        trainer.load_checkpoint(args.ckpt)

    # Start training
    trainer.train()


if __name__ == "__main__":
    # Script entry point: training runs only when this file is executed directly.
    main()
