import os
import random
import numpy as np
from PIL import Image
import torch

if __name__ != '__main__':
    import open_clip

# Hide every CUDA device from this process by exporting an empty
# CUDA_VISIBLE_DEVICES, so the reference data is generated without GPUs.
os.environ.update(CUDA_VISIBLE_DEVICES='')

def seed_all(seed = 0):
    """Make random number generation in this process reproducible.

    Puts cuDNN into deterministic, non-benchmarking mode and makes PyTorch
    use deterministic algorithms (``warn_only=False``, so nondeterministic
    ops raise instead of warning). It then seeds Python's ``random``,
    NumPy's global RNG and PyTorch with ``seed``.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    torch.use_deterministic_algorithms(True, warn_only=False)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def inference_text(model, model_name, batches):
    """Encode batches of text with ``model.encode_text``.

    Each batch is tokenized with ``open_clip.get_tokenizer(model_name)`` and
    encoded under ``torch.no_grad()``. The per-batch encodings are stacked
    along a new leading dimension, so every batch must produce an
    equally-shaped result (a ``torch.stack`` requirement).
    """
    tokenize = open_clip.get_tokenizer(model_name)
    with torch.no_grad():
        encoded = [model.encode_text(tokenize(texts)) for texts in batches]
        return torch.stack(encoded)

def inference_image(model, preprocess_val, batches):
    """Encode batches of images with ``model.encode_image``.

    Every image in a batch goes through ``preprocess_val``. The results are
    stacked into one input tensor and encoded under ``torch.no_grad()``. The
    per-batch encodings are stacked along a new leading dimension.
    """
    def encode(images):
        # one model call per batch
        pixels = torch.stack([preprocess_val(image) for image in images])
        return model.encode_image(pixels)

    with torch.no_grad():
        return torch.stack([encode(images) for images in batches])
    
def forward_model(model, model_name, preprocess_val, image_batch, text_batch):
    """Run the full ``model(images, texts)`` forward pass over paired batches.

    ``image_batch`` and ``text_batch`` are iterated in lockstep (``zip``, so
    extra batches on the longer side are ignored). Images are preprocessed
    with ``preprocess_val`` and stacked. Texts are tokenized with
    ``open_clip.get_tokenizer(model_name)``. Everything runs under
    ``torch.no_grad()``.

    Returns:
        If the model returns a dict, a dict mapping each key to the per-batch
        values stacked along a new leading dimension. Otherwise, a list whose
        i-th entry stacks the i-th output element of every batch.

    Raises:
        ValueError: if there is no (image, text) batch pair to process.
    """
    tokenizer = open_clip.get_tokenizer(model_name)
    y = []
    with torch.no_grad():
        for x_im, x_txt in zip(image_batch, text_batch):
            x_im = torch.stack([preprocess_val(im) for im in x_im])
            x_txt = tokenizer(x_txt)
            # collect one model output per batch pair
            y.append(model(x_im, x_txt))
    if not y:
        raise ValueError("forward_model needs at least one (image, text) batch pair")
    if isinstance(y[0], dict):
        return {key: torch.stack([batch_out[key] for batch_out in y]) for key in y[0]}
    return [torch.stack([batch_out[i] for batch_out in y]) for i in range(len(y[0]))]

def random_image_batch(batch_size, size):
    """Return a list of ``batch_size`` random RGB PIL images.

    ``size`` is ``(height, width)``. Pixel values come from NumPy's global
    RNG. ``np.random.randint``'s upper bound is exclusive, so each channel
    takes values in [0, 254].
    """
    height, width = size
    pixels = np.random.randint(255, size = (batch_size, height, width, 3), dtype = np.uint8)
    images = []
    for frame in pixels:
        images.append(Image.fromarray(frame))
    return images

def random_text_batch(batch_size, min_length = 75, max_length = 75):
    """Return ``batch_size`` random strings built from open_clip's BPE vocabulary.

    Each string concatenates between ``min_length`` and ``max_length``
    (inclusive) vocabulary entries. The entries are drawn with replacement
    using Python's ``random`` module. Tokens whose id is in the tokenizer's
    ``all_special_ids`` are excluded, and the ``</w>`` end-of-word marker is
    turned into a space.
    """
    tok = open_clip.tokenizer.SimpleTokenizer()
    special = tok.all_special_ids
    vocab = []
    for token_id, token in tok.decoder.items():
        if token_id in special:
            continue
        vocab.append(token.replace('</w>', ' '))

    def one_string():
        # draw the length first, then the tokens (fixed RNG call order)
        n_tokens = random.randint(min_length, max_length)
        return ''.join(random.choices(vocab, k = n_tokens))

    return [one_string() for _ in range(batch_size)]

def create_random_text_data(
        path,
        min_length = 75,
        max_length = 75,
        batches = 1,
        batch_size = 1
):
    """Save ``batches`` batches of random text to ``path`` with ``torch.save``.

    Each batch comes from ``random_text_batch(batch_size, min_length,
    max_length)``. The target path is printed before saving.
    """
    text_batches = []
    for _ in range(batches):
        text_batches.append(random_text_batch(batch_size, min_length = min_length, max_length = max_length))
    print(f"{path}")
    torch.save(text_batches, path)

def create_random_image_data(path, size, batches = 1, batch_size = 1):
    """Save ``batches`` batches of random images to ``path`` with ``torch.save``.

    Each batch comes from ``random_image_batch(batch_size, size)``, where
    ``size`` is ``(height, width)``. The target path is printed before saving.
    """
    image_batches = []
    for _ in range(batches):
        image_batches.append(random_image_batch(batch_size, size = size))
    print(f"{path}")
    torch.save(image_batches, path)

def get_data_dirs(make_dir = True):
    """Return ``(input_dir, output_dir)`` for the test data.

    Both directories live in a ``data`` directory next to this file (symlinks
    resolved): ``data/input`` and ``data/output``.

    Args:
        make_dir: create both directories (and their parents) if missing.

    Raises:
        AssertionError: if either returned directory does not exist.
    """
    data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
    input_dir = os.path.join(data_dir, 'input')
    output_dir = os.path.join(data_dir, 'output')
    if make_dir:
        os.makedirs(input_dir, exist_ok = True)
        os.makedirs(output_dir, exist_ok = True)
    # verify the directories actually returned, not just their common parent
    assert os.path.isdir(input_dir), f"data directory missing, expected at {input_dir}"
    assert os.path.isdir(output_dir), f"data directory missing, expected at {output_dir}"
    return input_dir, output_dir

def create_test_data_for_model(
        model_name,
        pretrained = None,
        precision = 'fp32',
        jit = False,
        pretrained_hf = False,
        force_quick_gelu = False,
        create_missing_input_data = True,
        batches = 1,
        batch_size = 1,
        overwrite = False
):
    """Compute and save reference text/image encodings for one model.

    Outputs are written to the output directory as
    ``<model_name>_<pretrained or pretrained_hf>_<precision>_random_text.pt``
    and ``..._random_image.pt``. Existing outputs are kept unless
    ``overwrite`` is set; if both already exist the model is not built at all.

    Inputs are read from the input directory: ``random_text.pt`` and
    ``random_image_<h>_<w>.pt``, with ``(h, w)`` taken from
    ``model.visual.image_size``. When ``create_missing_input_data`` is true,
    a missing input file is generated first. All RNGs are re-seeded with
    ``seed_all()`` before the model is created.

    Raises:
        AssertionError: if a required input file does not exist.
    """
    input_dir, output_dir = get_data_dirs()
    prefix = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
    text_out = os.path.join(output_dir, f'{prefix}_random_text.pt')
    image_out = os.path.join(output_dir, f'{prefix}_random_image.pt')
    need_text = overwrite or not os.path.exists(text_out)
    need_image = overwrite or not os.path.exists(image_out)
    if not (need_text or need_image):
        return

    seed_all()
    model, _, preprocess_val = open_clip.create_model_and_transforms(
            model_name,
            pretrained = pretrained,
            precision = precision,
            jit = jit,
            force_quick_gelu = force_quick_gelu,
            pretrained_hf = pretrained_hf
    )

    # NOTE(review): the input files hold pickled strings / PIL images; on
    # PyTorch versions where torch.load defaults to weights_only=True these
    # loads may fail — confirm against the supported torch version.
    if need_text:
        text_in = os.path.join(input_dir, 'random_text.pt')
        if create_missing_input_data and not os.path.exists(text_in):
            create_random_text_data(text_in, batches = batches, batch_size = batch_size)
        assert os.path.isfile(text_in), f"missing input data, expected at {text_in}"
        text_encodings = inference_text(model, model_name, torch.load(text_in))
        print(f"{text_out}")
        torch.save(text_encodings, text_out)

    if need_image:
        image_size = model.visual.image_size
        if not isinstance(image_size, tuple):
            # a single int means a square input
            image_size = (image_size, image_size)
        image_in = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt')
        if create_missing_input_data and not os.path.exists(image_in):
            create_random_image_data(image_in, image_size, batches = batches, batch_size = batch_size)
        assert os.path.isfile(image_in), f"missing input data, expected at {image_in}"
        image_encodings = inference_image(model, preprocess_val, torch.load(image_in))
        print(f"{image_out}")
        torch.save(image_encodings, image_out)

def create_test_data(
        models,
        batches = 1,
        batch_size = 1,
        overwrite = False
):
    """Generate reference test data for each requested model.

    Names that ``open_clip.list_models()`` does not know are dropped, as are
    the excluded timm models below. The rest are processed in sorted order
    with ``create_test_data_for_model``.

    Returns:
        The sorted list of model names that were processed.
    """
    # not available with timm
    # see https://github.com/mlfoundations/open_clip/issues/219
    unsupported = {'timm-convnext_xlarge', 'timm-vit_medium_patch16_gap_256'}
    known = set(open_clip.list_models())
    selected = sorted((set(models) - unsupported) & known)
    print(f"generating test data for:\n{selected}")
    for name in selected:
        print(name)
        create_test_data_for_model(
                name,
                batches = batches,
                batch_size = batch_size,
                overwrite = overwrite
        )
    return selected

def _sytem_assert(string):
    """Run ``string`` through the shell with ``os.system``; assert it returned 0."""
    exit_status = os.system(string)
    assert exit_status == 0

class TestWrapper(torch.nn.Module):
    """Wrap an open_clip model with a 2-way linear head for tests.

    ``forward`` runs the wrapped model on ``(image, text)``. It takes the
    ``'image_features'`` entry when ``output_dict`` is true, otherwise the
    first element of the model's output, and returns the head's result as
    ``{"test_output": ...}``.
    """
    # Final so TorchScript treats it as a constant and only compiles the
    # branch of forward() that matches it
    output_dict: torch.jit.Final[bool]

    def __init__(self, model, model_name, output_dict=True) -> None:
        super().__init__()
        self.model = model
        self.output_dict = output_dict
        # exact type match: only these two classes get their flag synced
        if type(model) in (open_clip.CLIP, open_clip.CustomTextCLIP):
            model.output_dict = output_dict
        embed_dim = open_clip.get_model_config(model_name)["embed_dim"]
        self.head = torch.nn.Linear(embed_dim, 2)

    def forward(self, image, text):
        outputs = self.model(image, text)
        if self.output_dict:
            features = outputs['image_features']
        else:
            features = outputs[0]
        assert features is not None  # remove Optional[], type refinement for torchscript
        return {"test_output": self.head(features)}

def _build_arg_parser():
    """Return the command-line parser used by ``main``."""
    import argparse
    parser = argparse.ArgumentParser(description = "Populate test data directory")
    parser.add_argument(
        '-a', '--all',
        action = 'store_true',
        help = "create test data for all models"
    )
    parser.add_argument(
        '-m', '--model',
        type = str,
        default = [],
        nargs = '+',
        help = "model(s) to create test data for"
    )
    parser.add_argument(
        '-f', '--model_list',
        type = str,
        help = "path to a text file containing a list of model names, one model per line"
    )
    parser.add_argument(
        '-s', '--save_model_list',
        type = str,
        help = "path to save the list of models that data was generated for"
    )
    parser.add_argument(
        '-g', '--git_revision',
        type = str,
        help = "git revision to generate test data for"
    )
    parser.add_argument(
        '--overwrite',
        action = 'store_true',
        help = "overwrite existing output data"
    )
    parser.add_argument(
        '-n', '--num_batches',
        default = 1,
        type = int,
        help = "amount of data batches to create (default: 1)"
    )
    parser.add_argument(
        '-b', '--batch_size',
        default = 1,
        type = int,
        help = "test data batch size (default: 1)"
    )
    return parser


def main(args):
    """Command-line entry point: generate reference test data for models.

    Args:
        args: command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    Models come from ``--all``, ``--model`` and/or ``--model_list``; blank
    lines in the list file are ignored. The process exits with status 1 if
    no model is given.

    With ``--git_revision``, the following happens:
      1. Local changes are stashed and the revision is checked out.
      2. ``open_clip`` is imported from that checkout and the data is
         generated.
      3. The original branch (or commit, when detached) is checked out
         again and the stash is popped.
      4. The generated ``data`` directory is carried across the switch back.
    """
    global open_clip
    import importlib
    import shutil
    import subprocess
    parser = _build_arg_parser()
    args = parser.parse_args(args)
    model_list = []
    if args.model_list is not None:
        with open(args.model_list, 'r') as f:
            # ignore blank / whitespace-only lines (e.g. a trailing empty line)
            model_list = [line.strip() for line in f if line.strip()]
    if not args.all and not args.model and not model_list:
        print("error: at least one model name is required")
        parser.print_help()
        parser.exit(1)
    current_branch = None
    has_stash = False
    if args.git_revision is not None:
        stash_output = subprocess.check_output(['git', 'stash']).decode().splitlines()
        has_stash = len(stash_output) > 0 and stash_output[0] != 'No local changes to save'
        current_branch = subprocess.check_output(['git', 'branch', '--show-current'])
        if len(current_branch) < 1:
            # not on a branch -> detached head
            current_branch = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        current_branch = current_branch.splitlines()[0].decode()
        # NOTE: the revision is interpolated into a shell command; only pass
        # trusted revision names.
        try:
            _sytem_assert(f'git checkout {args.git_revision}')
        except AssertionError:
            _sytem_assert(f'git checkout -f {current_branch}')
            if has_stash:
                os.system('git stash pop')
            raise
    # imported only now so that, with --git_revision, the checked-out version is used
    open_clip = importlib.import_module('open_clip')
    models = open_clip.list_models() if args.all else args.model + model_list
    try:
        models = create_test_data(
            models,
            batches = args.num_batches,
            batch_size = args.batch_size,
            overwrite = args.overwrite
        )
    finally:
        if args.git_revision is not None:
            # same location get_data_dirs() uses (symlinks resolved)
            script_dir = os.path.dirname(os.path.realpath(__file__))
            test_dir = os.path.join(script_dir, 'data')
            test_dir_ref = os.path.join(script_dir, 'data_ref')
            if os.path.exists(test_dir_ref):
                shutil.rmtree(test_dir_ref, ignore_errors = True)
            # move the generated data aside while switching back, then restore it;
            # skip both steps if no data directory was produced
            moved_data = os.path.exists(test_dir)
            if moved_data:
                os.rename(test_dir, test_dir_ref)
            _sytem_assert(f'git checkout {current_branch}')
            if has_stash:
                os.system('git stash pop')
            if moved_data:
                os.rename(test_dir_ref, test_dir)
    if args.save_model_list is not None:
        print(f"Saving model list as {args.save_model_list}")
        with open(args.save_model_list, 'w') as f:
            for m in models:
                print(m, file=f)


if __name__ == '__main__':
    # Script use: hand every argument after the program name to main().
    from sys import argv
    main(argv[1:])

