import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageNet
import torch.utils.data
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Subset
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.models import Swin_V2_T_Weights, swin_v2_t
import os
from PIL import Image
from pytorch_msssim import ssim


import torchvision.models as models
from torchvision.models import ConvNeXt_Tiny_Weights

from sklearn.metrics import normalized_mutual_info_score
import warnings
import time

# Silence every warning raised by the imported libraries for this run.
warnings.filterwarnings("ignore")


# Tiny ImageNet-200 directory layout.
dataset_path = "/mnt/2tb/data/tiny-imagenet-200"
train_dir = os.path.join(dataset_path, "train")
test_dir = os.path.join(dataset_path, "test")
val_dir = os.path.join(dataset_path, 'val/images')
wnids_file = os.path.join(dataset_path, 'wnids.txt')
val_annotations_file = os.path.join(dataset_path, 'val/val_annotations.txt')

# 224x224 pipeline: images are resized up to 224x224 and normalised with the
# ImageNet channel statistics.
transforms_ = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 64x64 pipeline: same normalisation at 64x64 (the original comment notes that
# Tiny ImageNet images are 64x64).
transforms_normal = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Two views of the same training folder, one per pipeline, plus the test split.
train_set = datasets.ImageFolder(train_dir, transform=transforms_)
train_set_normal = datasets.ImageFolder(train_dir, transform=transforms_normal)
test_set = datasets.ImageFolder(test_dir, transform=transforms_)

batch = 128

# The training loop zips train_loader with train_loader_normal and compares the
# i-th batch of one with the i-th batch of the other, so both loaders must
# visit the samples in the same shuffled order.  Each loader gets its own
# generator started from the same seed: as long as the two loaders are always
# iterated together they draw identical permutations, while the order still
# changes from one epoch to the next.
shuffle_seed = torch.initial_seed()

train_loader = data.DataLoader(
    train_set,
    batch_size=batch,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(shuffle_seed),
)
train_loader_normal = data.DataLoader(
    train_set_normal,
    batch_size=batch,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(shuffle_seed),
)
test_loader = data.DataLoader(test_set, batch_size=batch, shuffle=False, num_workers=4)

# Number of training batches per epoch.
print(len(train_loader))

# Folder name (WordNet ID) -> integer label, as assigned by ImageFolder.
class_to_idx = dict(train_set.class_to_idx)

# Integer label -> folder name (WordNet ID).
idx_to_class = {v: k for k, v in class_to_idx.items()}
class CustomDataset(Dataset):
    """Image dataset described by a whitespace-separated annotation file.

    Every non-blank line of ``annotation_file`` must start with an image file
    name followed by a class name; any further columns are ignored.  The class
    name is translated to an integer label through ``label_map``.

    Args:
        annotation_file: Path of the annotation text file.
        img_dir: Directory that contains the image files named in the
            annotation file.
        transform: Optional callable applied to the loaded RGB PIL image.
        label_map: Optional mapping from class name to integer label.  When
            omitted, the module-level ``class_to_idx`` mapping is used.

    Raises:
        KeyError: If a class name is missing from the label mapping.
        IndexError: If a non-blank line has fewer than two columns.
    """

    def __init__(self, annotation_file, img_dir, transform=None, label_map=None):
        self.img_dir = img_dir
        self.transform = transform

        # Only touch the module-level mapping when no explicit one was given.
        mapping = class_to_idx if label_map is None else label_map

        # List of (image file name, integer label) pairs in file order.
        self.img_labels = []
        with open(annotation_file, 'r') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    # Tolerate blank lines instead of failing on parts[0].
                    continue
                img_name = parts[0]
                label = mapping[parts[1]]
                self.img_labels.append((img_name, label))

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th annotation line."""
        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # Decode inside a context manager so the file handle is released
        # deterministically; convert() forces the pixel data to be read.
        with open(img_path, 'rb') as fh:
            image = Image.open(fh).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Validation data: the same annotated images once through the 224x224 pipeline
# and once through the 64x64 pipeline.
val_dataset = CustomDataset(annotation_file=val_annotations_file, img_dir=val_dir, transform=transforms_)
val_dataset_normal = CustomDataset(annotation_file=val_annotations_file, img_dir=val_dir, transform=transforms_normal)

# The validation routine zips these two loaders and compares the i-th batch of
# one with the i-th batch of the other.  Shuffling each loader independently
# would pair unrelated images, so both keep the annotation-file order; the
# reported averages do not depend on the order of the samples.
val_loader = data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
val_loader_normal = data.DataLoader(val_dataset_normal, batch_size=32, shuffle=False, num_workers=4)

print("Data Loaded ...\n")


# Point torch.hub at a custom directory, creating it when needed.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
custom_dir ='/mnt/2tb/pre_trained_model/TinyImageNet-Transformers'
os.makedirs(custom_dir, exist_ok=True)
torch.hub.set_dir(custom_dir)


# Use the first CUDA device when one is available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'




# Reference classifier: a ConvNeXt-Tiny initialised with ImageNet-1k weights.
# Its original 1000-way head is kept and followed by a small MLP that ends in
# 200 outputs (the number of Tiny ImageNet classes).
classifier = models.convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
original_classifier = classifier.classifier

new_layers = nn.Sequential(
    original_classifier,
    nn.Linear(1000, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 200),
)

classifier.classifier = new_layers

# The classifier is used as a fixed model: every parameter, including the
# added head, is frozen.
for param in classifier.parameters():
    param.requires_grad = False

# Restore the fine-tuned weights.  The model still lives on the CPU here, so
# the checkpoint is mapped to the CPU as well; this also lets the script start
# on a machine without the device the checkpoint was saved from.
checkpoint_path = '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_convnext.pth'
classifier.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))


torch.cuda.empty_cache()

classifier = classifier.to(device)


print("Loaded Classifier\n")

# Encoder: a second ConvNeXt-Tiny initialised with ImageNet-1k weights.  The
# original 1000-way head is kept and extended by an MLP whose last layer emits
# 3*64*64 values per sample (callers reshape them to a 3x64x64 tensor).
encoder = models.convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
original_classifier = encoder.classifier

new_layers1 = nn.Sequential(
    original_classifier,
    nn.Linear(1000, 2000),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(2000, 4000),
    nn.ReLU(),
    nn.Linear(4000, 3*64*64),
)

encoder.classifier = new_layers1

# Freeze the whole network first, then re-enable gradients for the head
# (the original head plus the layers appended above).
encoder.requires_grad_(False)
encoder.classifier.requires_grad_(True)

# Classification loss shared by the training and validation routines.
criterion = nn.CrossEntropyLoss()

# The optimizer is handed every encoder parameter; only those that receive
# gradients (the head) are actually updated.
optimizer_en = optim.Adam(encoder.parameters(), lr=1e-4)

encoder = encoder.to(device)


print("create encoder\n")

def build_decoder():
    """Build a fresh three-layer MLP that maps 200 features to 200 outputs.

    Returns:
        An ``nn.Module`` with linear layers ``fc1`` (200 -> 1000), ``fc2``
        (1000 -> 1000) and ``fc3`` (1000 -> 200) and a ``Dropout(0.5)`` named
        ``dropoff`` applied after the first activation.
    """

    class Decoder(nn.Module):
        """200 -> 1000 -> 1000 -> 200 perceptron with ReLU activations."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(200, 1000)
            self.fc2 = nn.Linear(1000, 1000)
            self.fc3 = nn.Linear(1000, 200)
            self.dropoff = nn.Dropout(0.5)

        def forward(self, x):
            # Dropout is applied only between the first and second layer.
            hidden = self.dropoff(torch.relu(self.fc1(x)))
            hidden = torch.relu(self.fc2(hidden))
            return self.fc3(hidden)

    return Decoder()

# Instantiate the decoder on the selected device and give it its own Adam
# optimizer, separate from the encoder's.
decoder = build_decoder().to(device)

optimizer_de = optim.Adam(decoder.parameters(), lr=1e-4)

def ssim_loss(img1, img2):
    """Return the SSIM score of ``img1`` against ``img2`` as computed by ``ssim``.

    Despite the name, the value returned is the similarity itself and not
    ``1 - ssim``; the latter variant was left commented out in the original.

    NOTE(review): ``data_range=1`` presumes unit-range images, while the
    callers in this file pass mean/std-normalised tensors and raw encoder
    outputs -- confirm this is intended.
    """
    similarity = ssim(img1, img2, data_range=1, size_average=True)
    return similarity


def train(encoder, classifier, loader, criterion):
    """Train the encoder and the decoder for one epoch against the classifier.

    For every batch the encoder output is reshaped to ``(-1, 3, 64, 64)`` and
    two losses are combined:

    * ``0.001 * ssim**2`` between the reshaped encoder output and the 64x64
      batch taken from the module-level ``train_loader_normal``;
    * ``criterion`` between ``decoder(classifier(encoded))`` and the class the
      classifier predicts for the unmodified input (the "fidelity" target).

    The 64x64 batches are obtained by zipping ``loader`` with
    ``train_loader_normal``; the SSIM term only compares an image with its own
    encoding if both loaders yield the same samples in the same order.

    Besides its arguments the function uses the module-level ``decoder``,
    ``optimizer_en``, ``optimizer_de``, ``device`` and ``ssim_loss``.

    Args:
        encoder: Model being trained; called on the batches of ``loader``.
        classifier: Model kept in eval mode; its parameters are not stepped
            by either optimizer here.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss taking ``(outputs, target_class_indices)``.

    Returns:
        None.  Running and final statistics are printed.
    """
    classifier.eval()
    encoder.train()
    # The decoder contains dropout, so make its mode explicit instead of
    # relying on whatever state a previous call left it in.
    decoder.train()

    running_loss = 0.0
    running_ssim_loss = 0.0
    correct = 0
    correct_fd = 0
    total = 0
    it = 0

    for (inputs, labels), (inputs_normal, _) in zip(loader, train_loader_normal):
        it += 1

        inputs, labels = inputs.to(device), labels.to(device)
        inputs_normal = inputs_normal.to(device)

        optimizer_en.zero_grad()
        optimizer_de.zero_grad()

        ob_inputs = encoder(inputs)
        ob_images = ob_inputs.view(-1, 3, 64, 64)

        # Similarity term: squaring keeps the objective non-negative, and
        # minimising it pushes the SSIM score towards zero.
        loss_ssim = ssim_loss(ob_images, inputs_normal)
        loss_ssim_square = loss_ssim * loss_ssim
        loss_encoder = 0.001 * loss_ssim_square

        # Gradients must flow through the classifier back to the encoder, so
        # this pass is tracked by autograd.
        ob_outputs = classifier(ob_images)
        outputs = decoder(ob_outputs)

        # The reference prediction is only a target; no graph is needed.
        with torch.no_grad():
            classifier_outputs = classifier(inputs)
            _, classifier_predicted = classifier_outputs.max(1)

        loss = criterion(outputs, classifier_predicted)

        # One backward pass over the summed objective yields the same
        # accumulated gradients as two passes with retain_graph=True, without
        # keeping the graph alive in between.
        (loss_encoder + loss).backward()
        optimizer_de.step()
        optimizer_en.step()

        running_ssim_loss += loss_ssim_square.item()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        correct_fd += predicted.eq(classifier_predicted).sum().item()
        if it % 100 == 0:
            print(f'{it} : {running_loss/it} {running_ssim_loss/it}')

    # Average over the batches actually processed (zip stops at the shorter
    # loader) and avoid dividing by zero when nothing was iterated.
    batches = max(it, 1)
    samples = max(total, 1)
    accuracy = 100. * correct / samples
    fidelity = 100. * correct_fd / samples
    print(f'Training Loss: {running_loss/batches}, SSIM Loss: {running_ssim_loss/batches}, Accuracy: {accuracy}%, Fidelity: {fidelity}%')

def _timed_call(module, batch):
    """Return ``module(batch)`` after printing the elapsed wall-clock seconds.

    NOTE(review): no device synchronisation is performed, so on a GPU the
    printed value may not cover the full execution time of the call.
    """
    start = time.time()
    result = module(batch)
    print(time.time() - start)
    return result


# Validation function
def validate(encoder, classifier, loader, criterion):
    """Evaluate the encoder -> classifier -> decoder pipeline without gradients.

    Batches of ``loader`` are zipped with the module-level
    ``val_loader_normal``; the second element supplies the 64x64 images used
    for the SSIM statistic, so both loaders must yield the same samples in the
    same order for that statistic to be meaningful.

    Reported (printed) quantities:

    * mean ``criterion`` loss of the decoder output against the true labels;
    * mean squared SSIM between the reshaped encoder output and the 64x64
      batch;
    * accuracy: decoder prediction equals the true label;
    * fidelity: decoder prediction equals the classifier's prediction on the
      unmodified input.

    Per-batch timings of the three model calls are printed as well.  The
    function also uses the module-level ``decoder``, ``device`` and
    ``ssim_loss``.

    Args:
        encoder: Model applied to the batches of ``loader``.
        classifier: Model applied to the reshaped encoder output and to the
            unmodified input.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss taking ``(outputs, target_class_indices)``.

    Returns:
        The accuracy in percent.
    """
    encoder.eval()
    classifier.eval()

    # The decoder contains dropout; evaluate it deterministically and put it
    # back into its previous mode afterwards so that callers which never reset
    # the mode themselves are unaffected.
    decoder_was_training = decoder.training
    decoder.eval()

    val_loss = 0.0
    correct = 0
    correct_fd = 0
    total = 0
    loss_ssim_all = 0.0
    batches = 0
    try:
        with torch.no_grad():
            for (inputs, labels), (inputs_normal, _) in zip(loader, val_loader_normal):
                batches += 1
                inputs, labels = inputs.to(device), labels.to(device)
                inputs_normal = inputs_normal.to(device)

                ob_inputs = _timed_call(encoder, inputs)
                ob_images = ob_inputs.view(-1, 3, 64, 64)

                loss_ssim = ssim_loss(ob_images, inputs_normal)
                loss_ssim_square = loss_ssim * loss_ssim
                loss_ssim_all += loss_ssim_square.item()

                ob_outputs = _timed_call(classifier, ob_images)
                outputs = _timed_call(decoder, ob_outputs)
                print("================================")

                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)

                classifier_outputs = classifier(inputs)
                _, classifier_predicted = classifier_outputs.max(1)

                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                correct_fd += predicted.eq(classifier_predicted).sum().item()
    finally:
        decoder.train(decoder_was_training)

    # Average over the batches actually processed (zip stops at the shorter
    # loader) and avoid dividing by zero when nothing was iterated.
    batches = max(batches, 1)
    samples = max(total, 1)
    accuracy = 100. * correct / samples
    fidelity = 100. * correct_fd / samples
    print(f'Validation Loss: {val_loss/batches}, SSIM Loss: {loss_ssim_all/batches}, Accuracy: {accuracy}%, Fidelity: {fidelity}%')
    return accuracy

def validate_new(encoder, classifier, loader, loader_normal, criterion):
    """Measure how often the classifier agrees with itself on encoded inputs.

    For each batch the classifier is run on the reshaped encoder output and on
    the unmodified input; the printed and returned "Accuracy" is the
    percentage of samples on which the two predictions coincide.  The printed
    loss is ``criterion`` of the encoded-input logits against the true labels,
    averaged over ``len(loader)``.

    ``loader_normal`` is iterated in lockstep with ``loader`` but its batches
    are not used.  The module-level ``device`` is used for tensor placement.

    Args:
        encoder: Model applied to the batches of ``loader``.
        classifier: Model applied to both the encoded and the original batch.
        loader: Iterable of ``(inputs, labels)`` batches.
        loader_normal: Second iterable of 2-tuples, consumed alongside
            ``loader``.
        criterion: Loss taking ``(outputs, target_class_indices)``.

    Returns:
        The agreement percentage described above.
    """
    encoder.eval()
    classifier.eval()

    loss_sum = 0.0
    agreed = 0
    seen = 0

    with torch.no_grad():
        for (images, targets), (_images_n, _targets_n) in zip(loader, loader_normal):
            images = images.to(device)
            targets = targets.to(device)

            # Classifier on the encoder output, viewed as 3x64x64 images.
            encoded = encoder(images).view(-1, 3, 64, 64)
            logits = classifier(encoded)
            loss_sum += criterion(logits, targets).item()
            predicted = logits.max(1)[1]
            seen += targets.size(0)

            # Classifier on the unmodified input serves as the reference.
            reference = classifier(images).max(1)[1]
            agreed += (predicted == reference).sum().item()

    val_loss = loss_sum
    accuracy = 100. * agreed / seen
    print(f'Validation Loss: {val_loss/len(loader)}, Accuracy: {accuracy}%')
    return accuracy


# Fine-tuning process
num_epochs = 100
acc_base = 0

# Checkpoints of the best encoder/decoder pair seen so far.
MODEL_PATH = '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_convnext_encoder_ssim.pth'
DECODER_PATH = '/mnt/2tb/pre_trained_model/fine_tuned_tiny_imagenet_convnext_decoder_ssim.pth'

# Evaluation before any training step.
validate(encoder, classifier, val_loader, criterion)
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    train(encoder, classifier, train_loader, criterion)
    acc = validate(encoder, classifier, val_loader, criterion)

    # Keep the pair with the best validation accuracy.
    if acc > acc_base:
        torch.save(encoder.state_dict(), MODEL_PATH)
        torch.save(decoder.state_dict(), DECODER_PATH)
        print("Saved\n")
        acc_base = acc

# Final evaluation of the best checkpoint.  The encoder and the decoder are
# saved together, so both are restored; reloading only the encoder would pair
# the best encoder with the decoder of the last epoch.
encoder.load_state_dict(torch.load(MODEL_PATH, map_location=device))
if os.path.exists(DECODER_PATH):
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=device))
validate(encoder, classifier, val_loader, criterion)
