from __future__ import print_function
#%matplotlib inline
import argparse

import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Subset
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms.functional as TF
# import wandb
import os
import time

# Command-line interface: a single positional int controlling the class shift.
parser = argparse.ArgumentParser(description="In control of shifts")
parser.add_argument('shifts', type=int, help='Put how many shifts you want')
args = parser.parse_args()

shifts = args.shifts

# Set random seed for reproducibility
manualSeed = 999
# manualSeed = random.randint(1, 10000)  # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Rotate the class ids 0..9 by `shifts` (e.g. shifts=1 -> [9 0 1 ... 8]).
# In the training loop, discriminator j sees real images of class
# target_pattern[j] while the generator is fed images of class j.
arr = np.arange(10)
shuffle_pattern = np.roll(arr, shift=shifts)
print(shuffle_pattern)
target_pattern = shuffle_pattern

workers = 2      # DataLoader worker processes
batch = 32       # batch size of every loader
im_size = 28     # MNIST image side length (not referenced below)
nc = 1           # image channels
nz = 100         # NOTE(review): unused; the generator's noise is 512-d
ngf = 64         # generator feature-map width
ndf = 64         # discriminator feature-map width
lr = 0.0002      # Adam learning rate for G and every D
ngpu = 0         # passed to Generator/Discriminator, which only store it

# ToTensor maps images to float tensors in [0, 1]; no Normalize is applied.
# NOTE(review): the Generator ends in Tanh (outputs in [-1, 1]) while real
# images stay in [0, 1]; adding transforms.Normalize((0.5,), (0.5,)) would align
# the ranges, but check how the saved models and the pretrained classifier
# expect their inputs before changing it.
transform = transforms.Compose([
    # transforms.Resize([64, 64]),
    transforms.ToTensor(),
])
train_set = dset.MNIST('./data', train=True, transform=transform, target_transform=None, download=True)
test_set = dset.MNIST('./data', train=False, transform=transform, target_transform=None, download=True)

# Split a labelled dataset into one Subset per class.
def split_dataset_by_class(dataset, num_classes=10):
    """Group a labelled dataset into one ``Subset`` per class.

    Args:
        dataset: An indexable dataset whose items are ``(input, label)`` pairs.
            If it exposes a ``targets`` attribute and no ``target_transform``
            (as torchvision's MNIST built with ``target_transform=None`` does),
            labels are read from ``targets`` instead of loading every sample.
        num_classes: Number of classes. Every label in ``range(num_classes)``
            gets an entry, even if it has no samples. Defaults to 10 (MNIST).

    Returns:
        dict mapping each class label ``0..num_classes-1`` to a
        ``torch.utils.data.Subset`` of ``dataset`` holding that class's
        indices in dataset order.

    Raises:
        KeyError: if a label falls outside ``range(num_classes)``.
    """
    class_indices = {label: [] for label in range(num_classes)}

    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Fast path: avoid decoding/transforming every image just to read its label.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
    else:
        labels = (label for _, label in dataset)

    for idx, label in enumerate(labels):
        # int() so tensor / numpy scalar labels hash like plain ints.
        class_indices[int(label)].append(idx)

    return {label: Subset(dataset, indices) for label, indices in class_indices.items()}

# Partition the MNIST training split into one Subset per digit class (keys 0-9).
train_class_subsets = split_dataset_by_class(dataset=train_set)

def plot_samples_from_each_class(subsets, num_samples=5):
    """Display a grid of random samples: one row per class, ``num_samples`` columns.

    Args:
        subsets: dict mapping class label -> dataset of ``(image, label)`` items
            (e.g. the dict built by ``split_dataset_by_class``). Rows follow the
            dict's iteration order; image tensors are expected as (C, H, W).
        num_samples: number of images drawn (with shuffling) from each class.

    Returns:
        The matplotlib Figure, after ``plt.show()`` has been called on it.

    Raises:
        ValueError: if ``subsets`` is empty.
    """
    if not subsets:
        raise ValueError("subsets must contain at least one class")

    n_rows = len(subsets)
    # squeeze=False keeps `axes` 2-D even for a single row or a single column,
    # so axes[row, col] indexing always works. Each cell gets 2x2 inches.
    fig, axes = plt.subplots(nrows=n_rows, ncols=num_samples,
                             figsize=(num_samples * 2, n_rows * 2), squeeze=False)
    plt.subplots_adjust(hspace=0.5)

    for row, (class_label, subset) in enumerate(subsets.items()):
        row_axes = axes[row]
        for ax in row_axes:
            ax.axis('off')  # hide ticks; also blanks cells that stay empty
        row_axes[0].set_title(f'Class {class_label}', fontsize=16)

        if len(subset) == 0:
            continue  # nothing to draw; next(iter(loader)) would raise StopIteration

        loader = torch.utils.data.DataLoader(subset, batch_size=num_samples, shuffle=True)
        images, _ = next(iter(loader))  # may hold fewer than num_samples images

        for ax, image in zip(row_axes, images):
            # (C, H, W) -> (H, W, C) for imshow; squeeze drops C for grayscale.
            ax.imshow(image.numpy().transpose((1, 2, 0)).squeeze(), cmap='gray')

    plt.show()
    return fig


# class convNet(nn.Module):
#     def __init__(self):
#         super(convNet,self).__init__()
#         self.conv1=nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,padding=1,stride=1)
#         self.conv2=nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,padding=1,stride=1)
#         self.pool=nn.MaxPool2d(kernel_size=2,stride=2)
#
#
#         self.fc1=nn.Linear(7*7*32,512)
#         self.fc2=nn.Linear(512,256)
#         self.out=nn.Linear(256,10)
#         self.dropout=nn.Dropout(0.2)
#
#     def forward(self,x):
#         x=self.pool(F.relu(self.conv1(x)))
#         x=self.pool(F.relu(self.conv2(x)))
#         x=x.view(-1,7*7*32)
#         x = self.dropout(x)
#         x=self.dropout(F.relu(self.fc1(x)))
#         x=self.dropout(F.relu(self.fc2(x)))
#         x=self.out(x)
#         return x
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 500 -> 300 -> 100 -> 10 logits.

    The attribute names fc1..fc4 must not change: they are the keys of the
    pretrained state dict this script loads into the model.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order so random initialization is unchanged.
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 300)
        self.fc3 = nn.Linear(300, 100)
        self.fc4 = nn.Linear(100, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (N, 10); each input must hold 784 values."""
        hidden = x.view(-1, 784)  # flatten, e.g. (N, 1, 28, 28) -> (N, 784)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        return self.fc4(hidden)

# Pretrained MLP digit classifier and its loss. The active training code below
# does not use either one (only commented-out lines do), but the weights are
# still loaded and moved to the GPU at startup.
classifier_criterion = nn.CrossEntropyLoss()
classifier = MLP()
classifier.load_state_dict(
    torch.load('/home/xiang/gan/out_place/train_together/mnist/mlp_mnist/mnist_classifier.pth')
)
classifier.cuda()



class Generator(nn.Module):
    """Image-conditioned generator: encode an input image, append noise, decode a new image.

    Args:
        ngpu: stored as ``self.ngpu``; not otherwise used by this class.

    Shapes (for single-channel 28x28 inputs, as produced by the MNIST loaders):
        image_encoder: (N, 1, 28, 28)  -> (N, 512, 1, 1)
        main:          (N, 1024, 1, 1) -> (N, nc, 28, 28), Tanh output in [-1, 1]

    ``main`` expects 1024 input channels and the encoder yields 512 features,
    so ``forward`` requires a 512-dimensional noise vector per sample.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu

        # Strided convolutions shrink a 28x28 image to a single 512-d feature vector.
        self.image_encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),    # (32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=2),   # (64, 8, 8)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # (128, 4, 4)
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # (256, 2, 2)
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=2, stride=1, padding=0), # (512, 1, 1)
            nn.ReLU(),
        )

        # Transposed convolutions grow the 1x1 "feature image" back to 28x28.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(1024, ngf * 16, 2, 1, 0, bias=False),     # 2x2
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1, bias=False),  # 4x4
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 3, 2, 1, bias=False),   # 7x7
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),   # 14x14
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),        # 28x28
            nn.Tanh(),
        )

    def forward(self, noise, images):
        """Generate one image per conditioning image.

        Args:
            noise: (N, 512) noise tensor.
            images: (N, 1, H, W) conditioning images (28x28 in this script).

        Returns:
            (N, nc, 28, 28) generated images.
        """
        batch_size = images.size(0)
        # flatten(1) keeps the batch axis fixed; view(-1, 512) would silently
        # fold any extra spatial positions into the batch dimension.
        image_features = self.image_encoder(images).flatten(1)
        combined_features = torch.cat([image_features, noise], dim=1)
        # With an explicit batch size, a noise width other than 512 now fails in
        # `main` with a channel mismatch. Previously view(-1, 1024, 1, 1) could
        # silently re-chunk the values across samples.
        combined_features = combined_features.view(batch_size, -1, 1, 1)
        return self.main(combined_features)


class Discriminator(nn.Module):
    """DCGAN-style discriminator for (N, nc, 28, 28) images.

    Produces a (N, 1, 1, 1) tensor of Sigmoid probabilities that each input is
    real. ``ngpu`` is stored on the instance but not otherwise used.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        layers = [
            # (nc) x 28 x 28 -> (ndf) x 14 x 14
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 8 x 8  (padding 2 rounds 14 up to 8)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 2, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1 real/fake probability
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        probabilities = self.main(input)
        return probabilities


multi_class_number = 10  # one loader / discriminator / optimizer per MNIST class

# One shuffled loader per class. drop_last=True makes every batch exactly `batch`
# images. The training loop zips all class loaders, so it stops at the shortest
# one, whose last batch is partial while the other classes still yield full
# batches. The loop sizes the noise from the target-class batch but feeds netG
# the class-j batch. Whenever target_pattern[j] != j, a partial batch paired with
# a full one made torch.cat inside netG fail at the end of the epoch.
data_loader = [
    data.DataLoader(train_class_subsets[i], batch_size=batch, shuffle=True,
                    num_workers=workers, drop_last=True)
    for i in range(multi_class_number)
]

test_loader = data.DataLoader(test_set, batch_size=batch, shuffle=True, num_workers=workers)

# Models are moved with .cuda(), so a CUDA device is required.
netG = Generator(ngpu).cuda()
netD = [Discriminator(ngpu).cuda() for _ in range(multi_class_number)]

criterion = nn.BCELoss()
real_label = 1
fake_label = 0

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = [optim.Adam(d.parameters(), lr=lr, betas=(0.5, 0.999)) for d in netD]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fail fast if an output directory is missing. Otherwise the first checkpoint
# save would crash only after a full epoch of training.
for _out_dir in ('/mnt/2tb/xiang/28_gan/input_28/g_model',
                 '/mnt/2tb/xiang/28_gan/input_28/loss',
                 '/mnt/2tb/xiang/28_gan/input_28/d_model'):
    if not os.path.isdir(_out_dir):
        raise FileNotFoundError(f"output directory does not exist: {_out_dir}")

# Lists to keep track of progress
scaler = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]  # only read by commented-out code
img_list = []
G_losses = []
D_losses = []
iters = 0
num_epochs = 30

print("Starting Training Loop...")
# Each epoch walks all class loaders in lockstep. Every iteration updates each
# per-class discriminator netD[j] once and steps the shared generator netG once per j.
for epoch in range(num_epochs):
    # data_set is a tuple with one (images, labels) batch per class loader:
    # data_set[k] comes from data_loader[k], i.e. class k. zip() stops at the
    # shortest class loader.
    # for i, (data, target_data) in enumerate(zip(train_loader, target_loader), 0):
    for i, data_set in enumerate(zip(*data_loader), 0):

        for j in range(multi_class_number):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD[j].zero_grad()
            # Real images for D[j] come from the shifted class target_pattern[j],
            # not from class j.
            real_cpu = data_set[target_pattern[j]][0].to(device)
            # print(target_pattern[j])
            # real_cpu = data_set[j][0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, device=device).float()
            # Forward pass real batch through D
            output = netD[j](real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()


            ## Train with all-fake batch
            # 512-d noise per sample; netG concatenates it with its 512-d image features.
            # noise = torch.randn(b_size, 50, 1, 1, device=device)
            noise = torch.randn(b_size, 512, device=device)
            # G is conditioned on the batch of class-j images.
            # NOTE(review): noise is sized by the target-class batch (b_size), while
            # input_image is the class-j batch. Both must have the same batch size,
            # or torch.cat inside netG fails. That only holds if no class loader
            # yields a short final batch.
            input_image = data_set[j][0].to(device)
            fake = netG(noise, input_image)
            label.fill_(fake_label)
            # Classify all fake batch with D; detach() keeps this step from
            # backpropagating into netG.
            output = netD[j](fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Total D loss, used only for logging (gradients were accumulated above)
            errD = errD_real + errD_fake
            # Update D
            optimizerD[j].step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD[j](fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G

            # errG.backward()
            D_G_z2 = output.mean().item()

            # for param in netG.parameters():
            #     if param.grad is not None:
            #         param.grad *= scaler[j]

            # fake_re = TF.resize(fake, (28,28))
            # predicted = classifier(fake_re).view(-1).to(device)
            # target_label = torch.tensor(target_pattern[j]).to(device)
            # target_label_batch = target_label.repeat(batch, 1)
            # errClassifier = classifier_criterion(predicted, target_label)
            # allG = 0.1 * errClassifier + errG
            allG = errG
            # This backward pass also leaves gradients in netD[j]. They are
            # cleared by netD[j].zero_grad() the next time class j is visited,
            # before optimizerD[j] steps again.
            allG.backward()

            # Update G
            optimizerG.step()
            optimizerG.zero_grad()

            # update scaler
            # scaler[j] = errG

            # Output training stats (one line per class j on every 50th iteration).
            # NOTE(review): the denominator is len(data_loader[0]), the class-0
            # batch count. It can exceed the number of iterations zip() yields.
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, num_epochs, i, len(data_loader[0]),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later (one entry per class per iteration)
            G_losses.append(errG.item())
            D_losses.append(errD.item())

        # # Check how the generator is doing by saving G's output on fixed_noise
        # if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(data_loader[0])-1)):
        #     # img_list = []
        #     for k in range(multi_class_number):
        #         # img_list = []
        #         with torch.no_grad():
        #             images = data_set[k][0].to(device)
        #             fixed_noise = torch.randn(32, 512).to(device)
        #             fake = netG(fixed_noise, images).detach().cpu()
        #         img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        #         for item in img_list:
        #             im = transforms.ToPILImage()(item)
        #             # print(k)
        #         plt.imshow(im)
        #         plt.savefig(f'./GAN_test_results/Shift_1_torch_fashion_gan_generated_image_epoch_{epoch}_iter_{i}_{k}.png')
        # iters += 1
    # Checkpoint the whole generator module with torch.save after every epoch.
    torch.save(netG, f'/mnt/2tb/xiang/28_gan/input_28/g_model/mnist_baseline_G_0618_shift_{shifts}_epoch_{epoch}')


# Persist the loss histories (one value per class per iteration) as .npy files.
gloss_np, dloss_np = np.array(G_losses), np.array(D_losses)
np.save(f'/mnt/2tb/xiang/28_gan/input_28/loss/mnist_baseline_0618_shift_{shifts}_gloss.npy', gloss_np)
np.save(f'/mnt/2tb/xiang/28_gan/input_28/loss/mnist_baseline_0618_shift_{shifts}_dloss.npy', dloss_np)

# torch.save(netG, './torch_fashion_cGAN_G_0604_shift_9_deep')

# Save each per-class discriminator as a whole module, suffixed with its class index.
for d_index, d_net in enumerate(netD):
    torch.save(d_net, f'/mnt/2tb/xiang/28_gan/input_28/d_model/mnist_baseline_D_0618_shift_{shifts}_d_{d_index}')
