

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.neighbors import KNeighborsClassifier

import time

import sys

import numpy as np

from sklearn.cluster import SpectralClustering
from sklearn.metrics.cluster import adjusted_rand_score, rand_score
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score

from helper import *
from CIFAR10 import *

'''
Image Acc: 0.2836
Embedding Acc: 0.2945
'''

import os

# Redirect torch's cache directory (TORCH_HOME) to the folder containing this
# script, so anything torch caches under TORCH_HOME lands next to the script
# instead of in the user-wide default location.
current_directory = os.path.abspath(os.path.dirname(__file__))
os.environ.update(TORCH_HOME=current_directory)


class AttentionScoreBlock(nn.Module):
    """Parameter-free block that maps a batch of embeddings to pairwise affinities.

    ``forward`` normalizes the embeddings with ``normalize`` (from ``helper``),
    takes all pairwise dot products, and converts them into per-row weights with
    ``softmax_without_diagonal_exp`` (also from ``helper``). Judging by its name,
    that helper leaves each sample's self-similarity out of the softmax — confirm
    against ``helper``.

    Args:
        input_size: Embedding width. Stored for reference; ``forward`` does not use it.
        output_size: Stored for reference; ``forward`` does not use it.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Kept for introspection only: the block owns no learnable weights.
        self.input_size = input_size
        self.output_size = output_size

    def forward(self, x):
        """Return the affinity matrix for the embeddings ``x``.

        Assumes ``x`` is (batch, features) — TODO confirm. With that shape the
        similarity matrix is (batch, batch) and ``dim=1`` normalizes each row.
        """
        normed = normalize(x)
        # Entry (i, j) is the dot product of normalized embeddings i and j
        # (cosine similarity if ``normalize`` is row-wise L2 normalization).
        similarity = normed @ normed.transpose(-2, -1)
        return softmax_without_diagonal_exp(similarity, dim=1)

# Define your custom model class
class CustomResNetWithAttention(nn.Module):
    """ImageNet-pretrained ResNet-18 feature extractor with an affinity head.

    The classification layer ``fc`` is swapped for ``nn.Flatten`` so the backbone
    emits its pooled feature vector (``fc.in_features`` wide, 512 for ResNet-18).
    ``forward`` returns ``(embedding, affinity)``, where ``affinity`` is the
    ``AttentionScoreBlock`` output computed from that embedding.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
        # favour of ``weights=``; kept for compatibility with older versions.
        backbone = models.resnet18(pretrained=True)
        feature_dim = backbone.fc.in_features  # 512 for ResNet-18
        backbone.fc = nn.Flatten()
        # Register the backbone before the attention block so submodule /
        # state_dict ordering matches existing checkpoints.
        self.resnet18 = backbone
        self.attention_block = AttentionScoreBlock(feature_dim, feature_dim)

    def forward(self, x):
        """Return ``(embedding, affinity)`` for an image batch ``x``."""
        features = self.resnet18(x)
        return features, self.attention_block(features)


def _affinity_kl_loss(model, inputs, labels, device):
    """Return the KL loss between the model's affinity matrix and the label target.

    The target is ``normalize_with_diagonal_zero(create_adjacency_matrix(labels))``
    (both from ``helper``); presumably a row-normalized same-class indicator with
    a zero diagonal — confirm in ``helper``.
    """
    adj_matrix = normalize_with_diagonal_zero(create_adjacency_matrix(labels)).to(device)
    _, affinity = model(inputs.to(device))
    # F.kl_div takes log-probabilities as input and probabilities as target.
    return F.kl_div(affinity.log(), adj_matrix, reduction='batchmean')


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimization pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = _affinity_kl_loss(model, inputs, labels, device)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Average over the real number of batches (not a hard-coded constant).
    return running_loss / max(num_batches, 1)


def _validation_loss(model, loader, device):
    """Return the mean per-batch affinity KL loss over ``loader`` (no weight updates)."""
    model.eval()
    total = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in loader:
            # .item() keeps the running sum a Python float rather than a device tensor.
            total += _affinity_kl_loss(model, inputs, labels, device).item()
            num_batches += 1
    return total / max(num_batches, 1)


def _report_downstream_accuracy(model, train_loader, test_loader, device):
    """Print LinearSVC and k-NN test accuracy on raw images and on learned embeddings."""
    model.eval()
    with torch.no_grad():
        all_images, all_labels, all_embeddings = extract_feature(model, train_loader, device)
        test_all_images, test_all_labels, test_all_embeddings = extract_feature(model, test_loader, device)
        # Flatten each image into a single feature vector for the raw-pixel k-NN baseline.
        all_images = all_images.reshape(all_images.shape[0], -1)
        test_all_images = test_all_images.reshape(test_all_images.shape[0], -1)

    classifier = LinearSVC()
    classifier.fit(all_embeddings, all_labels)
    accuracy = accuracy_score(test_all_labels, classifier.predict(test_all_embeddings))
    print('Embedding Acc:', accuracy)

    feature_sets = (
        ('Image', all_images, test_all_images),
        ('Embedding', all_embeddings, test_all_embeddings),
    )
    for n in [1, 5, 20, 50, 100, 200, 500, 1000]:
        for name, train_x, test_x in feature_sets:
            knn_classifier = KNeighborsClassifier(n_neighbors=n)
            knn_classifier.fit(train_x, all_labels)
            accuracy = accuracy_score(test_all_labels, knn_classifier.predict(test_x))
            print(f'KNN with k = {n}, {name} Acc: {accuracy}')


def main():
    """Train ``CustomResNetWithAttention`` on CIFAR-10 and report downstream accuracy.

    Usage: ``python <script> EXP_INDEX`` where ``EXP_INDEX`` indexes the
    ``learning_rate`` list below. The best checkpoint (lowest validation loss) is
    written to ``<EXP_INDEX>_best_model.pt``. When ``ReduceLROnPlateau`` lowers
    the learning rate, training restarts from that checkpoint.
    """
    batch_size = 128
    num_epochs = 200
    patience = 5      # epochs without improvement before the LR is reduced
    lr_factor = 0.1   # multiplicative LR decay applied on a plateau

    learning_rate = [1e-7, 1e-5, 1e-3, 1e-1, 1]
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} EXP_INDEX  (0-{len(learning_rate) - 1})')
    exp_index = sys.argv[1]
    lr = learning_rate[int(exp_index)]
    checkpoint_path = exp_index + '_best_model.pt'

    print(f'Learning Rate: {lr}')

    train_loader, val_loader, test_loader = CIFAR10(batch_size)

    # Load pretrained ResNet-18 model
    model = CustomResNetWithAttention()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_params}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The scheduler must be fed the validation loss every epoch via step();
    # otherwise it never tracks plateaus.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=patience, factor=lr_factor)
    best_val_loss = float('inf')
    total_training_time = 0.0

    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        print(f'[{epoch}] loss: {train_loss:.3f}')
        # Only the training pass is timed, as before.
        total_training_time += time.time() - start_time

        val_loss = _validation_loss(model, val_loader, device)
        print(f'[{epoch}] val loss: {val_loss:.3f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print('Best One')

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        if optimizer.param_groups[0]['lr'] < lr_before:
            # Plateau detected: continue from the best weights at the reduced LR.
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            # Scientific notation so tiny LRs (e.g. 1e-8) are not printed as 0.00000.
            print(f'Reducing step size to {optimizer.param_groups[0]["lr"]:.2e}')

    print(f'Time (min): {total_training_time // 60}')
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    _report_downstream_accuracy(model, train_loader, test_loader, device)





# Train and evaluate only when run as a script, never on import.
if __name__ == "__main__":
    main()
