import os
import ot
import argparse
import random
import pandas as pd
from torchvision.models import resnet152, ResNet152_Weights
from src.utils import *
from src.lp_robust_cp import LPRobustCP
from src.fdiv_robust_cp import FDivRobustCP
# Allow loading of truncated image files instead of raising an error.
# (ImageFile is presumably PIL.ImageFile, brought in by the star import from src.utils.)
ImageFile.LOAD_TRUNCATED_IMAGES = True

# specify device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# argument parser
parser = argparse.ArgumentParser('Robust-CP')
parser.add_argument('--num_trials',     type=int,   default=20,   help="number of experiment runs")
parser.add_argument('--alpha',          type=float, default=0.1,  help="miscoverage level; target coverage is 1-alpha")
parser.add_argument('--cal_ratio',      type=float, default=0.02, help="fraction of data used for calibration")
parser.add_argument('--batch_size',     type=int,   default=1024, help="batch size for loading data")
parser.add_argument('--fdiv_radius',    type=float, default=2.5,  help="radius for f-divergence ball")
parser.add_argument('--corrupt_ratio',  type=float, default=0.05, help="fraction of data whose labels are rolled (corrupted)")
parser.add_argument('--noise_upper',    type=float, default=1.,   help="upper end of the noise range used for noising images")
parser.add_argument('--noise_lower',    type=float, default=-1.,  help="lower end of the noise range used for noising images")
parser.add_argument('--rho_est',        type=float, default=-1.,  help="estimated rho")
parser.add_argument('--eps_est',        type=float, default=-1.,  help="estimated eps")
parser.add_argument('--worst_case',     type=int,   default=0,    help="nonzero to consider the worst-case distribution, 0 otherwise")
parser.add_argument('--data_dir',       type=str,   default='/home/gridsan/zwang1/LP-Robust-CP/datasets/ImageNet/val', help="dir to imagenet val data")
parser.add_argument('--save',           type=str,   default='experiments/imgnet', help="define the save directory")
args = parser.parse_args()


# define function for f-divergence
def f(x):
    """Return the chi-squared f-divergence generator evaluated at ``x``.

    Computes ``(x - 1) ** 2``. Only plain arithmetic is used, so this works
    for Python scalars and element-wise for NumPy arrays.
    """
    shifted = x - 1
    return shifted ** 2


"""
Set-up Stage
"""

# load pretrained model
weights = ResNet152_Weights.DEFAULT
model = resnet152()
state_dict = torch.load('../pretrained_models/resnet152-f82ba261.pth')
model.load_state_dict(state_dict)
model.to(device)

# load data transforms
preprocess = weights.transforms()

# instantiate robust cp class
lp_robust_cp = LPRobustCP(model, nll_score, args.alpha)
fdiv_robust_cp = FDivRobustCP(rho=args.fdiv_radius, tol=1e-12, f=f, is_chisq=True)

# sample for random seeds
seed_range = 100000
num_seeds = args.num_trials
unique_seeds = random.sample(range(seed_range), num_seeds)

columns = ["standard_coverage", "lp_robust_coverage","lp_robust_coverage_est", "fdiv_robust_coverage", 
           "standard_avgsize", "lp_robust_avgsize", "lp_robust_avgsize_est", "fdiv_robust_avgsize", "rho_est"]
result_hist = pd.DataFrame(columns=columns)

def _coverage(pred_sets, labels):
    """Return the fraction of test points whose true label lies in its prediction set.

    `pred_sets` is a boolean 2-D array (one row per test point, one column per
    class); `labels` holds the integer true label of each row.
    """
    return pred_sets[np.arange(pred_sets.shape[0]), labels].mean()


def _avg_size(pred_sets):
    """Return the mean number of labels per prediction set (row-wise count of True)."""
    return np.mean(np.sum(pred_sets, axis=1))


for seed in unique_seeds:
    # load dataset: calibration / test loaders for this trial's seed
    cal_loader, test_loader = load_imgnet_valdata(args.data_dir, preprocess, cal_ratio=args.cal_ratio,
                                                  batch_size=args.batch_size, seed=seed)

    # ---- Conformal Prediction Stage ----
    # obtain calibration and test scores under the configured corruption settings
    calib_scores, calib_labels, tst_scores, tst_labels = lp_robust_cp.get_scores(cal_loader, test_loader,
                                                                                 corrupt_ratio=args.corrupt_ratio,
                                                                                 noise_upper=args.noise_upper,
                                                                                 noise_lower=args.noise_lower,
                                                                                 worst_case=bool(args.worst_case))
    calib_scores = calib_scores.cpu().numpy()
    calib_labels = calib_labels.cpu().numpy()
    tst_scores = tst_scores.cpu().numpy()
    tst_labels = tst_labels.cpu().numpy()

    # calibration scores: the score each calibration point assigns to its own label
    cal_scores = calib_scores[np.arange(calib_scores.shape[0]), calib_labels]

    # standard cp quantile
    qhat = lp_robust_cp.standard_quantile(cal_scores)

    # lp robust cp quantile using the configured corruption: rho is the label-corruption
    # ratio, epsilon the largest absolute noise bound
    rho = args.corrupt_ratio
    epsilon = np.max(np.abs((args.noise_upper, args.noise_lower)))
    lp_robust_qhat = lp_robust_cp.lp_robust_quantile(cal_scores, rho=rho, epsilon=epsilon, k=2.)

    # lp robust cp with user-supplied estimates of rho and epsilon
    # NOTE(review): both default to -1.; confirm how lp_robust_quantile treats negative values.
    lp_robust_qhat_est = lp_robust_cp.lp_robust_quantile(cal_scores, rho=args.rho_est, epsilon=args.eps_est, k=2.)

    # f-div robust cp quantile
    fdiv_robust_qhat = fdiv_robust_cp.adjusted_quantile(cal_scores, cal_scores.shape[0], args.alpha)

    # form prediction sets: a label is included when its test score is <= the threshold
    prediction_sets = tst_scores <= qhat
    lp_prediction_sets = tst_scores <= lp_robust_qhat
    lp_prediction_est_sets = tst_scores <= lp_robust_qhat_est
    fdiv_prediction_sets = tst_scores <= fdiv_robust_qhat

    # ---- Evaluation Stage ----
    # compute empirical coverage
    empirical_coverage = _coverage(prediction_sets, tst_labels)
    lp_robust_coverage = _coverage(lp_prediction_sets, tst_labels)
    lp_robust_coverage_est = _coverage(lp_prediction_est_sets, tst_labels)
    fdiv_robust_coverage = _coverage(fdiv_prediction_sets, tst_labels)
    print(f"The empirical coverage is: {empirical_coverage: .3f}")
    print(f"The LP robust coverage is: {lp_robust_coverage: .3f}")
    print(f"The LP robust coverage w estimated rho is: {lp_robust_coverage_est: .3f}")
    print(f"The f-div robust coverage is: {fdiv_robust_coverage: .3f}")

    # compute average prediction set width
    avg_width = _avg_size(prediction_sets)
    lp_robust_avgwidth = _avg_size(lp_prediction_sets)
    lp_robust_avgwidth_est = _avg_size(lp_prediction_est_sets)
    fdiv_robust_avgwidth = _avg_size(fdiv_prediction_sets)
    print(f"The average width is: {avg_width: .3f}")
    print(f"The average LP robust width is: {lp_robust_avgwidth: .3f}")
    print(f"The average LP robust width w estimated rho is: {lp_robust_avgwidth_est: .3f}")
    print(f"The average f-div robust width is: {fdiv_robust_avgwidth: .3f}")

    # append this trial's row; order must match `columns`
    result_hist.loc[len(result_hist.index)] = [empirical_coverage, lp_robust_coverage, lp_robust_coverage_est, fdiv_robust_coverage,
                                               avg_width, lp_robust_avgwidth, lp_robust_avgwidth_est, fdiv_robust_avgwidth, args.rho_est]

# save the results
os.makedirs(args.save, exist_ok=True)

# 'wc' marks worst-case runs, 'reg' regular ones. Use bool() so any nonzero
# --worst_case value is classified the same way as the get_scores call does.
run_tag = 'wc' if bool(args.worst_case) else 'reg'
setting_tag = f'{args.corrupt_ratio}_{args.noise_upper}_{args.noise_lower}'
results_file = os.path.join(args.save, f'{run_tag}_result_hist_{setting_tag}.csv')
result_hist.to_csv(results_file)

# plotting: read back the file written above. Previously this always read the
# 'reg' file, so a worst-case run plotted the wrong data (or crashed if no
# regular run existed).
saved_results = pd.read_csv(results_file).to_numpy()
# Column 0 is the index written by to_csv; columns 1-4 are coverages, 5-8 set sizes.
coverage_results = [saved_results[:, i] for i in range(1, 5)]
size_results = [saved_results[:, j] for j in range(5, 9)]
# Regular-run figure names are unchanged; worst-case figures get a 'wc' tag so
# they don't overwrite the regular ones.
fig_stem = 'imgnet' if run_tag == 'reg' else f'imgnet_{run_tag}'
plot_cp(coverage_results, plt_type='Coverage', plt_name=f'{fig_stem}_{setting_tag}_cover.png', save_dir='figures')
plot_cp(size_results, plt_type='Size', plt_name=f'{fig_stem}_{setting_tag}_size.png', save_dir='figures')


