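# Alternative backbone: google/vit-base-patch16-224-in21k (pass it via
# --blackbox-model-name / --blackbox-processor-name).
"""
Extract ViT features (the model's last_hidden_state) for each image and save
one .npy file per image.

Example:
    python <path/to/this_script.py> --image-dir ../datasets/SUN397/images --output-dir ../datasets/SUN397/images_feats
"""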
import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from tqdm.auto import tqdm
from transformers import ViTImageProcessor, AutoModel
from pathlib import Path
import sys
# Debug output: print this file's directory and its two ancestor directories.
print(Path(__file__).parents[0])
print(Path(__file__).parents[1])
print(Path(__file__).parents[2])
# Directory two levels above this file's directory, presumably the repository
# root; appended to sys.path so modules under it become importable.
# NOTE(review): nothing visible in this section imports from it -- confirm it is
# still needed.
path_root = Path(__file__).parents[2]
print(path_root)
sys.path.append(str(path_root))
import numpy as np
import random
from tqdm import tqdm
import h5py
from PIL import Image
import math
from skimage.transform import resize


def transform(image, processor=None):
    """Convert ``image`` to RGB and optionally preprocess it.

    Args:
        image: an object with a PIL-style ``convert(mode)`` method.
        processor: optional callable (e.g. a ``ViTImageProcessor``) invoked as
            ``processor(img, return_tensors='pt')`` that returns a mapping with
            a ``'pixel_values'`` tensor.

    Returns:
        With a processor: ``pixel_values`` with its leading dimension squeezed
        out. Without one: ``np.asarray`` of the RGB image.
    """
    rgb = image.convert("RGB")
    if processor is None:
        return np.asarray(rgb)
    pixel_values = processor(rgb, return_tensors='pt')['pixel_values']
    # The processor returns a batch of one; drop that batch dimension.
    return pixel_values.squeeze(0)
        

def parse_args():
    """Build the command-line parser for feature extraction.

    Despite its name, this returns the ``argparse.ArgumentParser`` itself;
    the caller is expected to invoke ``.parse_args()`` on it.

    Returns:
        argparse.ArgumentParser: parser for the options below.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--blackbox-model-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box model name')
    parser.add_argument('--blackbox-processor-name', type=str,
                        default='google/vit-base-patch16-224',
                        help='black box processor name')
    parser.add_argument('--class-names-filepath', type=str,
                        default=None,
                        help='class names filepath (one directory name per '
                             'line; the first character of each line is '
                             'dropped)')
    parser.add_argument('--image-dir', type=str,
                        default='../datasets/SUN397/images',
                        help='input dir')
    parser.add_argument('--output-dir', type=str,
                        default='../datasets/SUN397/images_feats',
                        help='output dir')
    # --start-dir/--end-dir slice the list of class directories (non-flat
    # mode only), so a large dataset can be split across several runs.
    parser.add_argument('--start-dir', type=int,
                        default=0,
                        help='start dir')
    parser.add_argument('--end-dir', type=int,
                        default=-1,
                        help='end dir (exclusive; -1 means through the last dir)')
    # NOTE(review): images are encoded one at a time; nothing in this file
    # reads this value.
    parser.add_argument('--batch-size', type=int,
                        default=32,
                        help='batch size')
    # store_true already defaults to False.
    parser.add_argument('--flat',
                        action='store_true',
                        help='if true, treat as flat')

    return parser

def _load_dirnames(args):
    """Return the class sub-directory names to process, sliced by --start-dir/--end-dir.

    Without ``--class-names-filepath`` the sorted entries of ``--image-dir``
    are used. With it, each non-blank line is stripped and its first character
    dropped (presumably a leading '/' as in SUN397's ClassName.txt -- confirm
    for other datasets).
    """
    if args.class_names_filepath is None:
        dirnames = sorted(os.listdir(args.image_dir))
    else:
        with open(args.class_names_filepath, 'rt') as input_file:
            # Skip blank lines, which would otherwise yield an empty dirname
            # that points back at --image-dir itself.
            dirnames = [line.strip()[1:] for line in input_file if line.strip()]
    # --end-dir == -1 is the sentinel for "through the last directory".
    end = None if args.end_dir == -1 else args.end_dir
    return dirnames[args.start_dir:end]


def _extract_one(model, processor, device, image_dir, output_dir, filename):
    """Encode ``image_dir/filename`` and save features to ``output_dir/filename.npy``.

    Existing outputs are left untouched, so an interrupted run can be resumed.
    Images that cannot be opened or preprocessed are reported and skipped
    instead of aborting the whole run.

    Returns:
        bool: True if a new feature file was written, False otherwise.
    """
    image_path = os.path.join(image_dir, filename)
    output_path = os.path.join(output_dir, filename + '.npy')
    if os.path.exists(output_path):
        return False

    try:
        image = Image.open(image_path)
    except Exception as err:  # e.g. a non-image file or a directory
        print(f'failed {image_path}: {err!r}\n\n\n')
        return False
    # Image.open is lazy and keeps the file open; close it once preprocessed.
    with image:
        try:
            pixel_values = transform(image, processor=processor).to(device)
        except Exception as err:
            print(f'failed transform {image_path}: {err!r}\n\n\n')
            return False

    # Add a batch dimension of 1 for the forward pass.
    output = model(pixel_values.unsqueeze(0)).last_hidden_state
    np.save(output_path, output.cpu().numpy())
    return True


def main():
    """Extract ViT features for every image and save one ``.npy`` per image.

    Non-flat mode (default): ``--image-dir`` contains one sub-directory per
    class; outputs mirror that layout under ``--output-dir``.
    Flat mode (``--flat``): every entry directly under ``--image-dir`` is
    treated as an image and outputs go directly into ``--output-dir``.
    Each saved array is the model's ``last_hidden_state`` for a batch of one.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg in vars(args):
        value = getattr(args, arg)
        print(arg, value, '\t', type(value))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device', device)

    processor = ViTImageProcessor.from_pretrained(args.blackbox_processor_name)
    # The model comes from --blackbox-model-name; the processor from
    # --blackbox-processor-name.
    model = AutoModel.from_pretrained(args.blackbox_model_name)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        if args.flat:
            os.makedirs(args.output_dir, exist_ok=True)
            for filename in tqdm(os.listdir(args.image_dir), file=sys.stdout):
                _extract_one(model, processor, device,
                             args.image_dir, args.output_dir, filename)
        else:
            for dirname in tqdm(_load_dirnames(args), file=sys.stdout):
                image_subdir = os.path.join(args.image_dir, dirname)
                output_subdir = os.path.join(args.output_dir, dirname)
                os.makedirs(output_subdir, exist_ok=True)
                for filename in tqdm(os.listdir(image_subdir)):
                    _extract_one(model, processor, device,
                                 image_subdir, output_subdir, filename)

        
# Run feature extraction only when executed as a script, not on import.
if __name__ == '__main__':
    main()