
import torch
import os
from tqdm import tqdm
import numpy as np
import argparse
from pathlib import Path

# Import saliency methods and models
from ViT_explanation_generator import Baselines
from ViT_new import vit_base_patch16_224, vit_large_patch16_224
# from models.vgg import vgg19
import glob

from sklearn import metrics
from dataset.expl_hdf5_optimized import ImagenetResults


def normalize(tensor,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize an image tensor channel-wise, in place.

    Computes ``(tensor - mean) / std`` per channel, writing the result back
    into ``tensor``. Callers that need to keep the input must pass a clone.

    Args:
        tensor: Floating-point tensor of shape ``(..., C, H, W)``; both a
            batched ``(B, C, H, W)`` tensor and a single ``(C, H, W)`` image
            are accepted.
        mean: Per-channel means (sequence of length ``C``).
        std: Per-channel standard deviations (sequence of length ``C``).

    Returns:
        The same tensor object that was passed in, now normalized.

    Raises:
        TypeError: If ``tensor`` is not a floating-point tensor (the in-place
            division cannot be stored in an integer tensor).
    """
    if not tensor.is_floating_point():
        raise TypeError(
            f"normalize expects a floating-point tensor, got {tensor.dtype}")

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Shape (C, 1, 1) broadcasts against the trailing (C, H, W) dimensions,
    # whatever the number of leading batch dimensions is.
    tensor.sub_(mean.view(-1, 1, 1)).div_(std.view(-1, 1, 1))
    return tensor



def eval(args):
    """Run the pixel-perturbation test and save per-sample hit arrays.

    For every batch the most relevant pixels (according to the saliency map
    ``vis``) are zeroed out in increasing amounts, the perturbed images are
    classified, and it is recorded whether the prediction still matches the
    target. A running AUC of accuracy over perturbation level is shown in the
    progress bar.

    NOTE: the name shadows the ``eval`` builtin; it is kept because it is the
    entry point used by the script section of this file.

    The function reads these module-level globals, which the ``__main__``
    section creates: ``imagenet_ds``, ``sample_loader``, ``model``, ``device``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``scale`` ('per' or '100'), ``wrong`` (evaluate only samples the
            model misclassifies), ``neg`` (remove the least relevant pixels
            first) and ``experiment_dir`` (output directory).

    Raises:
        ValueError: If ``args.scale`` is not a supported value or a saliency
            map has an unexpected spatial size.

    Side effects:
        Writes ``model_hits.npy`` and ``perturbations_hits.npy`` into
        ``args.experiment_dir`` and prints summary statistics.
    """
    if args.scale == 'per':
        base_size = 224 * 224
        perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    elif args.scale == '100':
        base_size = 100
        perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
    else:
        raise ValueError('scale not valid')

    n_total = len(imagenet_ds)
    n_steps = len(perturbation_steps)

    # Per-sample 0/1 hit flags, kept on the CPU.
    num_correct_model = torch.zeros(n_total)
    num_correct_pertub = torch.zeros((n_steps, n_total))
    model_index = 0    # samples classified so far (unperturbed)
    perturb_index = 0  # samples that went through the perturbation test

    n_batches = len(sample_loader)
    pbar = tqdm(sample_loader)
    for batch_idx, (data, vis, target) in enumerate(pbar):
        # vis is (batch, 1, 224, 224) for 'full_lrp' and (batch, 1, 14, 14)
        # for the other methods (per the original author's notes).
        if vis.shape[-2:] != (14, 14) and vis.shape[-2:] != (224, 224):
            raise ValueError(
                f"unexpected saliency map size {tuple(vis.shape[-2:])}")
        if vis.shape[-2:] == (14, 14):
            vis = torch.nn.functional.interpolate(
                vis, scale_factor=16, mode='bilinear', align_corners=False)

        # Min-max normalize each map. A constant map has a peak of 0 after
        # the shift; divide by 1 in that case to avoid 0/0 = NaN.
        vis = vis - vis.amin(dim=(-2, -1), keepdim=True)
        peak = vis.amax(dim=(-2, -1), keepdim=True)
        vis = vis / torch.where(peak > 0, peak, torch.ones_like(peak))

        batch_size = len(data)
        data = data.to(device)      # assumed (batch, C, 224, 224)
        vis = vis.to(device)        # (batch, 1, 224, 224)
        target = target.to(device)  # (batch, )

        # Accuracy of the model on the unperturbed images.
        with torch.no_grad():
            pred = model(normalize(data.clone()))  # (batch, classes)
        pred_class = pred.max(dim=1).indices       # (batch, )
        is_correct = target == pred_class          # (batch, ) bool
        num_correct_model[model_index:model_index + batch_size] = \
            is_correct.float().cpu()
        # Advance by the full batch here, before any filtering below, so the
        # unperturbed hits stay aligned with the dataset order.
        model_index += batch_size

        if args.wrong:
            # Keep only the samples the model got wrong.
            wid = torch.nonzero(~is_correct).flatten()
            if wid.numel() == 0:
                continue
            vis = vis.index_select(0, wid)
            data = data.index_select(0, wid)
            target = target.index_select(0, wid)

        org_shape = data.shape
        n_batch, n_channels = org_shape[0], org_shape[1]

        if args.neg:
            vis = -vis

        # Rank pixels once per sample, most relevant first.
        vis = vis.reshape(n_batch, -1)                   # (batch, H*W)
        idx_sort = vis.argsort(dim=-1, descending=True)  # (batch, H*W)
        flat_data = data.reshape(n_batch, n_channels, -1)  # (batch, C, H*W)

        perturbed = []
        for step in perturbation_steps:
            n_masked = int(base_size * step)
            idx = idx_sort[:, :n_masked]
            idx = idx.unsqueeze(1).repeat(1, n_channels, 1)  # (batch, C, k)
            # Out-of-place scatter: zero the selected pixels in a fresh copy.
            masked = flat_data.scatter(-1, idx, 0).reshape(*org_shape)
            perturbed.append(normalize(masked))
        _data_perturb = torch.stack(perturbed, dim=1)  # (batch, steps, C, H, W)

        # Classify every perturbation level of every sample in one pass.
        with torch.no_grad():
            out = model(_data_perturb.flatten(0, 1))  # (batch*steps, classes)
        out = out.unflatten(0, (n_batch, n_steps))    # (batch, steps, classes)

        perturbed_class = out.max(-1).indices         # (batch, steps)
        hits = target.unsqueeze(-1) == perturbed_class
        num_correct_pertub[:, perturb_index:perturb_index + n_batch] = \
            hits.T.float().cpu()
        perturb_index += n_batch

        if batch_idx % 10 == 0 or batch_idx == n_batches - 1:
            # Running AUC of accuracy vs. perturbation level.
            if args.wrong:
                # Every evaluated sample is misclassified before perturbation.
                baseline = 0.0
            else:
                baseline = num_correct_model[:model_index].mean().item()
            steps = [0] + perturbation_steps
            res = [baseline] + \
                num_correct_pertub[:, :perturb_index].mean(dim=-1).tolist()
            auc = metrics.auc(steps, res)
            pbar.set_postfix({"AUC": f"{auc*100:.2f}"})

    evaluated = num_correct_pertub[:, :perturb_index]
    np.save(os.path.join(args.experiment_dir, 'model_hits.npy'),
            num_correct_model.numpy())
    np.save(os.path.join(args.experiment_dir, 'perturbations_hits.npy'),
            evaluated.numpy())

    print(f"{num_correct_model.mean()=}, {num_correct_model.std()=}")
    print(perturbation_steps)
    # Statistics cover only the evaluated columns, matching the saved file.
    print(f"num_correct_pertub.mean(dim=1)={evaluated.mean(dim=1)!r}, "
          f"num_correct_pertub.std(dim=1)={evaluated.std(dim=1)!r}")


def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True; this converter interprets the text instead.

    Args:
        value: The raw option value (a string, or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Perturbation test for ViT saliency maps')
    parser.add_argument('--batch-size', type=int,
                        default=16,
                        help='')
    # Accepts `--neg`, `--neg True` and `--neg False`.
    parser.add_argument('--neg', type=str2bool, nargs='?', const=True,
                        default=False,
                        help='')
    parser.add_argument('--value', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--scale', type=str,
                        default='per',
                        choices=['per', '100'],
                        help='')
    # The method name is deliberately unrestricted; it selects the
    # visualization directory to read.
    parser.add_argument('--method', type=str,
                        default='grad_rollout',
                        help='')
    parser.add_argument('--vis-class', type=str,
                        default='top',
                        choices=['top', 'target', 'index'],
                        help='')
    parser.add_argument('--wrong', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--class-id', type=int,
                        default=0,
                        help='')
    # Accepts `--is-ablation`, `--is-ablation True` and `--is-ablation False`.
    parser.add_argument('--is-ablation', type=str2bool, nargs='?', const=True,
                        default=False,
                        help='')
    parser.add_argument('--arc', type=str,
                        default='vit-b',
                        choices=['vit-b', 'vit-l'],
                        help='')
    parser.add_argument('--dataset-name', type=str,
                        required=True,
                        help='')
    parser.add_argument('--imagenet-validation-path', type=str,
                        required=True,
                        help='')
    args = parser.parse_args()

    torch.multiprocessing.set_start_method('spawn')

    # Root of all inputs/outputs: <script dir>/optimized/<dataset>/<arch>/
    PATH = Path(__file__).parent / 'optimized'
    PATH = PATH / f"{args.dataset_name}/"
    PATH = PATH / f"{args.arc}/"
    os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True)

    exp_name = args.method
    exp_name += '_neg' if args.neg else '_pos'
    print(exp_name)

    # Sub-directory layout shared by the experiment and visualization trees.
    ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
    if args.vis_class == 'index':
        vis_subdir = '{}_{}'.format(args.vis_class, args.class_id)
    else:
        vis_subdir = '{}/{}'.format(args.vis_class, ablation_fold)

    args.runs_dir = os.path.join(
        PATH, 'experiments/perturbations/{}/{}'.format(exp_name, vis_subdir))
    if args.wrong:
        args.runs_dir += '_wrong'

    # Pick the next free experiment number. Compare the numeric suffixes as
    # integers: a lexicographic sort would rank 'experiment_9' after
    # 'experiment_10' and reuse an existing directory.
    existing_ids = []
    for exp_path in glob.glob(os.path.join(args.runs_dir, 'experiment_*')):
        suffix = os.path.basename(exp_path).rsplit('_', 1)[-1]
        if suffix.isdigit():
            existing_ids.append(int(suffix))
    experiment_id = max(existing_ids) + 1 if existing_ids else 0
    args.experiment_dir = os.path.join(args.runs_dir, 'experiment_{}'.format(str(experiment_id)))
    os.makedirs(args.experiment_dir, exist_ok=True)

    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    vis_method_dir = os.path.join(
        PATH, 'visualizations/{}/{}'.format(args.method, vis_subdir))

    imagenet_ds = ImagenetResults(vis_method_dir, args.imagenet_validation_path)

    # Model: place it on the same device the evaluation loop moves data to.
    if args.arc == 'vit-b':
        print("Loading ViT-B/16")
        model = vit_base_patch16_224(pretrained=True).to(device)
    elif args.arc == 'vit-l':
        print("Loading ViT-L/16")
        model = vit_large_patch16_224(pretrained=True).to(device)
    else:
        raise ValueError(f"Architecture {args.arc} not supported")
    model.eval()

    # Currently unused by this script; kept for parity with related scripts.
    save_path = PATH / 'results/'

    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds,
        batch_size=args.batch_size,
        num_workers=2,
        shuffle=False)

    try:
        eval(args)
    finally:
        # Always release the dataset's resources, even if evaluation fails.
        imagenet_ds.cleanup()

