import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE
from scipy import linalg



# PyTorch imports

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
from tqdm import tqdm
from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn.functional as F


# model = torchvision.models.resnet50(progress=False, weights='IMAGENET1K_V1')
# model.fc = nn.Sequential(
#   nn.Linear(in_features=2048, out_features=4000, bias=True),
#   nn.Sigmoid()
# )

# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") ## specify the GPU id's, GPU id's start from 0.

# # model.load_state_dict(torch.load("rscnn_resnet50_imagenet.pt", map_location=device))


# model= nn.DataParallel(model,device_ids = [1, 2, 3])
# model.to(device)

# # torch.save(model, '/home/shireen/RSCNN/ImageNet/model_img.pt')
# torch.save(model.cpu().state_dict(), '/home/shireen/RSCNN/ImageNet/model_img.pt')

# Preprocessing pipeline: resize to 224x224 with nearest-neighbour
# interpolation, then normalise each channel with the ImageNet mean/std.
# ToTensor is left out (it was commented out originally), so this pipeline
# expects a float tensor input rather than a PIL image.
transform_steps = transforms.Compose(
    [
        transforms.Resize(
            (224, 224),
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225),
        ),
    ]
)
# Both names refer to the same Compose object.
transform = transform_steps


# ImageNet-pretrained ResNet-50 (IMAGENET1K_V2 weights). The custom fc heads
# below are commented out, so torchvision's stock classifier is used.
model = torchvision.models.resnet50(progress=False, weights='IMAGENET1K_V2')
# model.fc = nn.Sequential(
#   # nn.Linear(in_features=2048, out_features=2800, bias=True),
#   # nn.ReLU(),
#   # nn.Linear(in_features=2800, out_features=3000, bias=True),
#   # nn.ReLU(),
#   nn.Linear(in_features=2048, out_features=4000, bias=True),
#   nn.Sigmoid()
# )

# Preferred GPU id (ids start from 0). If this machine has fewer GPUs than
# that, fall back to the first GPU (or the CPU) instead of failing with
# "invalid device ordinal" when the model is moved to the device.
GPU_ID = 2
if torch.cuda.is_available() and GPU_ID < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_ID}")
elif torch.cuda.is_available():
    print(f"GPU {GPU_ID} not found; using cuda:0 instead.")
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# model= nn.DataParallel(model,device_ids = [1, 2, 3])
# checkpoint = torch.load('model_img_v2.pth', map_location=device)
# model.load_state_dict(checkpoint["model"])
# model.load_state_dict(torch.load('/home/shireen/RSCNN/ImageNet/model_img_dont_delete_p_f.pt', map_location=device))
# model = torch.load('/home/shireen/RSCNN/ImageNet/model_img.pt')

# model = torchvision.models.resnet50(progress=False, weights='IMAGENET1K_V1')
# model.fc = torch.nn.Identity()
# model = nn.Sequential(
#   model,
#   nn.Linear(in_features=2048, out_features=4000, bias=True),
#   nn.Sigmoid()
# )

# model.load_state_dict(torch.load("rscnn_resnet50_imagenet.pt", map_location=device))

model.to(device)
# model= nn.DataParallel(model,device_ids = [1, 2, 3])

# Load the evaluation arrays. The context manager closes the .npz file
# handle. Indexing an NpzFile reads that array fully into memory, so
# x_test / y_test remain valid after the file is closed.
DATA_PATH = "/mnt/eris-alpha/mubashar/shared_data/imagenet_o_data.npz"
with np.load(DATA_PATH) as npzfile:
    x_test, y_test = npzfile["x"], npzfile["y"]
print(x_test.shape, y_test.shape)

class ImageNetTest(Dataset):
    """In-memory test dataset yielding ``(image_tensor, label)`` pairs.

    Each ``x_test[idx]`` is transposed with axes ``(2, 0, 1)`` (channels-last
    to channels-first), divided by 255.0, converted to a float32 tensor and
    passed through ``self.transform``. Labels are returned unchanged from
    ``y_test[idx]``.

    Args:
        x_test: Indexable of 3-D images, presumably ``(H, W, C)`` arrays of
            0-255 pixel values (the ``/ 255.0`` scaling assumes this; confirm
            against the data source).
        y_test: Labels aligned with ``x_test``; must have the same length.
        transform: Optional callable applied to each image tensor. When
            omitted, images are resized to 224x224 (nearest-neighbour) and
            normalised with the ImageNet channel mean/std.

    Raises:
        ValueError: If ``x_test`` and ``y_test`` differ in length.
    """

    def __init__(self, x_test, y_test, *, transform=None):
        # Fail early: a mismatch would otherwise surface as an IndexError in
        # the middle of iteration, or silently ignore trailing labels.
        if len(x_test) != len(y_test):
            raise ValueError(
                "x_test and y_test must have the same length, "
                f"got {len(x_test)} and {len(y_test)}"
            )
        self.x_test = x_test
        self.y_test = y_test
        if transform is None:
            # ToTensor is intentionally absent: __getitem__ already builds a
            # float CHW tensor.
            transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.x_test)

    def __getitem__(self, idx):
        """Return the preprocessed image at ``idx`` together with its label."""
        img = torch.from_numpy(np.transpose(self.x_test[idx], (2, 0, 1)) / 255.0).float()
        return self.transform(img), self.y_test[idx]
      
# x_test = transform(torch.from_numpy(x_test)/255.0).float()

# class ImageNetTest(Dataset):
#     def __init__(self, x_test):
#         self.x_test = x_test
#         self.transform = transform_steps = transforms.Compose([
#             transforms.Resize((224,224), interpolation=transforms.InterpolationMode.NEAREST),
#             # transforms.ToTensor(),
#             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),#, (0.2023, 0.1994, 0.2010)),
#         ])
#     def __len__(self):
#         return len(self.x_test)

#     def __getitem__(self, idx):
#         img = torch.from_numpy(np.transpose(self.x_test[idx], (2,0,1))/255.0).float()
#         img = self.transform(img)

#         return img

test_dataset = ImageNetTest(x_test, y_test)
# test_dataset = ImageNetTest(x_test)
test_dataloader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
model.eval()

# Collect per-batch softmax outputs and labels on the CPU. The preprocessed
# input batches are deliberately not kept: copying every 3x224x224 float batch
# back to the host (and concatenating them into a second full copy) cost a lot
# of memory, and the result was never used.
test_preds = []
test_labels = []
with torch.no_grad():
    for x, y in tqdm(test_dataloader):
        logits = model(x.to(device))
        test_preds.append(torch.softmax(logits, dim=-1).cpu())
        test_labels.append(y)  # labels never leave the CPU

test_labels = torch.cat(test_labels).numpy()
test_preds = torch.cat(test_preds).numpy()
print(test_labels)
# Sanity check: each softmax row sums to 1, so this should be ~len(test_labels).
print(np.sum(test_preds))
# accuracy = np.sum(np.argmax(test_preds, axis=-1) == test_labels)*100 / len(test_labels)
# print(accuracy)
np.save("/home/shireen/RSCNN/ImageNet/test_preds_cnn_o.npy", test_preds)

# from sklearn.metrics import top_k_accuracy_score
# acc_5 = top_k_accuracy_score(y_test, test_preds, k=5, normalize=True, labels=np.arange(1000))
# print(f"Top 5 Test Accuracy: {acc_5*100}%")