import argparse
import sys
import os
import numpy as np
import torch
import time
import data
import utils
import json
import sharpness
import sparse_sharpness
import mask_sharpness
import csv
import argparse
import torchattacks
import torch.nn.functional as F
import models


def get_args(argv=None):
    """Parse the command-line options of the sharpness evaluation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option registered below.
        Registration order is kept stable because it is also the key order
        of ``vars(args)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--model', default='resnet18', type=str)
    parser.add_argument('--model_path', type=str, help='model path')
    parser.add_argument('--n_eval', default=10000, type=int, help='#examples to evaluate on error')
    parser.add_argument('--bs', default=256, type=int, help='batch size for error computation')
    parser.add_argument('--n_eval_sharpness', default=1024, type=int, help='#examples to evaluate on sharpness')
    parser.add_argument('--bs_sharpness', default=128, type=int, help='batch size for sharpness experiments')
    # rho is used with both norm='l2' and norm='linf' (chosen via --algorithm),
    # so the help text must not claim it is always an L2 radius.
    parser.add_argument('--rho', default=0.1, type=float,
                        help='radius (L2 or Linf, depending on --algorithm) for sharpness')
    parser.add_argument('--step_size_mult', default=1.0, type=float, help='step size multiplier for sharpness')
    parser.add_argument('--n_iters', default=20, type=int, help='number of iterations for sharpness')
    parser.add_argument('--n_restarts', default=1, type=int, help='number of restarts for sharpness')
    parser.add_argument('--model_width', default=64, type=int,
                        help='model width (# conv filters on the first layer for ResNets)')
    parser.add_argument('--sharpness_on_test_set', action='store_true', help='compute sharpness on the test set')
    parser.add_argument('--sharpness_rand_init', action='store_true', help='random initialization')
    parser.add_argument('--merge_bn_stats', action='store_true',
                        help='merge BN means and variances to its learnable parameters')
    parser.add_argument('--no_grad_norm', action='store_true', help='no gradient normalization in APGD')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--algorithm', default='m_apgd_linf',
                        choices=['avg_l2', 'avg_linf', 'm_apgd_l2', 'm_apgd_linf'], type=str)
    parser.add_argument('--log_folder', default='logs_eval', type=str)
    parser.add_argument('--adaptive', action='store_true')
    parser.add_argument('--normalize_logits', action='store_true')
    parser.add_argument('--data_augm_sharpness', action='store_true')
    parser.add_argument('--lwm', action='store_true')
    parser.add_argument('--rb', action='store_true')
    parser.add_argument('--sparse', action='store_true')
    parser.add_argument('--k', default=0.1, type=float)
    parser.add_argument('--task_mode', type=str,
                        choices=("score_prune", "harp_prune", "score_finetune", "harp_finetune",
                                 "harp_finetune_lwm", "pretrain"),
                        default="harp_prune")
    parser.add_argument('--csv_file', default="robustloss_results.csv")
    parser.add_argument('--dense', action='store_true', help='compute sharpness on pruned model with dense layer.')
    parser.add_argument('--masked', action='store_true', help='compute mask loss')

    return parser.parse_args(argv)


start_time = time.time()
args = get_args()

n_cls = 100 if args.dataset == 'cifar100' else 10
sharpness_split = 'test' if args.sharpness_on_test_set else 'train'
# Require n_eval_sharpness to be a positive multiple of bs_sharpness
# (presumably so all sharpness batches have equal size -- confirm against the
# sharpness modules). An explicit raise is used instead of `assert`, which is
# stripped under `python -O`.
if args.bs_sharpness <= 0 or args.n_eval_sharpness % args.bs_sharpness != 0:
    raise ValueError(
        'args.n_eval_sharpness ({}) should be a positive multiple of args.bs_sharpness ({})'.format(
            args.n_eval_sharpness, args.bs_sharpness))
# torch initializes CUDA lazily, so restricting the visible device here still
# takes effect even though torch has already been imported.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
np.set_printoptions(precision=4, suppress=True)
# Seed every RNG source used by the script for reproducible evaluation.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

def loss_f(logits, y):
    """Return the mean cross-entropy of ``logits`` w.r.t. labels ``y``."""
    return F.cross_entropy(logits, y, reduction='mean')


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_normalization = False

if args.model_path is None:
    raise ValueError('--model_path is required')

# Build the architecture with the dataset's class count (n_cls is 100 for
# cifar100, 10 otherwise); a hard-coded 10 would make loading a cifar100
# checkpoint fail on the classifier's shape. Sparse (non-dense) models also
# need the pruning task mode. Construct the model only once.
build_model = models.__dict__[args.model]
if args.sparse and not args.dense:
    model = build_model(n_cls=n_cls, task_mode=args.task_mode)
else:
    model = build_model(n_cls=n_cls)
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model_dict = torch.load(args.model_path, map_location=device)['state_dict']
# dict(...) makes a plain-dict copy (dropping any `_metadata` attribute of the
# saved state dict), exactly like the original key-by-key copy did.
model.load_state_dict(dict(model_dict), strict=False)
model.eval()
model.to(device)
if args.sparse or args.masked:
    utils.set_prune_rate_model(model, args, device)
    # A pretrained checkpoint fine-tuned with harp_finetune_lwm gets its scores
    # initialized from the weights.
    if args.task_mode == 'harp_finetune_lwm' and 'pretrain' in args.model_path:
        utils.initialize_scaled_score(model, prune_reg='weight')
        

# Wrap the network; normalize_logits is forwarded to LogitNormalizationWrapper.
model = models.__dict__['LogitNormalizationWrapper'](model, normalize_logits=args.normalize_logits)

# Un-shuffled, un-augmented loaders over the first n_eval examples of each split
# (train first, then test -- the same call order as two explicit calls).
eval_train_batches, eval_test_batches = (
    data.get_loaders(
        args.dataset, args.n_eval, args.bs,
        split=split_name,
        normalization=data_normalization,
        shuffle=False,
        data_augm=False,
        drop_last=False,
    )
    for split_name in ('train', 'test')
)

# Clean (atk=None) error and loss on each split, again train before test.
(train_err, train_loss), (test_err, test_loss) = (
    utils.compute_robust_err(loader, model, atk=None)
    for loader in (eval_train_batches, eval_test_batches)
)

# Report the baseline numbers before the sharpness evaluation.
print('[train] err={:.2%} loss={:.5f}, [test] err={:.2%}, loss={:.4f}'.format(train_err, train_loss, test_err, test_loss))

# Batches from the split chosen by --sharpness_on_test_set (train by default)
# on which sharpness is measured; --data_augm_sharpness is forwarded as both
# `data_augm` and `randaug`.
batches_sharpness = data.get_loaders(args.dataset, args.n_eval_sharpness, args.bs_sharpness, normalization=data_normalization, split=sharpness_split, shuffle=False,
                                        data_augm=args.data_augm_sharpness, drop_last=False, randaug=args.data_augm_sharpness)

# Keyword arguments shared by every APGD-based sharpness call below.
apgd_kwargs = dict(
    device=device, rho=args.rho, n_iters=args.n_iters, n_restarts=args.n_restarts,
    step_size_mult=args.step_size_mult, rand_init=args.sharpness_rand_init,
    no_grad_norm=args.no_grad_norm, verbose=True, return_output=True,
    adaptive=args.adaptive, version='default')

if args.sparse or args.masked:
    # Only the Linf APGD variant is wired up for sparse/masked models. Fail
    # loudly instead of hitting a NameError on `sharpness_obj` further down.
    if args.algorithm != 'm_apgd_linf':
        raise ValueError(
            '--algorithm {} is not supported with --sparse/--masked; use m_apgd_linf'.format(args.algorithm))
    # --masked takes precedence over --sparse.
    sharpness_impl = mask_sharpness if args.masked else sparse_sharpness
    sharpness_obj, sharpness_err, _, output = sharpness_impl.eval_APGD_Rsharpness(
        model, batches_sharpness, loss_f, train_err, train_loss,
        norm='linf', args=args, **apgd_kwargs)
elif args.algorithm in ('m_apgd_l2', 'm_apgd_linf'):
    sharpness_obj, sharpness_err, _, output = sharpness.eval_APGD_Rsharpness(
        model, batches_sharpness, loss_f, train_err, train_loss,
        norm='l2' if args.algorithm == 'm_apgd_l2' else 'linf', **apgd_kwargs)
elif args.algorithm in ('avg_l2', 'avg_linf'):
    sharpness_obj, sharpness_err, _, output = sharpness.eval_average_Rsharpness(
        model, batches_sharpness, loss_f, device=device, rho=args.rho, n_iters=args.n_iters,
        return_output=True, adaptive=args.adaptive,
        norm='l2' if args.algorithm == 'avg_l2' else 'linf')
else:
    raise ValueError('unknown --algorithm {!r}'.format(args.algorithm))

print('sharpness: obj={:.5f}, err={:.2%}'.format(sharpness_obj, sharpness_err))


### Save all the arguments together with train/test error and loss, the
### sharpness objective/error and the elapsed time (in minutes) as JSON.
checkpoint = dict(vars(args))
checkpoint['train_err'] = train_err
checkpoint['train_loss'] = train_loss
checkpoint['test_err'] = test_err
checkpoint['test_loss'] = test_loss
checkpoint['sharpness_obj'] = sharpness_obj
checkpoint['sharpness_err'] = sharpness_err
checkpoint['time'] = (time.time() - start_time) / 60


# Append a one-cell summary row such as "0.01234 (e5.0%)" to the results CSV.
# This is best effort: an I/O failure is reported but does not abort the run,
# so the JSON log below is still written.
objc = "{:.5f} (e{:.1%})".format(sharpness_obj, sharpness_err)
csv_file = args.csv_file
try:
    with open(csv_file, "a", newline="") as csvfile:
        csv.writer(csvfile).writerow([objc])
except (OSError, csv.Error) as e:
    print("Error occurred while writing to CSV:", e)

path = utils.get_path(args, args.log_folder)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(args.log_folder, exist_ok=True)
with open(path, 'w') as outfile:
    json.dump(checkpoint, outfile)

print('Done in {:.2f}m'.format((time.time() - start_time) / 60))