from __future__ import print_function

import logging
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
import numpy as np

import torchvision
import torchvision.transforms as transforms
import random
import pandas as pd
import os
import sys
import time
import argparse
from pathlib import Path
import wandb
from os.path import join
from adabelief_pytorch import AdaBelief
from adamp import AdamP, SGDP
from train_utils_tbr import train, test, getNetwork, net_esd_estimator
import json
import torch_optimizer
from lars_optim import lars


# Command-line interface.
# Every option is declared once in the table below and registered in the
# order listed.  Boolean-like switches (--wandb-on, --print-tofile,
# --filter-zeros) are plain strings that the rest of the script compares
# against the literal 'True'.
parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training')

_ARGUMENT_SPECS = (
    # model / optimisation / logging
    ('--lr', dict(type=float, default=0.01, help='learning_rate')),
    ('--net-type', dict(type=str, default='wide-resnet', help='model')),
    ('--depth', dict(type=int, default=28, help='depth of model')),
    ('--num-epochs', dict(type=int, default=200, help='number of epochs')),
    ('--widen-factor', dict(type=float, default=1, help='width of model')),
    ('--warmup-epochs', dict(type=int, default=0)),
    ('--dataset', dict(type=str, default='cifar10', help='dataset = [cifar10/cifar100]')),
    ('--lr-sche', dict(type=str, default='step', choices=['step', 'cosine', 'warmup_cosine'])),
    ('--weight-decay', dict(type=float, default=1e-4)),
    ('--wandb-tag', dict(type=str, default='')),
    ('--wandb-on', dict(type=str, default='True')),
    ('--print-tofile', dict(type=str, default='True')),
    ('--ckpt-path', dict(type=str, default='')),
    # data / run control
    ('--batch-size', dict(type=int, default=128)),
    ('--datadir', dict(type=str, default='', help='directory of dataset')),
    ('--optim-type', dict(type=str, default='SGD', help='type of optimizer')),
    ('--resume', dict(type=str, default='', help='resume from checkpoint')),
    ('--seed', dict(type=int, default=42)),
    ('--ww-interval', dict(type=int, default=10)),
    ('--epochs-to-save', dict(type=int, nargs='+', default=[])),
    ('--fix-fingers', dict(type=str, default=None, help="xmin_peak")),
    ('--pl-package', dict(type=str, default='powerlaw')),
    ('--filter-zeros', dict(type=str, default='False')),
    ('--look-k', dict(type=int, default=5)),
    ('--look-alpha', dict(type=float, default=0.8)),
    # esd related parameters
    ('--xmin-pos', dict(type=int, default=2, help='xmin_index = size of eigs // xmin_pos')),
)
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# Echo the resolved configuration and a banner naming the chosen optimizer.
print(args)
print(f"--------------------> Main Baseline Compare OPT {args.optim_type}<--------------------")

def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds NumPy, Python's ``random`` module and PyTorch (default and CUDA
    generators), puts cuDNN into deterministic mode with benchmarking
    turned off, and stores the seed in the ``PYTHONHASHSEED`` environment
    variable.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    print(f"=====> Set the random seed as {seed}")

    # One seed, applied to each library's generator in turn.
    for seed_fn in (np.random.seed, random.seed,
                    torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # The cuDNN backend needs two extra switches for repeatable results.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Fix the hash seed value in the environment as well.
    os.environ["PYTHONHASHSEED"] = str(seed)
    

def save_args_to_file(args, output_file_path):
    """Write the attributes of ``args`` to a file as 4-space-indented JSON.

    Args:
        args: Object whose ``vars(args)`` mapping is serialised (the script
            passes the parsed argparse namespace).
        output_file_path: Destination path; opened in ``"w"`` mode, so an
            existing file is overwritten.
    """
    with open(output_file_path, "w") as handle:
        handle.write(json.dumps(vars(args), indent=4))


# Hyper Parameter settings
use_cuda = torch.cuda.is_available()
best_acc = 0
start_epoch = cf.start_epoch
set_seed(args.seed)


# Data Upload
print('\n[Phase 1] : Data Preparation')

# Supported datasets: CLI name -> (banner label, torchvision class, #classes).
# Validating against this table up front turns an unsupported name (including
# other names that merely contain "cifar") into one clear error instead of a
# later KeyError/NameError.
_SUPPORTED_DATASETS = {
    'cifar10': ('CIFAR-10', torchvision.datasets.CIFAR10, 10),
    'cifar100': ('CIFAR-100', torchvision.datasets.CIFAR100, 100),
}
if args.dataset not in _SUPPORTED_DATASETS:
    raise NotImplementedError(
        f"Unsupported dataset {args.dataset!r}; "
        f"expected one of {sorted(_SUPPORTED_DATASETS)}"
    )

print(f"prepare preprocessing, {args.dataset}")
# Training pipeline: padded random crop + horizontal flip, then per-dataset
# mean/std normalisation taken from the config module.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
]) # meanstd transformation

# Evaluation pipeline: normalisation only, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
])

data_path = join(args.datadir, args.dataset)
_dataset_label, _dataset_cls, num_classes = _SUPPORTED_DATASETS[args.dataset]
print(f"| Preparing {_dataset_label} dataset...")
sys.stdout.write("| ")
trainset = _dataset_cls(root=data_path, train=True, download=True, transform=transform_train)
# The test split is opened with download=False, i.e. it relies on the files
# already being present under data_path after the call above.
testset = _dataset_cls(root=data_path, train=False, download=False, transform=transform_test)


trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=6)
testloader = torch.utils.data.DataLoader(testset, batch_size=500, shuffle=False, num_workers=4)

# Create the checkpoint/output directory (and parents) if it is missing.
Path(args.ckpt_path).mkdir(parents=True, exist_ok=True)


# Optionally send everything printed from here on to log files inside the
# checkpoint directory instead of the console.
if args.print_tofile == 'True':
    stdout_file = open(join(args.ckpt_path, 'stdout.log'), 'w')
    stderr_file = open(join(args.ckpt_path, 'stderr.log'), 'w')
    sys.stdout, sys.stderr = stdout_file, stderr_file

# Optionally start a Weights & Biases run configured from the parsed args.
if args.wandb_on == 'True':
    _wandb_settings = dict(
        config=args,
        project='tbr_snr',
        dir=args.ckpt_path,
        entity='weightwatcher_train',
        name=args.wandb_tag,
    )
    wandb.init(**_wandb_settings)

# Keep a JSON copy of the arguments next to the checkpoints.
save_args_to_file(args, join(args.ckpt_path, 'args.json'))


# Model
print('\n[Phase 2] : Model setup')
if args.resume:
    print('| Resuming from checkpoint...')
else:
    print('| Building net type [' + args.net_type + ']...')

# The network is built and wrapped the same way whether or not we resume;
# only the weights and bookkeeping values differ.
net, file_name = getNetwork(args, num_classes)
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))

if args.resume:
    # Load checkpoint (onto CPU first; the model is moved to GPU below).
    checkpoint = torch.load(args.resume, map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['test_acc']
    start_epoch = checkpoint['epoch']
    print(f"Loaded Epoch: {start_epoch} \n Test Acc: {best_acc:.3f} Train Acc: {checkpoint['train_acc']:.3f}")
else:
    best_acc = 0

if use_cuda:
    net.cuda()
    # cudnn.benchmark is deliberately NOT switched on here: set_seed() turned
    # it off (together with cudnn.deterministic = True) so that seeded runs
    # are repeatable, and re-enabling it would silently undo that.
print(net)
criterion = nn.CrossEntropyLoss()

print('\n[Phase 3] : Training model')
print('| Training Epochs = ' + str(args.num_epochs))
print('| Initial Learning Rate = ' + str(args.lr))
print('| Optimizer = ' + str(args.optim_type))

# Evaluate the freshly built / restored model once before any training.
test_acc, test_loss = test(epoch=0, net=net, testloader=testloader, criterion=criterion)
print(f"Reevaluated: Test Acc: {test_acc:.3f}, Test Loss: {test_loss:.3f}")

#######################ESD analysis###############################
##################################################################
# ESD analysis of the model before any training in this run; results are
# logged as "epoch 0" and written under <ckpt_path>/stats.
print("####################Start ESD analysis###################")
stats_dir = os.path.join(args.ckpt_path, 'stats')
Path(stats_dir).mkdir(parents=True, exist_ok=True)
esd_start_time = time.time()
metrics = net_esd_estimator(net, 
                  EVALS_THRESH = 0.00001,
                  bins = 100,
                  fix_fingers=args.fix_fingers,
                  xmin_pos=args.xmin_pos, 
                  filter_zeros= args.filter_zeros == 'True')

estimated_time = time.time() - esd_start_time
print(f"-----> ESD estimation time: {estimated_time:.3f}")
# summary and submit to wandb: every metric except the raw eigenvalues and
# the layer names is averaged into a single value
metric_summary = {
    key: np.mean(metrics[key])
    for key in metrics
    if key not in ('eigs', 'longname')
}
metric_summary.update({'test_acc':test_acc, 
          'test_loss': test_loss, 
          'epoch': 0})
if args.wandb_on == 'True':
    wandb.log(metric_summary)

# save metrics to disk and ESD
layer_stats=pd.DataFrame({key:metrics[key] for key in metrics if key!='eigs'})
layer_stats_origin = layer_stats.copy()
layer_stats_origin.to_csv(os.path.join(stats_dir, f"origin_layer_stats_epoch_{0}.csv"))
# The file name previously lacked the f-prefix and was written literally as
# 'esd_epoch_{0}.npy'; it now matches the per-epoch 'esd_epoch_<n>.npy' files
# written by the training loop.
np.save(os.path.join(stats_dir, f'esd_epoch_{0}.npy'), metrics)
##################################################################
# Build the optimizer selected by --optim-type.  Every branch uses the
# command-line learning rate and weight decay; the SGD-based ones use a fixed
# momentum of 0.9.
if args.optim_type == 'SGD':
    optimizer = optim.SGD(net.parameters(), 
                        lr=args.lr,  
                        momentum=0.9, 
                        weight_decay=args.weight_decay)
    
elif args.optim_type == 'Adam':
    optimizer = optim.Adam(net.parameters(), 
                        lr=args.lr,  
                        weight_decay=args.weight_decay)
    
elif args.optim_type == 'SGDP':
    print(f"--------------------> Initialize the SGDP with lr:{args.lr}, wd:{args.weight_decay}")
    optimizer = SGDP( net.parameters(), 
                        lr=args.lr,  
                        momentum=0.9, 
                        weight_decay=args.weight_decay)

elif args.optim_type == 'Lookahead':
    print(f"--------------------> Initialize the Lookahead optimizer with k {args.look_k}  alpha {args.look_alpha}")
    # Lookahead wraps a plain SGD optimizer.
    optimizer = optim.SGD(net.parameters(), 
                        lr=args.lr,  
                        momentum=0.9, 
                        weight_decay=args.weight_decay)
    #α = 0.8 and k = 5   α = {0.2, 0.5, 0.8}    k = {5, 10}
    optimizer = torch_optimizer.Lookahead(optimizer, k=args.look_k, alpha=args.look_alpha)

elif args.optim_type == 'LARS':
    print(f"--------------------> Initialize the LARS optimizer with lr:{args.lr}, wd:{args.weight_decay}")
    optimizer = lars.LARS(net.parameters(), args.lr, weight_decay=args.weight_decay)

else:
    # Without this branch `optimizer` stays undefined and the failure only
    # shows up later as a NameError inside the training loop.
    raise NotImplementedError(
        f"Unsupported --optim-type {args.optim_type!r}; "
        "expected one of: SGD, Adam, SGDP, Lookahead, LARS"
    )

##################################################################

# Resolve the --lr-sche choice to the matching schedule function in the
# config module.  Only the selected attribute is looked up on `cf`.
_SCHEDULE_ATTRS = {
    'step': 'stepwise_decay',
    'cosine': 'cosine_decay',
    'warmup_cosine': 'warmup_cosine_decay',
}
if args.lr_sche not in _SCHEDULE_ATTRS:
    raise NotImplementedError
lr_schedule = getattr(cf, _SCHEDULE_ATTRS[args.lr_sche])


# Main training loop.
# Each epoch: train, evaluate, optionally checkpoint, optionally run the ESD
# analysis, then compute the learning rate for the NEXT epoch and record
# statistics (wandb + training_stats.npy).
elapsed_time = 0
# 'test_acc'/'test_loss' start with the pre-training evaluation, so those two
# lists are one entry longer than the others.
training_stats = {
    'test_acc': [test_acc],
    'test_loss': [test_loss],
    'train_acc': [],
    'train_loss': [],
    'current_lr': [],
    'schedule_next_lr': [],
}
untuned_lr = args.lr
is_current_best = False
for epoch in range(start_epoch, start_epoch+args.num_epochs):
    epoch_start_time = time.time()

    # this is current LR (set at the end of the previous iteration)
    current_lr = untuned_lr
    print(f"##############Epoch {epoch}  current LR: {current_lr:.5f}################")

    # train and test
    train_acc, train_loss = train(epoch, net, args.num_epochs, trainloader, criterion, optimizer, args.optim_type)
    print("\n| Train Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, train_loss, train_acc))
    test_acc, test_loss = test(epoch, net, testloader, criterion)
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, test_loss, test_acc))

    # save in interval
    if epoch in args.epochs_to_save:
        state = {
            'net': net.state_dict(),
            'test_acc':test_acc,
            'test_loss':test_loss,
            'train_acc':train_acc,
            'train_loss':train_loss,
            'epoch':epoch
        }
        torch.save(state, join(args.ckpt_path, f'epoch_{epoch}.ckpt'))

    # save best
    is_current_best = bool(test_acc > best_acc)
    if is_current_best:
        print('| Saving Best model')
        # Update best_acc BEFORE building the state so the checkpoint records
        # the accuracy of the model it contains, not the previous best.
        best_acc = test_acc
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'test_acc':test_acc,
            'best_acc': best_acc,
            'test_loss':test_loss,
            'train_acc':train_acc,
            'train_loss':train_loss,
            'epoch':epoch
        }
        torch.save(state, join(args.ckpt_path, 'epoch_best.ckpt'))

    #######################ESD analysis###############################
    ##################################################################
    # Run on epoch 1 and on every multiple of --ww-interval.
    if epoch == 1 or epoch % args.ww_interval == 0:
        print("################ Start ESD analysis#############")
        esd_start_time = time.time()
        metrics = net_esd_estimator(net, 
                  EVALS_THRESH = 0.00001,
                  bins = 100,
                  fix_fingers=args.fix_fingers,
                  xmin_pos = args.xmin_pos,
                  filter_zeros= args.filter_zeros == 'True')

        # Average every metric except the raw eigenvalues and layer names.
        metric_summary = {
            key: np.mean(metrics[key])
            for key in metrics
            if key not in ('eigs', 'longname')
        }

        layer_stats= pd.DataFrame({key:metrics[key] for key in metrics if key!='eigs'})
        # save metrics to disk and ESD
        layer_stats_origin = layer_stats.copy()
        layer_stats_origin.to_csv(os.path.join(args.ckpt_path, 'stats',  f"origin_layer_stats_epoch_{epoch}.csv"))
        np.save(os.path.join(args.ckpt_path, 'stats', f'esd_epoch_{epoch}.npy'), metrics)
        # esd_best.npy is only refreshed when a new best epoch coincides with
        # an ESD-analysis epoch.
        if is_current_best:
            np.save(os.path.join(args.ckpt_path, 'esd_best.npy'), metrics)

        esd_estimated_time = time.time() - esd_start_time
        print(f"-----> ESD estimation time: {esd_estimated_time:.3f}")

    else:
        metric_summary = {}

    ##################################################################
    # Reschedule the learning rate
    untuned_lr = lr_schedule(args.lr, epoch, args.num_epochs, warmup_epochs=args.warmup_epochs)
    print(f"------------>Rescheduled decayed LR: {untuned_lr:.5f}<--------------------")

    print("------------>  Schedule by default")
    for param_group in optimizer.param_groups:
        param_group['lr'] = untuned_lr

    # NOTE: 'elapsed_time' here is the total up to the END OF THE PREVIOUS
    # epoch; it is only advanced at the bottom of this iteration.
    train_summary = {'test_acc':test_acc, 
                    'test_loss': test_loss,
                    'train_acc': train_acc, 
                    'train_loss': train_loss,
                    'current_lr': current_lr,
                    'schedule_next_lr': untuned_lr,
                    'elapsed_time':elapsed_time}

    train_summary.update(metric_summary)

    if args.wandb_on == 'True':
        wandb.log(train_summary)
    training_stats['test_acc'].append(test_acc)
    training_stats['test_loss'].append(test_loss)
    training_stats['train_acc'].append(train_acc)
    training_stats['train_loss'].append(train_loss)
    training_stats['current_lr'].append(current_lr)
    training_stats['schedule_next_lr'].append(untuned_lr)

    # Rewritten every epoch so a partial run still leaves usable statistics.
    np.save(join(args.ckpt_path, "training_stats.npy"), training_stats)
    epoch_time = time.time() - epoch_start_time
    elapsed_time += epoch_time
    print('| Elapsed time : %d:%02d:%02d'  %(cf.get_hms(elapsed_time)))
    print('--------------------> <--------------------')
    
# Shut down logging: close the wandb run, then the redirected log files.
if args.wandb_on == 'True':
    wandb.finish()

if args.print_tofile == 'True':
    # Point the interpreter back at its original streams BEFORE closing the
    # log files; otherwise anything written afterwards (exit hooks, warnings)
    # would target a closed file.
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__
    # Close the files to flush the output
    stdout_file.close()
    stderr_file.close()