
import torch
import sys
import numpy 
import random
import os
from tqdm import tqdm
from matplotlib import pyplot as plt

# ---------------------------------------------------------------------------
# Command-line configuration.
#
# Positional arguments (argv index):
#    1 n_examples      total generated examples (training + 5000 held-out)
#    2 hidden_size     hidden width of each residual block
#    3 data_dimension  input dimension / residual-stream width
#    4 total_layers    number of residual blocks
#    5 learning_rate   Adam learning rate
#    6 seed            RNG seed for python / numpy / torch
#    7 raptr_layers    '_'-separated ints: active blocks in each RaPTr stage
#    8 raptr_stages    '_'-separated floats: stage boundaries as fractions
#                      of total training steps
#    9 layerdrop       LayerDrop rate of the deepest block; > 0 selects
#                      LayerDrop training instead of RaPTr
#   10 fig_name        tag embedded in output file names
#   11 subset_size     monomial indices are drawn from range(subset_size)
#   12 component_num   number of monomials per degree
#   13 max_degree      monomial degrees run over 1 .. max_degree - 1
# ---------------------------------------------------------------------------
if len(sys.argv) < 14:
    sys.exit(
        "usage: %s n_examples hidden_size data_dimension total_layers "
        "learning_rate seed raptr_layers raptr_stages layerdrop fig_name "
        "subset_size component_num max_degree" % sys.argv[0]
    )

n_examples = int(sys.argv[1])
hidden_size = int(sys.argv[2])
data_dimension = int(sys.argv[3])
total_layers = int(sys.argv[4])
learning_rate = float(sys.argv[5])
seed = int(sys.argv[6])
raptr_layers_arg = sys.argv[7]
raptr_stages_arg = sys.argv[8]
layerdrop = float(sys.argv[9])
fig_name = sys.argv[10]
subset_size = int(sys.argv[11])
component_num = int(sys.argv[12])
max_degree = int(sys.argv[13])

batch_size = 1024
num_epochs = 400

# Monomial degrees used by the target polynomial: 1 .. max_degree - 1.
component_size = range(1, max_degree)

raptr_layers = [int(n) for n in raptr_layers_arg.split('_')]
raptr_stages = [float(t) for t in raptr_stages_arg.split('_')]

# Run identifier embedded in every output file name. The literal "2_3_4_5"
# is constant (it does not follow max_degree) -- presumably a legacy tag,
# kept so file names stay comparable with earlier runs.
additional_name = "_".join([
    str(learning_rate),
    str(seed),
    str(layerdrop),
    raptr_layers_arg,
    raptr_stages_arg,
    str(total_layers),
    str(hidden_size),
    "2", "3", "4", "5",
    str(component_num),
    str(subset_size),
    str(max_degree),
])

#if os.path.exists('Accuracy_poly_simpler_Adam/Accuracy_'+str(fig_name)+'_' + additional_name + '_multilayers_quadratic_samecoeff_400epochs_randomized.pkl'):
#    exit(0)

def act_poly(x):
    """Apply the network's element-wise nonlinearity (ReLU) to tensor ``x``."""
    return torch.relu(x)

def set_seed(seed: int = 42) -> None:
    """Seed every RNG the script draws from and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    the current CUDA device). It then disables cuDNN autotuning and forces
    deterministic cuDNN kernels.

    Args:
        seed: value used for all generators; also exported as PYTHONHASHSEED.
    """
    # Library random number generators.
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: benchmark mode may select different kernels between runs, so
    # turn it off and ask for deterministic algorithms instead.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this
    # influences child processes, not str hashing in the current process.
    os.environ["PYTHONHASHSEED"] = f"{seed}"
    print("Random seed set as {}".format(seed))

def get_features(d):
    """Sample the index sets (monomials) of the target polynomial.

    For each degree ``k`` in ``component_size``, draw ``component_num`` index
    arrays of length ``k``. Each array is sampled without replacement from
    ``range(subset_size)`` using numpy's global RNG.

    Args:
        d: unused -- indices come from ``range(subset_size)``.
           NOTE(review): the caller passes ``data_dimension``; confirm that
           sampling from ``subset_size`` rather than ``d`` is intended.

    Returns:
        A list with one entry per degree; each entry is a list of
        ``component_num`` integer index arrays.
    """
    features_by_degree = []
    for degree in component_size:
        monomials = []
        for _ in range(component_num):
            monomials.append(numpy.random.choice(subset_size, size=(degree,), replace=False))
        features_by_degree.append(monomials)
    return features_by_degree

def get_coefficients():
    """Draw one standard-normal coefficient per monomial.

    Returns:
        A list with one array of shape ``(component_num,)`` per degree in
        ``component_size``, drawn from numpy's global RNG.
    """
    coefficients = []
    for _degree in component_size:
        coefficients.append(numpy.random.normal(0., 1., (component_num,)))
    return coefficients

def true_predictions(data, all_coefficients, all_features):
    """Evaluate the target polynomial on every row of ``data``.

    The target is the sum, over all degrees and all monomials, of
    ``alpha * prod(data[:, feature])``.

    Args:
        data: 2-D array with one example per row.
        all_coefficients: per-degree sequences of coefficients ``alpha``.
        all_features: per-degree sequences of column-index arrays, aligned
            element-wise with ``all_coefficients``.

    Returns:
        A float array holding one target value per row of ``data``.
    """
    result = numpy.zeros((len(data),))
    terms = (
        (alpha, feature)
        for alphas, features in zip(all_coefficients, all_features)
        for alpha, feature in zip(alphas, features)
    )
    # Accumulate term by term in a fixed order so the float sum is reproducible.
    for alpha, feature in terms:
        result += alpha * numpy.prod(data[:, feature], axis=-1)
    return result


def get_data(n, all_coefficients, all_features):
    """Generate ``n`` labelled examples for the sparse-polynomial task.

    Inputs are drawn uniformly from the hypercube {-1, +1}^data_dimension
    using numpy's global RNG. Labels are the noiseless target polynomial
    evaluated by ``true_predictions``.

    Returns:
        ``(data, label)``: an integer array of shape ``(n, data_dimension)``
        and a float array of shape ``(n,)``.
    """
    bits = numpy.random.choice(2, size=(n, data_dimension))  # entries in {0, 1}
    data = bits * 2 - 1                                       # map to {-1, +1}
    return data, true_predictions(data, all_coefficients, all_features)

# Seed first so that target sampling, data and weight init are reproducible.
set_seed(seed)

# Order matters: both helpers draw from numpy's global RNG, so swapping these
# calls would change the sampled target polynomial.
all_coefficients = get_coefficients()
all_features = get_features(data_dimension)
    
# create a residual network here!
# Each of the total_layers blocks maps x -> x + linear2(relu(linear1(LN(x))))
# (see network_predictions / layerdrop_predictions). A final LayerNorm and a
# linear head (classifier1) produce the scalar prediction.

parameters = []   # every trainable tensor, handed to the optimizer below
all_layers = []   # (linear1, linear2) pair per residual block
for layer in range(total_layers):
    # nn.Linear initializes randomly, so construction order (linear1, linear2
    # per block, then classifier1) fixes the initial weights under the seed.
    linear1 = torch.nn.Linear(data_dimension, hidden_size).to('cuda')
    linear2 = torch.nn.Linear(hidden_size, data_dimension).to('cuda')
    
    # Parameter dtype; the training code reuses it to build input tensors.
    # NOTE(review): only defined when total_layers >= 1.
    dtype = linear1.weight.dtype
    all_layers += [(linear1, linear2)]
    parameters += [linear1.weight, linear1.bias, linear2.weight, linear2.bias]
    
# One LayerNorm per block, applied to the block's input (pre-norm).
all_ln_layers = []    
for layer in range(total_layers):
    layernorm = torch.nn.LayerNorm(data_dimension,).to('cuda')
    all_ln_layers += [layernorm]
    parameters += [layernorm.weight, layernorm.bias]

    
# Output head: single-unit regression layer on the final-normalized stream.
classifier1 = torch.nn.Linear(data_dimension, 1).to('cuda')

final_layernorm = torch.nn.LayerNorm(data_dimension,).to('cuda')
parameters += [classifier1.weight, classifier1.bias, final_layernorm.weight, final_layernorm.bias]

# Despite the name, this is an Adam optimizer over all parameters above.
SGD = torch.optim.Adam(params=parameters, lr=learning_rate)


    

def network_predictions(input_, all_layers, n_layer):
    """Run a forward pass with a random subset of residual blocks active (RaPTr).

    ``n_layer`` distinct block indices are drawn uniformly, without
    replacement, from ``range(total_layers)``. Only those blocks add their
    branch to the residual stream. Inactive branches are still computed but
    multiplied by 0, and active branches are not rescaled.

    The index draw uses numpy's global RNG even when
    ``n_layer == total_layers`` (as during evaluation), so every call
    advances the random stream.

    Args:
        input_: batch fed into the residual stream.
        all_layers: sequence of ``(linear1, linear2)`` pairs, one per block.
        n_layer: number of blocks to keep active.

    Returns:
        The output of ``classifier1`` applied to the final layer-normed stream.
    """
    active = set(numpy.random.choice(total_layers, n_layer, replace=False).tolist())

    stream = input_
    for idx in range(total_layers):
        up_proj, down_proj = all_layers[idx]
        branch = down_proj(act_poly(up_proj(all_ln_layers[idx](stream))))
        gate = 1. if idx in active else 0.
        stream = stream + gate * branch
    return classifier1(final_layernorm(stream))

def layerdrop_predictions(input_, all_layers, p_layer):
    """Run a forward pass with stochastic depth (LayerDrop).

    Block ``idx`` is kept with probability ``p_layer[idx]`` (one draw from
    numpy's global RNG per block). A kept branch is scaled by
    ``1 / p_layer[idx]``, so its expected contribution matches the undropped
    branch. Dropped branches are still computed and multiplied by 0.

    Args:
        input_: batch fed into the residual stream.
        all_layers: sequence of ``(linear1, linear2)`` pairs, one per block.
        p_layer: keep probability for each of the ``total_layers`` blocks.
            NOTE(review): a keep probability of exactly 0 would give a NaN scale.

    Returns:
        The output of ``classifier1`` applied to the final layer-normed stream.
    """
    stream = input_
    for idx in range(total_layers):
        keep_prob = p_layer[idx]
        # Bernoulli(keep_prob) draw, then inverted-dropout scaling.
        kept = numpy.random.choice(2, size=(1,), p=[1. - keep_prob, keep_prob])[0]
        scale = kept / keep_prob
        up_proj, down_proj = all_layers[idx]
        branch = down_proj(act_poly(up_proj(all_ln_layers[idx](stream))))
        stream = stream + scale * branch
    return classifier1(final_layernorm(stream))




def loss_fn(pred, target):
    """Return the mean squared error between column 0 of ``pred`` and ``target``.

    Args:
        pred: 2-D prediction tensor; only its first column is used.
        target: tensor broadcast against ``pred[:, 0]`` (expected shape (batch,)).
    """
    residual = pred[:, 0] - target
    return residual.pow(2).mean()

import pickle

# Evaluation history: eval_plot[i] is the held-out MSE measured after steps[i]
# optimizer steps.
eval_plot = []
steps = []
total_steps = 0

eval_split = 5000  # number of held-out examples
all_steps = ((n_examples - eval_split) // batch_size) * num_epochs

# Output locations. Create the directories up front so a missing folder fails
# immediately, not after the first epoch (checkpoint) or at the very end of
# training (results pickle).
checkpoint_dir = 'Analysis_poly_simpler_Adam/model_states_new'
accuracy_dir = 'Accuracy_poly_simpler_Adam'
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(accuracy_dir, exist_ok=True)
run_suffix = '_' + str(fig_name) + '_' + additional_name + '_multilayers_quadratic_samecoeff_400epochs_randomized'

# get the data
all_data, all_y = get_data(n_examples, all_coefficients, all_features)
train_split = len(all_data) - eval_split
eval_data, eval_y = all_data[train_split:], all_y[train_split:]
train_data, train_y = all_data[:train_split], all_y[:train_split]
n_batches = train_split // batch_size

# The held-out set never changes, so move it to the GPU once.
eval_cuda_data = torch.tensor(eval_data, device='cuda', dtype=dtype)
eval_cuda_y = torch.tensor(eval_y, device='cuda', dtype=dtype)

# LayerDrop drop rates grow linearly with depth: 0 for the first block up to
# `layerdrop` for the last. They do not depend on the step, so compute them
# once. max(..., 1) avoids ZeroDivisionError for a single-block network.
prob_drop = [layerdrop * ((1. * i) / max(total_layers - 1, 1)) for i in range(total_layers)]


def _evaluate():
    """Return the full-depth MSE of the current model on the held-out set.

    network_predictions is called with every block active; the call still
    draws once from numpy's global RNG.
    """
    with torch.no_grad():
        predicted = network_predictions(eval_cuda_data, all_layers, total_layers)
        return loss_fn(predicted, eval_cuda_y).item()


def _checkpoint_state():
    """Collect state_dicts in checkpoint order.

    The order is: (linear1, linear2) for each block, then each block's
    LayerNorm, then classifier1, then final_layernorm.
    """
    state = []
    for lin_in, lin_out in all_layers:
        state += [lin_in.state_dict(), lin_out.state_dict()]
    state += [ln.state_dict() for ln in all_ln_layers]
    state += [classifier1.state_dict(), final_layernorm.state_dict()]
    return state


for epoch in range(num_epochs):
    # Fresh training sample every epoch; the held-out split stays fixed.
    train_data, train_y = get_data(len(train_data), all_coefficients, all_features)

    for bt in tqdm(range(n_batches)):
        lo, hi = bt * batch_size, (bt + 1) * batch_size
        cuda_batch_data = torch.tensor(train_data[lo:hi], device='cuda', dtype=dtype)
        cuda_batch_y = torch.tensor(train_y[lo:hi], device='cuda', dtype=dtype)

        if layerdrop > 0.:
            # Keep probabilities start at 1 and decay exponentially toward
            # 1 - prob_drop as training progresses.
            prob_keep = [1. - p + p * numpy.exp(-100 * total_steps / (1. * all_steps)) for p in prob_drop]
            predicted_y = layerdrop_predictions(cuda_batch_data, all_layers, prob_keep)
        else:
            # RaPTr: use the layer count of the first stage whose boundary is
            # >= the fraction of training completed. The bound check stops
            # the scan at len(raptr_stages) instead of indexing past the end.
            progress = total_steps / (1. * all_steps)
            stage = 0
            while stage < len(raptr_stages) and progress > raptr_stages[stage]:
                stage += 1
            predicted_y = network_predictions(cuda_batch_data, all_layers, raptr_layers[stage])

        loss = loss_fn(predicted_y, cuda_batch_y)
        loss.backward()
        SGD.step()
        SGD.zero_grad()
        total_steps += 1

        if total_steps % 200 == 0:
            eval_plot += [_evaluate()]
            steps += [total_steps]

    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'state_dict': _checkpoint_state(),
        }, checkpoint_dir + '/state_dict' + run_suffix + '_epoch-' + str(epoch))

# Final full-depth evaluation after training.
eval_plot += [_evaluate()]
steps += [total_steps]

# Use a context manager so the results file is flushed and closed deterministically.
with open(accuracy_dir + '/Accuracy' + run_suffix + '.pkl', 'wb') as results_file:
    pickle.dump([eval_plot, steps], results_file)

    
    
    
    