import argparse
import os

# Distributed-launch settings read from the environment. Each one falls back
# to a single-process default when the launcher did not export it.
rank = int(os.getenv("RANK", 0))
local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
# Restrict this process to its own GPU. Reuse the parsed ``local_rank`` so a
# missing LOCAL_RANK falls back to GPU 0 instead of raising KeyError.
# NOTE: this is set before the torch import below; presumably it has to be in
# place before CUDA is initialised to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)

import clip
import mxnet as mx
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImageRecDataset(Dataset):
    """Dataset over one rank's contiguous slice of an MXNet indexed record file.

    The record keys are split into ``world_size`` contiguous shards whose
    lengths differ by at most one; the first ``num_keys % world_size`` ranks
    receive the extra key. Only the shard belonging to the module-level
    ``rank`` is kept.

    Args:
        path_imgidx: Path of the ``.idx`` index file.
        path_imgrec: Path of the ``.rec`` record file.
        transform: Callable applied to each decoded PIL image, or ``None``
            to return the PIL image unchanged.
    """

    def __init__(self, path_imgidx, path_imgrec, transform):
        super().__init__()
        self.transform = transform
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')

        all_keys = np.array(list(self.imgrec.keys))
        # Even split: every rank gets ``per_rank`` keys and the first
        # ``leftover`` ranks get one more.
        per_rank, leftover = divmod(len(all_keys), world_size)
        shard_len = per_rank + int(rank < leftover)
        shard_begin = per_rank * rank + min(rank, leftover)
        self.imgidx = all_keys[shard_begin: shard_begin + shard_len]

    def __getitem__(self, index):
        """Decode the ``index``-th record of this shard and apply the transform."""
        record = self.imgrec.read_idx(self.imgidx[index])
        _, encoded = mx.recordio.unpack(record)
        pixels = mx.image.imdecode(encoded).asnumpy()

        # The decoded array is treated as RGB (presumably imdecode's default
        # channel order -- confirm against the MXNet version in use).
        image = Image.fromarray(pixels)
        if self.transform is None:
            return image
        return self.transform(image)

    def __len__(self):
        """Return the number of records assigned to this rank."""
        return len(self.imgidx)

@torch.no_grad()
def extract_torch(data_iter,
                  model_torch,
                  chunk_size: int,
                  rank: int,
                  output: str, ):
    """Encode every batch with ``model_torch.encode_image`` and save the features in chunks.

    Features are accumulated in a pre-allocated float32 buffer holding
    ``chunk_size`` rows. Each time the buffer fills up it is written to
    ``"{output}_{rank}_{chunk_id}.npy"`` and ``chunk_id`` is incremented.
    After the last batch the remaining rows are written as one more chunk,
    which is empty (zero rows) when the data ended exactly on a chunk
    boundary. If ``data_iter`` yields nothing, no file is written.

    Args:
        data_iter: Iterable of batches; each batch must provide ``.cuda()``
            and is passed to ``model_torch.encode_image``.
        model_torch: ``torch.nn.Module`` exposing ``encode_image``, which is
            expected to return a 2-D ``(batch, dim)`` tensor.
        chunk_size: Number of feature rows stored per chunk file; must be
            positive. Batches may straddle a chunk boundary.
        rank: Rank id embedded in the chunk file names.
        output: Path prefix of the chunk files.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    assert isinstance(model_torch, torch.nn.Module)
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive, got {}".format(chunk_size))

    filled = 0  # rows currently held in the buffer
    chunk_id = 0
    cache = None
    for data_batch in data_iter:
        feat: torch.Tensor = model_torch.encode_image(data_batch.cuda())
        if cache is None:
            # Allocate lazily: the feature width is only known after the
            # first forward pass. Follow the feature's device rather than
            # assuming GPU index 0.
            cache = torch.zeros(
                size=(chunk_size, feat.shape[1]),
                dtype=torch.float32,
                device=feat.device)

        # Copy the batch into the buffer piecewise so that a batch crossing
        # a chunk boundary is split instead of overflowing the buffer.
        offset = 0
        batch_rows = feat.shape[0]
        while offset < batch_rows:
            take = min(chunk_size - filled, batch_rows - offset)
            cache[filled: filled + take] = feat[offset: offset + take]
            filled += take
            offset += take
            if filled == chunk_size:
                np.save("{}_{}_{}.npy".format(output, rank, chunk_id), cache.cpu().numpy())
                chunk_id += 1
                # No need to clear the buffer: later rows overwrite it and
                # the final save only takes the first ``filled`` rows.
                filled = 0

    if cache is None:
        # Nothing was encoded, so the feature width is unknown and there is
        # nothing to save.
        return
    np.save("{}_{}_{}.npy".format(output, rank, chunk_id), cache[:filled].cpu().numpy())

@torch.no_grad()
def main(args):
    """Extract CLIP image features for this rank's shard of a record file.

    Loads the CLIP checkpoint, wraps ``{args.prefix}.idx`` / ``{args.prefix}.rec``
    in an ``ImageRecDataset`` that uses CLIP's own preprocessing, and hands an
    unshuffled ``DataLoader`` to ``extract_torch``.

    Args:
        args: Parsed command-line namespace; reads ``prefix``, ``batch_size``,
            ``chunk_size`` (counted in batches) and ``output``.
    """
    # ``args.chunk_size`` counts batches, so convert it to a row count.
    rows_per_chunk = args.chunk_size * args.batch_size
    model_clip, preprocess = clip.load("/models_pretrained/clip/ViT-L-14-336px.pt")

    record_prefix = args.prefix
    loader = DataLoader(
        ImageRecDataset(
            "{}.idx".format(record_prefix),
            "{}.rec".format(record_prefix),
            preprocess,
        ),
        batch_size=args.batch_size,
        shuffle=False,
    )

    model_clip.eval().cuda()
    extract_torch(loader, model_clip, rows_per_chunk, rank, args.output)


def merge(args):
    """Concatenate the per-rank chunk files into a single ``{args.output}.npy``.

    Chunk files are named ``"{output}_{gpu}_{chunk}.npy"``. For each gpu id,
    chunks are collected in increasing chunk order until the first missing
    one; scanning stops as soon as the next gpu id has no chunk 0. The arrays
    are concatenated along the first axis in that order.

    Args:
        args: Namespace providing ``output`` (path prefix of the chunk files
            and of the merged file) and ``rm`` (truthy to delete the chunk
            files after merging).

    Raises:
        ValueError: If no chunk file exists for ``args.output``.
    """
    chunk_name = "{}_{}_{}.npy".format
    file_list = []
    for id_gpu in range(10000):
        for id_chunk in range(10000):
            path = chunk_name(args.output, id_gpu, id_chunk)
            if not os.path.exists(path):
                break
            file_list.append(path)
        # Stop once the following gpu id has produced no first chunk.
        if not os.path.exists(chunk_name(args.output, id_gpu + 1, 0)):
            break
    print(file_list)

    if not file_list:
        # Fail with an actionable message instead of numpy's generic
        # "need at least one array to concatenate".
        raise ValueError(
            "no chunk files found for output prefix {!r}".format(args.output))

    feat = np.concatenate([np.load(x) for x in file_list])
    np.save("{}.npy".format(args.output), feat)

    if args.rm:
        for x in file_list:
            os.remove(x)
        print("file_list cache has been removed!")


if __name__ == '__main__':
    # Command-line entry point: parse the options and run feature extraction.
    # Help texts use ``%(default)s`` so they cannot drift from the real
    # defaults, and the auto-generated usage line is kept instead of a
    # placeholder string.
    parser = argparse.ArgumentParser(
        description="Extract CLIP image features from an MXNet record file.")
    parser.add_argument("--output", type=str, default="./output",
                        help="path prefix for the output .npy files (default: %(default)s)")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch size for each gpu (default: %(default)s)")
    # main() multiplies this by the batch size to get the rows per chunk.
    parser.add_argument('--chunk-size', type=int, default=2000,
                        help="number of batches stored in each .npy chunk (default: %(default)s)")
    # Required: without it the dataset would try to open "None.idx".
    parser.add_argument("--prefix", type=str, required=True,
                        help="path prefix of the .idx/.rec record files")
    parser.add_argument("--rm", type=int, default=0,
                        help="if non-zero, remove the chunk cache files after merging "
                             "(default: %(default)s)")
    main(parser.parse_args())
