
import torch
import sys
import numpy 
import random
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
import pickle


# ---------------------------------------------------------------------------
# Run configuration, read from positional command-line arguments:
#   argv[1]  n_examples      argv[2]  hidden_size     argv[3]  data_dimension
#   argv[4]  total_layers    argv[5]  learning_rate   argv[6]  seed
#   argv[7]  raptr_layers    argv[8]  raptr_stages    argv[9]  layerdrop
#   argv[10] fig_name        argv[11] subset_size     argv[12] component_num
#   argv[13] max_degree
# ---------------------------------------------------------------------------
n_examples, hidden_size, data_dimension, total_layers = (
    int(sys.argv[position]) for position in (1, 2, 3, 4)
)
learning_rate = float(sys.argv[5])
seed = int(sys.argv[6])

# Fixed settings (not configurable from the command line).
batch_size = 1000
num_epochs = 200

subset_size, component_num, max_degree = (
    int(sys.argv[position]) for position in (11, 12, 13)
)

# raptr_layers / raptr_stages arrive as underscore-separated strings. The raw
# strings are kept in the *_arg variables because they are embedded verbatim
# in `additional_name` below.
raptr_layers_arg = sys.argv[7]
raptr_stages_arg = sys.argv[8]
layerdrop = float(sys.argv[9])
fig_name = sys.argv[10]

raptr_layers = list(map(int, raptr_layers_arg.split('_')))
raptr_stages = list(map(float, raptr_stages_arg.split('_')))

# Sizes of the higher-order index sets: 2, 3, ..., max_degree - 1
# (range() excludes its upper bound).
component_size = range(2, max_degree)

# Hyper-parameter tag that is embedded in checkpoint and result file names.
additional_name = "_".join([
    str(learning_rate),
    str(seed),
    str(layerdrop),
    raptr_layers_arg,
    raptr_stages_arg,
    str(total_layers),
    str(hidden_size),
    "2_3_4_5",
    str(component_num),
    str(subset_size),
    str(max_degree),
])

# Disabled early-exit guard (kept for reference):
#if os.path.exists('Analysis_poly_simpler_Adam/Accuracy_'+str(fig_name)+'_' + additional_name + '_multilayers_quadratic_linearcoeff.pkl'):
#    exit(0)

def act_poly(x):
    """Apply the network's hidden activation to ``x``.

    Despite the name, the activation is a plain element-wise ReLU. A new
    tensor is returned; ``x`` is not modified.
    """
    return torch.nn.functional.relu(x)

def set_seed(seed: int = 42) -> None:
    """Seed the random number generators used by this script.

    Seeds NumPy's global generator, Python's ``random`` module and PyTorch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed``), puts cuDNN into
    deterministic mode with benchmarking disabled, stores the seed in the
    ``PYTHONHASHSEED`` environment variable and prints a confirmation line.

    Args:
        seed: The seed value handed to every generator.
    """
    seeders = (
        numpy.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

    # cuDNN needs both flags for reproducible results.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Fixed value for the hash seed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def get_features(d):
    """Build the coordinate index sets of every monomial in the target.

    Args:
        d: Unused. Indices are always drawn from ``range(subset_size)``
            (module-level setting), not from ``range(d)``.

    Returns:
        A list of groups. Group 0 contains the singletons ``[0]`` ...
        ``[subset_size - 1]``. Every following group belongs to one size
        ``i`` from the module-level ``component_size`` and contains
        ``component_num`` arrays, each holding ``i`` distinct indices sampled
        without replacement through NumPy's global random generator.
    """
    singletons = [[index] for index in range(subset_size)]

    sampled_groups = []
    for degree in component_size:
        group = []
        for _ in range(component_num):
            group.append(
                numpy.random.choice(subset_size, size=(degree,), replace=False)
            )
        sampled_groups.append(group)

    return [singletons] + sampled_groups

def get_coefficients():
    """Draw the ground-truth coefficient of every monomial.

    Returns:
        A list of 1-D arrays of standard-normal samples: first one array of
        length ``subset_size`` (one coefficient per singleton term), then one
        array of length ``component_num`` for each entry of the module-level
        ``component_size``. All draws use NumPy's global random generator.
    """
    coefficient_groups = [
        numpy.random.normal(loc=0., scale=1., size=(subset_size,))
    ]
    for _ in component_size:
        coefficient_groups.append(
            numpy.random.normal(loc=0., scale=1., size=(component_num,))
        )
    return coefficient_groups

def true_predictions(data, all_coefficients, all_features):
    """Evaluate the ground-truth polynomial on every row of ``data``.

    Each term is ``alpha * prod(data[:, feature])``: the product of the
    selected columns, scaled by its coefficient. All terms are summed.

    Args:
        data: 2-D array with one example per row.
        all_coefficients: Groups of coefficients, aligned with ``all_features``.
        all_features: Groups of column-index collections, one per term.

    Returns:
        1-D float array of length ``len(data)`` with the summed terms.
    """
    total = numpy.zeros((len(data),))
    weighted_terms = (
        (weight, columns)
        for weights, groups in zip(all_coefficients, all_features)
        for weight, columns in zip(weights, groups)
    )
    for weight, columns in weighted_terms:
        monomial = numpy.prod(data[:, columns], axis=-1)
        total += weight * monomial
    return total

def correlations(data, output, all_features, all_coefficients):
    """Estimate the correlation of ``output`` with every monomial feature.

    For each feature the monomial ``prod(data[:, feature])`` is computed per
    row and the batch mean of ``output * monomial`` is recorded.

    Args:
        data: 2-D array with one example per row.
        output: 1-D array with one value per row of ``data``.
        all_features: Groups of column-index collections.
        all_coefficients: Groups aligned with ``all_features``. Only their
            lengths matter: features beyond the number of coefficients in a
            group (or groups beyond the number of coefficient groups) are
            skipped because the two are zipped together.

    Returns:
        1-D array with one mean per visited feature, in iteration order.
    """
    estimates = [
        numpy.mean(output * numpy.prod(data[:, columns], axis=-1))
        for group, group_coefficients in zip(all_features, all_coefficients)
        for columns, _ in zip(group, group_coefficients)
    ]
    return numpy.asarray(estimates)

# Uniform +/-1 data: every one of the data_dimension coordinates is -1 or +1 with equal probability
def get_data(n, all_coefficients, all_features):
    """Sample ``n`` random inputs and label them with the true polynomial.

    Args:
        n: Number of examples to draw.
        all_coefficients: Coefficient groups passed to ``true_predictions``.
        all_features: Feature index groups passed to ``true_predictions``.

    Returns:
        ``(data, label)`` where ``data`` has shape ``(n, data_dimension)``
        with entries in ``{-1, +1}`` (drawn with NumPy's global generator)
        and ``label`` is the output of ``true_predictions`` on ``data``.
    """
    bits = numpy.random.choice(2, size=(n, data_dimension))
    inputs = 2 * bits - 1  # map {0, 1} -> {-1, +1}
    targets = true_predictions(inputs, all_coefficients, all_features)
    return inputs, targets

set_seed(seed)

# Draw the target function. Coefficients are sampled before the feature index
# sets; both consume NumPy's global generator, so this order is part of what
# makes a run reproducible for a given seed.
all_coefficients = get_coefficients()
all_features = get_features(data_dimension)

# Residual network: `total_layers` two-layer MLP blocks, one LayerNorm per
# block, then a final LayerNorm followed by a linear read-out to one scalar.
# `parameters` collects every trainable tensor for the optimizer.

parameters = []
all_layers = []
for layer in range(total_layers):
    linear1 = torch.nn.Linear(data_dimension, hidden_size).to('cuda')
    linear2 = torch.nn.Linear(hidden_size, data_dimension).to('cuda')

    # Remembered so that input tensors can later be created with the same dtype.
    dtype = linear1.weight.dtype
    all_layers.append((linear1, linear2))
    parameters.extend(
        [linear1.weight, linear1.bias, linear2.weight, linear2.bias]
    )

all_ln_layers = []
for layer in range(total_layers):
    layernorm = torch.nn.LayerNorm(data_dimension,).to('cuda')
    all_ln_layers.append(layernorm)
    parameters.extend([layernorm.weight, layernorm.bias])


classifier1 = torch.nn.Linear(data_dimension, 1).to('cuda')

final_layernorm = torch.nn.LayerNorm(data_dimension,).to('cuda')
parameters.extend(
    [
        classifier1.weight,
        classifier1.bias,
        final_layernorm.weight,
        final_layernorm.bias,
    ]
)

optimizer = torch.optim.Adam(params=parameters, lr=learning_rate)


    

def network_predictions(input_, all_layers, n_layer):
    """Run the residual network with a random subset of blocks switched on.

    ``n_layer`` distinct block indices are drawn (without replacement, via
    NumPy's global generator) from ``range(total_layers)``. Every block is
    still evaluated, but the residual update of a block that was not drawn is
    multiplied by 0 before being added.

    Args:
        input_: Input tensor fed to the first block.
        all_layers: Sequence of ``(first_linear, second_linear)`` pairs,
            indexed by block.
        n_layer: Number of blocks whose update is kept.

    Returns:
        Output of ``classifier1`` applied to the final layer-normalised
        activations.
    """
    kept = set(numpy.random.choice(total_layers, n_layer, replace=False).tolist())

    hidden = input_
    for index in range(total_layers):
        gate = 1. if index in kept else 0.
        block = all_layers[index]
        normed = all_ln_layers[index](hidden)
        update = block[1](act_poly(block[0](normed)))
        hidden = hidden + gate * update

    return classifier1(final_layernorm(hidden))




def loss_fn(pred, target):
    """Mean squared error between column 0 of ``pred`` and ``target``.

    Args:
        pred: 2-D tensor; only ``pred[:, 0]`` is used.
        target: Tensor compared element-wise against ``pred[:, 0]``.

    Returns:
        Scalar tensor holding the mean of the squared differences.
    """
    residual = pred[:, 0] - target
    return (residual ** 2).mean()

eval_plot = []
steps = []
total_steps = 0


# Generate the evaluation set; labels come from the ground-truth polynomial.
all_data, all_y = get_data(n_examples, all_coefficients, all_features)

train_split = len(all_data)
n_batches = train_split // batch_size

# Only full batches are evaluated; a trailing partial batch is dropped.
# The count does not depend on the checkpoint, so compute it once up front.
n_eval_batches = train_split // batch_size
if n_eval_batches == 0:
    # Fail early with a clear message instead of dividing by zero after the
    # first checkpoint has been loaded.
    raise ValueError(
        f"n_examples ({train_split}) must be at least batch_size "
        f"({batch_size}) to form one evaluation batch"
    )

PATH = 'Analysis_poly_simpler_Adam/model_states_new/state_dict_'+str(fig_name)+'_' + additional_name + '_multilayers_quadratic_samecoeff_400epochs'
#'Analysis_poly_simpler_Adam/model_states/state_dict_'+str(fig_name)+'_' + additional_name + '_multilayers_quadratic_linearcoeff'

# Per-checkpoint results: averaged feature correlations and averaged MSE.
Learned_alphas = []
eval_losses = []
for epoch in tqdm(range(0, 400, 5)):
    checkpoint = torch.load(PATH + '_epoch-'+ str(epoch))
    # NOTE: this rebinds the module-level `parameters` name to the list of
    # saved state dicts (kept as in the original script).
    parameters = checkpoint['state_dict']

    # The saved list is consumed sequentially: for every block its two linear
    # layers, then the per-block LayerNorms, then the read-out and the final
    # LayerNorm.
    iterator = 0
    for layer in range(total_layers):
        all_layers[layer][0].load_state_dict(parameters[iterator])
        iterator += 1
        all_layers[layer][1].load_state_dict(parameters[iterator])
        iterator += 1

    for layer in range(total_layers):
        all_ln_layers[layer].load_state_dict(parameters[iterator])
        iterator += 1

    classifier1.load_state_dict(parameters[iterator])
    iterator += 1
    final_layernorm.load_state_dict(parameters[iterator])
    iterator += 1

    # Evaluate this checkpoint with every block active (n_layer == total_layers).
    eval_loss = 0.
    # Running sum of the per-batch correlations. Initialised explicitly for
    # every checkpoint instead of relying on a NameError caught by a bare
    # `except`, which would also have hidden genuine failures.
    learned_alphas = None
    for ebt in tqdm(range(n_eval_batches)):
        batch_slice = slice(ebt*batch_size, (ebt+1)*batch_size)
        eval_batch_data, eval_batch_y = all_data[batch_slice], all_y[batch_slice]

        eval_cuda_batch_data = torch.tensor(eval_batch_data, device='cuda', dtype=dtype)
        eval_cuda_batch_y = torch.tensor(eval_batch_y, device='cuda', dtype=dtype)

        with torch.no_grad():
            predicted_y = network_predictions(eval_cuda_batch_data, all_layers, total_layers)
            loss = loss_fn(predicted_y, eval_cuda_batch_y)
            eval_loss += loss.item()

        batch_alphas = correlations(
            eval_batch_data,
            predicted_y.detach().cpu().numpy()[:, 0],
            all_features,
            all_coefficients,
        )
        if learned_alphas is None:
            learned_alphas = batch_alphas
        else:
            learned_alphas += batch_alphas

    #print (eval_loss/n_eval_batches)
    eval_losses += [eval_loss/n_eval_batches]
    learned_alphas /= n_eval_batches
    Learned_alphas += [learned_alphas]

# `epoch` still holds the last checkpoint index of the loop above; it is part
# of the output file name.
output_path = 'Moresample_Analysis_poly_simpler_Adam_400epochs/Correlations/Analysis_'+str(fig_name)+'_' + additional_name + '_multilayers_quadratic_samecoeff'+'_epoch-'+str(epoch)+'.pkl'
# Create the target directory if needed so the computed results are not lost
# to a FileNotFoundError at the very end of the run.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Context manager guarantees the file is flushed and closed.
with open(output_path, 'wb') as output_file:
    pickle.dump([Learned_alphas, all_features, all_coefficients, eval_losses], output_file)
