import argparse
import os
import shutil
from datetime import datetime

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import train_util
import data_util
#import eval_loss_util
from sklearn.model_selection import train_test_split


from sklearn.linear_model import RidgeCV
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn import metrics


import models as models

import numpy as np


device = 'cuda' if torch.cuda.is_available() else 'cpu'

from sys import argv 


import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
import torch.nn.functional as F

import time

# ---------------------------------------------------------------------------
# Command-line arguments (all positional)
#   argv[1] first         group code of the models whose features are the
#                         regression inputs  ("to", "tu", "ro" or "ru")
#   argv[2] second        group code of the models whose features are the
#                         regression targets (same four codes)
#   argv[3] under_model   results sub-directory used for the "*u" groups
#   argv[4] over_model    results sub-directory used for the "*o" groups
#   argv[5] concat_number number of `first` models concatenated per run
#   argv[6] aug_on        the literal string "True" enables train augmentation
# ---------------------------------------------------------------------------
if len(argv) < 7:
    raise SystemExit(
        "usage: {} <first> <second> <under_model> <over_model> "
        "<concat_number> <aug_on>".format(argv[0])
    )

first = argv[1]
second = argv[2]
under_model = argv[3]
over_model = argv[4]

# The rest of the script reads `concat_number`; the value used to be bound
# only to `concat_number1`, which made the first guard below fail with a
# NameError.  The old name is kept as an alias for backward compatibility.
concat_number = int(argv[5])
concat_number1 = concat_number

aug_on = argv[6]

# Largest `concat_number` that is run for each of these under-models; other
# under-models have no per-model limit.
_MAX_CONCAT_BY_UNDER_MODEL = {
    "vgg16_under_d2_bn": 10,
    "resnet18_under_d2_bn": 10,
    "vgg16_under_d4_bn": 16,
    "resnet18_under_d4_bn": 16,
    "vgg16_under_d8_bn": 64,
    "resnet18_under_d8_bn": 64,
}

_max_concat = _MAX_CONCAT_BY_UNDER_MODEL.get(under_model)
if _max_concat is not None and concat_number > _max_concat:
    # Skipped configuration: exit quietly with status 0 (as quit() did).
    raise SystemExit

# Over-parameterised inputs ("to" / "ro") are never run with more than 16
# concatenated models.
if concat_number > 16 and first in ("to", "ro"):
    raise SystemExit


# ---------------------------------------------------------------------------
# CIFAR-10 input pipelines and data loaders.
# ---------------------------------------------------------------------------

# Per-channel (mean, std) normalisation applied to every image, train and test.
_normalize = transforms.Normalize(
    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

print(aug_on)

# The training pipeline optionally starts with random crop + horizontal flip;
# only the exact string "True" switches the augmentation on.
_train_steps = []
if aug_on == "True":
    print("aug on")
    _train_steps.append(transforms.RandomCrop(32, padding=4))
    _train_steps.append(transforms.RandomHorizontalFlip())
_train_steps.append(transforms.ToTensor())
_train_steps.append(_normalize)

transform_train = transforms.Compose(_train_steps)
transform_test = transforms.Compose([transforms.ToTensor(), _normalize])

# Both splits are stored under (and downloaded to, if missing) ./data/.
path_data = "./data/"
trainset = torchvision.datasets.CIFAR10(
    root=path_data, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root=path_data, train=False, download=True, transform=transform_test)

# shuffle=False keeps the sample order identical on every pass over a loader.
_loader_options = dict(batch_size=100, shuffle=False, num_workers=4)
trainloader = torch.utils.data.DataLoader(trainset, **_loader_options)
val_loader = torch.utils.data.DataLoader(testset, **_loader_options)


# Seed NumPy's global random number generator.
np.random.seed(0)

# Reference point for the elapsed-time ("TIME") reports printed later.
start = time.time()

# The to-to and ro-ro pairings are only evaluated for the d16 under-models;
# for any other under-model the script stops here.
_same_over_pairing = (first, second) in (("to", "to"), ("ro", "ro"))
_d16_under_models = ("resnet18_under_d16_bn", "vgg16_under_d16_bn")
if _same_over_pairing and under_model not in _d16_under_models:
    print("finised(to-to)")
    quit()

print(first + second)

# Suffix appended to the output directory name when the input models are the
# "random" ones (group codes "ro" / "ru").
rando = "random" if first in ("ro", "ru") else ""


# ---------------------------------------------------------------------------
# Ridge regression between the features of two groups of saved models.
#
# For each of four runs (se1 = 1..4) the script
#   1. extracts and column-concatenates the features of `concat_number`
#      checkpoints selected by `first`  (regression inputs X),
#   2. does the same for ten checkpoints selected by `second` (targets Y),
#   3. fits a per-target RidgeCV from X to Y, and
#   4. writes 1 - R^2 per target (validation / training / test) to text files
#      below <cwd>/results/shuf_fse<rando>/.
# ---------------------------------------------------------------------------

# Group code -> (results sub-directory, checkpoint file-name prefix).
_GROUPS = {
    "ro": (over_model, "result_random_"),
    "ru": (under_model, "result_random_"),
    "to": (over_model, "result_trained_"),
    "tu": (under_model, "result_trained_"),
}

# Number of the checkpoint that precedes the first one used, per file prefix.
# Input models additionally advance by `concat_number` for every run.
_SOURCE_OFFSET = {"result_random_": 530, "result_trained_": 30}
_TARGET_OFFSET = {"result_random_": 500, "result_trained_": 0}
_N_TARGET_MODELS = 10

# Fail early with a clear message instead of an empty torch.cat() later on.
for _group in (first, second):
    if _group not in _GROUPS:
        raise ValueError(
            "unknown model group %r; expected one of %s"
            % (_group, ", ".join(sorted(_GROUPS))))

# Candidate regularisation strengths tried by RidgeCV (order kept as-is).
alphas = np.array([
    0.00001, 0.0001, 0.005, 0.001, 0.05, 0.01, 0.1, 0.5, 1, 2, 3, 5, 7, 8,
    10, 11, 13, 15, 20, 25, 30, 35, 40, 45, 50, 55, 65, 75, 80, 90, 100, 110,
    125, 140, 160, 180, 200, 250, 300, 400, 500, 600, 700, 850, 1000, 1200,
    1400, 1500, 1600, 1800, 2000, 2100, 2200, 2500, 3000, 3500, 4000, 4500,
    5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000,
    16000, 20000, 30000])


def _checkpoint_paths(group, offset, count):
    """Return (and print) the paths of checkpoints offset+1 .. offset+count.

    Paths have the form ``<cwd>/results/<model dir>/<prefix><index>`` where
    the model directory and prefix come from ``_GROUPS[group]``.
    """
    model_dir, prefix = _GROUPS[group]
    root = os.path.join(os.getcwd(), "results", model_dir)
    paths = [os.path.join(root, prefix + str(offset + k))
             for k in range(1, count + 1)]
    for path in paths:
        print(path)
    return paths


def _extract_features(model, loader):
    """Run `model.get_features` over `loader`; return all rows as a CPU tensor."""
    batches = []
    with torch.no_grad():
        for inputs, _targets in loader:
            # Copy each batch to the CPU right away so that features of the
            # whole dataset never pile up in GPU memory.
            batches.append(model.get_features(inputs.to(device)).cpu())
    return torch.cat(batches, dim=0)


def _collect_features(paths):
    """Return (train, test) NumPy feature matrices for the checkpoints in `paths`.

    Checkpoints are loaded one at a time, in the given order; their features
    are concatenated along dim 1 (assumes get_features returns at least 2-D
    tensors with samples on dim 0 -- TODO confirm).
    """
    train_parts = []
    test_parts = []
    for path in paths:
        # NOTE(review): torch.load unpickles a whole model object here, which
        # can execute arbitrary code -- only load checkpoints you trust.
        model = torch.load(path, map_location=torch.device('cpu'))
        # The inputs are moved to `device`, so the model has to live there as
        # well; leaving it on the CPU breaks as soon as CUDA is available.
        # (Assumes the checkpoint is an nn.Module -- it already needs .eval().)
        model = model.to(device)
        model.eval()
        train_parts.append(_extract_features(model, trainloader))
        test_parts.append(_extract_features(model, val_loader))
        del model
    return (torch.cat(train_parts, dim=1).numpy(),
            torch.cat(test_parts, dim=1).numpy())


def _print_sizes(title, X_tr, Y_tr, X_te, Y_te):
    """Print, for each matrix, its row length and then its number of rows."""
    print(title)
    for matrix in (X_tr, Y_tr, X_te, Y_te):
        print(matrix[0].size)
        print(matrix[:, 0].size)


def _print_elapsed():
    """Print the wall-clock seconds elapsed since `start`."""
    print("TIME:::::::::::::")
    print(time.time() - start)
    print("::::::::::::::::::::::")


def _write_errors(split, errors, run):
    """Write one error per line to .../Ridge/<split>/.../result_<run>.txt."""
    # The directory used to be built as os.getcwd() + "results/..." without a
    # separator, which put the files next to the working directory instead
    # of inside its results/ folder.
    out_dir = os.path.join(
        os.getcwd(), "results", "shuf_fse" + rando,
        "concat_" + str(concat_number), "Ridge", split,
        over_model, under_model, first + "_" + second)
    os.makedirs(out_dir, exist_ok=True)

    path = os.path.join(out_dir, "result_" + str(run) + ".txt")
    print(path)
    with open(path, 'w') as f:
        f.writelines("%s\n" % e for e in errors)


for se1 in range(1, 5):
    # ---- regression inputs: `concat_number` fresh checkpoints per run ----
    source_offset = (_SOURCE_OFFSET[_GROUPS[first][1]]
                     + concat_number * (se1 - 1))
    X_training, X_test = _collect_features(
        _checkpoint_paths(first, source_offset, concat_number))

    print(se1)
    print(first)
    print(second)

    # ---- regression targets: always the same ten checkpoints ----
    # NOTE(review): these paths do not depend on se1, so the target features
    # are recomputed on every run; they are left inside the loop because a
    # random training transform makes each pass over trainloader different.
    # NOTE(review): for the same reason, with aug_on == "True" the rows of X
    # and Y are not features of identically augmented images -- confirm that
    # this is intended.
    target_paths = _checkpoint_paths(
        second, _TARGET_OFFSET[_GROUPS[second][1]], _N_TARGET_MODELS)
    # Highest-numbered checkpoint first: keeps the target column order (and
    # therefore the line order of the result files) of earlier runs.
    Y_training, Y_test = _collect_features(target_paths[::-1])

    print("finished")

    _print_sizes("size before split", X_training, Y_training, X_test, Y_test)

    # Build the regression train/test sets: split the CIFAR-10 train features
    # and the CIFAR-10 test features 85/15 (same seed) and pool the pieces.
    X_fit_a, X_held_a, Y_fit_a, Y_held_a = train_test_split(
        X_training, Y_training, train_size=0.85, shuffle=True, random_state=42)
    X_fit_b, X_held_b, Y_fit_b, Y_held_b = train_test_split(
        X_test, Y_test, train_size=0.85, shuffle=True, random_state=42)

    X_training = np.concatenate((X_fit_a, X_fit_b), axis=0)
    X_test = np.concatenate((X_held_a, X_held_b), axis=0)
    Y_training = np.concatenate((Y_fit_a, Y_fit_b), axis=0)
    Y_test = np.concatenate((Y_held_a, Y_held_b), axis=0)

    _print_sizes("size after split", X_training, Y_training, X_test, Y_test)

    # Standardise the inputs with statistics of the regression training set.
    scaler = StandardScaler()
    X_training = scaler.fit_transform(X_training)
    X_test = scaler.transform(X_test)

    _print_elapsed()

    # One ridge penalty per target column, chosen by RidgeCV's built-in
    # cross-validation (cv=None) with R^2 as the score.
    clf = RidgeCV(alphas=alphas, cv=None, alpha_per_target=True,
                  scoring="r2").fit(X_training, Y_training)

    # clf.alpha_ holds the selected alpha values themselves (the printed
    # label is kept unchanged for compatibility with existing logs).
    alphastars = clf.alpha_
    print("best_alpha_indeces")
    print(alphastars)

    # Errors are reported as 1 - R^2, one value per target column.
    errors_validation = list(1 - clf.best_score_)
    errors_training = list(1 - metrics.r2_score(
        Y_training, clf.predict(X_training), multioutput='raw_values'))
    errors_test = list(1 - metrics.r2_score(
        Y_test, clf.predict(X_test), multioutput='raw_values'))

    _write_errors("validation", errors_validation, se1)
    _write_errors("training", errors_training, se1)
    _write_errors("test", errors_test, se1)

    _print_elapsed()

   
    
    

       
