import argparse
import os
import sys
import random
import shutil
import time
import warnings
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from . import models
from .allutils import get_dataset, get_model, get_optimizer, get_scheduler
from .allutils import Recorder, LossTracker
from .allutils import get_num_W_tot, get_num_W, get_tensor_dims, get_ntf, get_lname_for_statedict, make_smask_for_layer, get_fanin, adjust_layer_init

def _str2bool(value):
  """Convert a command-line token such as 'true', '0' or 'no' into a bool.

  Without a converter argparse stores the raw string, and every non-empty
  string (including 'False') is truthy.

  Raises:
    argparse.ArgumentTypeError: if the token is not a recognised boolean.
  """
  if isinstance(value, bool):
    return value
  token = value.strip().lower()
  if token in ('1', 'true', 't', 'yes', 'y', 'on'):
    return True
  # The empty string was the only falsy value before, so it stays falsy.
  if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
    return False
  raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface. The parsed result is stored in the module-level
# `args`, which the functions in this file read directly.
parser = argparse.ArgumentParser(description='PyTorch Training')

# --- storage / data / architecture
parser.add_argument('--bucket', default='bucket2',
                    help='bucket name (default: bucket2)')
parser.add_argument('--datadir', default='imagenet',
                    help='path to dataset (default: imagenet)')
parser.add_argument('--arch', metavar='ARCH', default='resnet18',
                    help='model architecture: (default: resnet18)')

# --- width of the first conv layer: baseline model vs. the model being trained
parser.add_argument('--noc1_base', default=64, type=int,
                    help='number of output channels in the very 1st layer (conv1) IN THE BASELINE MODEL (default: 64)')
parser.add_argument('--noc1', default=64, type=int,
                    help='number of output channels in the very 1st layer (conv1) (default: 64)')

# --- training setup
parser.add_argument('--dataset', default='imagenet', type=str,
                    help='dataset')
parser.add_argument('--workers', default=32, type=int,
                    help='number of data loading workers, total (default: 32)')
parser.add_argument('--epochs', default=150, type=int,
                    help='number of total epochs to run')
parser.add_argument('--batchsize', default=1024, type=int,
                    help='mini-batch size (default: 1024), this is the total')
parser.add_argument('--optimizer', default="sgd", type=str,
                    help='optimizer')
parser.add_argument('--scheduler', default="cosine", type=str,
                    help='lr scheduler')
parser.add_argument('--lr', default=0.1, type=float,
                    help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', default=1e-4, type=float,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--printfreq', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--evaluate', action='store_true',
                    help='evaluate model on validation set (no training)')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--half', default=False, action='store_true',
                    help='training with half precision')
# There is deliberately no --logdir option: the log directory name is
# composed from the other arguments instead.
parser.add_argument('--saveparam', default='optimizer', type=str,
                    help='param names and values to use when saving files')

# --- sparsification
parser.add_argument('--init_run', default=False, action='store_true',
                    help='init run creates and saves model and smask')
# Fix: parse the value as a real boolean so that
# `--adjust_init_for_sparse False` actually yields False.
parser.add_argument('--adjust_init_for_sparse', default=True, type=_str2bool,
                    help='adjust initialization of sparse layers (reduced effective fan-in)')
parser.add_argument('--io_only', default=False, action='store_true',
                    help='sparsify conv layers along IO dims only')
args = parser.parse_args()



def main():
  """Entry point: either run the one-off init step or train/evaluate a model.

  With `--init_run` the model and its sparsity mask (smask) are created,
  written out, and the process exits. Otherwise the data loaders, model,
  optimizer and scheduler are built, an optional checkpoint (plus smask) is
  restored, the weight masking is verified and the train/validate loop (or a
  single validation pass with `--evaluate`) is run.
  """
  ltypes = ['Linear', 'Conv2d']  # layer/module types to sparsify

  set_seed(args.seed)

  # === initiate a recorder for saving and loading stats and checkpoints
  logdir = f'{args.dataset}_{args.arch}_{args.noc1_base}_{args.noc1}'
  print(f'\n>> Initializing recorder with out_dir={logdir}')
  rc = Recorder(out_dir=logdir, bucket_dir=args.bucket)

  local_ckpt_savedir = f'checkpoints/{logdir}'

  # ======================== init run start ==========================
  # prepare model and smask, save them, and quit main
  if args.init_run:
    print(f'\n>> Entered init_run conditional block! (args.init_run={args.init_run})')
    model, smask = prep_model_and_smask(ltypes)
    optimizer = get_optimizer(args.optimizer, model.parameters(), args.lr, args.momentum, args.wd)
    scheduler = get_scheduler(args.scheduler, optimizer, num_epochs=args.epochs)

    start_epoch = 0
    val_acc1 = 0

    # The local checkpoint directory is written to directly below, so make
    # sure it exists instead of relying on another component to create it.
    os.makedirs(local_ckpt_savedir, exist_ok=True)

    # === save smask (dict)
    torch.save(smask, f'{local_ckpt_savedir}/smask.pt')

    # === save model
    rc.save_full_checkpoint(model, optimizer, scheduler, args, start_epoch, val_acc1)

    # `with` closes the file even if printing the model raises
    with open(f'{local_ckpt_savedir}/model_printout.txt', 'w') as fh:
      print(model, file=fh)

    sys.exit()  # quit main
  # ======================== init run end ==========================

  # === create datasets and dataloaders
  tr_set = get_dataset(args.dataset, args.datadir, 'train')
  val_set = get_dataset(args.dataset, args.datadir, 'val')

  train_loader = torch.utils.data.DataLoader(tr_set, batch_size=args.batchsize,
                      shuffle=True, num_workers=args.workers, pin_memory=True)
  val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batchsize,
                      shuffle=False, num_workers=args.workers, pin_memory=True)

  # === setup model
  model = get_model(args, args.noc1)

  # === define loss function (criterion), optimizer and scheduler
  criterion = nn.CrossEntropyLoss().cuda()
  optimizer = get_optimizer(args.optimizer, model.parameters(), args.lr, args.momentum, args.wd)
  scheduler = get_scheduler(args.scheduler, optimizer, num_epochs=args.epochs)

  start_epoch = 0
  # Fix: smask used to be bound only inside the resume branch although it is
  # always handed to train(), so a non-resumed run died with UnboundLocalError.
  smask = None

  # === load model checkpoint if resuming
  if args.resume:
    print(f'\n>> Entered the args.resume conditional block! (args.resume={args.resume})')

    print('* * * Loading smask now...')
    smask = load_smask(local_ckpt_savedir)

    print('* * * Loading checkpoint now...')
    start_epoch = rc.resume_full_checkpoint(args.resume, model, optimizer, scheduler)

  # === collect model properties required for sparsification
  sparse = args.noc1 > args.noc1_base
  tensor_dims = get_tensor_dims(model, ltypes)  # dimensions of each layer tensor
  num_W = get_num_W(tensor_dims)  # num weights of each layer (prod of corresp tensor dims)
  num_W_tot = sum(num_W.values())  # total number of weights in current model
  num_W_tot_base = get_num_W_tot(args, args.noc1_base, ltypes)  # total number of weights in baseline
  num_to_freeze_tot = int(num_W_tot - num_W_tot_base)  # total number of weights to freeze
  lnames_sorted = sorted(num_W, key=num_W.get, reverse=True)  # layer names sorted by size
  num_to_freeze = get_ntf(num_to_freeze_tot, num_W, tensor_dims, lnames_sorted, args.io_only)
  # NOTE(review): `num_to_freeze > 0` relies on get_ntf returning an
  # array-like that supports elementwise comparison - confirm.
  num_layers_to_sparsify = sum(num_to_freeze > 0)  # total number of layers that are to be sparse

  # ===== check that layers are masked correctly
  # A sparse model that was not restored from an init-run checkpoint fails here.
  state = model.state_dict()
  for lind in range(num_layers_to_sparsify):
    lname = lnames_sorted[lind]
    lname_for_statedict = get_lname_for_statedict(lname)

    num_zero = torch.sum(state[lname_for_statedict] == 0)
    assert num_zero >= num_to_freeze[lind], f"(!) Error: layer {lname} not sparsified properly!"

  # ======================== training start ==========================
  if args.evaluate:  # evaluate only
    validate(val_loader, model, criterion)
  else:  # train and validate
    for epoch in range(start_epoch, args.epochs):
      tr_loss, tr_acc1, tr_acc5 = train(train_loader, model, criterion, optimizer, epoch,
        sparse, smask, num_layers_to_sparsify, lnames_sorted, num_to_freeze)
      val_loss, val_acc1, val_acc5 = validate(val_loader, model, criterion)

      rc.add_losses(tr_loss, tr_acc1, tr_acc5, val_loss, val_acc1, val_acc5, global_step=epoch+1)
      rc.save_full_checkpoint(model, optimizer, scheduler, args, epoch+1, val_acc1)

      scheduler.step()
  # ======================== training end ==========================

  rc.close()



def load_smask(savedir):
  """Load the sparsity-mask dict from `<savedir>/smask.pt` and move its masks to the GPU.

  Entries whose value is None are left untouched; every other value is
  replaced by the result of calling `.cuda()` on it.

  Args:
    savedir: directory that contains `smask.pt`.

  Returns:
    The loaded dict, updated in place.
  """
  mask_file = f'{savedir}/smask.pt'
  smask = torch.load(mask_file)
  print(f'--> Loaded smask from {savedir}!')

  # putting all on GPU (iterate over a snapshot of the keys while reassigning)
  for lname in list(smask):
    mask = smask[lname]
    if mask is None:
      continue
    smask[lname] = mask.cuda()

  return smask


def prep_model_and_smask(ltypes):
  """Build the model, create a sparsity mask (smask) per sparse layer and apply it.

  Nothing is written to disk here; the caller receives the objects and is
  responsible for saving them.

  Args:
    ltypes: layer/module type names to consider for sparsification.

  Returns:
    A `(model, smask)` tuple. `smask` maps every layer name to its mask, or
    to None for layers that are not sparsified.
  """
  # Fix: this flag used to be negated, which switched the init adjustment OFF
  # under the default setting (True) and contradicted the option's help text.
  # A raw string is also interpreted, because the option may arrive unparsed.
  flag = args.adjust_init_for_sparse
  if isinstance(flag, str):
    flag = flag.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')
  adjust_init_for_sparse = bool(flag)

  # total number of weights in the _baseline_ model
  # (computed before the model is built, keeping the original call order)
  num_W_tot_base = get_num_W_tot(args, args.noc1_base, ltypes)

  # ===== setup model
  model = get_model(args, args.noc1)

  # ===== collect model properties required for sparsification
  tensor_dims = get_tensor_dims(model, ltypes)  # dimensions of each layer tensor
  num_W = get_num_W(tensor_dims)  # num weights of each layer (prod of corresp tensor dims)

  num_W_tot = sum(num_W.values())  # total number of weights in current model
  num_to_freeze_tot = int(num_W_tot - num_W_tot_base)  # total number of weights to freeze

  ctvt = num_W_tot_base / num_W_tot  # connectivity

  lnames_sorted = sorted(num_W, key=num_W.get, reverse=True)  # layer names sorted by size

  # init smask (need smask value to be None for layers that are not sparsified)
  smask = {lname: None for lname in lnames_sorted}

  # num_W-to-freeze for each layer (in sorted order)
  num_to_freeze = get_ntf(num_to_freeze_tot, num_W, tensor_dims, lnames_sorted, args.io_only)

  num_layers_to_sparsify = sum(num_to_freeze > 0)  # total number of layers that are to be sparse

  ll_str = '\n'.join([lnames_sorted[lind] for lind in range(num_layers_to_sparsify)])
  print(f'ctvt = {ctvt}\n\n{num_layers_to_sparsify} layers to sparsify:\n{ll_str}')

  for lind in range(num_layers_to_sparsify):  # create and apply smask for each layer

    # ===== layer properties
    lname = lnames_sorted[lind]
    lsize = num_W[lname]
    ldims = tensor_dims[lname]
    lntf = num_to_freeze[lind]
    lname_for_statedict = get_lname_for_statedict(lname)  # layer key in model.state_dict()

    # ===== make smask for current layer
    print(f'\n>> Making smask for layer {lname} ({lntf} to freeze) ...')
    smask[lname] = make_smask_for_layer(lsize, lntf, ldims, args.io_only)
    num_1_in_smask = torch.sum(smask[lname])  # number of ones in smask (weights to freeze) in given layer

    if args.io_only:
      # with io_only each mask entry covers a whole kernel of a 4-D tensor
      kernel_size = np.prod(ldims[-2:]) if len(ldims) == 4 else 1
      assert kernel_size*num_1_in_smask == lntf, f"(!) smask is wrong! {kernel_size*num_1_in_smask} {lntf}"
    else:
      assert num_1_in_smask == lntf, f"(!) smask is wrong! {num_1_in_smask} {lntf}"

    # ===== adjust initialization values in sparse layer
    if adjust_init_for_sparse:
      print(f'\n>> Adjusting init for sparse layer {lname} ...')
      adjust_layer_init(model, lname, lname_for_statedict, ldims, lntf, lsize)

    # ===== apply smask to layer
    print(f'\n>> Applying smask ({lntf} to freeze) to layer {lname} with dims {ldims} ...')

    with torch.no_grad():
      # Looked up after the init adjustment so the tensor reflects any change it made.
      weight = model.state_dict()[lname_for_statedict]
      if args.io_only and len(ldims) == 4:
        # expand the IO mask over the two trailing (kernel) dims
        weight.masked_fill_(smask[lname].unsqueeze(2).unsqueeze(3), 0)
      elif args.io_only and len(ldims) != 2:
        print('(!) Error: tensor dims len is not 2 and not 4!')
      else:
        weight.masked_fill_(smask[lname], 0)
      # check
      num_zero = torch.sum(weight == 0)
      assert num_zero >= lntf, f"(!) Error: layer {lname} not sparsified properly!"

  return model, smask


def _mask_frozen_grad(grad, mask):
  """Zero, in place, the entries of `grad` selected by `mask`."""
  if not args.io_only:
    grad.masked_fill_(mask, 0)
  elif grad.dim() == 4:
    # the IO mask is expanded over the two trailing (kernel) dims
    grad.masked_fill_(mask.unsqueeze(2).unsqueeze(3), 0)
  elif grad.dim() == 2:
    grad.masked_fill_(mask, 0)
  else:
    print('(!) Error: tensor dims len is not 2 and not 4!')


def train(train_loader, model, criterion, optimizer, epoch, sparse, smask, num_layers_to_sparsify, lnames_sorted, num_to_freeze):
  """Run one training epoch, zeroing the gradients of masked weights when `sparse`.

  Args:
    train_loader: iterable yielding `(images, target)` batches.
    model: network to train; switched to train mode here.
    criterion: loss called as `criterion(output, target)`.
    optimizer: optimizer stepped once per batch.
    epoch: epoch index, used only in the progress prefix.
    sparse: if true, gradients are masked with `smask` before each step.
    smask: dict mapping layer name to its mask; read only when `sparse`.
    num_layers_to_sparsify: number of leading entries of `lnames_sorted`
      that carry a mask.
    lnames_sorted: layer names, sparsified layers first.
    num_to_freeze: per-layer minimum number of zero gradient entries expected
      after masking, indexed like `lnames_sorted`.

  Returns:
    The tracker's `(losses.avg, top1.avg, top5.avg)` for the epoch.
  """
  # switch to train mode
  model.train()
  tracker = LossTracker(len(train_loader), f'Epoch: [{epoch}]', args.printfreq)

  # Resolve the masked parameters once per epoch instead of re-scanning
  # named_parameters() for every layer on every batch.
  masked_params = []
  if sparse:
    params_by_name = dict(model.named_parameters())
    for lind in range(num_layers_to_sparsify):
      lname = lnames_sorted[lind]
      pname = get_lname_for_statedict(lname)
      param = params_by_name.get(pname)
      if param is not None:  # names without a matching parameter are skipped, as before
        masked_params.append((lind, lname, pname, param))

  for i, (images, target) in enumerate(train_loader):

    images, target = cuda_transfer(images, target)
    output = model(images)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()  # compute gradients

    for lind, lname, pname, param in masked_params:
      # Fix: the mask layout is derived from the gradient's own rank; the old
      # code read an undefined name `tensor_dims` and raised NameError under
      # --io_only.
      _mask_frozen_grad(param.grad, smask[lname])
      num_zero = torch.sum(param.grad == 0)
      assert num_zero >= num_to_freeze[lind], f"(!) Error: grad in {pname} not sparsified properly!"
    optimizer.step()  # do SGD step

    # tracking acc and loss
    tracker.update(loss, output, target)
    tracker.display(i)

  return tracker.losses.avg, tracker.top1.avg, tracker.top5.avg



def validate(val_loader, model, criterion):
  """Evaluate `model` on every batch of `val_loader` with gradients disabled.

  Returns:
    The tracker's `(losses.avg, top1.avg, top5.avg)` over the whole loader.
  """
  model.eval()  # switch to evaluate mode

  with torch.no_grad():
    tracker = LossTracker(len(val_loader), 'Test', args.printfreq)
    for step, (batch_images, batch_target) in enumerate(val_loader):
      batch_images, batch_target = cuda_transfer(batch_images, batch_target)
      prediction = model(batch_images)
      batch_loss = criterion(prediction, batch_target)

      tracker.update(batch_loss, prediction, batch_target)
      tracker.display(step)

  avg_loss = tracker.losses.avg
  avg_top1 = tracker.top1.avg
  avg_top5 = tracker.top5.avg
  return avg_loss, avg_top1, avg_top5



def set_seed(seed=None):
  """Seed Python's and PyTorch's RNGs and switch cuDNN to deterministic mode.

  Does nothing when `seed` is None.

  Args:
    seed: integer seed, or None to leave the RNG state untouched.
  """
  if seed is None:
    return
  # Fix: seed from the argument; the old code tested `seed` but then seeded
  # from the global `args.seed`, ignoring what the caller passed.
  random.seed(seed)
  torch.manual_seed(seed)
  torch.backends.cudnn.deterministic = True
  warnings.warn('You have chosen to seed training. '
                'This will turn on the CUDNN deterministic setting, '
                'which can slow down your training considerably! '
                'You may see unexpected behavior when restarting '
                'from checkpoints.')

def gen_comment(args):
  """Build a '-'-joined tag of `<name>_<value>` pairs.

  The names come from the comma-separated `args.saveparam`; each value is
  looked up among the attributes of `args`. An empty name contributes an
  empty piece, so its separator is still emitted.

  Raises:
    KeyError: if a listed name is not an attribute of `args`.
  """
  values = vars(args)
  pieces = [
    f'{name}_{values[name]}' if name else ''
    for name in args.saveparam.split(',')
  ]
  return '-'.join(pieces)


def cuda_transfer(images, target):
  """Move a batch to the GPU with non-blocking copies.

  The images are additionally converted with `.half()` when `args.half` is
  set; the targets are only moved.

  Returns:
    The `(images, target)` pair after the transfer.
  """
  gpu_images = images.cuda(non_blocking=True)
  gpu_target = target.cuda(non_blocking=True)
  if args.half:
    gpu_images = gpu_images.half()
  return gpu_images, gpu_target


# Script entry point: run main() only when this file is executed directly.
# NOTE(review): the relative imports at the top of the file suggest it has to
# be launched as a module inside its package - confirm.
if __name__ == '__main__':
    main()

