from run_experiments import load_data, train_network, mse_loss
from mlp import CentralMLP

import argparse
import numpy as np

def IST_Converge(weight, a, X, Y, num_neurons, num_subnets, init_scale, prob, 
        learning_rate, local_iter, mask_type, verbose=0):
    """Run subnet-based training of a CentralMLP and return its best smoothed loss.

    Each round splits the central model into ``num_subnets`` subnetworks
    (``generate_subnets``), trains each one for ``local_iter`` steps with
    ``train_network``, and merges them back (``aggregate_updates``). After each
    round the full-data loss is recorded, and a moving average over the recent
    history is tracked. Training stops early once the moving average has failed
    to improve more than ``20 / local_iter ** 0.3`` times in total. The counter
    is cumulative and is never reset.

    Args:
        weight: Initial value assigned to ``full_model.W``. It is stored without
            copying, so callers that reuse it should pass a copy.
        a: Initial value assigned to ``full_model.a``. It is also stored
            without copying.
        X: Inputs; ``X.shape[1]`` is used as the model's input dimension and
            ``X.shape[0]`` normalises the recorded loss.
        Y: Targets passed to ``mse_loss`` and ``train_network``.
        num_neurons: Hidden width passed to ``CentralMLP``.
        num_subnets: Number of subnetworks generated per round.
        init_scale: Forwarded to ``CentralMLP`` (meaning defined there).
        prob: Forwarded to ``CentralMLP`` (meaning defined there).
        learning_rate: Step size for ``train_network``. For mask types other
            than ``'Bernoulli'`` it is divided by ``num_subnets``.
        local_iter: Local training steps per subnet per round. It also sets the
            number of rounds (``total_iter // local_iter``) and the patience
            limit.
        mask_type: Passed as ``method`` to ``generate_subnets``. ``'Bernoulli'``
            selects a 1000-step budget with a window of 10; anything else
            selects 300 steps with a window of 20.
        verbose: If positive, print the moving average every ``verbose`` rounds.

    Returns:
        The smallest moving-average loss observed, or ``inf`` if no round ran
        (i.e. ``local_iter`` exceeds the step budget).
    """
    full_model = CentralMLP(num_neurons, X.shape[1], init_scale, prob)
    # Replace the constructor's parameters with the caller-supplied initialisation.
    full_model.W = weight
    full_model.a = a
    num_samples = X.shape[0]

    # Every history entry uses the same normalisation (loss / number of samples)
    # so the moving average never mixes scales.
    error_hist = [mse_loss(full_model.forward(X), Y).numpy() / num_samples]

    if mask_type == 'Bernoulli':
        # Bernoulli masks keep the full learning rate, train longer and smooth less.
        total_iter = 1000
        avg_range = 10
    else:
        # Other mask types scale the learning rate down by the number of subnets.
        learning_rate /= num_subnets
        total_iter = 300
        avg_range = 20
    num_rounds = total_iter // local_iter
    # More local steps per round means fewer tolerated non-improving rounds.
    patience_limit = 20 / (local_iter ** 0.3)

    patience = 0
    best_avg = float('inf')
    for gstep in range(1, num_rounds + 1):
        subnets, masks = full_model.generate_subnets(num_subnets, method=mask_type)
        for subnet in subnets:
            train_network(subnet, X, Y, learning_rate, local_iter)
        full_model.aggregate_updates(subnets, masks)

        error_hist.append(mse_loss(full_model.forward(X), Y).numpy() / num_samples)
        # Average of the most recent avg_range + 1 entries (fewer in early rounds).
        avg = np.mean(error_hist[max(0, gstep - avg_range):])
        if avg < best_avg:
            best_avg = avg
        else:
            patience += 1

        if verbose > 0 and gstep % verbose == 0:
            print('The Avg Error of Iteration %d is %f' % (gstep, avg))

        if patience > patience_limit:
            break

    return best_avg

if __name__ == '__main__':
    # Sweep IST_Converge over a grid of settings, print each cell's mean result
    # and save the grid with np.savetxt.
    #   Bernoulli:   rows = number of subnets (1..10), cols = prob x / 10 (x = 1..10)
    #   Categorical: rows = number of subnets (1..10), cols = local_iter (1..10)

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--num_neurons', type=int, default=100)
    parser.add_argument('-n', '--num_samples', type=int, default=100)
    parser.add_argument('-l', '--learning_rate', type=float, default=0.01)
    parser.add_argument('-e', '--local_iter', type=int, default=1)
    parser.add_argument('-k', '--init_scale', type=float, default=1.)
    parser.add_argument('-t', '--mask_type', type=str, choices=['Bernoulli', 'Categorical'], default='Bernoulli')
    parser.add_argument('-r', '--repetition', type=int, default=1)
    parser.add_argument('-f', '--result_fname', type=str, default='result.txt')
    args = parser.parse_args()
    # Each cell is averaged over the repetitions, so at least one is required.
    if args.repetition < 1:
        parser.error('--repetition must be at least 1')

    X, Y = load_data(args.num_samples)
    # Shared initialisation. The Categorical sweep reuses it for every run;
    # the Bernoulli sweep draws a fresh one per repetition.
    weight = np.random.normal(scale=1, size=(X.shape[1], args.num_neurons))
    a = np.random.choice([-1, 1], size=(args.num_neurons,), p=[0.5, 0.5])

    if args.mask_type == 'Bernoulli':

        workers = np.arange(1, 11)
        xis = np.arange(1, 11)
        local_iter = 5
        heatmap_data = np.zeros((len(workers), len(xis)))
        for w in workers:
            for x in xis:
                total = 0.
                for _ in range(args.repetition):
                    weight = np.random.normal(scale=1, size=(X.shape[1], args.num_neurons))
                    a = np.random.choice([-1, 1], size=(args.num_neurons,), p=[0.5, 0.5])
                    total += IST_Converge(weight, a, X, Y, args.num_neurons, w,
                                          args.init_scale, x / 10,
                                          args.learning_rate, local_iter,
                                          args.mask_type)
                heatmap_data[w - 1, x - 1] = total / args.repetition
                print('Worker %d, Prob %f: %f' % (w, x / 10, heatmap_data[w - 1, x - 1]))
        np.savetxt(args.result_fname, heatmap_data)

    elif args.mask_type == 'Categorical':

        workers = np.arange(1, 11)
        local_iters = np.arange(1, 11)
        heatmap_data = np.zeros((len(workers), len(local_iters)))

        for w in workers:
            for local_iter in local_iters:
                total = 0.
                for _ in range(args.repetition):
                    # IST_Converge stores these arrays on the model without copying.
                    # Pass copies so an in-place update in one run cannot leak
                    # into the initialisation of the next.
                    total += IST_Converge(weight.copy(), a.copy(), X, Y, args.num_neurons, w,
                                          args.init_scale, 1. / w,
                                          args.learning_rate, local_iter,
                                          args.mask_type)
                heatmap_data[w - 1, local_iter - 1] = total / args.repetition
                print('Worker %d, Local iter: %d: %f' % (w, local_iter, heatmap_data[w - 1, local_iter - 1]))
        np.savetxt(args.result_fname, heatmap_data)
