import torch
import torch.nn as nn
import torch.utils.data
import torchvision
from torchvision import datasets, transforms
import numpy as np
import random
import argparse
import os
import subprocess
import pickle
import importlib
import time
from trainer.CustomSummaryWriter import *
from trainer.nets import FFNN_sparse_NTK
from trainer.utils import *

def str2bool(s):
    """Convert the strings 'true' / 'false' (any letter case) to a bool.

    Raises:
        RuntimeError: if `s` is neither 'true' nor 'false' after lower-casing.
    """
    lowered = s.lower()
    table = {'true': True, 'false': False}
    if lowered not in table:
        raise RuntimeError('bool arg conversion failed!')
    return table[lowered]


if __name__ == '__main__':

    # Reference point for the wall-clock values logged with the metrics.
    start_time = time.time()

    # ========== command-line interface ==========
    # ============================================
    parser = argparse.ArgumentParser()

    # ---- on/off switches (all default to False)
    parser.add_argument('--no-cuda', default=False, action='store_true',
                        help='disables CUDA')
    parser.add_argument('--no-bn', default=False, action='store_true',
                        help='disables BatchNorm')
    # deprecated
    #parser.add_argument('--no-sinit', default=False, action='store_true', 
    #                    help='disables adjusted init for sparsified layer')
    parser.add_argument('--no_ES', default=False, action='store_true',
                        help='disable Early Stopping')
    parser.add_argument('--no-da', default=False, action='store_true',
                        help='disables Data Augmentation on CIFAR10')
    parser.add_argument('--make_linear', default=False, action='store_true',
                        help='do not apply activation function')
    parser.add_argument('--NTK_style', default=False, action='store_true',
                        help='use initialization and forward pass in NTK style')
    parser.add_argument('--scheduler_on', default=False, action='store_true',
                        help='apply LR scheduler')
    parser.add_argument('--train_cl_only', default=False, action='store_true',
                        help='train only the classifier')
    # help text fixed: the closing parenthesis was missing
    parser.add_argument('--max_epochs', default=1, type=int,
                        help='number of epochs (default: 1)')
    parser.add_argument('--resume', default=False, action='store_true',
                        help='resume from checkpoint')

    # ---- dataset
    parser.add_argument('--dataset', default='MNIST', type=str,
                        help='dataset')
    parser.add_argument('--dataset_dir', default='./data', type=str,
                        help='dataset directory')
    parser.add_argument('--data_normalization', default='proper', type=str,
                        help='normalization for Fashion/MNIST dataset; options: proper, proper2, 0505, none, pixelwise')
    parser.add_argument('--num0_input', default=0, type=int,
                        help='number of pixels in input image to set to zero (default: 0)')

    # ---- model architecture and sparsity
    parser.add_argument('--num_hidden_layers', default=1, type=int,
                        help='number of hidden layers (default: 1)')
    parser.add_argument('--no_bias', default=False, action='store_true',
                        help='no bias in the layers')
    parser.add_argument('--Nh_base', default=56, type=int,
                        help='number of units in the hidden layer in the base case (i.e., dense model) (default: 56)')
    parser.add_argument('--Nh', default=56, type=int,
                        help='number of units in the hidden layer in the given (sparse) model (default: 56)')
    parser.add_argument('--num_to_freeze_cl', default=0, type=int,
                        help='number of weights to freeze in cl layer')
    parser.add_argument('--num_to_freeze_fc', default=0, type=int,
                        help='number of weights to freeze in fc layer')

    # ---- optimisation
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--mbs', default=100, type=int, help='mini-batch size')
    parser.add_argument('--init_distrib', default='uniform', type=str,
                        help='probability distribution for parameter initialization; options: uniform, normal')
    # deprecated
    #parser.add_argument('--single_batch_training', default='false', type=str2bool, 
    #                    help='train on a single batch only (MNIST xperiment)')
    parser.add_argument('--train_subset_size', default=0, type=int,
                        help='number of samples if training on a subset of the original train set')

    # ---- bookkeeping
    parser.add_argument('--bucket', default='bucket2', type=str, help='my bucket address')
    parser.add_argument('--seed', default=4567, type=int, help='random seed')
    parser.add_argument('--pre_dir', default='unspecified_dir', type=str,
                        help='folder name indicating new param setting')

    args = parser.parse_args()
    # Echo the full configuration so it ends up in the job's stdout log.
    dargs = vars(args)
    print(dargs)

    # ==== device configuration
    # CUDA is used only when it is available and not disabled with --no-cuda.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # ==== random seeds
    # Seed every RNG this script draws from. Python's `random` is included
    # because the train loop calls random.sample() to choose which input
    # pixels are zeroed when --num0_input > 0; without seeding it those masks
    # would differ between runs that share the same --seed.
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # ========== training and dataset hyper-params ==========
    # =======================================================

    dataset = args.dataset
    no_da = args.no_da
    dataset_dir = args.dataset_dir
    data_normalization = args.data_normalization
    num0 = args.num0_input  # number of input pixels zeroed per image

    no_ES = args.no_ES
    learning_rate = args.lr

    ckpt_every = 25                  # a periodic checkpoint is saved when epoch % ckpt_every == 0
    max_epochs = args.max_epochs     # cut-off value for train loop


    # ========== load dataset ==========
    # ==================================

    # Extra DataLoader options, only supplied when running on CUDA.
    if use_cuda:
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}

    train_subset_size = args.train_subset_size
    test_batch_size = 1000

    # When training on a subset, the loaders are requested with a batch size
    # equal to the subset size; the first batch of train_loader is later taken
    # as the whole training set. Otherwise the regular mini-batch size is used.
    batch_size = train_subset_size if train_subset_size > 0 else args.mbs

    (train_loader,
     train_loader_for_eval,
     test_loader,
     input_size,
     num_classes) = load_dataset(dataset, dataset_dir, batch_size, test_batch_size,
                                 data_normalization, no_da, kwargs)

    # From here on batch_size always means the mini-batch size.
    batch_size = args.mbs

    # ========== model hyper-params ==========
    # ========================================

    # BatchNorm is on unless --no-bn is given
    do_batch_norm = not args.no_bn

    # model arch
    init_distrib = args.init_distrib
    make_linear = args.make_linear
    num_hidden_layers = args.num_hidden_layers
    if num_hidden_layers == 0:
        net_type = 'Lin'
    else:
        net_type = 'ffnn'
    add_bias = not args.no_bias
    NTK_style = args.NTK_style
    NTK_tag = '_NTK_style' if NTK_style else ''

    Nh_base, Nh = args.Nh_base, args.Nh

    # names of the layers that are looked up in the model's `layers` container
    lkeys = ['fc', 'cl']

    # sparsity
    num_to_freeze_fc = args.num_to_freeze_fc
    num_to_freeze_cl = args.num_to_freeze_cl
    # ratio of base width to actual width; the model counts as sparse when
    # the actual hidden layer is wider than the base one (ratio below 1)
    ctvt_total = (Nh_base / Nh) if Nh > 0 else 1
    sparse = ctvt_total < 1

    # name under which metrics and checkpoints of this run are stored
    pre_dir = args.pre_dir
    metrics_savedir = compose_name_for_output(
        net_type, Nh_base, Nh, input_size, num_classes,
        num_to_freeze_fc, num_to_freeze_cl,
        make_linear, NTK_style,
        dataset, batch_size, learning_rate,
        pre_dir, seed)

    writer = CustomSummaryWriter(metrics_savedir, args.bucket)


    # ========== set up model ==========
    # ==================================
    # No hidden layers -> plain linear model; otherwise the (sparse) FFNN.
    if num_hidden_layers == 0:
        model = Lin(input_size, num_classes, init_distrib, add_bias)
    else:
        model = FFNN_sparse_NTK(input_size, num_classes, Nh, Nh_base, init_distrib,
                                num_to_freeze_fc, num_to_freeze_cl,
                                do_batch_norm, make_linear, NTK_style, add_bias)
    model = model.to(device)
    print(model)

    if args.train_cl_only:
        # Exclude the 'fc' layer from training: turn off its gradients and
        # hand only the remaining trainable parameters to the optimizer.
        lkey = 'fc'
        frozen_layer = model._modules['layers'][lkey]
        frozen_layer.weight.requires_grad = False
        if add_bias:
            frozen_layer.bias.requires_grad = False
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(trainable_params, lr=learning_rate)
        print('training only the classifier')
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # momentum 0.9 for CNN

    # loss
    criterion = nn.CrossEntropyLoss()

    if args.scheduler_on:
        # learning rate decay: multiply LR by 0.1 at each milestone
        lr_milestones = [60, 140, 230] #[40,100,150] #[50,100,150]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_milestones, gamma=0.1)


    if args.resume:
        # ==== load model checkpoint and resume training ====

        model_dir = f'checkpoints/{pre_dir}'
        # The checkpoint file name is metrics_savedir without its leading
        # pre_dir component. str.strip(pre_dir) must not be used for this: it
        # removes *any of the characters* of pre_dir from both ends of the
        # string, which can eat into the rest of the name.
        model_fname = metrics_savedir
        if model_fname.startswith(pre_dir):
            model_fname = model_fname[len(pre_dir):]
        model_fname = model_fname.strip('/') + '_final.ckpt'
        full_checkpoint_path = f'{model_dir}/{model_fname}'

        print(f'... gonna load checkpoint from {full_checkpoint_path}')
        checkpoint = load_checkpoint(model_dir, model_fname)
        start_epoch = checkpoint['epoch']

        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        print(f'=> successfully loaded checkpoint {model_fname} (epoch {checkpoint["epoch"]})!')

        # rename old stats file, otherwise it'll be overwritten
        stats_fname_old = f'runs/{metrics_savedir}/stats.pth'
        stats_fname_new = f'runs/{metrics_savedir}/stats_prev_{checkpoint["epoch"]}.pth'
        if os.path.exists(stats_fname_old):
            os.rename(stats_fname_old, stats_fname_new)
        else:
            # nothing to protect; do not abort the resumed run over it
            print(f'... no previous stats file at {stats_fname_old}, nothing to rename')
    else:
        # ======== save init model checkpoint
        start_epoch = 0
        state = {'epoch': start_epoch, 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(), 'args': args}
        save_name = f'{metrics_savedir}_init'
        sshproc = save_checkpoint(state, save_name)

    if sparse:
        # ==== get smask from model ====
        # Boolean mask per layer marking weights that are exactly zero. It is
        # taken once, after a possible checkpoint load, so it always reflects
        # the weights that will actually be trained.
        smask = {}
        for lkey in lkeys:
            smask[lkey] = model._modules['layers'][lkey].weight == 0

    # ============= train ==============
    # ==================================
    best_test_acc = 0
    patience = 20
    # sliding window (oldest first) over the test accuracy of the last
    # 2*patience epochs, used by the early-stopping rule below
    test_acc_tracker = list(np.zeros(2*patience))
    # last epoch whose training pass has completed; this is what the final
    # checkpoint records, so a later --resume continues with the next epoch
    last_epoch = start_epoch

    if train_subset_size > 0:
        # train_loader was requested with batch size == subset size, so its
        # first batch is used as the whole (fixed) training set
        images_, labels_ = next(iter(train_loader))
        new_train_set = torch.utils.data.TensorDataset(images_, labels_)
        train_loader = torch.utils.data.DataLoader(new_train_set, batch_size=batch_size, shuffle=True)

    # Epochs are numbered from 1. The upper bound is inclusive so that exactly
    # max_epochs epochs are trained in a fresh run (the previous
    # `epoch < max_epochs` test trained one epoch too few, i.e. none at all
    # with the default --max_epochs 1).
    for epoch in range(start_epoch + 1, max_epochs + 1):

        # evaluate() is defined elsewhere; in case it leaves the model in eval
        # mode, switch back to training mode at the start of every epoch
        # (a no-op if the model already is in training mode).
        model.train()

        # ======== train cycle start
        #|
        loss_sum, total, correct = 0, 0, 0
        for images, labels in train_loader:

            images = images.reshape(-1, input_size).to(device)
            if data_normalization == 'pixelwise':
                images = normalize_pixelwise(images)

            if num0 > 0:
                # Zero num0 randomly chosen pixels in every image. The index
                # tensor is sized by the actual batch, which can be smaller
                # than batch_size for the last batch of an epoch.
                n_in_batch = images.size(0)
                inds_for_batch = torch.tensor(
                    [random.sample(range(input_size), num0) for _ in range(n_in_batch)],
                    dtype=torch.long, device=device)
                row_inds = torch.arange(n_in_batch, device=device).unsqueeze(1)
                images[row_inds, inds_for_batch] = 0
                # Sanity check on the first image (always present). An image
                # may already contain zero-valued pixels before masking, so
                # require at least num0 zeros rather than exactly num0.
                assert torch.sum(images[0] == 0) >= num0, "error: images not masked properly in train loop!"

            labels = labels.to(device)

            # ==== forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # criterion yields the batch mean; weight by batch size so that
            # loss_sum/total is the per-sample mean over the epoch
            loss_sum += len(images)*loss.item()
            _, predicted = torch.max(outputs.detach(), 1)
            correct += (predicted == labels).sum().item()
            total += len(images)

            # ==== backward and optimize
            optimizer.zero_grad()
            loss.backward()  # compute gradients

            if sparse: # apply smask to gradients
                for lkey in lkeys:
                    if smask[lkey] is None: # no mask -> layer is not sparsified
                        continue
                    layer = model._modules['layers'][lkey]
                    grad = layer.weight.grad
                    if grad is None:
                        # weight has requires_grad=False (--train_cl_only
                        # freezes 'fc'), so there is no gradient to mask
                        continue
                    if torch.isnan(grad).any():
                        print(f'NaNs in grads of {lkey} before zeroing')

                    # masked (zero) weights get zero gradient
                    grad[smask[lkey]] = 0

                    if torch.isnan(grad).any():
                        print(f'NaNs in grads of {lkey} after zeroing')

            optimizer.step() # update parameters
        #|
        # ======== train cycle end (one epoch completed)
        last_epoch = epoch

        train_loss_inloop = loss_sum/total
        train_acc_inloop = correct/total

        # ======== save model checkpoint every some epochs
        if epoch % ckpt_every == 0:
            state = {'epoch': epoch, 'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict(), 'args': args}
            save_name = f'{metrics_savedir}_epoch_{epoch}'
            sshproc = save_checkpoint(state, save_name)

        # ======== evaluate ========
        test_acc, test_loss = evaluate(model, test_loader, data_normalization, num0, input_size, device, criterion)
        train_acc, train_loss = evaluate(model, train_loader_for_eval, data_normalization, num0, input_size, device, criterion)

        # ======== write to TB and stats-file
        # (saves both to tb event files and a separate dict called "stats") every epoch
        writer.add_scalars('acc', {
                                    'test': test_acc,
                                    'train_inloop': train_acc_inloop,
                                    'train': train_acc
                                    }, global_step=epoch, walltime=time.time()-start_time)
        writer.add_scalars('loss', {
                                    'test': test_loss,
                                    'train_inloop': train_loss_inloop,
                                    'train': train_loss
                                    }, global_step=epoch, walltime=time.time()-start_time)

        # Early Stopping routine: compare the mean test accuracy of the most
        # recent `patience` epochs with that of the `patience` epochs before
        if not no_ES:
            test_acc_tracker.append(test_acc)
            test_acc_tracker.pop(0)
            prev_avg_acc = np.mean(test_acc_tracker[:patience])
            curr_avg_acc = np.mean(test_acc_tracker[patience:])
            if curr_avg_acc < prev_avg_acc and epoch > (2*patience):
                # terminate training
                print(f'* * * Early Stopping: epoch {epoch}')
                print(f'* current avg: {curr_avg_acc}')
                print(f'* previous avg: {prev_avg_acc}')
                print(f'(no improvement over past {patience} epochs)')
                break

        # remember best test acc and save checkpoint
        is_best = test_acc > best_test_acc
        best_test_acc = max(test_acc, best_test_acc)

        if is_best:
            state = {'epoch': epoch, 'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict(), 'args': args}
            save_name = f'{metrics_savedir}_best'
            sshproc = save_checkpoint(state, save_name)

        if args.scheduler_on:
            scheduler.step()

    writer.close()  # close current event file

    # ========== save final model checkpoint =============
    # ====================================================
    state = {'epoch': last_epoch, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'args': args}
    save_name = f'{metrics_savedir}_final'
    print(f'saving checkpoint as {save_name}')
    sshproc = save_checkpoint(state, save_name)
    # presumably a process handle for the checkpoint transfer (see
    # save_checkpoint in trainer.utils); wait so the script does not exit
    # before it has finished
    sshproc.wait()



