import random

from pytorch_lightning import seed_everything
from torchvision import transforms

from reward_helpers import *

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Keys selecting an entry of VERSION2SPECS / DATASET2SOURCES below.
VERSION = "vwm"
DATASET = "IMG"

# Root directory for outputs ("virtual" = generated, "real" = inputs).
SAVE_PATH = "outputs/"

# NUM_ROUNDS > 1 switches the guider to "TrianglePredictionGuider".
NUM_ROUNDS = 1
# Frames per clip; also how many times the chosen image path is repeated.
NUM_FRAMES = 25
# Seed handed to seed_everything() before each sample.
RANDOM_SEED = 23

# Size (pixels) every input image is cropped/resized to.
TARGET_HEIGHT, TARGET_WIDTH = 576, 1024

# Sampler settings forwarded to init_sampling() / the conditioning dict.
CFG_SCALE = 2.5   # cfg_scale argument of init_sampling()
COND_AUG = 0.0    # scale of Gaussian noise added to the conditioning frame
STEPS = 10        # steps argument of init_sampling()

# True: keep sampling forever, jumping to a random next image.
# False: visit each image once, in order, then stop.
RANDOM_GEN = True
# Forwarded to set_lowvram_mode().
LOWVRAM_MODE = False

# Model variants: inference config file and checkpoint per VERSION key.
VERSION2SPECS = dict(
    vwm=dict(
        config="configs/inference/vista_action.yaml",
        ckpt="checkpoints/pytorch_model.bin",
    ),
)

# Data sources: folder of input images per DATASET key.
DATASET2SOURCES = dict(
    IMG=dict(data_root="image_folder"),
)


def get_sample(selected_index=0, data_root=None, num_frames=None):
    """Pick one image from the dataset folder and repeat its path into a clip.

    The folder listing is sorted and restricted to regular files, so a given
    index maps to the same image on every run and platform (``os.listdir``
    order is arbitrary) and sub-directories are never passed to the loader.

    Args:
        selected_index: Requested image index. It is wrapped into
            ``[0, total_length)`` with ``%``, so out-of-range values cycle.
        data_root: Folder to read images from. Defaults to the ``data_root``
            of the active ``DATASET`` entry in ``DATASET2SOURCES``.
        num_frames: How many times the chosen path is repeated. Defaults to
            ``NUM_FRAMES``.

    Returns:
        ``(path_list, selected_index, total_length)``: ``path_list`` holds
        ``num_frames`` copies of the chosen image path, ``selected_index`` is
        the wrapped index actually used, and ``total_length`` is the number of
        files found in ``data_root``.

    Raises:
        ValueError: if ``data_root`` contains no files.
    """
    if data_root is None:
        data_root = DATASET2SOURCES[DATASET]["data_root"]
    if num_frames is None:
        num_frames = NUM_FRAMES

    image_list = sorted(
        name
        for name in os.listdir(data_root)
        if os.path.isfile(os.path.join(data_root, name))
    )
    total_length = len(image_list)
    # An empty folder has no valid index; fail loudly instead of wrapping
    # around zero forever.
    if total_length == 0:
        raise ValueError(f"No image files found in {data_root!r}")

    selected_index %= total_length
    image_path = os.path.join(data_root, image_list[selected_index])
    return [image_path] * num_frames, selected_index, total_length


def load_img(file_name, target_height=320, target_width=576, device="cuda"):
    """Load an image, center-crop it to the target aspect ratio, and resize it.

    Args:
        file_name: Path of the image to read; ``None`` is rejected.
        target_height: Output height in pixels.
        target_width: Output width in pixels.
        device: Device the returned tensor is moved to.

    Returns:
        A float tensor of shape ``(3, target_height, target_width)`` with pixel
        values rescaled from ``[0, 1]`` to ``[-1, 1]``, placed on ``device``.

    Raises:
        ValueError: if ``file_name`` is ``None``.
    """
    if file_name is None:
        raise ValueError(f"Invalid image file {file_name}")

    # Open inside a context manager so the file handle is always released.
    # convert() returns a new, fully loaded RGB image (a copy when the source
    # is already RGB), which stays valid after the file is closed.
    with Image.open(file_name) as src:
        image = src.convert("RGB")
    ori_w, ori_h = image.size

    target_ratio = target_width / target_height
    if ori_w / ori_h > target_ratio:
        # Wider than the target: keep the full height, trim left/right equally.
        tmp_w = int(target_ratio * ori_h)
        left = (ori_w - tmp_w) // 2
        right = (ori_w + tmp_w) // 2
        image = image.crop((left, 0, right, ori_h))
    elif ori_w / ori_h < target_ratio:
        # Taller than the target: keep the full width, trim top/bottom equally.
        tmp_h = int(target_height / target_width * ori_w)
        top = (ori_h - tmp_h) // 2
        bottom = (ori_h + tmp_h) // 2
        image = image.crop((0, top, ori_w, bottom))
    image = image.resize((target_width, target_height), resample=Image.LANCZOS)

    # ToTensor yields a (C, H, W) float tensor in [0, 1]; map it to [-1, 1].
    tensor = transforms.ToTensor()(image) * 2.0 - 1.0
    return tensor.to(device)


if __name__ == "__main__":
    # Driver: load the model, then repeatedly pick an image, build the
    # conditioning dict, sample, and save generated ("virtual") and input
    # ("real") frames. With RANDOM_GEN=True the loop never terminates.
    set_lowvram_mode(LOWVRAM_MODE)
    version_dict = VERSION2SPECS[VERSION]
    model = init_model(version_dict)
    unique_keys = {x.input_key for x in model.conditioner.embedders}

    # Loop-invariant sampling settings.
    guider = "TrianglePredictionGuider" if NUM_ROUNDS > 1 else "VanillaCFG"
    uc_keys = [
        "cond_frames", "cond_frames_without_noise", "command",
        "trajectory", "speed", "angle", "goal",
    ]
    virtual_dir = os.path.join(SAVE_PATH, "virtual")
    real_dir = os.path.join(SAVE_PATH, "real")
    save_types = ("images", "grids", "videos")

    # seed_everything() reseeds Python's global `random` module on every
    # iteration, so drawing the next index from it would yield the same step
    # each time. A dedicated generator keeps index jumps random yet reproducible.
    index_rng = random.Random(RANDOM_SEED)

    sample_index = 0
    while sample_index >= 0:
        seed_everything(RANDOM_SEED)

        frame_list, sample_index, dataset_length = get_sample(sample_index)

        # get_sample may repeat the same path for every frame; decode each
        # distinct file only once (torch.stack copies, so sharing is safe).
        loaded = {}
        img_seq = []
        for each_path in frame_list:
            if each_path not in loaded:
                loaded[each_path] = load_img(each_path, TARGET_HEIGHT, TARGET_WIDTH)
            img_seq.append(loaded[each_path])
        images = torch.stack(img_seq)

        value_dict = init_embedder_options(unique_keys)
        cond_img = img_seq[0][None]  # first frame with a leading batch dim
        value_dict["cond_frames_without_noise"] = cond_img
        value_dict["cond_aug"] = COND_AUG
        value_dict["cond_frames"] = cond_img + COND_AUG * torch.randn_like(cond_img)

        sampler = init_sampling(guider=guider, steps=STEPS, cfg_scale=CFG_SCALE, num_frames=NUM_FRAMES)

        out = do_sample(
            images,
            model,
            sampler,
            value_dict,
            num_rounds=NUM_ROUNDS,
            num_frames=NUM_FRAMES,
            force_uc_zero_embeddings=uc_keys
        )

        if not isinstance(out, (tuple, list)):
            raise TypeError(
                f"do_sample returned {type(out).__name__}; expected a tuple or list"
            )
        samples, _samples_z, inputs, _reward = out
        for save_type in save_types:
            perform_save_locally(virtual_dir, samples, save_type, DATASET, sample_index)
        for save_type in save_types:
            perform_save_locally(real_dir, inputs, save_type, DATASET, sample_index)

        if RANDOM_GEN:
            # Jump forward by a random nonzero offset; get_sample wraps it.
            # With a single image the step is 1 (wraps back to 0) instead of
            # randint(1, 0), which would raise ValueError.
            sample_index += index_rng.randint(1, max(1, dataset_length - 1))
        else:
            sample_index += 1
            if sample_index >= dataset_length:
                sample_index = -1  # every image processed; ends the loop
