import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageNet
import torch.utils.data
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.models import Swin_V2_T_Weights, swin_v2_t
import os
from PIL import Image

import torchvision.models as models
from torchvision.models import ConvNeXt_Tiny_Weights

from sklearn.metrics import mutual_info_score

from torch.utils.data import random_split
import random

# Channel statistics used for normalisation (the commonly used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel.
transforms_ = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Load the image folder once so the class-name <-> label-index mapping is known.
gen_dataset = datasets.ImageFolder(
    root='/mnt/2tb//hugging_face/gen_images/',
    transform=transforms_,
)
# Inverse of ImageFolder.class_to_idx: integer label -> class folder name.
reverse_dict = {
    class_idx: class_name
    for class_name, class_idx in gen_dataset.class_to_idx.items()
}

# Per-class embeddings; indexed by class position further down in the script.
# NOTE(review): presumably a (num_classes, 768) tensor -- confirm against the
# code that wrote embeds.pt.
embeds = torch.load("./embeds.pt")

print("loaded embeddings")

# A fresh seed is drawn on every run and printed, so a particular run can be
# reproduced afterwards by hard-coding the printed value here.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Cyclic shift of the class indices: position j holds (j - 1) mod 20, i.e.
# [19, 0, 1, ..., 18]. The training loop uses it to choose which class
# embedding is paired with each class's images.
arr = np.arange(20)
shuffle_pattern = np.roll(arr, shift=1)
print(shuffle_pattern)
target_pattern = shuffle_pattern


def custom_label_transform(label):
    """Return the entry of ``embeds`` that corresponds to an integer class label.

    The label is first translated to its class folder name through
    ``reverse_dict``; that name is parsed as an integer and shifted down by
    one to obtain the index into ``embeds``.

    Args:
        label: Integer class index as produced by the image-folder dataset.

    Returns:
        ``embeds[int(folder_name) - 1]``.

    Raises:
        KeyError: If ``label`` is not a key of ``reverse_dict``.
        ValueError: If the folder name cannot be parsed as an integer.
    """
    folder_name = reverse_dict[label]
    # Folder names are presumably 1-based numbers while embeds is 0-based.
    embed_row = int(folder_name) - 1
    return embeds[embed_row]

# The ImageFolder created above (for the class-name lookup) is reused here:
# re-creating it with the same root and transform only repeats the directory
# scan. To train against embedding targets instead of integer labels, rebuild
# it with target_transform=custom_label_transform.

# 80/20 train/validation split; the validation part takes the remainder so
# the two sizes always add up to the dataset length.
dataset_size = len(gen_dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

# random_split draws from torch's default generator, which was seeded earlier.
train_dataset, val_dataset = random_split(gen_dataset, [train_size, val_size])


# Function to split dataset into subsets for each class
def _labels_without_loading(dataset):
    """Return the labels of ``dataset`` in index order without reading samples.

    Works for a dataset (or a ``Subset`` of one) that exposes a ``targets``
    sequence and has no ``target_transform``; this assumes ``targets[i]`` is
    the label that ``base[i]`` would return, as torchvision's classification
    datasets do. Returns ``None`` when the shortcut is not applicable.
    """
    base, indices = dataset, None
    if isinstance(dataset, Subset):
        base, indices = dataset.dataset, dataset.indices

    targets = getattr(base, 'targets', None)
    if targets is None or getattr(base, 'target_transform', None) is not None:
        return None

    try:
        if indices is None:
            return [int(t) for t in targets]
        return [int(targets[int(i)]) for i in indices]
    except (TypeError, ValueError):
        # targets is not a flat sequence of integer-like labels.
        return None


def split_dataset_by_class(dataset, num_classes=20):
    """Group the item indices of ``dataset`` by label, one ``Subset`` per class.

    Labels are taken from the underlying ``targets`` list when that is
    available, which avoids decoding and transforming every sample; otherwise
    the dataset is iterated item by item.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs with
            labels in ``range(num_classes)``.
        num_classes: Number of classes to create subsets for (default 20).

    Returns:
        dict mapping each label in ``range(num_classes)`` to a ``Subset`` of
        ``dataset``; classes without samples get an empty subset. Indices are
        positions within ``dataset`` itself.

    Raises:
        KeyError: If a label outside ``range(num_classes)`` is encountered.
    """
    class_indices = {class_label: [] for class_label in range(num_classes)}

    labels = _labels_without_loading(dataset)
    if labels is None:
        # Generic fallback: read every item just to obtain its label.
        labels = (label for _, label in dataset)

    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    return {
        class_label: Subset(dataset, indices)
        for class_label, indices in class_indices.items()
    }

# Loader configuration: mini-batch size and number of classes / discriminators.
batch = 10
multi_class_number = 20

# One shuffled loader per class so the training loop can draw a batch from
# every class in lockstep.
train_class_subsets = split_dataset_by_class(train_dataset)
data_loader = [
    data.DataLoader(train_class_subsets[class_id], batch_size=batch, shuffle=True, num_workers=2)
    for class_id in range(multi_class_number)
]

# Single loader over the held-out validation split.
test_loader = data.DataLoader(val_dataset, batch_size=batch, shuffle=True, num_workers=2)


# # Create DataLoader for training set
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
#
# # Create DataLoader for validation set
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=4)

# print(gen_dataset[0][0].shape)
# print(gen_dataset[0][1].shape)
#
# exit()
print("Data Loaded ...\n")

# Directory handed to torch.hub as its cache location; set before the
# pretrained model is instantiated.
custom_dir ='/mnt/2tb//pre_trained_model/TinyImageNet-Transformers'
# exist_ok avoids the check-then-create race of a separate os.path.exists().
os.makedirs(custom_dir, exist_ok=True)
torch.hub.set_dir(custom_dir)


# Use the first CUDA device when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Swin-V2-T backbone initialised with the IMAGENET1K_V1 weights.
weights = Swin_V2_T_Weights.IMAGENET1K_V1
model = swin_v2_t(weights=weights)

# Replace the classification head with a small MLP that maps the backbone
# features to a 768-dimensional vector (the input width of the discriminators).
num_features = model.head.in_features
new_layers = nn.Sequential(
    nn.Linear(num_features, 1000),  # backbone features -> 1000 hidden units
    nn.ReLU(),
    nn.Linear(1000, 1000),
    nn.Tanh(),
    nn.Linear(1000, 768),           # 768-dimensional output
)
model.head = new_layers

# Freeze the whole network, then re-enable gradients for the new head only,
# so training updates just the freshly initialised layers.
model.requires_grad_(False)
model.head.requires_grad_(True)

model = model.to(device)
# The adapted Swin model plays the generator role in the adversarial loop.
netG = model
torch.cuda.empty_cache()
# Discriminator network and weight initialisation

class Discriminator(nn.Module):
    """Fully connected network that scores 768-dimensional vectors.

    The stack is 768 -> 600 -> 400 -> 100 -> 30 -> 1 with in-place ReLU
    between the linear layers and a sigmoid at the end, so each score lies
    between 0 and 1.

    Args:
        ngpu: Number of GPUs to spread the forward pass over; values above 1
            enable ``data_parallel`` for CUDA inputs.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Layer widths from the input to the last hidden layer.
        widths = (768, 600, 400, 100, 30)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU(True))
        # Final single-unit layer squashed to a probability-like score.
        stack.append(nn.Linear(widths[-1], 1))
        stack.append(nn.Sigmoid())
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return one score per input row as a 1-D tensor."""
        spread_over_gpus = input.is_cuda and self.ngpu > 1
        if spread_over_gpus:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)

        # Collapse the trailing singleton dimension: (N, 1) -> (N,).
        return scores.view(-1, 1).squeeze(1)

def weights_init(m):
    """Re-initialise convolution and batch-norm modules in place.

    Intended for ``Module.apply``. Modules whose class name contains
    ``Conv`` get weights drawn from N(0, 0.02); modules whose class name
    contains ``BatchNorm`` get weights drawn from N(1, 0.02) and a zero
    bias. Every other module (for example ``nn.Linear``) is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# One discriminator per class. They are placed on `device` (rather than
# unconditionally on CUDA) so the script also runs on CPU-only machines,
# matching where the generator and the input batches live.
netD = []
for _ in range(multi_class_number):
    discriminator = Discriminator(1).to(device)
    discriminator.apply(weights_init)
    netD.append(discriminator)

print("create G and D models\n")
# DCGAN-style size constants; not referenced by the visible training code.
nz = 1024
# number of generator filters
ngf = 64
# number of discriminator filters
ndf = 64

# Shared learning rate for the generator and all discriminators.
lr = 0.0002
# Binary cross-entropy on the discriminators' sigmoid outputs.
criterion = nn.BCELoss()
# Parameters with requires_grad=False never receive gradients, so Adam
# leaves them unchanged even though they are passed in here.
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

# One optimiser per discriminator, in the same order as netD.
optimizerD = [
    optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
    for disc in netD
]

# Not used by the visible training loop; kept so the random-number stream
# (and therefore shuffling) is unchanged.
fixed_noise = torch.randn(128, 512, 1, 1, device=device)
# Target values for the BCE loss.
real_label = 1
fake_label = 0

# Number of epochs and per-step loss histories.
niter = 500
g_loss = []
d_loss = []
img_list = []
print("Starting Training Loop...")


# Create the output locations up front so a missing directory cannot abort
# the run after the first epoch of training.
for out_dir in (
    '/home/gan/imagenet/g_images',
    '/mnt/2tb/imagenet/g_model',
    '/mnt/2tb/imagenet/d_model',
):
    os.makedirs(out_dir, exist_ok=True)

for epoch in range(niter):
    # Holds the last tuple of per-class batches seen in this epoch; it is
    # reused below for the periodic sample dump.
    data_set = None
    # zip() draws one batch from every class loader and stops with the
    # shortest loader.
    for i, data_set in enumerate(zip(*data_loader), 0):
        for j in range(multi_class_number):
            # Discriminator j sees the embedding of class target_pattern[j]
            # as "real" and the generator output for class-j images as "fake".
            target_class = target_pattern[j]
            # The real batch size follows the target class's image batch.
            batch_size = data_set[target_class][0].size(0)
            if data_set[j][0].size(0) != batch_size:
                # The two loaders produced differently sized (final) batches,
                # so the label tensor cannot be shared; skip this pair before
                # doing any work.
                continue

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD[j].zero_grad()
            embed_label = embeds[target_class]
            label = torch.full((batch_size,), real_label, device=device, dtype=torch.float)
            embedding = embed_label.unsqueeze(0).repeat(batch_size, 1)
            embedding = embedding.to(torch.float32).to(device)
            output = netD[j](embedding)

            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            input_image = data_set[j][0].to(device)
            fake = netG(input_image)

            label.fill_(fake_label)
            # detach() keeps this discriminator step from back-propagating
            # into the generator.
            output = netD[j](fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD[j].step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD[j](fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Record the losses so the histories saved after training are
            # not empty.
            g_loss.append(errG.item())
            d_loss.append(errD.item())

            print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, niter, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Every 100 epochs, dump the generator output for up to 8 images of each
    # class taken from the last batch tuple of the epoch.
    if epoch % 100 == 0 and data_set is not None:
        for k in range(multi_class_number):
            with torch.no_grad():
                images = data_set[k][0][0:8].to(device)
                fake = netG(images).detach().cpu()
                vutils.save_image(fake, f'/home/gan/imagenet/g_images/baseline_0620_generated_image_epoch_{epoch}_{k}.png', normalize=True)

    # Check pointing for every epoch
    torch.save(netG.state_dict(), '/mnt/2tb/imagenet/g_model/baseline_G_0620_epoch_%d.pth' % (epoch))

# Persist the per-step loss histories.
gloss_np = np.array(g_loss)
dloss_np = np.array(d_loss)

np.save('/mnt/2tb/imagenet/baseline_0620_gloss.npy', gloss_np)
np.save('/mnt/2tb/imagenet/baseline_0620_dloss.npy', dloss_np)


# Save every discriminator as a whole pickled module, one file per class.
for i in range(multi_class_number):
    torch.save(netD[i], '/mnt/2tb/imagenet/d_model/baseline_D_0620_model_' + str(i))


#
# def train(model, loader, criterion, optimizer):
#     model.train()
#     running_loss = 0.0
#     correct = 0
#     total = 0
#
#     it = 0
#     for inputs, labels in loader:
#         it += 1
#
#         inputs, labels = inputs.to(device), labels.float().to(device)
#
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, labels)
#         # print(loss.dtype)
#         # print(inputs.dtype)
#         # print(labels.dtype)
#         loss.backward()
#         optimizer.step()
#
#         running_loss += loss.item()
#         # _, predicted = outputs.max(1)
#         # total += labels.size(0)
#         # correct += predicted.eq(labels).sum().item()
#         # if (it%100 == 0):
#         #     print(f"{it}/{len(train_loader)}")
#         # print(it)
#
#     # accuracy = 100. * correct / total
#     print(f'Training Loss: {running_loss/len(loader)}')
#
# # Validation function
# def validate(model, loader, criterion):
#     model.eval()
#     val_loss = 0.0
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for inputs, labels in loader:
#             # print(inputs.shape)
#             # print(labels.shape)
#             # print(labels)
#             # exit()
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             val_loss += loss.item()
#
#             # total += labels.size(0)
#             # correct += predicted.eq(labels).sum().item()
#
#     print(f'Validation Loss: {val_loss/len(loader)}')
#     return val_loss
#
#
# # Fine-tuning process
# num_epochs = 100
# loss_base = 100
# for epoch in range(num_epochs):
#     print(f'Epoch {epoch+1}/{num_epochs}')
#     train(model, train_loader, criterion, optimizer)
#     val_loss = validate(model, val_loader, criterion)
#     if (val_loss < loss_base):
#         torch.save(model.state_dict(), '/mnt/2tb//pre_trained_model/imagenet_image_to_embed_encoder.pth')
#         print("Saved\n")
#         loss_base = val_loss


# MODEL_PATH = '/mnt/2tb//pre_trained_model/fine_tuned_tiny_imagenet_convnext_encoder.pth'
# model.load_state_dict(torch.load(MODEL_PATH))
#
# validate_new(model, classifier, val_loader, val_loader_normal, criterion)