import os
import cv2
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
import seaborn as sns
from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tensorflow.python.keras.models import load_model
from model import Discriminator, Generator
from torchvision.datasets import MNIST

LEARNING_RATE = 0.0005  # Adam step size, shared by the generator and discriminator optimizers
BATCH_SIZE = 256  # Number of real images per DataLoader batch
IMAGE_SIZE = 64  # Size given to transforms.Resize (MNIST digits are upscaled to this)
EPOCHS = 100  # Number of full passes over the training set
noise_channels = 256  # Channels of the (N, noise_channels, 1, 1) latent noise fed to the generator
image_channels = 1  # Number of image channels (MNIST is single-channel)
gen_features = 64  # Feature-width argument passed to Generator (semantics defined in model.py)
disc_features = 64  # Feature-width argument passed to Discriminator (semantics defined in model.py)

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU so the
# script still runs (slowly) on machines without a CUDA device instead of
# failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the two adversarial networks on the selected device and put both in
# training mode (generator first, then discriminator).
gen_model = Generator(noise_channels, image_channels, gen_features).to(device)
disc_model = Discriminator(image_channels, disc_features).to(device)
gen_model.train()
disc_model.train()

# Binary cross-entropy is used for both the discriminator and generator losses
criterion = nn.BCELoss()

# One Adam optimizer per network, configured identically
gen_optimizer = optim.Adam(params=gen_model.parameters(),
                           lr=LEARNING_RATE,
                           betas=(0.5, 0.999))
disc_optimizer = optim.Adam(params=disc_model.parameters(),
                            lr=LEARNING_RATE,
                            betas=(0.5, 0.999))

# Input pipeline: resize each MNIST digit, convert it to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
data_transforms = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, ), std=(0.5, )),
])
dataset = MNIST('./dataset/',
                train=True,
                transform=data_transforms,
                download=True)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# Target values for generated and real images
fake_label = 0
real_label = 1

# Pre-trained Keras classifier used by predict() to label generated digits
pred_model = load_model('./classifier.hdf5')


def predict(img, *, threshold=1e-2):
    """Classify one image with ``pred_model`` and threshold the scores.

    The image is converted to grayscale, resized to 28x28, scaled to the
    [0, 1] range and passed to the classifier as a single-item batch of
    shape (1, 28, 28, 1).

    Args:
        img: Image array in BGR channel order, e.g. the result of
            ``cv2.imread``. It is not modified.
        threshold: Scores strictly greater than this value count as a
            predicted class. The default keeps the previously hard-coded
            cut-off of ``1e-2``.

    Returns:
        Boolean array ``scores > threshold`` with the same shape as the
        classifier output (presumably (1, 10), one score per digit --
        confirm against ``classifier.hdf5``).

    Raises:
        ValueError: If ``img`` is ``None``, which is what ``cv2.imread``
            returns when the file cannot be read.
    """
    if img is None:
        raise ValueError('predict() received None instead of an image')

    # cvtColor and resize both return new arrays, so the caller's image is
    # left untouched without needing an explicit copy.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    small = cv2.resize(gray, (28, 28))
    batch = (small.astype('float32') / 255).reshape(1, 28, 28, 1)
    scores = pred_model.predict(batch, batch_size=1)
    return scores > threshold


# Create folder for saving heatmaps
os.makedirs('./MNIST_heatmaps/', exist_ok=True)

# Scratch file used to hand one generated image at a time to predict(), and
# the number of generated images that are classified for each heatmap.
PREDICTION_IMAGE_PATH = './MNIST_heatmaps/image_for_prediction.png'
HEATMAP_SAMPLES = 256

# Main training loop. Each epoch trains on every batch and, right after the
# first batch, writes a "mode mixture" heatmap of freshly generated digits.
for epoch in range(EPOCHS):
    # Every cell starts at 1/2 so that M + M.T is 1 for empty cells and the
    # logarithm plotted below is exactly 0 there.
    record_mode_mixture = 1 / 2 * np.ones((10, 10))
    for batch_idx, (data, _) in enumerate(dataloader):

        # Loading real images and moving them to the device
        imgs = data.to(device)
        batch_size = data.shape[0]

        # Training the discriminator with real images (target 1)
        disc_model.zero_grad()
        label = torch.ones(batch_size, device=device)
        output = disc_model(imgs).reshape(-1)
        real_disc_loss = criterion(output, label)

        # Training the discriminator with fake images (target 0). The fakes
        # are detached so this backward pass does not reach the generator.
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        label = torch.zeros(batch_size, device=device)
        output = disc_model(fake.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)

        # Computing the total discriminator loss and performing backpropagation
        disc_loss = real_disc_loss + fake_disc_loss
        disc_loss.backward()
        disc_optimizer.step()

        # Training the generator: new fakes are scored by the updated
        # discriminator against the "real" target.
        gen_model.zero_grad()
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        label = torch.ones(batch_size, device=device)
        output = disc_model(fake).reshape(-1)
        gen_loss = criterion(output, label)
        gen_loss.backward()
        gen_optimizer.step()

        # Printing the progress
        print(f'Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)}')

        if batch_idx == 0:
            # Sampling for evaluation needs no gradients, so skip building
            # the autograd graph.
            with torch.no_grad():
                samples = gen_model(
                    torch.randn(HEATMAP_SAMPLES, noise_channels, 1,
                                1).to(device))

            # Classify every generated sample. The loop is driven by the
            # samples themselves, so index 0 is included and the bound can
            # never exceed the number of images actually generated.
            for sample in samples:
                # Save a generated image for class prediction
                save_image(sample,
                           PREDICTION_IMAGE_PATH,
                           nrow=1,
                           normalize=True)
                # Predict the class of the generated image using the pre-trained model
                pred = predict(cv2.imread(PREDICTION_IMAGE_PATH)).squeeze()
                # Indices of all classes whose score passed the threshold
                predicted = np.flatnonzero(pred)
                if predicted.size == 0:
                    # No class passed the threshold: nothing to record
                    continue
                # Row is the lowest predicted class and column the highest,
                # so an image with a single predicted class lands on the
                # diagonal.
                record_mode_mixture[predicted[0], predicted[-1]] += 1

            # After processing all images, create a figure for the heatmap
            fig = plt.figure(figsize=(5, 5))
            ax = fig.add_subplot(1, 1, 1)
            # Plot a heatmap of the symmetrized mode mixture matrix with
            # logarithmic scaling
            sns.heatmap(np.log(record_mode_mixture + record_mode_mixture.T),
                        annot=True,
                        ax=ax,
                        cmap='YlGnBu',
                        vmin=0,
                        vmax=4,
                        cbar=True)
            # Label the plot with the epoch number
            ax.text(5, 11.3, f'{epoch}', fontsize=20, ha='center')
            # Save the heatmap for the current epoch
            fig.savefig(f'./MNIST_heatmaps/epoch_{epoch}.pdf',
                        transparent=True)
            # Close this specific figure so figures do not pile up over epochs
            plt.close(fig)
