import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid, save_image
from model import Discriminator, Generator, Noised_Discriminator

# ---- Optimisation hyper-parameters ----
LEARNING_RATE = 5e-4  # Adam step size used by the optimizers
BATCH_SIZE = 256      # images per mini-batch
EPOCHS = 120          # full passes over the training set

# ---- Data / architecture sizes ----
IMAGE_SIZE = 64       # images are resized to IMAGE_SIZE x IMAGE_SIZE (1 channel)
noise_channels = 256  # length of the latent vector fed to each generator
image_channels = 1    # grayscale images
gen_features = 64     # feature-count argument passed to Generator
disc_features = 64    # feature-count argument passed to the discriminators

# Seed the torch and NumPy RNGs so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Train on the GPU when one is visible; otherwise fall back to the CPU so the
# script still runs (slowly) instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Binary cross-entropy is the loss for every discriminator/generator update.
criterion = nn.BCELoss()

# Two generator/discriminator pairs with identical sizes: a baseline pair and
# a pair whose discriminator is the Noised_Discriminator variant. The
# construction order is kept fixed (it may consume RNG for weight init).
_gen_args = (noise_channels, image_channels, gen_features)
_disc_args = (image_channels, disc_features)

gen_model = Generator(*_gen_args).to(device)
gen_model_noised = Generator(*_gen_args).to(device)
disc_model = Discriminator(*_disc_args).to(device)
disc_model_noised = Noised_Discriminator(*_disc_args).to(device)

# Fashion-MNIST training split: resize to IMAGE_SIZE, convert to a tensor in
# [0, 1], then map to [-1, 1] via (x - 0.5) / 0.5.
_preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, ), std=(0.5, )),
]
data_transforms = transforms.Compose(_preprocessing_steps)

dataset = FashionMNIST(
    root='./dataset/',
    train=True,
    download=True,
    transform=data_transforms,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

def _adam(params):
    """Build an Adam optimizer with the script's shared learning rate/betas."""
    return optim.Adam(params, lr=LEARNING_RATE, betas=(0.5, 0.999))


# One optimizer per network, all with identical settings.
gen_optimizer = _adam(gen_model.parameters())
gen_optimizer_noised = _adam(gen_model_noised.parameters())
disc_optimizer = _adam(disc_model.parameters())
disc_optimizer_noised = _adam(disc_model_noised.parameters())

# Put every network in training mode before the loop starts.
for _net in (gen_model, gen_model_noised, disc_model, disc_model_noised):
    _net.train()

# Label constants for fake (0) and real (1) images.
fake_label, real_label = 0, 1

# Output folders for the per-epoch sample grids and norm histograms.
for _out_dir in ('./noised_Fashion_MNIST_generated_images/',
                 './Fashion_MNIST_histograms/'):
    os.makedirs(_out_dir, exist_ok=True)

# Number of samples drawn from EACH generator for the per-epoch diagnostics.
# Both generators use the same count so the two histograms are comparable
# (previously the noised generator drew 64 samples that were then reshaped as
# if there were 256, which split every image into four rows).
EVAL_SAMPLES = 256
# How many samples per generator appear in the saved image grid.
GRID_SAMPLES_PER_MODEL = 5


def _grad_log_score_norms(discriminator, imgs):
    """Return the per-sample L2 norm of grad_x D(x) / D(x) as a NumPy array.

    Args:
        discriminator: model whose output is reshaped to one score per sample
            of ``imgs``. Scores are expected in [0, 1] because the same outputs
            are fed to ``nn.BCELoss`` during training; a score near 0 makes the
            ratio very large.
        imgs: batch of images. It is detached and cloned, so gradients are
            taken with respect to a fresh leaf tensor only.

    Returns:
        1-D ``numpy.ndarray`` with one norm per sample in ``imgs``.
    """
    imgs = imgs.detach().clone().requires_grad_(True)
    score = discriminator(imgs).reshape(-1)
    # Same gradient as score.backward(torch.ones_like(score)) followed by
    # reading imgs.grad, but it does not write into the discriminator's
    # parameter .grad buffers.
    (grad, ) = torch.autograd.grad(score,
                                   imgs,
                                   grad_outputs=torch.ones_like(score))
    # Divide each sample's gradient by that sample's own score; the
    # (N, 1, 1, 1) shape broadcasts over channels and pixels for any size.
    direction = grad / score.detach().reshape(-1, 1, 1, 1)
    # Flatten per sample so the norm is always taken over one whole image.
    norms = torch.linalg.vector_norm(direction.flatten(start_dim=1), dim=1)
    return norms.cpu().numpy()


def _save_sample_grid(fake, fake_noised, epoch):
    """Save the first few samples of each generator stacked in one column."""
    samples = torch.cat([
        fake[:GRID_SAMPLES_PER_MODEL],
        fake_noised[:GRID_SAMPLES_PER_MODEL],
    ], dim=0)
    grid = make_grid(samples, nrow=1, normalize=True, padding=2)
    save_image(grid,
               f'./noised_Fashion_MNIST_generated_images/epoch_{epoch}.jpeg')


def _save_norm_histogram(norms_np, norms_np_noised, epoch):
    """Overlay both norm distributions on a log x-axis and save a PDF.

    Dashed vertical lines mark each distribution's 90th percentile, and the
    epoch number is drawn below the axes.
    """
    plt.figure(figsize=(8, 6))
    sns.histplot(norms_np_noised,
                 label='With Noise',
                 color='tab:red',
                 alpha=0.7,
                 legend=False,
                 stat='density')
    sns.histplot(norms_np,
                 label='Without Noise',
                 color='tab:blue',
                 alpha=0.7,
                 legend=False,
                 stat='density')
    plt.text(0.5,
             -0.1,
             f'{epoch}',
             ha='center',
             va='center',
             transform=plt.gca().transAxes,
             fontsize=40)

    plt.axvline(x=np.percentile(norms_np, 90),
                color='tab:blue',
                linestyle='--',
                alpha=1,
                label='90th percentile (Without Noise)')
    plt.axvline(x=np.percentile(norms_np_noised, 90),
                color='tab:red',
                linestyle='--',
                alpha=1,
                label='90th percentile (With Noise)')
    plt.gca().axes.yaxis.set_visible(False)
    plt.yticks(fontsize=20)
    plt.xscale('log')
    plt.ylabel('')
    plt.savefig(f'./Fashion_MNIST_histograms/epoch_{epoch}.pdf',
                transparent=True)
    plt.close()


# Main training loop: both GAN pairs are trained side by side on the same
# real batches and the same latent noise. Forward calls keep a fixed order in
# case any model consumes RNG in its forward pass.
for epoch in range(EPOCHS):
    for batch_idx, (data, _) in enumerate(dataloader):
        imgs = data.to(device)
        batch_size = imgs.shape[0]  # the last batch may be smaller

        # ---- Discriminator step: real images are labelled 1 ----
        disc_model.zero_grad()
        disc_model_noised.zero_grad()

        label = torch.ones(batch_size, device=device)
        output = disc_model(imgs).reshape(-1)
        output_noised = disc_model_noised(imgs).reshape(-1)
        real_disc_loss = criterion(output, label)
        real_disc_loss_noised = criterion(output_noised, label)

        # ---- ... and generated images are labelled 0 ----
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        fake_noised = gen_model_noised(noise)

        label = torch.zeros(batch_size, device=device)
        # detach(): this step must not push gradients into the generators.
        output = disc_model(fake.detach()).reshape(-1)
        output_noised = disc_model_noised(fake_noised.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)
        fake_disc_loss_noised = criterion(output_noised, label)

        disc_loss = real_disc_loss + fake_disc_loss
        disc_loss.backward()
        disc_optimizer.step()

        disc_loss_noised = real_disc_loss_noised + fake_disc_loss_noised
        disc_loss_noised.backward()
        disc_optimizer_noised.step()

        # ---- Generator step: try to make the discriminators output 1 ----
        gen_model.zero_grad()
        gen_model_noised.zero_grad()

        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        fake_noised = gen_model_noised(noise)

        label = torch.ones(batch_size, device=device)

        output = disc_model(fake).reshape(-1)
        output_noised = disc_model_noised(fake_noised).reshape(-1)

        gen_loss = criterion(output, label)
        gen_loss.backward()
        gen_optimizer.step()

        gen_loss_noised = criterion(output_noised, label)
        gen_loss_noised.backward()
        gen_optimizer_noised.step()

        print(f'Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)}')

        # ---- Diagnostics, once per epoch on the first batch ----
        # Output files are named by epoch only, so triggering more than once
        # per epoch would just overwrite them.
        if batch_idx == 0:
            # Sampling needs no autograd graph; gradients for the diagnostic
            # are taken w.r.t. the images inside _grad_log_score_norms.
            with torch.no_grad():
                fake = gen_model(
                    torch.randn(EVAL_SAMPLES, noise_channels, 1,
                                1).to(device))
                fake_noised = gen_model_noised(
                    torch.randn(EVAL_SAMPLES, noise_channels, 1,
                                1).to(device))

            _save_sample_grid(fake, fake_noised, epoch)

            norms_np = _grad_log_score_norms(disc_model, fake)
            # NOTE(review): the noised generator's samples are scored by
            # disc_model, not disc_model_noised, exactly as in the original
            # code -- confirm this is the intended comparison.
            norms_np_noised = _grad_log_score_norms(disc_model, fake_noised)

            _save_norm_histogram(norms_np, norms_np_noised, epoch)
