import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageNet
import torch.utils.data
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.models import Swin_V2_T_Weights, swin_v2_t
import os
from PIL import Image
from pytorch_msssim import ssim


import torchvision.models as models
from torchvision.models import ConvNeXt_Tiny_Weights

from sklearn.metrics import normalized_mutual_info_score
import warnings
import time
import matplotlib.pyplot as plt

# Silence every warning category for the remainder of the run.
warnings.filterwarnings("ignore")


# Dataset locations, all derived from the Tiny ImageNet-200 root.
# (Alternative root used previously: "/mnt/2tb/data/tiny-imagenet-10-animals")
dataset_path = "/mnt/2tb/data/tiny-imagenet-200"
train_dir, test_dir, val_dir, wnids_file, val_annotations_file = (
    os.path.join(dataset_path, relative)
    for relative in (
        "train",
        "test",
        'val/images',
        'wnids.txt',
        'val/val_annotations.txt',
    )
)


def _build_transform(side):
    """Return a pipeline: resize to ``side`` x ``side``, to tensor, normalise.

    The mean/std values are the per-channel statistics conventionally used
    with ImageNet-pretrained torchvision models.
    """
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# 224x224 tensors: fed to the encoder and to the reference classifier.
transforms_ = _build_transform(224)

# 64x64 tensors: same spatial size as the encoder output once it is reshaped
# to (3, 64, 64); used as the comparison target of the SSIM term.
transforms_normal = _build_transform(64)


# Load datasets. ``train_set`` and ``train_set_normal`` read the same folder
# (ImageFolder enumerates classes and files in sorted order), so index i refers
# to the same image in both; only the resize differs.
train_set = datasets.ImageFolder(train_dir, transform=transforms_)
train_set_normal = datasets.ImageFolder(train_dir, transform=transforms_normal)
test_set = datasets.ImageFolder(test_dir, transform=transforms_)
# val_set = datasets.ImageFolder(val_dir, transform=transforms_)

batch = 128


class _PairedShuffleSampler(data.Sampler):
    """Shuffling sampler whose order depends only on ``seed`` and the epoch count.

    Two instances built with the same ``num_samples`` and ``seed`` yield the
    same permutation on their n-th iteration. That lets two DataLoaders over
    index-aligned datasets be zipped while still reshuffling every epoch.
    The guarantee holds as long as both loaders are iterated the same number
    of times (i.e. always together).
    """

    def __init__(self, num_samples, seed):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0  # number of times this sampler has been iterated

    def __iter__(self):
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)
        self.epoch += 1
        return iter(torch.randperm(self.num_samples, generator=generator).tolist())

    def __len__(self):
        return self.num_samples


# train() zips train_loader with train_loader_normal and compares the encoder
# output for batch k of the former with batch k of the latter. With two
# independent ``shuffle=True`` loaders those batches hold unrelated images, so
# both loaders use samplers that share one (per-run random) seed instead.
_pair_seed = int(torch.randint(0, 2**31 - 1, (1,)).item())

train_loader = data.DataLoader(
    train_set,
    batch_size=batch,
    sampler=_PairedShuffleSampler(len(train_set), _pair_seed),
    num_workers=4,
)
train_loader_normal = data.DataLoader(
    train_set_normal,
    batch_size=batch,
    sampler=_PairedShuffleSampler(len(train_set_normal), _pair_seed),
    num_workers=4,
)
test_loader = data.DataLoader(test_set, batch_size=batch, shuffle=False, num_workers=4)
# val_loader = data.DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)

# Number of training batches per epoch.
print(len(train_loader))

# Class-folder name (WordNet ID) -> integer label, copied so later code does
# not alias the dataset's own dict, plus the reverse mapping.
class_to_idx = dict(train_set.class_to_idx)
idx_to_class = {v: k for k, v in class_to_idx.items()}
class CustomDataset(Dataset):
    """Image dataset driven by a whitespace-separated annotation file.

    Every non-blank line of ``annotation_file`` must start with an image file
    name followed by a class identifier; any further columns are ignored.
    The identifier is converted to an integer label through the module-level
    ``class_to_idx`` mapping, so it must be a key of that mapping.

    Args:
        annotation_file: Path of the annotation text file.
        img_dir: Directory that contains the image files named in the
            annotation file.
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        KeyError: If a class identifier is missing from ``class_to_idx``.
        IndexError: If a non-blank line has fewer than two columns.
    """

    def __init__(self, annotation_file, img_dir, transform=None):
        self.img_labels = []  # list of (image file name, integer label)
        self.img_dir = img_dir
        self.transform = transform

        with open(annotation_file, 'r') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    # Blank / whitespace-only lines carry no sample; skipping
                    # them avoids an IndexError on e.g. a stray empty line.
                    continue
                img_name = parts[0]
                label = class_to_idx[parts[1]]
                self.img_labels.append((img_name, label))

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded from ``img_dir``, converted to RGB and, when a
        transform was supplied, passed through it.
        """
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager closes the file handle deterministically;
        # convert() returns an independent image that outlives it.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Validation datasets: identical files and labels (same annotation file and
# image directory), delivered as 224x224 and as 64x64 tensors respectively.
val_dataset = CustomDataset(annotation_file=val_annotations_file, img_dir=val_dir, transform=transforms_)
val_dataset_normal = CustomDataset(annotation_file=val_annotations_file, img_dir=val_dir, transform=transforms_normal)

# validate() zips these two loaders and compares the encoder output for batch k
# of val_loader with batch k of val_loader_normal. Shuffling each loader
# independently would pair unrelated images, and validation metrics do not
# benefit from shuffling, so both loaders keep the dataset order.
val_loader = data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
val_loader_normal = data.DataLoader(val_dataset_normal, batch_size=32, shuffle=False, num_workers=4)
# exit()

print("Data Loaded ...\n")


# Directory handed to torch.hub as its cache location (presumably where the
# pretrained weights requested below are downloaded to).
# custom_dir = '/mnt/2tb/pre_trained_model/vit_imagenet/'
custom_dir = '/mnt/2tb/pre_trained_model/TinyImageNet-Transformers'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(custom_dir, exist_ok=True)
torch.hub.set_dir(custom_dir)


# Use the first CUDA device when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Swin-V2-Tiny initialised from the ImageNet-1k weights.
weights = Swin_V2_T_Weights.IMAGENET1K_V1
classifier = swin_v2_t(weights=weights)

# Replace the stock head with a small MLP ending in 200 outputs
# (one per Tiny ImageNet class).
num_features = classifier.head.in_features
new_layers = nn.Sequential(
    nn.Linear(num_features, 512),  # fully connected layer with 512 units
    nn.ReLU(),
    nn.Dropout(0.5),               # dropout for regularisation
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 200),           # final layer: 200 class scores
)
classifier.head = new_layers

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(classifier.parameters(), lr=0.0001)

# The classifier is used as a fixed reference model: freeze every parameter,
# backbone and new head alike (no optimizer is ever created for it).
classifier.requires_grad_(False)

checkpoint_path = '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_swin_v2_t_another.pth'
# map_location='cpu' lets a checkpoint saved from a CUDA model be restored on a
# machine without CUDA (the script supports that via ``device``); the model is
# moved to ``device`` right afterwards.
classifier.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

classifier = classifier.to(device)

torch.cuda.empty_cache()


print("Loaded Classifier\n")


# Encoder: ConvNeXt-Tiny initialised from ImageNet-1k weights. Its stock
# classification head is kept and extended by an MLP whose output has
# 3*64*64 values per input, i.e. one flattened 3x64x64 tensor.
encoder = models.convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
original_classifier = encoder.classifier

new_layers1 = nn.Sequential(
    original_classifier,           # stock head (1000 outputs)
    nn.Linear(1000, 2000),         # widen to 2000 units
    nn.ReLU(),
    nn.Dropout(0.5),               # dropout for regularisation
    nn.Linear(2000, 4000),         # widen to 4000 units
    nn.ReLU(),
    nn.Linear(4000, 3 * 64 * 64),  # flattened 3x64x64 output
)
encoder.classifier = new_layers1

# Freeze everything, then re-enable gradients for the whole head (the stock
# classifier layers plus the layers added above).
encoder.requires_grad_(False)
encoder.classifier.requires_grad_(True)

criterion = nn.CrossEntropyLoss()

# All parameters are registered with the optimizer; the frozen ones never
# receive a gradient.
optimizer_en = optim.Adam(encoder.parameters(), lr=0.0001)

encoder = encoder.to(device)


print("create encoder\n")

def build_decoder():
    """Create a freshly initialised decoder network.

    The decoder is a three-layer MLP mapping 200 inputs to 200 outputs
    (200 -> 1000 -> 1000 -> 200). ReLU follows the first two linear layers,
    dropout with p=0.5 is applied after the first activation, and the last
    layer's output is returned without an activation.

    Returns:
        nn.Module: The new ``Decoder`` instance (on the default device).
    """

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(200, 1000)
            self.fc2 = nn.Linear(1000, 1000)
            self.fc3 = nn.Linear(1000, 200)
            self.dropoff = nn.Dropout(0.5)

        def forward(self, x):
            hidden = self.dropoff(torch.relu(self.fc1(x)))
            hidden = torch.relu(self.fc2(hidden))
            return self.fc3(hidden)

    return Decoder()

# Instantiate the decoder on the compute device and give it its own Adam
# optimizer (same learning rate as the encoder's).
decoder = build_decoder().to(device)

optimizer_de = optim.Adam(decoder.parameters(), lr=0.0001)

def ssim_loss(img1, img2):
    """Return the SSIM between ``img1`` and ``img2``, averaged to one value.

    Despite the name, the value returned is the similarity itself rather than
    ``1 - SSIM``: minimising it pushes the two image batches apart.

    NOTE(review): ``data_range=1`` presumes pixel values spanning a range of
    1; the callers in this file pass mean/std-normalised tensors -- confirm
    this is intended.
    """
    similarity = ssim(img1, img2, data_range=1, size_average=True)
    return similarity


def train(encoder, classifier, loader, criterion):
    """Run one training epoch for ``encoder`` and the module-level ``decoder``.

    For every batch the encoder output is reshaped to (N, 3, 64, 64) and two
    losses are combined:

    * ``0.001 * SSIM**2`` between the reshaped encoder output and the 64x64
      version of the batch, and
    * ``criterion`` between ``decoder(classifier(encoder output))`` and the
      class the classifier predicts for the original input (fidelity target).

    Both optimizers (``optimizer_en``, ``optimizer_de``) are stepped once per
    batch. Progress is printed every 10 batches and a summary at the end.

    Args:
        encoder: Network mapping a 224x224 batch to 3*64*64 values per sample.
        classifier: Reference classifier; only used in eval mode here.
        loader: Loader yielding ``(inputs, labels)`` batches for the encoder.
        criterion: Classification loss taking ``(logits, target_indices)``.

    NOTE: batches of the module-level ``train_loader_normal`` are consumed in
    lockstep with ``loader``; the SSIM term is only meaningful if both yield
    the same samples in the same order.
    """
    classifier.eval()
    encoder.train()
    # Make the decoder's mode explicit instead of relying on whatever state a
    # previous call left it in (it contains dropout).
    decoder.train()

    running_loss = 0.0
    running_ssim_loss = 0.0
    correct = 0
    correct_fd = 0
    total = 0

    paired_batches = zip(loader, train_loader_normal)
    for it, ((inputs, labels), (inputs_normal, _)) in enumerate(paired_batches, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs_normal = inputs_normal.to(device)

        optimizer_en.zero_grad()
        optimizer_de.zero_grad()

        # Encoder output viewed as a batch of 3x64x64 images.
        ob_images = encoder(inputs).view(-1, 3, 64, 64)

        loss_ssim = ssim_loss(ob_images, inputs_normal)
        loss_ssim_square = loss_ssim * loss_ssim
        loss_encoder = 0.001 * loss_ssim_square

        ob_outputs = classifier(ob_images)
        outputs = decoder(ob_outputs)

        # Fidelity target: the classifier's own prediction on the original
        # input. It is an index tensor, so no gradient is needed for it.
        with torch.no_grad():
            _, classifier_predicted = classifier(inputs).max(1)

        # loss = criterion(outputs, labels)
        loss = criterion(outputs, classifier_predicted)   # for fidelity

        # One backward pass over the summed losses accumulates the same
        # gradients as two separate passes, without retaining the graph.
        (loss + loss_encoder).backward()
        optimizer_de.step()
        optimizer_en.step()

        running_ssim_loss += loss_ssim_square.item()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        correct_fd += predicted.eq(classifier_predicted).sum().item()
        if it % 10 == 0:
            print(f'{it} : {running_loss/it} {running_ssim_loss/it}')

    # Guard the averages so an empty loader reports zeros instead of raising
    # ZeroDivisionError.
    num_batches = max(len(loader), 1)
    num_samples = max(total, 1)
    accuracy = 100. * correct / num_samples
    fidelity = 100. * correct_fd / num_samples
    print(f'Training Loss: {running_loss/num_batches}, SSIM Loss: {running_ssim_loss/num_batches}, Accuracy: {accuracy}%, Fidelity: {fidelity}%')

# Validation function
def validate(encoder, classifier, loader, criterion):
    """Evaluate the encoder -> classifier -> decoder pipeline on ``loader``.

    Reports, and prints, four quantities: the mean ``criterion`` value of the
    decoder output against the ground-truth labels, the mean squared SSIM
    between the reshaped encoder output and the 64x64 batch, the accuracy
    against the labels, and the fidelity (agreement with the classifier's
    prediction on the original input).

    Args:
        encoder: Network mapping a 224x224 batch to 3*64*64 values per sample.
        classifier: Reference classifier.
        loader: Loader yielding ``(inputs, labels)`` batches.
        criterion: Classification loss taking ``(logits, target_indices)``.

    Returns:
        float: Accuracy in percent.

    NOTE: batches of the module-level ``val_loader_normal`` are consumed in
    lockstep with ``loader``; the SSIM figure is only meaningful if both
    yield the same samples in the same order.
    """
    encoder.eval()
    classifier.eval()
    # The module-level decoder contains dropout; evaluate it deterministically
    # and restore its previous mode afterwards so a following training epoch
    # is unaffected.
    decoder_was_training = decoder.training
    decoder.eval()

    val_loss = 0.0
    correct = 0
    correct_fd = 0
    total = 0
    loss_ssim_all = 0.0
    try:
        with torch.no_grad():
            for (inputs, labels), (inputs_normal, _) in zip(loader, val_loader_normal):
                inputs, labels = inputs.to(device), labels.to(device)
                inputs_normal = inputs_normal.to(device)

                # Encoder output viewed as a batch of 3x64x64 images.
                ob_images = encoder(inputs).view(-1, 3, 64, 64)

                loss_ssim = ssim_loss(ob_images, inputs_normal)
                loss_ssim_square = loss_ssim * loss_ssim
                loss_ssim_all += loss_ssim_square.item()

                ob_outputs = classifier(ob_images)
                outputs = decoder(ob_outputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)

                # Reference prediction on the original input (fidelity).
                classifier_outputs = classifier(inputs)
                _, classifier_predicted = classifier_outputs.max(1)

                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                correct_fd += predicted.eq(classifier_predicted).sum().item()
    finally:
        decoder.train(decoder_was_training)

    # Guard the averages so an empty loader reports zeros instead of raising
    # ZeroDivisionError.
    num_batches = max(len(loader), 1)
    num_samples = max(total, 1)
    accuracy = 100. * correct / num_samples
    fidelity = 100. * correct_fd / num_samples
    print(f'Validation Loss: {val_loss/num_batches}, SSIM Loss: {loss_ssim_all/num_batches}, Accuracy: {accuracy}%, Fidelity: {fidelity}%')
    return accuracy

def validate_new(encoder, classifier, loader, loader_normal, criterion):
    """Evaluate the classifier directly on encoder outputs (no decoder).

    The printed loss is ``criterion`` against the ground-truth labels. The
    value printed and returned as "Accuracy" is the percentage of samples for
    which the classifier's prediction on the encoded input equals its
    prediction on the original input.

    Args:
        encoder: Network mapping a batch to 3*64*64 values per sample.
        classifier: Reference classifier.
        loader: Loader yielding ``(inputs, labels)`` batches.
        loader_normal: Second loader stepped in lockstep with ``loader``; its
            batches are not used, but it bounds the number of iterations.
        criterion: Classification loss taking ``(logits, target_indices)``.

    Returns:
        float: Agreement rate in percent.
    """
    encoder.eval()
    classifier.eval()

    loss_sum = 0.0
    n_agree = 0
    n_seen = 0

    with torch.no_grad():
        for (inputs, labels), (_inputs_n, _labels_n) in zip(loader, loader_normal):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Classify the encoder output viewed as 3x64x64 images.
            encoded = encoder(inputs)
            logits = classifier(encoded.view(-1, 3, 64, 64))
            loss_sum += criterion(logits, labels).item()
            predicted = logits.max(1)[1]

            n_seen += labels.size(0)

            # Reference prediction on the unmodified input.
            reference = classifier(inputs).max(1)[1]
            n_agree += predicted.eq(reference).sum().item()

    accuracy = 100. * n_agree / n_seen
    print(f'Validation Loss: {loss_sum/len(loader)}, Accuracy: {accuracy}%')
    return accuracy

def test_generation():
    """Plot encoder outputs for the first test batch, save them, then exit.

    Up to 10 samples from the first batch of the module-level ``test_loader``
    are passed through the module-level ``encoder``, reshaped to 3x64x64,
    min-max scaled to [0, 1] and drawn side by side. The figure is written to
    ``./sampled_images.png`` and ``./sampled_images.pdf`` and shown.

    Raises:
        SystemExit: Always, after the first batch has been plotted; the
            script terminates here. If ``test_loader`` yields no batch the
            function simply returns.
    """
    classifier.eval()
    # The encoder head contains dropout; eval mode makes the generated
    # images deterministic.
    encoder.eval()

    num_samples = 10  # maximum number of images to plot

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)

            ob_data = encoder(images).detach().cpu()

            # Never index past the end of a batch smaller than num_samples.
            sample_indices = list(range(min(num_samples, ob_data.size(0))))
            n_cols = len(sample_indices)

            # squeeze=False keeps ``axes`` two-dimensional even for one image.
            fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 2, 2), squeeze=False)

            for i, idx in enumerate(sample_indices):
                img = ob_data[idx].view(3, 64, 64)
                print(img.shape)
                img = img.permute(1, 2, 0)  # (3, 64, 64) -> (64, 64, 3)

                # Min-max scale to [0, 1]; a constant image would divide by
                # zero, so it is rendered as all zeros instead.
                low, high = img.min(), img.max()
                if high > low:
                    img = (img - low) / (high - low)
                else:
                    img = torch.zeros_like(img)

                axes[0][i].imshow(img.numpy())
                axes[0][i].axis('off')

            plt.tight_layout()
            plt.savefig('./sampled_images.png')
            plt.savefig('./sampled_images.pdf')
            plt.show()
            # Only the first batch is visualised. SystemExit is raised
            # directly because the ``exit`` helper is only installed by the
            # ``site`` module.
            raise SystemExit



# Fine-tuning process: report a baseline validation pass, then alternate
# training and validation, checkpointing the encoder and decoder whenever the
# validation accuracy exceeds the best value seen so far.
num_epochs = 100
acc_base = 0
validate(encoder, classifier, val_loader, criterion)
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    train(encoder, classifier, train_loader, criterion)
    acc = validate(encoder, classifier, val_loader, criterion)

    if not acc > acc_base:
        continue

    torch.save(encoder.state_dict(), '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_swin_encoder_ssim.pth')
    torch.save(decoder.state_dict(), '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_swin_decoder_ssim.pth')
    print("Saved\n")
    acc_base = acc

# MODEL_PATH = '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_swin_encoder_ssim.pth'
# encoder.load_state_dict(torch.load(MODEL_PATH))

# MODEL_PATH = '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_swin_decoder_ssim.pth'
# decoder.load_state_dict(torch.load(MODEL_PATH))
# acc = validate(encoder, classifier, val_loader, criterion)
# test_generation()
# # validate_new(model, classifier, val_loader, val_loader_normal, criterion)
