from __future__ import print_function


import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import config as cf
import numpy as np
import torchvision.datasets as datasets


import torchvision.transforms as transforms
import random
import pandas as pd
import wandb
import os
import sys
import time
import argparse
from pathlib import Path
from os.path import join

from train_utils_tbr import train, test, getNetwork, net_esd_estimator, get_layer_temps
from sgdsnr import SGDSNR

#import weightwatcher as ww
import json

# parse arguments
# Reference recipe: epochs 200, lr = 0.1, batch size 128, cosine decay, weight decay 5e-4.
# Note that several argparse defaults below differ from it (lr 0.01, step decay, weight decay 1e-4).
# Boolean-like switches are plain strings and are compared against the literal 'True' later on.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# model / optimisation / logging
parser.add_argument('--lr',             type=float,  default=0.01,                    help='learning_rate')
parser.add_argument('--net-type',       type=str,    default='resnet_tiny_imagenet',  help='model')
parser.add_argument('--depth',          type=int,    default=18,                      help='depth of model')
parser.add_argument('--num-epochs',     type=int,    default=200,                     help='number of epochs')
parser.add_argument('--widen-factor',   type=int,    default=10,                      help='width of model')
parser.add_argument('--warmup-epochs',  type=int,    default=0)
parser.add_argument('--dataset',        type=str,    default='tiny-imagenet',         help='dataset name (default: tiny-imagenet)')
parser.add_argument('--lr-sche',        type=str,    default='step',                  choices=['step', 'cosine', 'warmup_cosine'])
parser.add_argument('--weight-decay',   type=float,  default=1e-4) # 5e-4
parser.add_argument('--wandb-tag',      type=str,    default='')
parser.add_argument('--wandb-on',       type=str,    default='True')
parser.add_argument('--print-tofile',   type=str,    default='True')
parser.add_argument('--ckpt-path',      type=str,    default='')


parser.add_argument('--batch-size',     type=int,    default=128)
parser.add_argument('--datadir',        type=str,    default='',          help='directory of dataset')
parser.add_argument('--optim-type',     type=str,    default='SGD',       help='type of optimizer')
parser.add_argument('--resume',         type=str,    default='',          help='resume from checkpoint')
parser.add_argument('--seed',           type=int,    default=42)
parser.add_argument('--ww-interval',    type=int,    default=1)
parser.add_argument('--epochs-to-save', type=int,    nargs='+',  default=[])
parser.add_argument('--fix-fingers',    type=str,    default=None,        help="xmin_peak")
parser.add_argument('--pl-package',     type=str,    default='powerlaw')

# temperature balance related
parser.add_argument('--remove-last-layer',    type=str,    default='True',  help='if remove the last layer')
parser.add_argument('--remove-first-layer',   type=str,    default='True',  help='if remove the first layer')
parser.add_argument('--metric',               type=str,    default='alpha', help='ww metric')
parser.add_argument('--temp-balance-lr',      type=str,    default='',
                    help="use tempbalance for learning rate (pass the string 'None' to disable)")
parser.add_argument('--batchnorm',            type=str,    default='False')
parser.add_argument('--lr-min-ratio',         type=float,  default=0.7)
parser.add_argument('--lr-slope',             type=float,  default=0.6)
parser.add_argument('--xmin-pos',             type=float,  default=2, help='xmin_index = int(round(size of eigs / xmin_pos))')
parser.add_argument('--lr-min-ratio-stage2',  type=float,  default=1)
# spectral regularization related
parser.add_argument('--sg',                   type=float,  default=0.01, help='spectrum regularization')
parser.add_argument('--stage-epoch',          type=int,    default=0,    help='stage_epoch')
parser.add_argument('--filter-zeros',         type=str,    default='False')

args = parser.parse_args()

print(args)
print(f"--------------------> TIN TB {args.net_type} <---------------------")
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and CUDA),
    puts cuDNN into deterministic mode with benchmarking turned off, and
    exports the seed as ``PYTHONHASHSEED``.

    Args:
        seed: integer seed handed to every generator.
    """
    print(f"=====> Set the random seed as {seed}")

    # Same seed, same order: NumPy, stdlib random, torch CPU, torch CUDA
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # The cuDNN backend needs both switches for repeatable kernels
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Fixed value for the hash seed, exported through the environment
    os.environ.update({"PYTHONHASHSEED": str(seed)})
    

def save_args_to_file(args, output_file_path):
    """Write the attributes of ``args`` to ``output_file_path`` as indented JSON.

    Args:
        args: object whose ``vars()`` mapping is serialised (here the parsed
            argparse namespace).
        output_file_path: destination path; the file is opened in write mode.
    """
    settings = vars(args)
    with open(output_file_path, mode="w") as handle:
        json.dump(settings, handle, indent=4)

# Run-wide state: device availability, best accuracy so far, first epoch index
use_cuda = torch.cuda.is_available()
best_acc = 0
start_epoch = cf.start_epoch
set_seed(args.seed)

# load data
print('\n[Phase 1] : Data Preparation')

# NOTE(review): the original comment here warned that the val set cannot be
# loaded with ImageFolder, yet both splits are read through ImageFolder below --
# presumably the val directory was rearranged into per-class folders; confirm.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training adds random crop (with 4px padding) and horizontal flip before normalising
train_transform = transforms.Compose([
    transforms.RandomCrop(64, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=list(channel_mean), std=list(channel_std)),
])
# Validation only converts to tensor and normalises
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=list(channel_mean), std=list(channel_std)),
])
data_transforms = {'train': train_transform, 'val': val_transform}

# The dataset is expected under <datadir>/tiny-imagenet-200/{train,val}
args.datadir = join(args.datadir, 'tiny-imagenet-200')

trainset = datasets.ImageFolder(join(args.datadir, 'train'), data_transforms['train'])
testset = datasets.ImageFolder(join(args.datadir, 'val'), data_transforms['val'])

train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=6,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=200,
    shuffle=False,
    num_workers=6,
)

# num of classes for tiny imagenet
num_classes = 200

# make checkpoint path (created recursively, no error if it already exists)
Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)

if args.print_tofile == 'True':
    # Send everything printed from here on to log files inside the checkpoint
    # directory; the handles stay open for the whole run and are closed at the
    # very end of the script.
    stdout_file, stderr_file = (
        open(join(args.ckpt_path, log_name), 'w')
        for log_name in ('stdout.log', 'stderr.log')
    )
    sys.stdout, sys.stderr = stdout_file, stderr_file

if args.wandb_on == 'True':
    wandb_settings = {
        'config': args,
        'project': 'tbr_snr',
        'dir': args.ckpt_path,
        'entity': 'weightwatcher_train',
        'name': args.wandb_tag,
    }
    wandb.init(**wandb_settings)

# Keep a JSON copy of the arguments next to the checkpoints
save_args_to_file(args, os.path.join(args.ckpt_path, 'args.json'))

# set up the model
print('\n[Phase 2] : Model setup')
if args.resume:
    print('| Resuming from checkpoint...')
else:
    print(f'| Building net type [{args.net_type}]...')

# The network is built and wrapped the same way whether or not we resume
net, file_name = getNetwork(args, num_classes)
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))

if args.resume:
    # Restore weights plus the bookkeeping stored alongside them; the checkpoint
    # is loaded onto the CPU first and moved to the GPU further down.
    checkpoint = torch.load(args.resume, map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['test_acc']
    start_epoch = checkpoint['epoch']
    print(f"Loaded Epoch: {start_epoch} \n Test Acc: {best_acc:.3f} Train Acc: {checkpoint['train_acc']:.3f}")
else:
    best_acc = 0

if use_cuda:
    net.cuda()
    # NOTE(review): this switches cuDNN benchmarking back on after set_seed()
    # turned it off for reproducibility -- confirm which setting is intended.
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()

# train model
print('\n[Phase 3] : Training model')
print(f'| Training Epochs = {args.num_epochs}')
print(f'| Initial Learning Rate = {args.lr}')
print(f'| Optimizer = {args.optim_type}')

# Evaluate once before any training so epoch 0 has a baseline accuracy/loss
test_acc, test_loss = test(epoch=0, net=net, testloader=test_loader, criterion=criterion)
print(f"Reevaluated: Test Acc: {test_acc:.3f}, Test Loss: {test_loss:.3f}")

#######################ESD analysis###############################
##################################################################
# Epoch-0 ESD analysis of the freshly built (or resumed) network. Its per-layer
# metrics are written to <ckpt_path>/stats and feed the first learning-rate
# assignment further down.
print("####################Start ESD analysis###################")
Path(os.path.join(args.ckpt_path, 'stats')).mkdir(parents=True, exist_ok=True)
esd_start_time = time.time()
metrics = net_esd_estimator(net, 
                  EVALS_THRESH = 0.00001,
                  bins = 100,
                  fix_fingers=args.fix_fingers,
                  xmin_pos=args.xmin_pos,
                  filter_zeros = args.filter_zeros=='True')
estimated_time = time.time() - esd_start_time
print(f"-----> ESD estimation time: {estimated_time:.3f}")
# summary and submit to wandb: every metric except the raw eigenvalues and the
# layer names is averaged over its per-layer entries
metric_summary = {}
for key in metrics:
    if key != 'eigs' and key != 'longname':
        metric_summary[key] = np.mean(metrics[key])
metric_summary.update({'test_acc':test_acc, 
          'test_loss': test_loss, 
          'epoch': 0})
if args.wandb_on == 'True':
    wandb.log(metric_summary)

# save metrics to disk and ESD
layer_stats=pd.DataFrame({key:metrics[key] for key in metrics if key!='eigs'})
layer_stats_origin = layer_stats.copy()
layer_stats_origin.to_csv(os.path.join(args.ckpt_path, 'stats',  f"origin_layer_stats_epoch_{0}.csv"))
# The name must be an f-string: without the prefix the file was literally called
# 'esd_epoch_{0}.npy' and did not match the per-epoch 'esd_epoch_<epoch>.npy' files.
np.save(os.path.join(args.ckpt_path, 'stats', f'esd_epoch_{0}.npy'), metrics)

######################  TBR scheduling ##########################
##################################################################
# Build the optimizer. With temp balance enabled every analysed layer gets its
# own parameter group and learning rate, followed by ONE final group holding all
# remaining parameters at args.lr. The group order matters: the per-epoch
# rescheduling later addresses the groups by position.

# Arguments shared by both optimizer variants
snr_options = dict(
    momentum=0.9,
    weight_decay=args.weight_decay,
    spectrum_regularization=args.sg,
    stage_epoch=args.stage_epoch,
    epoch=1,
)

if args.temp_balance_lr == 'None':
    print("-------------> Disable temp balance")
    optimizer = SGDSNR(net.parameters(), lr=args.lr, **snr_options)
else:
    print("################## Enable temp balance ##############")

    if args.remove_first_layer == 'True':
        print("remove first layer of alpha<---------------------")
        layer_stats = layer_stats.drop(index=0)
        # index must be reset otherwise may delete the wrong row
        layer_stats.index = list(range(len(layer_stats[args.metric])))
    if args.remove_last_layer == 'True':
        print("remove last layer of alpha<---------------------")
        last_row = len(layer_stats) - 1
        layer_stats = layer_stats.drop(index=last_row)
        # index must be reset otherwise may delete the wrong row
        layer_stats.index = list(range(len(layer_stats[args.metric])))

    # One learning rate per remaining layer, derived from the chosen metric
    metric_scores = np.array(layer_stats[args.metric])
    scheduled_lr = get_layer_temps(
        args,
        temp_balance=args.temp_balance_lr,
        n_alphas=metric_scores,
        epoch_val=args.lr,
    )
    layer_stats['scheduled_lr'] = scheduled_lr

    # these params should be tuned
    layer_name_to_tune = list(layer_stats['longname'])
    all_params = []
    params_to_tune_ids = set()
    tune_batchnorm = args.batchnorm == 'True'

    for name, module in net.named_modules():
        if name in layer_name_to_tune:
            # a layer that was analysed directly
            lr_source = name
        elif tune_batchnorm \
                and isinstance(module, nn.BatchNorm2d) \
                and name.replace('bn', 'conv') in layer_name_to_tune:
            # a batch norm whose name maps ('bn' -> 'conv') onto an analysed
            # layer borrows that layer's learning rate
            lr_source = name.replace('bn', 'conv')
        else:
            continue
        params_to_tune_ids.update(map(id, module.parameters()))
        layer_lr = layer_stats.loc[layer_stats['longname'] == lr_source, 'scheduled_lr'].item()
        all_params.append({'params': module.parameters(), 'lr': layer_lr})

    # those params are untuned: everything not claimed above keeps the base lr
    untuned_params = filter(lambda p: id(p) not in params_to_tune_ids, net.parameters())
    all_params.append({'params': untuned_params, 'lr': args.lr})
    # create optimizer
    optimizer = SGDSNR(all_params, **snr_options)
##################################################################


# save scheduled learning rate
layer_stats.to_csv(os.path.join(args.ckpt_path, 'stats', "layer_stats_with_lr_epoch_0.csv"))

#########################################################################

# Translate the --lr-sche choice into the matching decay function from the
# config module; only the selected attribute is looked up.
schedule_attr_by_name = {
    'step': 'stepwise_decay',
    'cosine': 'cosine_decay',
    'warmup_cosine': 'warmup_cosine_decay',
}
if args.lr_sche not in schedule_attr_by_name:
    raise NotImplementedError
lr_schedule = getattr(cf, schedule_attr_by_name[args.lr_sche])

# Running totals for the training loop. The test lists start with the epoch-0
# evaluation, so they hold one more entry than the train lists.
elapsed_time = 0
training_stats = dict(
    test_acc=[test_acc],
    test_loss=[test_loss],
    train_acc=[],
    train_loss=[],
    current_lr=[],
    schedule_next_lr=[],
)
untuned_lr = args.lr
is_current_best = False

# Main loop, one iteration per epoch: train + test, checkpointing, ESD analysis,
# then computing the learning rates that the NEXT epoch will use.
for epoch in range(start_epoch, start_epoch+args.num_epochs):
    epoch_start_time = time.time()
    # consider use another (maybe bigger) minimum learning rate in tbr
    # From args.stage_epoch onwards (when it is > 0) lr_min_ratio is overwritten
    # with the stage-2 value; args is later handed to get_layer_temps.
    if args.stage_epoch > 0 and epoch >= args.stage_epoch:
        print("------> Enter the second stage!!!!!!!!!!")
        args.lr_min_ratio = args.lr_min_ratio_stage2
    else:
        pass

    # this is current LR
    # (the base rate computed at the end of the previous iteration, or args.lr
    # on the first one)
    current_lr = untuned_lr
    print(f"##############Epoch {epoch}  current LR: {current_lr:.5f}################")

    # train and test
    train_acc, train_loss = train(epoch, net, args.num_epochs, train_loader, criterion, optimizer)
    print("\n| Train Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, train_loss, train_acc))
    test_acc, test_loss = test(epoch, net, test_loader, criterion)
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, test_loss, test_acc))

    # save in interval
    # (only for epochs listed in --epochs-to-save; optimizer state is not included)
    if epoch in args.epochs_to_save:
        state = {
            'net': net.state_dict(),
            'test_acc':test_acc,
            'test_loss':test_loss,
            'train_acc':train_acc,
            'train_loss':train_loss,
            'epoch':epoch
        }
        torch.save(state, join(args.ckpt_path, f'epoch_{epoch}.ckpt'))
    # save best
    if test_acc > best_acc:
        print('| Saving Best model')
        # 'best_acc' stored here is the PREVIOUS best: best_acc is only updated
        # after this dict has been built.
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'test_acc':test_acc,
            'best_acc': best_acc,
            'test_loss':test_loss,
            'train_acc':train_acc,
            'train_loss':train_loss,
            'epoch':epoch
        }
        best_acc = test_acc
        is_current_best=True
        torch.save(state, join(args.ckpt_path, f'epoch_best.ckpt'))
    else:
        is_current_best=False
    
    #######################ESD analysis###############################
    ##################################################################
    # Runs at epoch 1 and at every multiple of --ww-interval; other epochs leave
    # metric_summary empty and keep the previous layer_stats untouched.
    if epoch == 1 or epoch % args.ww_interval == 0:
        print("################ Start ESD analysis#############")
        esd_start_time = time.time()
        metrics = net_esd_estimator(net, 
                  EVALS_THRESH = 0.00001,
                  bins = 100,
                  fix_fingers=args.fix_fingers,
                  xmin_pos=args.xmin_pos,
                  filter_zeros=args.filter_zeros=='True')
        
        # Average every metric over its per-layer entries, skipping the raw
        # eigenvalues and the layer names
        metric_summary = {}
        for key in metrics:
            if key != 'eigs' and key != 'longname':
                metric_summary[key] = np.mean(metrics[key])

        layer_stats= pd.DataFrame({key:metrics[key] for key in metrics if key!='eigs'})
        # save metrics to disk and ESD
        layer_stats_origin = layer_stats.copy()
        layer_stats_origin.to_csv(os.path.join(args.ckpt_path, 'stats',  f"origin_layer_stats_epoch_{epoch}.csv"))
        np.save(os.path.join(args.ckpt_path, 'stats', f'esd_epoch_{epoch}.npy'), metrics)
        # Keep a copy of the ESD belonging to the best checkpoint so far
        if is_current_best:
            np.save(os.path.join(args.ckpt_path, f'esd_best.npy'), metrics)

        esd_estimated_time = time.time() - esd_start_time
        print(f"-----> ESD estimation time: {esd_estimated_time:.3f}")

    else:
        metric_summary = {}

    ##################################################################
    # Reschedule the learning rate
    # untuned_lr is the decayed base rate; it becomes current_lr next iteration
    untuned_lr = lr_schedule(args.lr, epoch, args.num_epochs, warmup_epochs=args.warmup_epochs)
    print(f"------------>Rescheduled decayed LR: {untuned_lr:.5f}<--------------------")

    if args.temp_balance_lr != 'None':
        ######################  TBR scheduling ##########################
        ##################################################################

        print("############### Schedule by Temp Balance###############")
        # metric_summary is empty whenever ESD analysis was skipped this epoch,
        # so with temp balance enabled this fails unless the analysis ran (and
        # therefore layer_stats is fresh and has not yet had rows removed).
        assert len(metric_summary) > 0, "in TBR, every epoch should has an updated metric summary"
        if args.remove_first_layer == 'True':
            print('remove first layer <--------------------')
            layer_stats = layer_stats.drop(labels=0, axis=0)
            # index must be reset otherwise next may delete the wrong row 
            layer_stats.index = list(range(len(layer_stats[args.metric])))
        if args.remove_last_layer == 'True':
            print('remove last layer <--------------------')
            layer_stats = layer_stats.drop(labels=len(layer_stats) - 1, axis=0)
            # index must be reset otherwise may delete the wrong row 
            layer_stats.index = list(range(len(layer_stats[args.metric])))

        # Per-layer learning rates derived from the chosen metric and the decayed base rate
        metric_scores = np.array(layer_stats[args.metric])
        scheduled_lr = get_layer_temps(args, args.temp_balance_lr, metric_scores, untuned_lr)
        layer_stats['scheduled_lr'] = scheduled_lr
        #print(layer_stats.to_string())
        layer_name_to_tune = list(layer_stats['longname'])
        all_params_lr = []
        params_to_tune_ids = []
        # c counts the tuned groups; modules are visited in named_modules() order.
        # NOTE(review): this relies on the optimizer's parameter groups having
        # been created in the same order -- confirm against the optimizer setup.
        c = 0
        for name, module in net.named_modules():
            if name in layer_name_to_tune:
                params_to_tune_ids += list(map(id, module.parameters()))
                scheduled_lr = layer_stats[layer_stats['longname'] == name]['scheduled_lr'].item()
                all_params_lr.append(scheduled_lr)
                c = c + 1
            # a batch norm whose name maps ('bn' -> 'conv') onto a tuned layer
            # takes that layer's learning rate when --batchnorm is 'True'
            elif args.batchnorm == 'True' \
                and isinstance(module, nn.BatchNorm2d) \
                    and name.replace('bn', 'conv') in layer_name_to_tune:
                params_to_tune_ids += list(map(id, module.parameters()))
                scheduled_lr = layer_stats[layer_stats['longname'] == name.replace('bn', 'conv')]['scheduled_lr'].item()
                all_params_lr.append(scheduled_lr)
                c = c + 1

        layer_stats.to_csv(os.path.join(args.ckpt_path, 'stats', f"layer_stats_with_lr_epoch_{epoch}.csv"))
        # The first c groups receive their per-layer rate by position; any
        # remaining group gets the decayed base rate. Every group's 'epoch'
        # counter is advanced by one.
        for index, param_group in enumerate(optimizer.param_groups):
            param_group['epoch'] = param_group['epoch'] + 1
            if index <= c - 1:
                param_group['lr'] = all_params_lr[index]
            else:
                param_group['lr'] = untuned_lr
    ##################################################################
    ##################################################################
    else:
        print("------------>  Schedule by default")
        # Without temp balance every group simply follows the decayed base rate
        for param_group in optimizer.param_groups:
            param_group['epoch'] = param_group['epoch'] + 1
            param_group['lr'] = untuned_lr

    # 'elapsed_time' logged here does not yet include the current epoch; it is
    # only increased at the bottom of the loop.
    train_summary = {'test_acc':test_acc, 
                    'test_loss': test_loss,
                    'train_acc': train_acc, 
                    'train_loss': train_loss,
                    'current_lr': current_lr,
                    'schedule_next_lr': untuned_lr,
                    'elapsed_time':elapsed_time}

    train_summary.update(metric_summary)

    if args.wandb_on == 'True':
        wandb.log(train_summary)
    training_stats['test_acc'].append(test_acc)
    training_stats['test_loss'].append(test_loss)
    training_stats['train_acc'].append(train_acc)
    training_stats['train_loss'].append(train_loss)
    training_stats['current_lr'].append(current_lr)
    training_stats['schedule_next_lr'].append(untuned_lr)

    # The full history is rewritten to disk after every epoch
    np.save(join(args.ckpt_path, "training_stats.npy"), training_stats)
    epoch_time = time.time() - epoch_start_time
    elapsed_time += epoch_time
    print('| Elapsed time : %d:%02d:%02d'  %(cf.get_hms(elapsed_time)))
    print('--------------------> <--------------------')
    
# Finish the wandb run (if one was started) before tearing down the log files
if args.wandb_on == 'True':
    wandb.finish()

if args.print_tofile == 'True':
    # Point the interpreter back at its original streams BEFORE closing the log
    # files: leaving sys.stdout/sys.stderr bound to closed files makes any later
    # write (e.g. from exit handlers or warnings during shutdown) raise
    # "ValueError: I/O operation on closed file".
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__
    # Close the files to flush the output
    stdout_file.close()
    stderr_file.close()