import pandas as pd
import numpy as np
import pickle
import time
import scipy
import matplotlib.pyplot as plt

import torchvision.transforms as transforms

import random

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision

import torch.nn.functional as F

import torchvision.models as models

from sklearn.neighbors import KNeighborsClassifier

from tqdm import tqdm


# Torch device used by run() when loading the CNN weights and moving input images.
DEVICE: str = 'cpu'


def main():
    """Script entry point: delegate to ``run`` to extract and save CNN features."""
    return run()


def load_dataloaders(b_size=1, shuffle=False, root='./data/CIFAR10', num_workers=2):
    """Build CIFAR-10 train and test ``DataLoader``s.

    Both splits are converted to tensors and normalized with mean 0.5 and
    std 0.5 on every RGB channel (mapping pixel values from [0, 1] to
    [-1, 1]). The dataset is downloaded into ``root`` if it is not present.

    Args:
        b_size: batch size used by both loaders.
        shuffle: whether to shuffle the training loader; the test loader is
            never shuffled.
        root: directory the CIFAR-10 dataset is stored in / downloaded to.
        num_workers: number of worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # Normalize's signature is (mean, std, inplace): give exactly one mean and
    # one std per RGB channel so no extra argument spills into ``inplace``.
    # The same (stateless) pipeline is used for both splits.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        download=True,
        transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=b_size,
        shuffle=shuffle,
        num_workers=num_workers)

    test_set = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        download=True,
        transform=transform)

    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=b_size,
        shuffle=False,
        num_workers=num_workers)

    return train_loader, test_loader


class CNN(nn.Module):
    """Six-stage Conv2d -> BatchNorm2d -> ReLU network with a pooled linear head.

    ``conv2`` and ``conv4`` use stride 2, so a 32x32 input (as produced by the
    CIFAR-10 loaders in this file) becomes an 8x8 map that ``AvgPool2d(8)``
    reduces to 1x1. ``forward`` relies on that when it flattens the pooled
    tensor to ``(batch, channels)``.

    ``forward(I)`` returns ``(logits, x, C)``: the 10 class scores, the pooled
    128-d feature vector, and the final post-ReLU feature map ``C``.
    """

    # (in_channels, out_channels, kernel_size, stride, padding) for conv1..conv6;
    # each conv is paired with a BatchNorm2d over its output channels (bn1..bn6).
    _STAGES = (
        (3, 8, 5, 1, 2),
        (8, 16, 5, 2, 2),
        (16, 32, 5, 1, 2),
        (32, 64, 5, 2, 2),
        (64, 128, 3, 1, 1),
        (128, 128, 1, 1, 0),
    )

    def __init__(self):
        super().__init__()
        # Register convN then bnN in order so parameter names (state_dict keys)
        # and initialization order match the explicit per-layer definition.
        for idx, (c_in, c_out, k, s, p) in enumerate(self._STAGES, start=1):
            setattr(self, 'conv%d' % idx,
                    nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p))
            setattr(self, 'bn%d' % idx, nn.BatchNorm2d(c_out))

        self.avgpool = nn.AvgPool2d(8)
        self.linear = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, I):
        out = I
        for idx in range(1, len(self._STAGES) + 1):
            conv = getattr(self, 'conv%d' % idx)
            bn = getattr(self, 'bn%d' % idx)
            out = self.relu(bn(conv(out)))
        C = out  # last post-ReLU feature map, returned alongside the logits

        pooled = self.avgpool(C)
        batch, channels = pooled.shape[0], pooled.shape[1]
        x = pooled.view(batch, channels)
        return self.linear(x), x, C


def _extract_features(model, loader, weights):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    For each sample the predicted class ``y_hat`` is the argmax of its logits,
    and its contribution vector is the element-wise product of the pooled
    feature vector ``x`` with row ``y_hat`` of ``weights``.

    Args:
        model: network whose forward returns ``(logits, x, C)`` like ``CNN``.
        loader: iterable of ``(img, label)`` batches; labels are ignored.
        weights: linear-layer weight matrix, one row per class.

    Returns:
        ``(contributions, conv_maps, pooled, predictions)`` as NumPy arrays with
        one leading entry per sample. ``conv_maps`` keeps a singleton axis per
        sample, i.e. ``(N, 1, channels, h, w)``.
    """
    contributions, conv_maps, pooled, predictions = [], [], [], []
    # Inference only: no autograd graph is needed for any of these tensors.
    with torch.no_grad():
        for img, _ in tqdm(loader):
            logits, x, C = model(img.to(DEVICE))
            y_hat = logits.argmax(dim=1)
            c = x * weights[y_hat]

            contributions.append(c.cpu().numpy())
            # unsqueeze(1) reproduces the per-sample (1, ch, h, w) layout that
            # storing the raw batch-of-one output yields.
            conv_maps.append(C.unsqueeze(1).cpu().numpy())
            pooled.append(x.cpu().numpy())
            predictions.append(y_hat.cpu().numpy())

    # float64 matches the dtype obtained by converting through Python floats,
    # which is what these .npy files have contained; float32 would halve them.
    return (
        np.concatenate(contributions).astype(np.float64),
        np.concatenate(conv_maps).astype(np.float64),
        np.concatenate(pooled).astype(np.float64),
        np.concatenate(predictions),
    )


def run():
    """Extract CNN features for every CIFAR-10 image and save them under ``data/``.

    Loads ``weights/cnn.pth`` into ``CNN``, feeds the train split and then the
    test split through it one image at a time, and writes per split:

    - ``X_<split>_cont.npy``: pooled features times the predicted class's
      linear-layer weights
    - ``X_<split>_conv.npy``: final conv feature map, shape ``(N, 1, 128, 8, 8)``
    - ``X_<split>_x.npy``: pooled 128-d feature vector
    - ``X_<split>_y.npy``: predicted class index
    """
    train_loader, test_loader = load_dataloaders(b_size=1, shuffle=False)

    netC = CNN()
    netC.load_state_dict(torch.load('weights/cnn.pth', map_location=torch.device(DEVICE)))
    # load_state_dict copies into the module's existing (CPU-allocated)
    # parameters, so the model must be moved explicitly to match the inputs.
    netC = netC.to(DEVICE).eval()
    weights = netC.linear.weight

    X_train_c, X_train_C, X_train_x, X_train_y = _extract_features(netC, train_loader, weights)
    X_test_c, X_test_C, X_test_x, X_test_y = _extract_features(netC, test_loader, weights)

    os.makedirs("data", exist_ok=True)
    np.save("data/X_train_cont.npy", X_train_c)
    np.save("data/X_test_cont.npy",  X_test_c)
    np.save("data/X_train_conv.npy", X_train_C)
    np.save("data/X_test_conv.npy",  X_test_C)
    np.save("data/X_train_x.npy",    X_train_x)
    np.save("data/X_test_x.npy",     X_test_x)
    np.save("data/X_train_y.npy",    X_train_y)
    np.save("data/X_test_y.npy",     X_test_y)

    print(X_train_c.shape)
    print(X_train_x.shape)
    print(X_train_C.shape)
    print(X_train_y.shape)


# Run the feature-extraction pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()


