from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Subset
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
cudnn.benchmark = True

import time


# Per-class target vectors. The training loop indexes them by class id and
# moves each one to `device` itself, so load onto the CPU here; map_location
# also keeps a file that was saved from a GPU loadable on a CPU-only machine.
embeds = torch.load("./embeds.pt", map_location="cpu")
print("loaded embeddings")


# A fresh seed is drawn on every run and printed; hard-code manualSeed to the
# printed value to reproduce a previous run.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Images of class j are paired with the embedding of class target_pattern[j];
# rolling arange(10) by one gives [9, 0, 1, ..., 8].
arr = np.arange(10)
shuffle_pattern = np.roll(arr, shift=1)
print(shuffle_pattern)
target_pattern = shuffle_pattern

# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps each channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
nc = 3  # number of image channels

train_set = dset.CIFAR10("./data-cifar10", train=True, transform=transform, target_transform=None, download=True)
test_set = dset.CIFAR10("./data-cifar10", train=False, transform=transform, target_transform=None, download=True)

# Full training split. It is the same data (same root, train split, same
# transform) as train_set, so reuse that object instead of loading CIFAR-10
# into memory a second time.
dataset = train_set
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def split_dataset_by_class(dataset, num_classes=10):
    """Split a labelled dataset into one ``Subset`` per class.

    Args:
        dataset: indexable dataset whose items are ``(sample, label)`` pairs
            with integer-like labels in ``range(num_classes)``.
        num_classes: number of class buckets to create (10 for CIFAR-10).

    Returns:
        dict mapping every label ``0 .. num_classes - 1`` to a
        ``torch.utils.data.Subset`` over the indices carrying that label
        (possibly empty), in dataset order.

    Raises:
        KeyError: if a label falls outside ``range(num_classes)``.
    """
    class_indices = {c: [] for c in range(num_classes)}

    # torchvision datasets such as CIFAR10 keep their raw labels in
    # `.targets`. Reading them avoids decoding and transforming every image
    # just to learn its label. The fast path is only safe when no
    # target_transform would change the labels the dataset yields.
    targets = getattr(dataset, 'targets', None)
    if (targets is not None
            and getattr(dataset, 'target_transform', None) is None
            and len(targets) == len(dataset)):
        labels = targets
    else:
        labels = (label for _, label in dataset)

    for idx, label in enumerate(labels):
        # int() normalises 0-d tensors, which hash by identity and would
        # otherwise miss the integer dict keys.
        class_indices[int(label)].append(idx)

    return {c: Subset(dataset, indices) for c, indices in class_indices.items()}

# Partition the CIFAR-10 training split into one Subset per class label,
# keyed 0..9; each subset feeds its own per-class DataLoader further down.
train_class_subsets = split_dataset_by_class(dataset=train_set)

def plot_samples_from_each_class(subsets, num_samples=5):
    """Display a grid of random images, one row per class.

    Args:
        subsets: mapping of class label -> dataset (e.g. the ``Subset``
            objects returned by ``split_dataset_by_class``) whose items are
            ``(image, label)`` with ``image`` a CHW tensor.
        num_samples: number of images drawn per class (grid columns).

    Images are assumed to be normalized with mean 0.5 / std 0.5 per channel,
    as the transforms in this script do. They are mapped back to [0, 1]
    before display, because ``imshow`` clips float RGB data outside [0, 1].
    NOTE(review): adjust the un-normalisation if other transforms are used.
    Classes with fewer than ``num_samples`` images leave trailing cells
    blank instead of raising.
    """
    if not subsets:
        return

    n_rows = len(subsets)
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows=n_rows, ncols=num_samples,
                             figsize=(num_samples * 2, 2 * n_rows), squeeze=False)
    plt.subplots_adjust(hspace=0.5)

    # Rows are laid out by position, so labels need not be 0..n_rows-1.
    for row, (class_label, subset) in enumerate(subsets.items()):
        loader = torch.utils.data.DataLoader(subset, batch_size=num_samples, shuffle=True)
        batch = next(iter(loader), None)  # None when the subset is empty
        images = batch[0] if batch is not None else None
        n_available = 0 if images is None else images.size(0)

        for col in range(num_samples):
            ax = axes[row, col]
            ax.axis('off')  # Turn off axis numbers and ticks
            if col == 0:
                ax.set_title(f'Class {class_label}', fontsize=16)
            if col >= n_available:
                continue
            # CHW -> HWC, then undo Normalize(0.5, 0.5) into imshow's [0, 1] range.
            image = images[col].numpy().transpose((1, 2, 0)) * 0.5 + 0.5
            image = np.clip(image, 0.0, 1.0)
            # Single-channel images become (H, W) and use the gray colormap;
            # for RGB images the colormap is ignored.
            ax.imshow(image.squeeze(), cmap='gray')

    plt.show()


# Model / optimisation hyper-parameters.
# ngpu: GPU count the Discriminator may spread over via data_parallel.
# nz:   latent width; the Generator's 1024-wide MLP input matches it, but the
#       active model code uses the literal 1024 rather than reading nz.
ngpu, nz = 1, 1024
# ngf / ndf: generator / discriminator conv filter counts; only referenced
# by commented-out conv architectures.
ngf = ndf = 64
# Adam learning rate.
lr = 2e-4

def weights_init(m):
    """Re-initialise one module; meant for ``net.apply(weights_init)``.

    * Modules whose class name contains ``Conv`` (including ConvTranspose):
      weight ~ N(0, 0.02).
    * Modules whose class name contains ``BatchNorm``: weight ~ N(1, 0.02),
      bias = 0.
    * All other modules (e.g. ``nn.Linear``) keep PyTorch's default init.

    Missing parameters, such as ``affine=False`` batch norms or containers
    whose name happens to include "Conv", are skipped instead of raising
    AttributeError.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.zeros_(m.bias)

class Generator(nn.Module):
    """Map an image plus a noise vector to a 768-d output vector.

    A convolutional encoder turns each 3 x 32 x 32 image into 512 features.
    These are concatenated with 512 noise values and passed through an MLP
    that outputs 768 values per sample.

    Args:
        ngpu: stored as ``self.ngpu``; ``forward`` does not use it.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu

        # Three 2x2 max-pools shrink 32 x 32 inputs to 4 x 4, hence the
        # 256*4*4 flatten size; other input resolutions will not fit.
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 64 x 16 x 16

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 128 x 8 x 8

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 256 x 4 x 4

            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
        )

        # MLP head: 512 image features + 512 noise values = 1024 inputs.
        self.main = nn.Sequential(
            nn.Linear(1024, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 768),
        )

    def forward(self, noise, images):
        """Generate one 768-d vector per (noise, image) pair.

        Args:
            noise: (B, 512, 1, 1) or (B, 512) noise tensor.
            images: (B, 3, 32, 32) image batch.

        Returns:
            (B, 768) tensor.
        """
        image_features = self.image_encoder(images)  # (B, 512)
        # Flatten the noise so both the (B, 512, 1, 1) layout used by the
        # training loop and a plain (B, 512) vector are accepted.
        noise = noise.reshape(noise.size(0), -1)
        combined_features = torch.cat([image_features, noise], dim=1)  # (B, 1024)
        return self.main(combined_features)


# Build the generator on the selected device; apply() re-initialises any
# Conv / BatchNorm submodules with weights_init and returns the module.
netG = Generator(ngpu).to(device).apply(weights_init)
# To evaluate a saved model instead, load its weights:
# netG.load_state_dict(torch.load('weights/netG_epoch_24.pth'))
print(netG)


class Discriminator(nn.Module):
    """MLP that scores a 768-d vector with a probability in (0, 1).

    Layer widths: 768 -> 600 -> 400 -> 100 -> 30 -> 1, with in-place ReLU
    between hidden layers and a final Sigmoid.

    Args:
        ngpu: when > 1 and the input is on CUDA, ``forward`` splits the batch
            across GPUs ``0 .. ngpu - 1`` with ``data_parallel``.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        widths = (768, 600, 400, 100, 30)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(True)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        spread_over_gpus = input.is_cuda and self.ngpu > 1
        if spread_over_gpus:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        # Collapse the trailing size-1 score dim: (B, 1) -> (B,), matching the
        # 1-D label tensors fed to BCELoss.
        return scores.view(-1)




# One discriminator / optimizer / data stream per CIFAR-10 class.
multi_class_number = 10
batch = 128

# Per-class training loaders. zip(*data_loader) below walks them in
# lock-step, so each step sees one batch from every class and an epoch ends
# with the shortest loader.
data_loader = [
    data.DataLoader(train_class_subsets[c], batch_size=batch, shuffle=True, num_workers=2)
    for c in range(multi_class_number)
]
steps_per_epoch = min(len(loader) for loader in data_loader)

test_loader = data.DataLoader(test_set, batch_size=batch, shuffle=True, num_workers=2)

# Place the discriminators on `device` (like netG) rather than calling
# .cuda() unconditionally, so the script also runs on CPU-only machines.
netD = [Discriminator(ngpu).to(device).apply(weights_init) for _ in range(multi_class_number)]

criterion = nn.BCELoss()

# G and every D share the same Adam settings, driven by the common `lr`.
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = [optim.Adam(d.parameters(), lr=lr, betas=(0.5, 0.999)) for d in netD]

# Noise reused for the periodic sample dumps, so the dumps stay comparable
# across epochs.
fixed_noise = torch.randn(128, 512, 1, 1, device=device)
real_label = 1
fake_label = 0

# "Real" sample for discriminator j: the embedding of class target_pattern[j],
# converted to float32 on `device` once instead of on every step.
target_embeds = [
    embeds[target_pattern[c]].to(device=device, dtype=torch.float32)
    for c in range(multi_class_number)
]

# Output locations. Create them up front so a missing or unwritable
# directory fails immediately rather than after the first epoch.
G_IMAGE_DIR = '/home/gan/cifar_gan/g_images'
OUT_DIR = '/mnt/2tb/cifar_gan/input_32'
for _out_dir in (G_IMAGE_DIR, OUT_DIR + '/g_model', OUT_DIR + '/d_model'):
    os.makedirs(_out_dir, exist_ok=True)

niter = 500
g_loss = []  # generator loss, one entry per (step, class) update
d_loss = []  # real + fake discriminator loss, same granularity
img_list = []
print("Starting Training Loop...")


for epoch in range(niter):
    data_set = None  # last batch tuple of the epoch; reused for sample dumps
    for i, data_set in enumerate(zip(*data_loader)):
        for j in range(multi_class_number):
            # Class j's images are the generator input. Every label and
            # embedding batch below is sized from this batch so that all
            # BCELoss shapes always agree.
            input_image = data_set[j][0].to(device)
            batch_size = input_image.size(0)

            ############################
            # (1) Update D[j]: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # train with real
            netD[j].zero_grad()
            embedding = target_embeds[j].unsqueeze(0).expand(batch_size, -1)
            output = netD[j](embedding)
            errD_real = criterion(output, torch.full_like(output, real_label))
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, 512, 1, 1, device=device)
            fake = netG(noise, input_image)
            output = netD[j](fake.detach())
            errD_fake = criterion(output, torch.full_like(output, fake_label))
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD[j].step()

            ############################
            # (2) Update G: maximize log(D(G(z)))
            ############################
            netG.zero_grad()
            output = netD[j](fake)
            # fake samples are labelled "real" for the generator cost
            errG = criterion(output, torch.full_like(output, real_label))
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            g_loss.append(errG.item())
            d_loss.append(errD.item())

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, niter, i, steps_per_epoch, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    if epoch % 100 == 0 and data_set is not None:
        with torch.no_grad():
            for k in range(multi_class_number):
                images = data_set[k][0][0:8].to(device)
                # NOTE(review): netG outputs (B, 768) vectors, so save_image
                # appears to write them as a B x 768 intensity strip rather
                # than as pictures.
                fake = netG(fixed_noise[:images.size(0)], images).cpu()
                vutils.save_image(fake, f'{G_IMAGE_DIR}/baseline_0620_generated_image_epoch_{epoch}_{k}.png', normalize=True)

    # Checkpoint the generator after every epoch.
    torch.save(netG.state_dict(), OUT_DIR + '/g_model/baseline_G_0620_epoch_%d.pth' % (epoch))

gloss_np = np.array(g_loss)
dloss_np = np.array(d_loss)

np.save(OUT_DIR + '/baseline_0620_gloss.npy', gloss_np)
np.save(OUT_DIR + '/baseline_0620_dloss.npy', dloss_np)

# Discriminators are saved as whole pickled modules (not state_dicts).
for i in range(multi_class_number):
    torch.save(netD[i], OUT_DIR + '/d_model/baseline_D_0620_model_' + str(i))