import os, h5py, argparse
from tqdm import tqdm
import numpy as np
import sys
sys.path.append(os.getcwd())
from torchvision.datasets import ImageNet

# Import saliency methods and models
from baselines.misc_functions import *
from baselines.models import run_baselines, load_baseline_models

from src.utilities import normalize, min_max_normal
from src.algorithms import batch_pgd_attack


def compute_saliency_and_save(config, models, sample_loader):
    """Run a saliency baseline on every batch and stream the results to HDF5.

    For each ``(data, target)`` batch yielded by ``sample_loader`` the images
    are optionally perturbed with ``batch_pgd_attack``, stored together with
    their targets, passed through ``normalize`` and handed to
    ``run_baselines``. The returned maps are appended to
    ``<config['method_dir']>/results.hdf5`` under the datasets ``'vis'``,
    ``'image'`` and ``'target'``.

    Args:
        config: Mapping read for the keys ``'device'``, ``'method_dir'``,
            ``'to_attack'``, ``'noise_level'`` (only when attacking),
            ``'method'`` and ``'is_ablation'``.
        models: Forwarded to ``run_baselines``; ``models['vit']`` is the model
            handed to the PGD attack.
        sample_loader: Iterable of ``(data, target)`` tensor batches. Images
            must fit the ``(N, 3, 224, 224)`` dataset and the maps returned by
            ``run_baselines`` must fit the ``(N, 1, 224, 224)`` dataset.

    Note:
        The file is opened in append mode and the three datasets are created
        unconditionally, so the caller is expected to remove a stale
        ``results.hdf5`` beforehand.
    """
    device = config['device']
    # 'noise_level' is expressed out of 255; rescale it once, not per batch.
    noise_level = config['noise_level'] / 255 if config['to_attack'] else None

    results_path = os.path.join(config['method_dir'], 'results.hdf5')
    with h5py.File(results_path, 'a') as f:
        data_cam = f.create_dataset('vis',
                                    (1, 1, 224, 224),
                                    maxshape=(None, 1, 224, 224),
                                    dtype=np.float32,
                                    compression="gzip")
        data_image = f.create_dataset('image',
                                      (1, 3, 224, 224),
                                      maxshape=(None, 3, 224, 224),
                                      dtype=np.float32,
                                      compression="gzip")
        data_target = f.create_dataset('target',
                                       (1,),
                                       maxshape=(None,),
                                       dtype=np.int32,
                                       compression="gzip")
        datasets = (data_cam, data_image, data_target)

        # Number of rows filled so far. The datasets start with a single
        # placeholder row, which the first batch overwrites.
        rows_written = 0

        for data, target in tqdm(sample_loader):
            if config['to_attack']:
                # NOTE(review): the attacked model is always models['vit'],
                # whatever backbone was loaded -- confirm this is intended.
                data = batch_pgd_attack(data.to(device), models['vit'], noise_level)

            start = rows_written
            stop = start + data.shape[0]

            # Grow all three datasets in lockstep so row i always refers to
            # the same sample in each of them.
            for dataset in datasets:
                dataset.resize(stop, axis=0)

            # Store the (possibly attacked) images before normalization.
            data_image[start:stop] = data.detach().cpu().numpy()
            data_target[start:stop] = target.detach().cpu().numpy()

            data = normalize(data)
            data = data.to(device)
            data.requires_grad_()

            # NOTE(review): config['vis_class'] == 'target' used to bind the
            # targets to a local `index` that was never passed on; no class
            # index reaches run_baselines here. Confirm whether it should.
            Res = run_baselines(models, config['method'], data, device, config['is_ablation'])

            data_cam[start:stop] = Res.detach().cpu().numpy()
            rows_written = stop


def generate_visualizations(config, base_vit='vit'):
    """Prepare the output directory, models and data loader, then run the saliency pass.

    The output directory is created below ``<cwd>/visualizations/<method>``:

    * ``<vis_class>_<class_id>`` when ``config['vis_class'] == 'index'``;
    * ``<vis_class>/<ablation|not_ablation>/<attack_<noise_level>|no_attack>``
      otherwise.

    Any ``results.hdf5`` left there by a previous run is deleted, because
    ``compute_saliency_and_save`` opens the file in append mode and creates
    its datasets unconditionally.

    Args:
        config: Mutable mapping of run options. Read keys: ``'method'``,
            ``'vis_class'``, ``'class_id'`` (index mode), ``'is_ablation'``,
            ``'to_attack'`` and ``'noise_level'`` (other modes),
            ``'imagenet_validation_path'`` and ``'batch_size'``. The keys
            ``'device'`` and ``'method_dir'`` are written by this function.
        base_vit: Model identifier forwarded to ``load_baseline_models``.
    """
    # NOTE(review): `torch` and `transforms` are not imported explicitly in
    # this file; presumably they arrive through the wildcard import of
    # baselines.misc_functions -- confirm.
    config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # PATH variables: build the output directory once instead of re-formatting
    # the same string for remove / makedirs / config.
    method_root = os.path.join(os.getcwd(), 'visualizations', config['method'])
    if config['vis_class'] == 'index':
        method_dir = os.path.join(
            method_root,
            '{}_{}'.format(config['vis_class'], config['class_id']))
    else:
        ablation_fold = 'ablation' if config['is_ablation'] else 'not_ablation'
        attack_type = 'attack_' + str(config['noise_level']) if config['to_attack'] else 'no_attack'
        method_dir = os.path.join(method_root, config['vis_class'], ablation_fold, attack_type)

    # makedirs also creates the intermediate 'visualizations/<method>' levels.
    os.makedirs(method_dir, exist_ok=True)

    try:
        os.remove(os.path.join(method_dir, 'results.hdf5'))
    except FileNotFoundError:
        # Nothing to clean up: first run for this configuration. Any other
        # failure (e.g. permissions) is raised rather than leaving a stale
        # file to be appended to.
        pass

    config['method_dir'] = method_dir

    # Model
    models = load_baseline_models(base_vit, config['device'])

    # Dataset loader for sample images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    imagenet_ds = ImageNet(config['imagenet_validation_path'], split='val', transform=transform)

    # shuffle=False keeps the row order of the saved results deterministic.
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=4
    )

    compute_saliency_and_save(config, models, sample_loader)


def str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- parses as ``True``. This
    converter maps the usual spellings explicitly instead.

    Args:
        value: A string such as ``"true"``/``"False"``/``"1"``/``"no"``
            (case-insensitive, surrounding whitespace ignored) or an actual
            ``bool``, which is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


# Script entry point: parse the CLI options into a plain dict and run the
# visualization pipeline with it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate saliency visualizations and save them to HDF5')
    parser.add_argument('--batch-size', type=int, default=1, help='')
    parser.add_argument('--method', type=str, default='rollout',
                        choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
                                 'attn_last_layer', 'attn_gradcam', 'dds'], help='')
    parser.add_argument('--vis-class', type=str, default='top', choices=['top', 'target', 'index'], help='')
    parser.add_argument('--class-id', type=int, default=0, help='')
    # Accepts both `--is-ablation` (bare flag -> True) and `--is-ablation False`.
    parser.add_argument('--is-ablation', type=str2bool, nargs='?', const=True, default=False, help='')
    parser.add_argument('--imagenet-validation-path', type=str, default='data_dir/', help='')
    parser.add_argument('--base-vit', type=str, default='vit', choices=['vit', 'swin', 'deit'], help='Type of ViT to use')
    parser.add_argument('--to-attack', action='store_true', default=False, help='Boolean to perform default(PGD) attack')
    parser.add_argument('--noise-level', type=int, default=4, help='Noise level out of 255')

    # Unused args
    # parser.add_argument('--lmd', type=float, default=10, help='')
    # parser.add_argument('--no-ia', action='store_true', default=False, help='')
    # parser.add_argument('--no-m', action='store_true', default=False, help='')
    # parser.add_argument('--no-fgx', action='store_true', default=False, help='')
    # parser.add_argument('--no-fx', action='store_true', default=False, help='')
    # parser.add_argument('--no-reg', action='store_true', default=False, help='')
    # parser.add_argument('--cls-agn', action='store_true', default=False, help='')

    args = parser.parse_args()

    # generate_visualizations expects (and mutates) a dict, not a Namespace.
    args = vars(args)

    generate_visualizations(args, args['base_vit'])
