
import os
from pathlib import Path
from tqdm import tqdm
import h5py

import argparse

# Import saliency methods and models
from misc_functions import *

from ViT_explanation_generator import Baselines, LRP
from ViT_new import vit_base_patch16_224, vit_large_patch16_224
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_LRP import vit_large_patch16_224 as vit_large_LRP
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
from ViT_orig_LRP import vit_large_patch16_224 as vit_large_orig_LRP


import torch.utils.data
from torchvision.datasets import ImageNet

# Apply lovely_tensors' monkey patch once, at import time.
# NOTE(review): presumably a debugging aid that changes how torch tensors are
# printed -- confirm it is still needed, since it adds a runtime dependency.
import lovely_tensors as lt
lt.monkey_patch()


def normalize(tensor,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize ``tensor`` per channel, **in place**: ``(tensor - mean) / std``.

    Args:
        tensor: floating-point image tensor whose channel axis is third from
            the end, i.e. ``(C, H, W)`` or ``(B, C, H, W)``. It is modified
            in place.
        mean: per-channel means (sequence/tensor of length C, or a scalar).
        std: per-channel standard deviations (sequence/tensor of length C,
            or a scalar).

    Returns:
        The same ``tensor`` object, now normalized.
    """
    dtype = tensor.dtype
    # Tuples (not lists) as defaults: immutable, so they can never be shared
    # and mutated across calls.
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Shape (C, 1, 1) broadcasts over the trailing H, W dims for both 3-D and
    # 4-D inputs (the old [None, :, None, None] form only worked for 4-D).
    tensor.sub_(mean.reshape(-1, 1, 1)).div_(std.reshape(-1, 1, 1))
    return tensor

def get_baselines(args):
    """Load a pretrained ViT on the GPU and wrap it in ``Baselines``.

    ``args.arc == 'vit-l'`` selects ``vit_large_patch16_224``; any other value
    selects ``vit_base_patch16_224``.
    """
    if args.arc == 'vit-l':
        banner, build_model = "Loading vit_large_patch16_224...", vit_large_patch16_224
    else:
        banner, build_model = "Loading vit_base_patch16_224...", vit_base_patch16_224
    print(banner)
    net = build_model(pretrained=True).cuda()
    return Baselines(net)

def get_lrp(args):
    """Load the pretrained LRP-instrumented ViT in eval mode and wrap it in ``LRP``.

    ``args.arc == 'vit-l'`` selects the large variant (``vit_large_LRP``);
    anything else selects the base one (``vit_LRP``).
    """
    wants_large = args.arc == 'vit-l'
    banner = "Loading vit_large_LRP..." if wants_large else "Loading vit_LRP..."
    build_model = vit_large_LRP if wants_large else vit_LRP
    print(banner)
    net = build_model(pretrained=True).cuda()
    net.eval()
    return LRP(net)

def get_orig_lrp(args):
    """Load the pretrained ViT from ``ViT_orig_LRP`` in eval mode and wrap it in ``LRP``.

    ``args.arc == 'vit-l'`` selects ``vit_large_orig_LRP``; anything else
    selects ``vit_orig_LRP``.
    """
    wants_large = args.arc == 'vit-l'
    banner = "Loading vit_large_orig_LRP..." if wants_large else "Loading vit_orig_LRP..."
    build_model = vit_large_orig_LRP if wants_large else vit_orig_LRP
    print(banner)
    net = build_model(pretrained=True).cuda()
    net.eval()
    return LRP(net)

def get_method_func(args):
    """Build the saliency callable selected by ``args.method``.

    Only the model(s) needed by the chosen method are loaded.

    Args:
        args: parsed CLI namespace. Reads ``method``; ``arc`` (through the
            model loaders); ``layer`` for ``predmap*`` methods; and
            ``is_ablation`` (at call time) for ``lrp_last_layer`` /
            ``attn_last_layer``.

    Returns:
        A callable ``method_func(data, index)`` whose result is reshaped to
        ``(B, 1, 224, 224)`` for ``full_lrp`` and ``(B, 1, 14, 14)`` for every
        other method, with ``B = data.shape[0]``. ``index`` is passed through
        to the underlying method (``None`` included); ``rollout`` and
        ``attn_last_layer`` do not use it.

    Raises:
        ValueError: if ``args.method`` is not a supported method name.
    """
    method = args.method

    if method == 'rollout':
        baselines = get_baselines(args)
        method_func = lambda data, index: baselines.generate_rollout(
            data, start_layer=1).reshape(data.shape[0], 1, 14, 14)
    elif method == 'lrp':
        # NOTE(review): originally annotated "SAME AS 'transformer_attribution'".
        # Unlike that branch, no `method=` is passed here, so this relies on
        # generate_LRP's default method -- confirm the equivalence.
        lrp = get_lrp(args)
        method_func = lambda data, index: lrp.generate_LRP(
            data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14)
    elif method == 'transformer_attribution':
        lrp = get_lrp(args)
        method_func = lambda data, index: lrp.generate_LRP(
            data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14)
    elif str(method).startswith('predmap'):
        # predmap* names are looked up as methods on the baseline model itself.
        baselines = get_baselines(args)
        method_invoke = getattr(baselines.model, method)
        kwargs = {}
        if args.layer is not None:
            kwargs['layer'] = args.layer
        method_func = lambda data, index: method_invoke(
            data, idx=index, **kwargs).reshape(data.shape[0], 1, 14, 14)
    elif method == 'full_lrp':
        orig_lrp = get_orig_lrp(args)
        method_func = lambda data, index: orig_lrp.generate_LRP(
            data, method="full", index=index).reshape(data.shape[0], 1, 224, 224)
    elif method == 'lrp_last_layer':
        orig_lrp = get_orig_lrp(args)
        method_func = lambda data, index: orig_lrp.generate_LRP(
            data, method="last_layer", is_ablation=args.is_ablation,
            index=index).reshape(data.shape[0], 1, 14, 14)
    elif method == 'attn_last_layer':
        # `index` is not forwarded for this method.
        lrp = get_lrp(args)
        method_func = lambda data, index: lrp.generate_LRP(
            data, method="last_layer_attn",
            is_ablation=args.is_ablation).reshape(data.shape[0], 1, 14, 14)
    elif method == 'attn_gradcam':
        baselines = get_baselines(args)
        method_func = lambda data, index: baselines.generate_cam_attn(
            data, index=index).reshape(data.shape[0], 1, 14, 14)
    else:
        # Previously fell through and raised UnboundLocalError on `method_func`.
        raise ValueError(
            f"Unknown --method {method!r}; expected one of 'rollout', 'lrp', "
            "'transformer_attribution', 'full_lrp', 'lrp_last_layer', "
            "'attn_last_layer', 'attn_gradcam' or a 'predmap*' method of the model")

    return method_func



def _create_growable_dataset(h5file, name, item_shape, dtype):
    """Create an empty gzip-compressed dataset ``name`` whose rows have ``item_shape``.

    The dataset starts with 0 rows and is unlimited along axis 0, so it can be
    grown batch by batch with ``resize``.
    """
    return h5file.create_dataset(name,
                                 shape=(0, *item_shape),
                                 maxshape=(None, *item_shape),
                                 dtype=dtype,
                                 compression="gzip")


def compute_saliency_and_save(args):
    """Compute saliency maps for every batch of ``sample_loader`` and store them in HDF5.

    Writes ``<args.method_dir>/results.hdf5``, overwriting any existing file,
    with these datasets (``N`` = number of images seen):

    * ``vis``    -- float32 saliency maps, ``(N, 1, 224, 224)`` for
      ``full_lrp`` and ``(N, 1, 14, 14)`` otherwise.
    * ``image``  -- float32 ``(N, 3, 224, 224)``. Kept for file-layout
      compatibility but never filled (image writing is disabled), so its rows
      stay at the HDF5 fill value.
    * ``target`` -- int32 labels ``(N,)`` as yielded by the loader.
    * ``path``   -- fixed-length 256-byte ASCII paths ``(N,)`` as yielded by
      the loader.

    Relies on the module-level globals ``sample_loader`` and ``device`` that
    the ``__main__`` section defines.

    Args:
        args: parsed CLI namespace; reads ``method``, ``method_dir``,
            ``vis_class`` and ``class_id`` here, plus whatever
            ``get_method_func`` reads.
    """
    cam_shape = (1, 224, 224) if args.method == 'full_lrp' else (1, 14, 14)
    db_path = os.path.join(args.method_dir, 'results.hdf5')
    print(f"DB path: {db_path}")
    # 'w' truncates any existing file. With 'a', a leftover results.hdf5 made
    # create_dataset fail because the datasets already existed.
    with h5py.File(db_path, 'w') as f:
        data_cam = _create_growable_dataset(f, 'vis', cam_shape, np.float32)
        data_image = _create_growable_dataset(f, 'image', (3, 224, 224), np.float32)
        data_target = _create_growable_dataset(f, 'target', (), np.int32)
        data_path = _create_growable_dataset(f, 'path', (), 'S256')  # string of length 256
        growable = (data_cam, data_image, data_target, data_path)

        method_func = get_method_func(args)

        for data, target, path in tqdm(sample_loader):
            batch_size = data.shape[0]
            # Grow every dataset together so their rows stay aligned.
            for dset in growable:
                dset.resize(dset.shape[0] + batch_size, axis=0)

            data_target[-batch_size:] = target.cpu().numpy()
            data_path[-batch_size:] = [p.encode("ascii") for p in path]

            target = target.to(device)
            data = normalize(data).to(device)

            if args.vis_class == 'target':
                index = target
            elif args.vis_class == 'index':
                # Explain the fixed class --class-id for every image in the batch
                # (previously ignored, so the top class was explained instead).
                index = torch.full_like(target, args.class_id)
            else:  # 'top': index=None is passed through; the method picks the class
                index = None

            res = method_func(data, index)  # (B, 1, H, W); H = W = 224 for full_lrp, else 14
            data_cam[-batch_size:] = res.detach().cpu().numpy()


def get_imagenet_relative_path(path: str):
    """Return ``path`` relative to the directory three levels above the file.

    For a standard ImageNet layout this strips the dataset root, e.g.
    ``~/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG`` becomes
    ``val/n01440764/ILSVRC2012_val_00000293.JPEG`` (returned as a ``Path``).
    """
    image_file = Path(path)
    dataset_root = image_file
    # split dir / class dir / file -> climb three parents to reach the root
    for _ in range(3):
        dataset_root = dataset_root.parent
    return image_file.relative_to(dataset_root)


class ImageNetWithPaths(torch.utils.data.Dataset):
    """Dataset wrapper whose items also carry the image's relative path.

    Each item is ``(sample, target, relpath)``, where ``relpath`` is the string
    form of ``get_imagenet_relative_path`` applied to the wrapped dataset's
    ``imgs[index][0]``.
    """

    def __init__(self, imagenet_ds):
        # Wrapped dataset; must support indexing, len() and an ``imgs`` list
        # whose entries start with the file path (e.g. torchvision ImageNet).
        self.imagenet_ds = imagenet_ds

    def __len__(self):
        return len(self.imagenet_ds)

    def __getitem__(self, index: int):
        wrapped = self.imagenet_ds
        sample, target = wrapped[index]
        source_file = wrapped.imgs[index][0]
        return sample, target, str(get_imagenet_relative_path(source_file))
    

def imagenet_subset(imagenet_ds: ImageNetWithPaths, imagenet_subset_path):
    """Restrict ``imagenet_ds`` to the images listed in a subset file.

    Args:
        imagenet_ds: wrapped ImageNet dataset. Its underlying
            ``imagenet_ds.imagenet_ds.imgs`` entries (path first) are used to
            map relative paths to dataset indices.
        imagenet_subset_path: text file with one path per line, relative to
            the ImageNet root (e.g.
            ``val/n01440764/ILSVRC2012_val_00000293.JPEG``). Trailing
            whitespace is stripped and blank lines are ignored.

    Returns:
        ``torch.utils.data.Subset`` of ``imagenet_ds`` in the order the paths
        appear in the file.

    Raises:
        ValueError: if a listed path is not present in the dataset.
    """
    with Path(imagenet_subset_path).open("r") as f:
        subset_paths = [line.rstrip() for line in f]
    subset_paths = [p for p in subset_paths if p]

    print("Finding subset indices")
    # Build the path -> index map once: O(N + M) instead of the O(N * M) cost
    # of calling list.index() per subset entry. setdefault keeps the first
    # occurrence, matching list.index() if a path were ever duplicated.
    path_to_index = {}
    for idx, entry in enumerate(tqdm(imagenet_ds.imagenet_ds.imgs)):
        path_to_index.setdefault(str(get_imagenet_relative_path(entry[0])), idx)

    missing = [p for p in subset_paths if p not in path_to_index]
    if missing:
        raise ValueError(
            f"{len(missing)} subset path(s) not found in the dataset, "
            f"e.g. {missing[:3]!r}")

    subset_indices = [path_to_index[p] for p in subset_paths]
    return torch.utils.data.Subset(imagenet_ds, subset_indices)


if __name__ == "__main__":
    def _parse_bool(value):
        """Convert a CLI string such as 'True', 'false', '1' or 'no' to a bool.

        argparse's ``type=bool`` calls ``bool(str)``, which is True for every
        non-empty string, so ``--is-ablation False`` used to mean True.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description='Generate ViT saliency maps for ImageNet and save them to HDF5')
    parser.add_argument('--batch-size', type=int,
                        default=1,
                        help='')
    # NOTE(review): the default 'grad_rollout' is not handled by
    # get_method_func, so omitting --method fails -- confirm the intended default.
    parser.add_argument('--method', type=str,
                        default='grad_rollout',
                        help="saliency method: rollout, lrp, transformer_attribution, "
                             "full_lrp, lrp_last_layer, attn_last_layer, attn_gradcam, "
                             "or a predmap* method of the baseline model")
    parser.add_argument('--layer', type=int,
                        default=None,
                        help='')
    parser.add_argument('--lmd', type=float,
                        default=10,
                        help='')
    parser.add_argument('--vis-class', type=str,
                        default='top',
                        choices=['top', 'target', 'index'],
                        help='')
    parser.add_argument('--class-id', type=int,
                        default=0,
                        help='')
    parser.add_argument('--cls-agn', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-ia', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-fgx', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-m', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--no-reg', action='store_true',
                        default=False,
                        help='')
    parser.add_argument('--is-ablation', type=_parse_bool,
                        default=False,
                        help='')
    parser.add_argument('--dataset-type', type=str,
                        required=True,
                        choices=['imagenet'],
                        help='')
    parser.add_argument('--imagenet-validation-path', type=str,
                        required=True,
                        help='')
    parser.add_argument('--imagenet-subset-path', type=str,
                        default=None,
                        help='')
    parser.add_argument('--arc', type=str,
                        default='vit-b',
                        choices=['vit-b', 'vit-l'],
                        help='')
    parser.add_argument('--num-workers', type=int,
                        default=4,
                        help='DataLoader worker processes')
    args = parser.parse_args()

    # Output root: optimized/<imagenet | subset-file stem>/<arc>/
    PATH = Path(__file__).parent / 'optimized'
    if args.dataset_type == "imagenet":
        if args.imagenet_subset_path is not None:
            imagenet_subset_path = Path(args.imagenet_subset_path)
            if not imagenet_subset_path.exists():
                raise FileNotFoundError(f"imagenet_subset_path not found: {imagenet_subset_path}")
            PATH = PATH / imagenet_subset_path.stem
        else:
            PATH = PATH / 'imagenet'
    else:
        PATH = PATH / args.dataset_type
    PATH = PATH / args.arc

    method_layer = f"{args.method}{args.layer}" if args.layer is not None else args.method

    # visualizations/<method[layer]>/index_<class_id>            for --vis-class index
    # visualizations/<method[layer]>/<vis_class>/(not_)ablation  otherwise
    if args.vis_class == 'index':
        vis_subdir = f"{args.vis_class}_{args.class_id}"
    else:
        ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
        vis_subdir = os.path.join(args.vis_class, ablation_fold)
    args.method_dir = os.path.join(PATH, 'visualizations', method_layer, vis_subdir)
    os.makedirs(args.method_dir, exist_ok=True)

    # Remove stale results from the directory that will actually be written
    # (not a hard-coded .../not_ablation/ path, which missed the other layouts).
    results_path = Path(args.method_dir) / 'results.hdf5'
    print(f"Results path: {results_path}")
    if results_path.exists():
        print("Results already exist. Removing...", end=" ")
        results_path.unlink()
        print("REMOVED!")

    # compute_saliency_and_save reads module-level globals defined below
    # (e.g. `device`, `sample_loader`), so they must stay at module scope.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only resize + ToTensor here; normalization is applied per batch inside
    # compute_saliency_and_save.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', transform=transform)
    imagenet_ds = ImageNetWithPaths(imagenet_ds)
    if args.imagenet_subset_path is not None:
        imagenet_ds = imagenet_subset(imagenet_ds, args.imagenet_subset_path)
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds,
        batch_size=args.batch_size,
        shuffle=False,  # keep dataset order so HDF5 rows follow it
        num_workers=args.num_workers,
    )

    compute_saliency_and_save(args)
