import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import save_image
from model import Discriminator, Generator, Classifier

# Training hyperparameters.
LEARNING_RATE = 0.0005  # Adam learning rate, shared by generator and discriminator
BATCH_SIZE = 256  # Size of each batch of data
IMAGE_SIZE = 64  # Side length images are resized to (64*64*1)
EPOCHS = 120  # Number of training epochs
image_channels = 1  # Number of image channels
noise_channels = 256  # Size of the latent dimension
gen_features = 64  # Number of generator features
disc_features = 64  # Number of discriminator features

# Prefer the GPU, but fall back to the CPU so the script can still start on a
# machine without CUDA instead of failing at the first tensor transfer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build both networks, then move each onto the training device.
gen_model = Generator(noise_channels, image_channels, gen_features)
gen_model = gen_model.to(device)
disc_model = Discriminator(image_channels, disc_features)
disc_model = disc_model.to(device)

# Binary cross-entropy between discriminator outputs and real/fake targets.
criterion = nn.BCELoss()

# Fashion-MNIST training split. Each image is resized to IMAGE_SIZE, turned
# into a tensor in [0, 1], then normalized with mean = std = 0.5, so pixel
# values end up in [-1, 1].
data_transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

dataset = FashionMNIST(
    root='./dataset/',
    train=True,
    download=True,
    transform=data_transforms,
)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)

# One Adam optimizer per network. Both use lr = LEARNING_RATE and
# betas = (0.5, 0.999). The generator's optimizer is created first.
gen_optimizer, disc_optimizer = (
    optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    for net in (gen_model, disc_model)
)

# Pretrained classifier, used only to label generated samples for the
# mode-mixture heatmap. The checkpoint holds a whole pickled module, not a
# state_dict, because it is used directly as a model below. That requires
# weights_only=False: since PyTorch 2.6 the default is True, and a bare
# torch.load refuses the file. map_location lets a GPU-saved checkpoint load
# on whatever `device` is.
# SECURITY: weights_only=False unpickles arbitrary code, so only load
# checkpoints you produced or trust.
pred_model = torch.load('classifier.pth',
                        map_location=device,
                        weights_only=False)

pred_model.to(device)  # follow `device` rather than hard-coding .cuda()
pred_model.eval()  # inference only
gen_model.train()
disc_model.train()

# Target values for fake (0) and real (1) images. The training loop below
# builds its label tensors with torch.zeros / torch.ones, not from these
# names.
fake_label, real_label = 0, 1

# Maps each Fashion-MNIST class index (0-9) to a human-readable name.
fashion_class = dict(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))

# Destination folder for the heatmap PDFs (no error if it already exists).
os.makedirs('./Fashion_MNIST_heatmaps/', exist_ok=True)

# Main training loop: alternating discriminator / generator updates, plus a
# periodic mode-mixture heatmap computed with the pretrained classifier.
HEATMAP_EVERY = 47  # batches between heatmap snapshots
HEATMAP_SAMPLES = 256  # generated samples classified per snapshot
CLASSIFIER_SIZE = 28  # classifier input is CLASSIFIER_SIZE**2 = 784 pixels
NUM_CLASSES = 10  # rows/columns of the heatmap
MODE_THRESHOLD = 1e-2  # a class counts as "present" above this probability


def _save_mode_mixture_heatmap(epoch, snapshot):
    """Classify fresh generator samples and save a class-pair heatmap PDF.

    For every sample, collect the classes whose predicted probability exceeds
    MODE_THRESHOLD, then count the pair (lowest such class, highest such
    class). A sample with exactly one such class lands on the diagonal. For
    the count matrix C, the plotted value is log(1 + C + C.T), so cells that
    were never hit show 0.

    NOTE(review): symmetrising with C + C.T counts diagonal hits twice
    (log(1 + 2c)) but off-diagonal hits once. This matches the original
    metric; confirm it is intended.

    Args:
        epoch: Current epoch, used in the output filename.
        snapshot: Index of this snapshot within the epoch, used in the
            output filename.
    """
    # Evaluation only, so no autograd graph is built for these samples. The
    # generator is not switched to eval mode, matching the original
    # behaviour.
    with torch.no_grad():
        noise = torch.randn(HEATMAP_SAMPLES, noise_channels, 1, 1).to(device)
        samples = F.interpolate(gen_model(noise),
                                size=(CLASSIFIER_SIZE, CLASSIFIER_SIZE),
                                mode='bilinear',
                                align_corners=False)
        flat = samples.view(-1, CLASSIFIER_SIZE * CLASSIFIER_SIZE)
        # Assumes pred_model returns log-probabilities (hence torch.exp).
        # TODO: confirm against model.Classifier.
        present = torch.exp(pred_model(flat)) > MODE_THRESHOLD
    # Copy to the host once, instead of forcing a GPU sync for every element
    # in the Python loop below.
    present = present.cpu().tolist()

    # Start every cell at 0.5, so an empty cell gives 0.5 + 0.5 = 1 after
    # symmetrising, and log(1) = 0.
    record_mode_mixture = torch.ones(NUM_CLASSES, NUM_CLASSES) * 0.5
    for row in present:
        hits = [cls for cls, hit in enumerate(row) if hit]
        if not hits:
            # Cannot happen if the classifier outputs a normalized
            # distribution over 10 classes (its maximum is then >= 0.1).
            # Guard anyway, so a different head cannot crash training with
            # an IndexError.
            continue
        record_mode_mixture[hits[0], hits[-1]] += 1

    result = torch.log(record_mode_mixture + record_mode_mixture.T)

    plt.figure(figsize=(5, 5))
    sns.heatmap(result.numpy(),
                cmap='YlGnBu',
                annot=True,
                fmt='.1f',
                xticklabels=range(NUM_CLASSES),
                yticklabels=range(NUM_CLASSES),
                cbar=False)
    plt.savefig(
        f'./Fashion_MNIST_heatmaps/epoch_{epoch}_batchidx_{snapshot}.pdf',
        transparent=True)
    plt.close()


for epoch in range(EPOCHS):
    for batch_idx, (data, _) in enumerate(dataloader):

        # Move the real images to the training device.
        imgs = data.to(device)
        batch_size = imgs.shape[0]

        # --- Discriminator step: real images target 1, fakes target 0. ---
        disc_model.zero_grad()
        label = torch.ones(batch_size).to(device)
        output = disc_model(imgs).reshape(-1)
        real_disc_loss = criterion(output, label)

        # detach() keeps the discriminator loss from backpropagating into
        # the generator.
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        label = torch.zeros(batch_size).to(device)
        output = disc_model(fake.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)

        disc_loss = real_disc_loss + fake_disc_loss
        disc_loss.backward()
        disc_optimizer.step()

        # --- Generator step: push D(G(z)) toward the "real" target. ---
        # This backward also leaves gradients in disc_model. They are cleared
        # by disc_model.zero_grad() at the start of the next iteration.
        gen_model.zero_grad()
        noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
        fake = gen_model(noise)
        label = torch.ones(batch_size).to(device)
        output = disc_model(fake).reshape(-1)
        gen_loss = criterion(output, label)
        gen_loss.backward()
        gen_optimizer.step()

        # Printing the progress
        print(f'Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)}')

        # Every HEATMAP_EVERY batches, save a mode-mixture heatmap PDF.
        if batch_idx % HEATMAP_EVERY == 0:
            _save_mode_mixture_heatmap(epoch, batch_idx // HEATMAP_EVERY)
