# Training script: fits a two-layer MLP on MNIST / CIFAR-10 with optional pruning.

import os
import copy
import argparse
import pickle

from tqdm import tqdm
import torch

from data import get_mnist, get_cifar
from models import TwoLayerMLP, MultiClassMLP
from utils import generate_feature_map, check_in_hull
from pruning import prune_mlp_fixed_size


def _parse_args():
    """Parse the command-line options that configure training and pruning."""
    parser = argparse.ArgumentParser(description='MLP Training')
    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--use-step-sched', action='store_true', default=False)
    parser.add_argument('--lr', type=float, default=1e-6, metavar='LR')
    parser.add_argument('--num-hidden', type=int, default=20000)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--trn-size', type=int, default=50000)
    parser.add_argument('--validate', action='store_true', default=False)
    parser.add_argument('--prune-epochs', type=int, default=0)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--verbose', action='store_true', default=False)
    parser.add_argument('--prune', type=str, default='last_epoch')
    parser.add_argument('--prune-by-iter', action='store_true', default=False)
    parser.add_argument('--check-hull', action='store_true', default=False)
    parser.add_argument('--downsample-method', type=str, default='uniform')
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--prune-freq', type=int, default=1)
    parser.add_argument('--prune-size', type=int, default=500)
    parser.add_argument('--exp-name', type=str, default=None)
    return parser.parse_args()


def _build_downsample_params(args):
    """Return the downsampling config handed to the data loaders, or None."""
    # the CLI default is 'uniform', so None only occurs if the default changes
    if args.downsample_method is None:
        return None
    return {
        'size': args.trn_size,
        'method': args.downsample_method,
    }


def _load_data(args):
    """Return ``(input_size, trn_dl, test_dl)`` for ``args.dataset``.

    Raises:
        NotImplementedError: if the dataset name is not one of 'mnist',
            'cifar10_smallimg' or 'cifar10'.
    """
    downsample_params = _build_downsample_params(args)
    if args.dataset == 'mnist':
        trn_dl, test_dl = get_mnist(
                batch_size=args.batch_size, downsample_params=downsample_params)
        return 784, trn_dl, test_dl
    if args.dataset == 'cifar10_smallimg':
        trn_dl, test_dl = get_cifar(
                batch_size=args.batch_size, augment=False, small_img=True,
                binarize=True, downsample_params=downsample_params)
        # cifar is downsampled to a size of 18 (3 * 18 * 18 = 972 inputs)
        return 972, trn_dl, test_dl
    if args.dataset == 'cifar10':
        trn_dl, test_dl = get_cifar(
                batch_size=args.batch_size, augment=False, small_img=False,
                binarize=True, downsample_params=downsample_params)
        return 3072, trn_dl, test_dl
    raise NotImplementedError('Not a supported dataset')


def _prune_and_track(args, model, trn_dl, prune_accs, best_prune_model):
    """Prune ``model`` once, record its accuracy, and return the best pruned model.

    ``prune_accs`` must already hold at least one entry; it is appended to in
    place. A copy of the new pruned model replaces ``best_prune_model`` only
    when its accuracy strictly beats every accuracy recorded so far.
    """
    # presumably returns (pruned model, accuracy of pruned model) -- confirm
    # against pruning.prune_mlp_fixed_size
    prune_model, prune_acc = prune_mlp_fixed_size(
            args, model, trn_dl, args.prune_size)
    if prune_acc > max(prune_accs):
        best_prune_model = copy.deepcopy(prune_model)
    prune_accs.append(prune_acc)
    return best_prune_model


def _should_prune_after_epoch(args, e):
    """Decide whether the end of (0-indexed) epoch ``e`` triggers pruning."""
    is_last_epoch = (e + 1) >= args.epochs
    if args.prune_freq == 1 and not args.prune_by_iter:
        # prune after every single epoch
        return True
    if args.prune_freq < args.epochs and not args.prune_by_iter:
        # optionally only prune every few epochs (and always after the last)
        assert args.prune_freq > 1
        return (e + 1) % args.prune_freq == 0 or is_last_epoch
    # iteration-based pruning still gets one final prune after the last epoch
    return is_last_epoch


def _evaluate(model, test_dl, device):
    """Return the fraction of ``test_dl`` examples classified correctly."""
    num_ex = 0.
    num_corr = 0.
    with torch.no_grad():
        for model_in, target in test_dl:
            model_in = torch.flatten(model_in, start_dim=1)
            model_in = model_in.to(device)
            target = target.to(device)
            output = model(model_in).squeeze(dim=1)
            preds = output > 0.5
            num_corr += float(torch.sum(preds.long() == target))
            num_ex += target.shape[0]
    return num_corr / num_ex


def train_mlp():
    """Train a ``TwoLayerMLP`` with SGD and BCE loss, optionally pruning it.

    All configuration comes from the command line (see ``_parse_args``).
    Per-iteration and per-epoch training metrics, optional test accuracy and
    the accuracies of pruned models are collected and, when ``--exp-name`` is
    given, pickled to ``./results/<exp-name>.pckl``.

    Pruning modes (``--prune``):
        * ``'last_epoch'``: prune once after training finishes.
        * ``'all_epochs'``: prune before training and then periodically,
          either every ``--prune-freq`` iterations (when ``--prune-by-iter``
          is set or the frequency exceeds the epoch count) or every
          ``--prune-freq`` epochs, plus once after the final epoch.

    NOTE: with ``--check-hull`` the function returns as soon as
    ``check_in_hull`` succeeds, without pruning or saving any results.
    NOTE: the best pruned model is tracked but not written to disk.

    Raises:
        NotImplementedError: if ``--dataset`` is not supported.
    """
    args = _parse_args()

    # get the data
    input_size, trn_dl, test_dl = _load_data(args)

    # check if there is a GPU available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # construct the model and optimizer
    model = TwoLayerMLP(input_size, args.num_hidden)
    model = model.to(device)
    opt = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum)
    criterion = torch.nn.BCELoss()

    prune_all_epochs = args.prune == 'all_epochs'
    # assume you are pruning by iteration if prune frequency is large
    prune_by_iteration = prune_all_epochs and (
            args.prune_freq > args.epochs or args.prune_by_iter)

    # attempt pruning before training
    prune_accs = None
    best_prune_model = None
    if prune_all_epochs:
        prune_model, prune_acc = prune_mlp_fixed_size(
                args, model, trn_dl, args.prune_size)
        prune_accs = [prune_acc]
        best_prune_model = copy.deepcopy(prune_model)

    # main training loop
    iter_accs = []
    iter_losses = []
    trn_losses = []
    trn_accs = []
    test_accs = []
    agg_iter = 0
    for e in range(args.epochs):
        # optionally perform step schedule on the learning rate
        if args.use_step_sched and e == int(0.5*args.epochs):
            if args.verbose:
                print(f'Reducing LR 10X at Epoch #{e}')
            for g in opt.param_groups:
                g['lr'] = g['lr'] / 10

        # run a single epoch of training
        if args.verbose:
            print(f'Running Epoch {e + 1} / {args.epochs}')

        model = model.to(device)
        agg_trn_loss = 0.
        num_corr = 0.
        num_ex = 0.
        for model_in, target in trn_dl:
            opt.zero_grad()
            target = target.float()  # BCELoss needs floating-point targets
            model_in = torch.flatten(model_in, start_dim=1)
            model_in = model_in.to(device)
            target = target.to(device)
            output = model(model_in).squeeze(dim=1)
            loss = criterion(output, target)
            loss.backward()
            opt.step()
            preds = output > 0.5
            tmp_num_corr = float(torch.sum(preds.long() == target))
            num_corr += tmp_num_corr
            num_ex += target.shape[0]
            agg_trn_loss += loss.item()

            # track loss and acc within each iteration
            iter_accs.append((tmp_num_corr / target.shape[0]))
            iter_losses.append(loss.item())

            agg_iter += 1
            if prune_by_iteration and agg_iter % args.prune_freq == 0:
                best_prune_model = _prune_and_track(
                        args, model, trn_dl, prune_accs, best_prune_model)
                model = model.to(device) # ensure model on GPU

        # print out training metrics at each epoch
        agg_trn_loss /= len(trn_dl)
        agg_trn_acc = num_corr / num_ex
        trn_losses.append(agg_trn_loss)
        trn_accs.append(agg_trn_acc)
        if args.verbose:
            print(f'Training Loss: {agg_trn_loss:.4f}')
            print(f'Training Acc.: {agg_trn_acc:.4f}')

        # continue iterating until convex hull assumption is satisfied
        if args.check_hull and check_in_hull(trn_dl, model):
            print(f'Convex hull assumption satisfied at epoch #{e + 1}')
            return

        # evaluate model performance on test set
        if args.validate:
            agg_test_acc = _evaluate(model, test_dl, device)
            test_accs.append(agg_test_acc)
            if args.verbose:
                print(f'Test Acc.: {agg_test_acc:.4f}')

        # try to prune the model after the epoch, per the configured schedule
        if prune_all_epochs and _should_prune_after_epoch(args, e):
            best_prune_model = _prune_and_track(
                    args, model, trn_dl, prune_accs, best_prune_model)

    # prune after the final epoch if specified
    if args.prune == 'last_epoch':
        prune_model, prune_acc = prune_mlp_fixed_size(
                args, model, trn_dl, args.prune_size)
        prune_accs = [prune_acc]
        best_prune_model = copy.deepcopy(prune_model)

    if args.exp_name is not None:
        # exist_ok avoids the check-then-create race of exists() + mkdir()
        os.makedirs('./results/', exist_ok=True)
        fp = os.path.join('./results', args.exp_name + '.pckl')
        all_result = {
            'iter_accs': iter_accs,
            'iter_losses': iter_losses,
            'trn_loss': trn_losses,
            'trn_acc': trn_accs,
            'test_acc': test_accs,
            'prune_acc': prune_accs,
        }
        with open(fp, 'wb') as f:
            pickle.dump(all_result, f)
            
# script entry point: only start training when run directly, not on import
if __name__ == "__main__":
    train_mlp()
