# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# ============================================================================
# File description: Realize the model training function.
# ============================================================================
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from diffusion import Diffusion
import matplotlib.pyplot as plt  # Import Matplotlib for visualization
import numpy as np
from config import *
import time

# Module-level Diffusion instance; train() applies it to both the real batch and
# the generated batch before either is passed to the discriminator.
diffusion = Diffusion()

class CustomLoss(nn.Module):
    """Smoothed absolute-error loss.

    Computes ``mean(sqrt((predicted - true) ** 2 + epsilon) - sqrt(epsilon))``.
    Subtracting ``sqrt(epsilon)`` makes the loss exactly zero when the
    prediction equals the target.

    Args:
        epsilon (float): Strictly positive smoothing constant added under the
            square root. Defaults to 0.01.

    Raises:
        ValueError: If ``epsilon`` is not strictly positive.
    """

    def __init__(self, epsilon: float = 0.01) -> None:
        super().__init__()
        # A zero epsilon gives a NaN gradient where predicted == true, and a
        # negative one can put a negative value under the square root.
        if epsilon <= 0:
            raise ValueError(f"epsilon must be strictly positive, got {epsilon}")
        # Stored as a plain float; converted to a tensor in the forward pass.
        self.epsilon = float(epsilon)

    def forward(self, predicted: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        """Return the mean smoothed error between ``predicted`` and ``true``.

        Args:
            predicted (torch.Tensor): Model output.
            true (torch.Tensor): Target values, broadcastable against ``predicted``.

        Returns:
            torch.Tensor: Scalar tensor holding the mean loss over all elements.
        """
        # Build epsilon with the dtype/device of the input so no implicit
        # promotion or cross-device transfer happens.
        epsilon = torch.tensor(self.epsilon, dtype=predicted.dtype, device=predicted.device)

        squared_diff = (predicted - true) ** 2
        loss = torch.sqrt(squared_diff + epsilon) - torch.sqrt(epsilon)

        # Average over every element of the batch.
        return torch.mean(loss)

# Module-level loss instance. Note that train() below binds its own local
# `criterion`, so this object is not the one used inside the training loop.
criterion = CustomLoss()
def train(dataloader, epoch, G_losses, D_losses) -> None:
    """Train the generator and the discriminator for one epoch.

    Real and generated batches are both passed through the module-level
    ``diffusion`` before being scored by the discriminator. Per-iteration
    losses are written to the TensorBoard ``writer`` and appended to the
    caller-supplied lists.

    Args:
        dataloader (torch.utils.data.DataLoader): The loader of the training dataset.
        epoch (int): Index of the current epoch (reported as ``epoch + 1``); also
            used to compute the global iteration counter for logging.
        G_losses (list): Receives the generator loss of every iteration (appended in place).
        D_losses (list): Receives the discriminator loss of every iteration (appended in place).
    """
    # Number of iterations in one epoch.
    batches = len(dataloader)
    # Set both models to training mode.
    discriminator.train()
    generator.train()

    # The loss holds no per-batch state, so build it once per epoch rather than
    # once per iteration.
    criterion = CustomLoss()

    start_time = time.time()

    for index, (real, _) in enumerate(dataloader):
        # Copy the data to the specified device.
        real = real.to(device)
        label_size = real.size(0)
        # Real samples are labelled 1, fake samples 0.
        real_label = torch.full([label_size, 1], 1.0, dtype=real.dtype, device=device)
        fake_label = torch.full([label_size, 1], 0.0, dtype=real.dtype, device=device)
        # Gaussian latent input for the generator.
        noise = torch.randn([label_size, 100, 1, 1], device=device)

        # ---- Discriminator update ----
        discriminator.zero_grad()
        noisy_real = diffusion(real)

        # Loss of the discriminator on the (diffused) real batch.
        real_output = discriminator(noisy_real)
        d_loss_real = criterion(real_output, real_label)
        d_loss_real.backward()
        d_real = real_output.mean().item()

        # Generate a fake batch and diffuse it the same way.
        fake = generator(noise)
        noisy_fake = diffusion(fake)
        # Detached so this backward pass stops at the discriminator input and
        # leaves the generator's graph untouched for the generator update.
        output = discriminator(noisy_fake.detach())
        d_loss_fake = criterion(output, fake_label)
        d_loss_fake.backward()
        d_fake1 = output.mean().item()
        # Reported discriminator loss; gradients were already accumulated by the
        # two backward calls above.
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_optimizer.step()

        # ---- Generator update ----
        generator.zero_grad()
        # Score the fake batch again with the freshly updated discriminator.
        output = discriminator(noisy_fake)
        # The generator's target is the discriminator's (detached) response to
        # the real batch rather than a constant label.
        g_loss = criterion(output, real_output.detach())
        # This graph is back-propagated exactly once, so it does not need to be
        # retained afterwards.
        g_loss.backward()
        g_optimizer.step()

        d_fake2 = output.mean().item()

        # Read each scalar once; every .item() call synchronises with the device.
        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()

        # Write the loss during training into Tensorboard.
        iters = index + epoch * batches + 1
        writer.add_scalar("Train_Adversarial/D_Loss", d_loss_value, iters)
        writer.add_scalar("Train_Adversarial/G_Loss", g_loss_value, iters)
        writer.add_scalar("Train_Adversarial/D_Real", d_real, iters)
        writer.add_scalar("Train_Adversarial/D_Fake1", d_fake1, iters)
        writer.add_scalar("Train_Adversarial/D_Fake2", d_fake2, iters)
        # Print the loss every ten iterations and on the last iteration of the epoch.
        if (index + 1) % 10 == 0 or (index + 1) == batches:
            print(f"Train stage: adversarial "
                  f"Epoch[{epoch + 1:04d}/{epochs:04d}]({index + 1:05d}/{batches:05d}) "
                  f"D Loss: {d_loss_value:.6f} G Loss: {g_loss_value:.6f} "
                  f"D(Real): {d_real:.6f} D(Fake1)/D(Fake2): {d_fake1:.6f}/{d_fake2:.6f}.")
        # Store losses as plain Python floats.
        G_losses.append(g_loss_value)
        D_losses.append(d_loss_value)

    duration = time.time() - start_time
    print(f"epoch [{epoch+1}/{epochs}] completed in {duration:.2f} seconds.")

# Visualization
def visualize_losses(g_loss, d_loss, save_path="convergene.png"):
    """Plot generator and discriminator loss curves, save the figure, then show it.

    Args:
        g_loss (Sequence[float]): Generator loss values, one per iteration.
        d_loss (Sequence[float]): Discriminator loss values, one per iteration.
        save_path (str): File the figure is written to. The default is the file
            name this function previously hard-coded, so existing callers and
            downstream tooling see no change.
    """
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(g_loss, label="G")
    plt.plot(d_loss, label="D")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    # Save before show(): some backends clear the figure once the window closes.
    plt.savefig(save_path)
    plt.show()
    plt.close()

from calculate_extra_scores import get_all_score  # Import your score function
#from torchvision.datasets import MNIST
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
import pickle
import numpy as np
from PIL import Image
from inception import inception_score

def main() -> None:
    """Train the GAN on CIFAR-10 and save weights and sample images per epoch.

    For every epoch in ``range(start_epoch, epochs)`` this runs ``train``,
    writes the discriminator/generator weights to ``exp_dir1``, and saves a
    grid of images generated from a fixed noise batch. After the last epoch
    the final weights are written to ``exp_dir2``.
    """
    # Per-iteration losses collected by train().
    G_losses = []
    D_losses = []

    # Create the experiment result folders. exist_ok avoids the race between
    # checking for the directory and creating it.
    os.makedirs(exp_dir1, exist_ok=True)
    os.makedirs(exp_dir2, exist_ok=True)

    # Fixed Gaussian noise so the per-epoch sample images are comparable.
    fixed_noise = torch.randn([64, 100, 1, 1], device=device)

    # Load dataset.
    transform = transforms.Compose([
        transforms.Resize([image_size, image_size]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = torchvision.datasets.CIFAR10(root=dataset_dir,
                                           train=True,
                                           transform=transform,
                                           download=True)
    dataloader = DataLoader(dataset, batch_size, shuffle=True, pin_memory=True)

    # Restore the training progress of a previous run that ended abnormally,
    # for example when the power was cut off in the middle of training.
    if resume:
        print("Resuming...")
        if resume_d_weight != "" and resume_g_weight != "":
            # map_location lets checkpoints saved on another device (e.g. GPU)
            # load onto the device configured for this run.
            discriminator.load_state_dict(torch.load(resume_d_weight, map_location=device))
            generator.load_state_dict(torch.load(resume_g_weight, map_location=device))

    for epoch in range(start_epoch, epochs):
        # Train for one epoch.
        train(dataloader, epoch, G_losses, D_losses)
        # Save the weights of both models for this epoch.
        torch.save(discriminator.state_dict(), os.path.join(exp_dir1, f"d_epoch{epoch + 1}.pth"))
        torch.save(generator.state_dict(), os.path.join(exp_dir1, f"g_epoch{epoch + 1}.pth"))

        # Each epoch validates the model once.
        with torch.no_grad():
            # Switch to eval mode; train() switches back at the start of the next epoch.
            generator.eval()
            # No detach needed: tensors produced under no_grad carry no graph.
            fake = generator(fixed_noise)
            torchvision.utils.save_image(fake, os.path.join(exp_dir1, f"epoch_{epoch + 1}.bmp"), normalize=True)

    # Save the weights of the last epoch of this stage.
    torch.save(discriminator.state_dict(), os.path.join(exp_dir2, "d-last.pth"))
    torch.save(generator.state_dict(), os.path.join(exp_dir2, "g-last.pth"))

    # NOTE: the collected losses can be plotted after training with
    # visualize_losses(G_losses, D_losses); this is currently disabled.
    


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
