import torch
import torch.nn as nn

import numpy as np

import math
import time
import argparse
import os

from train_model import training_setup, train
from dataset import read_dataset

def main(dataset, noises, num_trials, epochs, lr, patience, factor,
         data_root='exponential_loss_csvs/datasets', results_dir='./results'):
    """Run the noise sweep and save per-run norms, weights and biases.

    For each value in ``noises`` this trains one model on the dataset for that
    noise level (the "expe" run), then ``num_trials`` more models, one per
    ``trial_{i}`` sub-directory (the "draw" runs). Every model is trained
    full-batch via ``training_setup``/``train``. When ``noise == 0.0`` the
    expe run and all draw runs read the base dataset directory
    ``{data_root}/{dataset}``.

    Args:
        dataset: Dataset name; a sub-directory of ``data_root``.
        noises: Noise levels to sweep. Also embedded (as ``str(noises)``) in
            the output file names.
        num_trials: Number of draw runs per noise level.
        epochs: Number of epochs passed to ``train``.
        lr, patience, factor: Passed through to ``training_setup``.
        data_root: Root directory holding the dataset directories.
        results_dir: Directory the ``.npy`` files are written to; created
            (with any missing parents) if it does not exist.

    Outputs are written with ``np.save`` (which appends ``.npy``) under the
    prefix ``{results_dir}/{dataset}_{noises}_{num_trials}_{epochs}``:
        ``_expe`` / ``_draw``: ``model.norm()`` values, shaped
            ``(len(noises),)`` and ``(len(noises), num_trials)``.
        ``_expe_weight`` / ``_expe_bias`` / ``_draw_weight`` / ``_draw_bias``:
            the corresponding ``model.w`` / ``model.b`` arrays, stacked the
            same way (assumes every run yields identically-shaped
            parameters -- TODO confirm).
    """
    # Seed once so the whole sweep is reproducible (assuming training_setup /
    # train draw from torch's global RNG -- not visible from this file).
    torch.manual_seed(3)

    def _dataset_path(noise, trial=None):
        # Directory for one run. Noise-free runs always use the base dataset
        # directory, even for the per-trial draws.
        path = f'{data_root}/{dataset}'
        if noise != 0.0:
            path = f'{path}/noise_{noise}'
            if trial is not None:
                path = f'{path}/trial_{trial}'
        return path

    def _fit(path):
        # Load one dataset, train a fresh model on it and return
        # (norm, weight, bias). The middle label array is unused here.
        X, _y0, y1 = read_dataset(path)
        X = torch.tensor(X)
        y1 = torch.tensor(y1)
        model, optimizer, scheduler = training_setup(X, lr, factor, patience)
        # Full-batch training: a single batch covers every sample.
        train(model, optimizer, scheduler, X, y1, epochs, X.shape[0], reg=None)
        # .numpy() requires CPU tensors; assumes training_setup keeps the
        # model on CPU.
        return (
            model.norm().item(),
            model.w.detach().numpy(),
            model.b.detach().numpy(),
        )

    expe_norms, expe_ws, expe_bs = [], [], []
    draw_norms, draw_ws, draw_bs = [], [], []
    for noise in noises:
        norm, w, b = _fit(_dataset_path(noise))
        expe_norms.append(norm)
        expe_ws.append(w)
        expe_bs.append(b)

        # Trials run in order so torch RNG consumption matches the seed above.
        trial_results = [_fit(_dataset_path(noise, trial)) for trial in range(num_trials)]
        draw_norms.append([result[0] for result in trial_results])
        draw_ws.append([result[1] for result in trial_results])
        draw_bs.append([result[2] for result in trial_results])

    # Convert everything before touching the filesystem, so a conversion
    # error leaves no partial output behind. Insertion order fixes save order.
    outputs = {
        'expe': np.array(expe_norms),
        'draw': np.array(draw_norms),
        'expe_weight': np.array(expe_ws),
        'expe_bias': np.array(expe_bs),
        'draw_weight': np.array(draw_ws),
        'draw_bias': np.array(draw_bs),
    }

    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(results_dir, exist_ok=True)

    # NOTE(review): str(noises) puts brackets, commas and spaces into the file
    # names; kept as-is for compatibility with existing result files.
    prefix = f'{results_dir}/{dataset}_{noises}_{num_trials}_{epochs}'
    for suffix, values in outputs.items():
        np.save(f'{prefix}_{suffix}', values)


def make_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: a parser with a single required string
        option, ``-d`` / ``--dataset``, naming the dataset to run on.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-d",
        "--dataset",
        required=True,
        type=str,
        help="Dataset to use as a string",
    )
    return cli

if __name__ == '__main__':
    args = make_parser().parse_args()
    print(args)

    # Fixed sweep configuration; only the dataset comes from the command line.
    run_config = dict(
        dataset=args.dataset,
        noises=[0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
        num_trials=100,
        epochs=1000,
        lr=0.1,
        patience=50,  # LR-decay patience handed to training_setup
        factor=0.3,  # LR-decay factor handed to training_setup
    )

    t_start = time.time()
    main(**run_config)
    print('Time:', time.time() - t_start)
