import argparse
import os
import shutil
from datetime import datetime

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import pandas as pd

import numpy as np

# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

from sys import argv 

import sys
 
# Make the local ``models`` package importable. '../' is resolved against the
# current working directory, not against this file's location, so the script
# must be launched from the directory that sits next to ``models``.
sys.path.extend(['../'])
import models as models


import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
import torch.nn.functional as F
import time


# The single positional command-line argument names the model family to
# process; it is looked up in ``total_models`` further down.
if len(argv) < 2:
    # Fail with a usage message instead of an IndexError traceback.
    sys.exit("usage: {} MODEL_NAME".format(argv[0]))
model_name = argv[1]

# CIFAR-10 per-channel normalisation statistics (mean, std), shared by both
# preprocessing pipelines.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# NOTE(review): the training split goes through random crop/flip
# augmentation, so each dumped training feature row comes from one random
# augmentation of its image. Confirm this is intended for a feature dump;
# otherwise pass ``transform_test`` to ``trainset`` as well.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

path_data = "../data/"
trainset = torchvision.datasets.CIFAR10(
    root=path_data, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root=path_data, train=False, download=True, transform=transform_test)

# shuffle=False keeps the extracted feature rows in dataset order.
# NOTE(review): worker processes are started from module-level code with no
# ``if __name__ == "__main__"`` guard. Under the "spawn" start method
# (Windows, macOS default) this is expected to fail; use num_workers=0 there.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=100, shuffle=False, num_workers=8)
val_loader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=8)

# Seed both RNGs. ``np.random.seed`` alone does not reach torch's generator,
# which drives the torchvision random transforms and the DataLoader worker
# seeds, so without ``torch.manual_seed`` the run is not reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Wall-clock reference for the elapsed-time report printed after each seed.
start = time.time()




# Per-model-family count that sizes the range of checkpoint seeds processed
# by the extraction loop.
total_models = {
    "resnet18_under_d16_bn": 508,
    "vgg16_under_d16_bn": 502,
    "resnet18_under_d8_bn": 256,
    "vgg16_under_d8_bn": 256,
    "resnet18_under_d4_bn": 64,
    "vgg16_under_d4_bn": 64,
    "resnet18_under_d2_bn": 16,
    "vgg16_under_d2_bn": 16,
    "vgg16_bn": 4,
    "resnet18_bn": 4,
}

# Root directory holding ``<model_name>/result_random_<seed>`` checkpoints.
model_dir = "../results"

# "under*" model families use a seed offset of 30 (see the range below).
start_seed = 30 if "under" in model_name else 0

# Constant added to every checkpoint seed number.
SEED_OFFSET = 500


def _extract_features(model, loader):
    """Run ``model.get_features`` over every batch of ``loader``.

    Returns ``(features, targets)`` as CPU tensors concatenated along dim 0,
    in loader order. Feature batches are moved to the CPU as they are
    produced so accelerator memory does not grow with the dataset size.
    """
    feature_batches = []
    target_batches = []
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            feature_batches.append(model.get_features(inputs).cpu())
            target_batches.append(targets)
    return torch.cat(feature_batches, dim=0), torch.cat(target_batches, dim=0)


if model_name not in total_models:
    sys.exit("unknown model_name {!r}; expected one of: {}".format(
        model_name, ", ".join(total_models)))

concat_number = total_models[model_name]

# NOTE(review): the lower bound omits ``start_seed`` while the upper bound
# includes it. The earlier scheme was
# range(start_seed + 1, start_seed + 1 + concat_number), so here "under"
# families process concat_number + 30 checkpoints starting at 501. This is
# kept as-is because the intended seed set cannot be confirmed from here;
# verify against the checkpoints present in ``model_dir``.
for se in range(SEED_OFFSET + 1, SEED_OFFSET + start_seed + 1 + concat_number):
    path = model_dir + "/" + model_name + "/" + "result_random_" + str(se)
    print(path)

    # The checkpoint is a whole pickled model object. torch.load unpickles
    # arbitrary objects, so only point this at trusted files.
    modelo = torch.load(path, map_location=torch.device('cpu'))
    # map_location leaves the weights on the CPU. Move them to the device the
    # inputs are sent to; otherwise inference fails on CUDA machines.
    modelo = modelo.to(device)

    X_training, target_training = _extract_features(modelo, trainloader)
    X_test, target_test = _extract_features(modelo, val_loader)
    del modelo

    # Everything is already on the CPU; convert to numpy and make the labels
    # a column so they can be prepended to the feature matrix.
    X_training_c = X_training.numpy()
    X_test_c = X_test.numpy()
    target_training = target_training.numpy().reshape(-1, 1)
    target_test = target_test.numpy().reshape(-1, 1)

    print(np.shape(X_test_c))
    print(np.shape(X_training_c))
    print(np.shape(target_test))
    print(np.shape(target_training))

    # Column 0 is the label, the remaining columns are the features.
    training_data = np.concatenate((target_training, X_training_c), axis=1)
    test_data = np.concatenate((target_test, X_test_c), axis=1)

    # Write the train and test CSVs for this seed.
    curr_dir = os.path.join(os.getcwd(), "random_features", model_name, str(se))
    os.makedirs(curr_dir, exist_ok=True)

    training_csv = os.path.join(curr_dir, 'training.csv')
    test_csv = os.path.join(curr_dir, 'test.csv')
    pd.DataFrame(training_data).to_csv(training_csv, header=False, index=False)
    pd.DataFrame(test_data).to_csv(test_csv, header=False, index=False)

    # Round-trip sanity check: re-read the test file just written and echo it.
    df = pd.read_csv(test_csv, header=None, index_col=False)
    print(df.to_numpy())

    print("TIME:::::::::::::")
    print(time.time() - start)
    print("::::::::::::::::::::::")

   
    
    

       

