import argparse
import os
import shutil
from datetime import datetime

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import train_util
import data_util


from sklearn.linear_model import RidgeCV
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn import metrics


import models as models



# Compute device for input batches: CUDA when a GPU is visible, else CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

from sys import argv 



import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
import torch.nn.functional as F

import time



# Positional command-line arguments (order matters):
#   argv[1] first          source model group: one of ro, ru, to, tu
#   argv[2] second         target model group: one of ro, ru, to, tu
#                          ("r*" = result_random_* checkpoints, "t*" = result_trained_*;
#                           "*o" = over_model directory, "*u" = under_model directory)
#   argv[3] under_model    model directory name under ./results for "*u" groups
#   argv[4] over_model     model directory name under ./results for "*o" groups
#   argv[5] concat_number  number of source checkpoints concatenated per seed group
#   argv[6] aug_on         "True" enables training-set augmentation
_MODEL_GROUPS = ("ro", "ru", "to", "tu")
_USAGE = ("usage: <script> FIRST SECOND UNDER_MODEL OVER_MODEL CONCAT_NUMBER AUG_ON\n"
          "  FIRST, SECOND: one of ro, ru, to, tu; CONCAT_NUMBER: positive integer")

# Fail early with a usage message instead of an IndexError.
if len(argv) < 7:
    raise SystemExit(_USAGE)

first = argv[1]
second = argv[2]
under_model = argv[3]
over_model = argv[4]
aug_on = argv[6]

try:
    concat_number = int(argv[5])
except ValueError:
    raise SystemExit("CONCAT_NUMBER must be an integer, got %r\n%s" % (argv[5], _USAGE)) from None

# Unknown groups or an empty concatenation used to crash much later with an
# opaque torch.cat error on an empty list; reject them up front.
if first not in _MODEL_GROUPS or second not in _MODEL_GROUPS:
    raise SystemExit("FIRST and SECOND must be one of %s, got %r and %r\n%s"
                     % (", ".join(_MODEL_GROUPS), first, second, _USAGE))
if concat_number < 1:
    raise SystemExit("CONCAT_NUMBER must be >= 1, got %d\n%s" % (concat_number, _USAGE))

# Largest concat_number that is run for each under-model variant. Beyond it
# (and beyond 10 when the source group is "to"/"ro") the script stops quietly
# with exit status 0 and does no work. SystemExit(0) is used instead of
# quit(), which is an interactive helper injected by the site module and is
# missing under `python -S`.
_MAX_CONCAT = {
    "vgg16_under_d2_bn": 10, "resnet18_under_d2_bn": 10,
    "vgg16_under_d4_bn": 16, "resnet18_under_d4_bn": 16,
    "vgg16_under_d8_bn": 64, "resnet18_under_d8_bn": 64,
}
if concat_number > _MAX_CONCAT.get(under_model, concat_number):
    raise SystemExit(0)

if concat_number > 10 and first in ("to", "ro"):
    raise SystemExit(0)


# Per-channel normalisation shared by every split.
# NOTE(review): keep these identical to the normalisation the loaded
# checkpoints were trained with (not visible in this file).
_CIFAR10_NORMALIZE = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

print(aug_on)

# Any capitalisation of "true" enables augmentation; previously only the exact
# string "True" did, and e.g. "true" silently disabled it.
if aug_on.strip().lower() == "true":
    print("aug on")
    # NOTE(review): trainloader is iterated once per loaded model and these
    # random transforms are re-drawn on every pass, so different models see
    # different augmented views of the same image. Confirm this is intended.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _CIFAR10_NORMALIZE,
    ])
else:
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        _CIFAR10_NORMALIZE,
    ])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    _CIFAR10_NORMALIZE,
])

path_data = "./data/"
trainset = torchvision.datasets.CIFAR10(
    root=path_data, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root=path_data, train=False, download=True, transform=transform_test)

# shuffle=False keeps the sample order fixed across passes, so row i of every
# model's features refers to the same image and the regression rows line up.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=100, shuffle=False, num_workers=4)
val_loader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=4)


# Reference point for the elapsed-time ("TIME:") reports printed later.
start = time.time()

# Same-group "to"/"ro" pairs are only run for the d16 under-models; any other
# under-model stops here.
if first == second and first in ("to", "ro"):
    if under_model not in ("resnet18_under_d16_bn", "vgg16_under_d16_bn"):
        print("finised(to-to)")
        quit()

print(first + second)

# Sources built from random-init checkpoints write under results/fse<rando>,
# i.e. "fserandom" instead of "fse".
rando = "random" if first in ("ro", "ru") else ""


# ---------------------------------------------------------------------------
# Main experiment. For each of four seed groups:
#   - concatenate the features of `concat_number` "first" (source) models,
#   - concatenate the features of 10 "second" (target) models,
#   - fit RidgeCV from the standardised source features to the target features,
#   - write 1 - R^2 per target column to text files under ./results/fse*/.
# The training error uses RidgeCV's best per-target score (cv=None, i.e. its
# efficient leave-one-out mode). The test error uses the CIFAR-10 test split.
# ---------------------------------------------------------------------------

# Ridge regularisation grid: a dense linear sweep over [1000, 10000] merged
# with a hand-picked wide-range sweep. np.unique also sorts the result.
_RIDGE_ALPHAS = np.unique(np.concatenate((
    np.linspace(1000, 10000, 91),
    np.array([0.00001, 0.0001, 0.005, 0.001, 0.01, 1, 5, 10, 15, 20, 30, 50, 75,
              100, 110, 125, 140, 160, 180, 200, 250, 300, 400, 500, 600, 700,
              850, 1000, 1200, 1400, 1500, 1600, 1800, 2000, 2100, 2200, 2500,
              3000, 3500, 4000, 4500, 5000, 6000, 7000, 8000, 9000, 10000,
              11000, 12000, 13000, 14000, 15000, 16000, 20000, 30000]),
)))

# Root of both the checkpoint tree and the result tree.
_RESULTS_DIR = os.path.join(os.getcwd(), "results")


def _checkpoint_paths(group, first_index, count):
    """Return `count` checkpoint paths for one model group.

    group: "ro"/"to" read over_model checkpoints and "ru"/"tu" read
        under_model ones. "r*" use result_random_<k> files and "t*" use
        result_trained_<k> files.
    first_index: the i-th path (i = 1..count) uses k = first_index + i.

    Raises ValueError for an unknown group or count < 1. Previously these
    surfaced later as an opaque torch.cat error on an empty list.
    """
    if group not in ("ro", "ru", "to", "tu"):
        raise ValueError("unknown model group %r; expected one of ro, ru, to, tu" % (group,))
    if count < 1:
        raise ValueError("need at least one checkpoint, got count=%d" % count)
    model_dir = os.path.join(_RESULTS_DIR, over_model if group.endswith("o") else under_model)
    if group.startswith("r"):
        prefix = "result_random_"
        # Random-init VGG checkpoints are numbered 500 higher than the rest.
        # NOTE(review): this keys on over_model even for "ru" (under-model)
        # paths, exactly as the script always has. Confirm this is intended.
        if "vgg" in over_model:
            first_index += 500
    else:
        prefix = "result_trained_"
    return [os.path.join(model_dir, prefix + str(first_index + i)) for i in range(1, count + 1)]


def _features(model, loader):
    """Run model.get_features over every batch of `loader`. Return the
    row-wise concatenation as a CPU tensor."""
    with torch.no_grad():
        # Bring each batch's features back to the CPU at once, so the GPU
        # never has to hold a whole split's worth of features.
        return torch.cat([model.get_features(inputs.to(device)).cpu() for inputs, _ in loader], dim=0)


def _stacked_features(paths):
    """Load each checkpoint in `paths` and return (train, test) numpy arrays
    with every model's features placed side by side (concatenated on dim 1).

    Only one model is kept in memory at a time.
    """
    train_parts, test_parts = [], []
    # Last checkpoint first. The script originally popped models off the end
    # of a list, and column order fixes the order of the per-target results
    # written to disk.
    for path in reversed(paths):
        # This unpickles a full model object, so only point it at trusted files.
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
        # which rejects pickled modules. Pass weights_only=False if loading fails.
        model = torch.load(path, map_location=torch.device('cpu'))
        # Bug fix: the model used to stay on the CPU while batches were sent to
        # `device`, which fails with a device mismatch on CUDA hosts.
        model = model.to(device)
        model.eval()
        train_parts.append(_features(model, trainloader))
        test_parts.append(_features(model, val_loader))
        del model
    return torch.cat(train_parts, dim=1).numpy(), torch.cat(test_parts, dim=1).numpy()


def _write_errors(middle, se, errors):
    """Write `errors` one per line to result_<se>.txt in
    results/fse<rando>/concat_<N>/Ridge/<middle...>/<over>/different_seed/<under>/<first>_<second>/.
    """
    # Bug fix: this used os.getcwd() + "results/fse" with no separator, which
    # wrote next to the working directory instead of into ./results.
    directory = os.path.join(
        _RESULTS_DIR, "fse" + rando, "concat_" + str(concat_number), "Ridge", *middle,
        over_model, "different_seed", under_model, first + "_" + second)
    # exist_ok: don't fail if the directory appears between check and create.
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, "result" + "_" + str(se) + ".txt")
    print(path)
    with open(path, 'w') as f:
        for e in errors:
            f.write("%s\n" % e)


for se1 in range(1, 5):
    # Source features: seed group se1 uses checkpoints
    # 30 + concat_number*(se1-1) + 1 .. + concat_number (+500 for random VGG).
    source_paths = _checkpoint_paths(first, 30 + concat_number * (se1 - 1), concat_number)
    for path in source_paths:
        print(path)
    X_training, X_test = _stacked_features(source_paths)

    # Standardise the inputs with training-set statistics only.
    scaler = StandardScaler()
    X_training = scaler.fit_transform(X_training)
    X_test = scaler.transform(X_test)

    print(se1)
    print(first)
    print(second)

    # Target features: always checkpoints 1..10 (+500 for random VGG), so all
    # seed groups regress onto the same target models.
    target_paths = _checkpoint_paths(second, 0, 10)
    for path in target_paths:
        print(path)
    Y_training, Y_test = _stacked_features(target_paths)

    print("finished")

    print("size")
    print(X_training[0].size)
    print(Y_training[0].size)
    print(X_test[0].size)
    print(Y_test[0].size)

    end = time.time()
    print("TIME:version2::::::::::::")
    print(end - start)
    print("::::::::::::::::::::::")

    # One alpha is chosen independently for every target column.
    clf = RidgeCV(alphas=_RIDGE_ALPHAS, cv=None, alpha_per_target=True, scoring="r2").fit(X_training, Y_training)
    print("best_alpha_indeces")
    print(clf.alpha_)

    # Errors are 1 - R^2 per target column.
    errors_training = list(1 - clf.best_score_)
    errors_test = list(1 - metrics.r2_score(Y_test, clf.predict(X_test), multioutput='raw_values'))

    _write_errors(("Reg",), se1, errors_training)
    _write_errors(("test", "Reg"), se1, errors_test)

    end = time.time()
    print("TIME:version2::::::::::::")
    print(end - start)
    print("::::::::::::::::::::::")

   
    
    

       
