import argparse
import os
import numpy as np
import torch
import pickle

from PIL import Image
from models import get_model
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

import torchvision.transforms as transforms

# Device selection: run on the GPU when CUDA is available, otherwise on the CPU.
is_cuda = torch.cuda.is_available()
device = 'cuda' if is_cuda else 'cpu'


class FolderReadDataset(Dataset):
    """Dataset yielding ``(file_name, image)`` pairs for images in one folder.

    Args:
        path: Folder that contains the image files.
        custom_name_txt: Optional path to a text file listing one image name
            per line, without extension; ``'.jpg'`` is appended to every
            name. When ``None``, every entry of ``os.listdir(path)`` is used.
        transform: Optional callable applied to each opened image before it
            is returned.
    """

    def __init__(self, path, custom_name_txt=None, transform=None):
        self.path = path
        self.transform = transform
        if custom_name_txt is None:
            # os.listdir returns entries in arbitrary order; sort them so the
            # sample order (and therefore the saved output) is reproducible.
            self.file_names = sorted(os.listdir(path))
        else:
            # TODO: edit this later when Skye figured out the 100K image names from SNAP
            with open(custom_name_txt, 'r') as f:
                stems = [line.rstrip('\n') for line in f]
            # Skip empty lines instead of blindly dropping the last entry, so
            # the final name survives when the file has no trailing newline.
            self.file_names = [stem + '.jpg' for stem in stems if stem]

    def __len__(self):
        """Return the number of image names held by the dataset."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Return ``(name, image)`` for ``index``, transformed if a transform is set."""
        name = self.file_names[index]
        x = self.get_image_from_folder(name)
        if self.transform:
            x = self.transform(x)
        return name, x

    def get_image_from_folder(self, name):
        """Open the image called ``name`` inside ``self.path`` and return it."""
        image = Image.open(os.path.join(self.path, name))
        return image


def main(args):
    """Compute an embedding for every image of a folder and pickle the result.

    The output file holds a pickled ``(name_list, embedding_list)`` tuple:
    the image file names in dataset order and the network outputs of all
    batches concatenated along the first axis.

    Args:
        args: Parsed command-line namespace providing ``model``, ``data``,
            ``data_custom_name``, ``batch_size`` and ``out_filename``.

    Raises:
        ValueError: If the dataset yields no images.
    """
    network = get_model(args.model)
    network = network.to(device)
    # Feature extraction is inference only: switch layers such as BatchNorm
    # and Dropout to evaluation behaviour so the embeddings do not depend on
    # the batch composition.
    network.eval()

    # NOTE(review): Resize with an int size presumably rescales only the
    # shorter edge, so images of different aspect ratios would give tensors
    # of different shapes -- confirm the inputs share one size.
    t = transforms.Compose(
        [transforms.Resize(size=224),
         transforms.ToTensor()])
    dataset = FolderReadDataset(
        args.data, custom_name_txt=args.data_custom_name, transform=t)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False)

    embedding_list = []
    name_list = []
    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        for names, images in tqdm(dataloader):
            images = images.to(device)
            embedding = network(images).detach().cpu().numpy()
            embedding_list.append(embedding)
            name_list.extend(names)

    if not embedding_list:
        # np.concatenate would fail on an empty list with an opaque message.
        raise ValueError(f'No images found in {args.data}')

    embedding_list = np.concatenate(embedding_list)
    print(f'Final collected embedding features shape is {embedding_list.shape}')
    with open(args.out_filename, 'wb') as f:
        pickle.dump((name_list, embedding_list), f)
    print(f'Saved to {args.out_filename}')

if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse them and
    # pass the resulting namespace to main().
    parser = argparse.ArgumentParser(description='parameterised models')

    # One (flag, keyword arguments) pair per supported option, in help order.
    option_specs = (
        ('--model',
         dict(type=str, default='resnet18', help='model name')),
        ('--data',
         dict(type=str, default='./data',
              help='path to the folder that contains a list of images')),
        ('--data_custom_name',
         dict(type=str, help='a file that contains a list of images')),
        ('--batch_size',
         dict(type=int, default=128, help='batch size')),
        ('--out_filename',
         dict(type=str, default='tmp.pkl', help='out file name')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)
