import torch
import einops
import glob
import os
import json
import jsonlines
import tqdm
import hydra
import logging
import numpy as np
from PIL import Image
from omegaconf import DictConfig
from vqlm.vqvae_muse import get_tokenizer_muse

# Module-level logger; used below for the config dump and the tokenized row counts.
log = logging.getLogger(__name__)

def read_image_to_tensor(path, center_crop=1.0, size=(256, 256)):
    """Load an image file as a normalized RGB float32 array.

    Despite the name, the result is a NumPy array, not a torch tensor.

    Args:
        path: Image file path readable by PIL.
        center_crop: Fraction of the width and height to keep, centered on the
            image. Values >= 1.0 disable cropping.
        size: ``(width, height)`` passed to ``PIL.Image.resize``. Defaults to
            256x256.

    Returns:
        np.ndarray of dtype float32 and shape ``(size[1], size[0], 3)``, with
        values scaled from [0, 255] to [0, 1].

    Raises:
        ValueError: If ``center_crop`` is not positive, because the crop box
            would be empty.
    """
    if center_crop <= 0:
        raise ValueError(f"center_crop must be positive, got {center_crop}")

    # The context manager closes the file handle. convert() returns an
    # independent, fully loaded image, so it stays usable after the file closes.
    with Image.open(path) as im:
        pil_im = im.convert('RGB')

    if center_crop < 1.0:
        width, height = pil_im.size
        # PIL crop boxes are (left, upper, right, lower). The x-bounds come
        # from the width and the y-bounds from the height.
        pil_im = pil_im.crop((
            int((1 - center_crop) * width / 2),
            int((1 - center_crop) * height / 2),
            int((1 + center_crop) * width / 2),
            int((1 + center_crop) * height / 2),
        ))

    input_img = np.array(pil_im.resize(size)) / 255.0
    input_img = input_img.astype(np.float32)
    return input_img

class SFTDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing input images with optional target images.

    Example ``idx`` is made of the image paths in ``input_files[idx]``, the
    paths in ``target_files[idx]`` (when ``target_files`` is not None), and
    the entries ``input_states[idx]`` and ``metas[idx]``. The last two are
    returned unchanged.
    """

    def __init__(self, input_files, target_files, input_states, metas):
        """Store the per-example file lists, states and metadata.

        Args:
            input_files: Sequence of per-example lists of input image paths.
                Must be non-empty.
            target_files: Sequence of per-example lists of target image paths,
                or None when examples have no targets.
            input_states: Per-example state objects.
            metas: Per-example metadata objects.

        Raises:
            ValueError: If ``input_files`` is empty.
        """
        # An explicit check, because `assert` disappears under `python -O`.
        if len(input_files) == 0:
            raise ValueError("SFTDataset requires at least one input example")
        self.input_files = input_files
        self.target_files = target_files
        self.input_states = input_states
        self.metas = metas

    def __len__(self):
        """Return the number of examples (one per entry of ``input_files``)."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load example ``idx`` from disk.

        Returns:
            A tuple ``(input_images, target_images, input_state, meta,
            original_size)``, where:

            - ``input_images`` stacks ``read_image_to_tensor`` outputs along a
              new leading axis.
            - ``target_images`` is built the same way from the target paths,
              or is an empty ``(0,)`` array when ``target_files`` is None.
            - ``original_size`` is ``np.array([width, height])`` of the last
              input image before resizing.
        """
        # Only the size is needed. The context manager closes the file handle
        # that Image.open would otherwise leave open.
        with Image.open(self.input_files[idx][-1]) as im:
            original_size = im.size

        input_images = np.stack(
            [read_image_to_tensor(f) for f in self.input_files[idx]],
            axis=0
        )

        if self.target_files is not None:
            target_images = np.stack(
                [read_image_to_tensor(f) for f in self.target_files[idx]],
                axis=0
            )
        else:
            target_images = np.empty((0,))
        return input_images, target_images, self.input_states[idx], self.metas[idx], np.array(original_size)


def _count_subdirs(path):
    # Number of immediate sub-directories of `path`. Fails clearly if it is missing.
    try:
        return len(next(os.walk(path))[1])
    except StopIteration:
        # os.walk yields nothing for a missing path, so next() would raise StopIteration.
        raise FileNotFoundError(f"Directory not found: {path}") from None


def _sorted_jpgs(folder_path):
    # The folder's "<int>.jpg" frames, ordered numerically so that "10.jpg" follows "9.jpg".
    jpg_files = sorted(
        # Escape the folder so that "[", "*" or "?" in a path are not read as glob patterns.
        glob.glob(os.path.join(glob.escape(folder_path), "*.jpg")),
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0]),
    )
    if not jpg_files:
        raise FileNotFoundError(f"No jpg files found in {folder_path}")
    return jpg_files


def _level_id(level):
    # Numeric suffix of a level directory name, e.g. "level3" -> 3 and "level12" -> 12.
    suffix = level[len(level.rstrip("0123456789")):]
    if not suffix:
        raise ValueError(f"Level name {level!r} does not end with a number")
    return int(suffix)


def collect_data(base_dir, levels, is_random = False, training_size = -1):
    """Gather per-trajectory frame paths, states and metadata for each level.

    The expected layout is ``<base_dir>/<level>/`` containing ``data.json``,
    ``data_distance_map.json`` and numbered trajectory folders ``0, 1, ...``.
    In random mode each trajectory folder holds numbered sub-folders, and each
    sub-folder is one sample.

    Args:
        base_dir: Root directory that contains one directory per level.
        levels: Level directory names. Each must end with its numeric id.
        is_random: If True, walk ``<level>/<i>/<j>`` sub-folders and take
            ``data[str(i)]["states"][j]`` as the sample's state. Otherwise
            walk ``<level>/<i>`` and take ``data[str(i)]["states"]``.
        training_size: Random mode only. Number of trajectory folders per
            level to use. -1 means all of them, and larger values are capped
            at the number of folders that exist.

    Returns:
        ``(img_files, states_lists, meta_list)``, three parallel lists:

        - ``img_files``: one list of numerically sorted jpg paths per sample.
        - ``states_lists``: one state entry per sample, taken from
          ``data.json``.
        - ``meta_list``: a fresh copy of ``data_distance_map[str(i)]`` per
          sample, with an added ``'level'`` key.

    Raises:
        FileNotFoundError: If an expected folder is missing or has no jpg files.
        ValueError: If a level name has no numeric suffix.
    """
    img_files, states_lists, meta_list = [], [], []

    for level in levels:
        level_dir = os.path.join(base_dir, level)
        num_folders = _count_subdirs(level_dir)
        level_id = _level_id(level)

        with open(os.path.join(level_dir, "data.json"), "r", encoding="utf-8") as f:
            data = json.load(f)
        with open(os.path.join(level_dir, "data_distance_map.json"), "r", encoding="utf-8") as f:
            data_distance_map = json.load(f)

        if is_random:
            if training_size == -1:
                level_training_size = num_folders
            else:
                # Never ask for more trajectory folders than exist.
                level_training_size = min(training_size, num_folders)
            for i in range(level_training_size):
                traj_dir = os.path.join(level_dir, str(i))
                for j in range(_count_subdirs(traj_dir)):
                    img_files.append(_sorted_jpgs(os.path.join(traj_dir, str(j))))
                    states_lists.append(data[str(i)]["states"][j])
                    # Copy per sample rather than mutating (and aliasing) the loaded JSON dict.
                    meta_list.append({**data_distance_map[str(i)], 'level': level_id})
        else:
            for i in range(num_folders):
                img_files.append(_sorted_jpgs(os.path.join(level_dir, str(i))))
                states_lists.append(data[str(i)]["states"])
                meta_list.append({**data_distance_map[str(i)], 'level': level_id})

    return img_files, states_lists, meta_list


def process_pairs(img_files, states_lists, meta_list, is_train = True, is_random = False):
    """Expand per-trajectory data into single-frame (input, target) examples.

    The three input lists are parallel: frame paths, states and metadata for
    each trajectory. The return value is four parallel lists: inputs, targets,
    states and metas. Every input and target is a one-element list of paths.

    Modes:
        * train, random: the trajectory's frame 0 maps to frame 1. The state
          and meta are kept as-is.
        * train, not random: every consecutive pair (frame t -> frame t+1)
          becomes an example. Each pair takes the state at index t, and the
          trajectory's meta is repeated for every pair.
        * eval, not random: the first frame maps to the last frame. The
          state list and meta are kept whole.
        * eval, random: not supported. Raises NotImplementedError.
    """
    if not is_train and is_random:
        raise NotImplementedError

    inputs, targets, out_states, out_metas = [], [], [], []
    for frames, traj_states, traj_meta in zip(img_files, states_lists, meta_list):
        if is_train and is_random:
            inputs.append([frames[0]])
            targets.append([frames[1]])
            out_states.append(traj_states)
            out_metas.append(traj_meta)
        elif is_train:
            for current, following in zip(frames, frames[1:]):
                inputs.append([current])
                targets.append([following])
            # The final state has no successor frame, so it is dropped.
            out_states.extend(traj_states[:-1])
            out_metas.extend([traj_meta] * (len(frames) - 1))
        else:
            inputs.append([frames[0]])
            targets.append([frames[-1]])
            out_states.append(traj_states)
            out_metas.append(traj_meta)

    return inputs, targets, out_states, out_metas

def tokenize_random(cfg: DictConfig):
    """Tokenize random-trajectory samples into a JSON-lines training set.

    Samples come from ``collect_data(..., is_random=True)`` and are paired by
    ``process_pairs(..., is_random=True)``, which maps frame 0 to frame 1. The
    frames are encoded with the MUSE tokenizer and written one JSON object per
    line to ``cfg.tokenized_train_set_pth``, replacing any existing file.

    Config keys read: ``base_dir``, ``levels``, ``training_size``,
    ``torch_device`` and ``tokenized_train_set_pth``.

    Each row holds ``input_tokens``, ``output_tokens``, ``input_state``,
    ``output_state`` and ``meta``. ``input_state`` and ``output_state`` are
    elements 0 and 1 of the sample's state entry; that entry is presumably a
    (before, after) pair (TODO confirm against data.json).
    """
    torch_device = cfg.torch_device
    out_path = cfg.tokenized_train_set_pth

    all_img_files, all_states_lists, all_meta_list = collect_data(cfg.base_dir, cfg.levels, True, cfg.training_size)
    train_input_files, train_target_files, train_states, train_metas = process_pairs(
        all_img_files, all_states_lists, all_meta_list, is_train=True, is_random=True
    )
    train_set = SFTDataset(train_input_files, train_target_files, train_states, train_metas)

    tokenizer = get_tokenizer_muse().to(torch_device)

    def encode(images):
        # `images` is a stack of (H, W, C) frames from SFTDataset. The tokenizer
        # gets them channels-first, and only the flattened token ids are kept.
        pixels = torch.tensor(
            einops.rearrange(np.asarray(images, dtype=np.float32), 'b h w c -> b c h w')
        ).to(torch_device)
        _, ids = tokenizer.encode(pixels)
        return ids.reshape(-1).tolist()

    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    num_lines = 0
    # Opening once in "w" mode truncates old output and avoids reopening per sample.
    # no_grad because only token ids are needed, not an autograd graph.
    with torch.no_grad(), jsonlines.open(out_path, mode='w') as writer:
        for i in tqdm.tqdm(range(len(train_set)), desc="Encoding training images"):
            input_images, target_images, input_state, meta, _ = train_set[i]
            writer.write({
                'input_tokens': encode(input_images),
                # An empty list, rather than reusing the previous sample's ids,
                # when a sample has no target images.
                'output_tokens': encode(target_images) if target_images.size > 0 else [],
                'input_state': input_state[0],
                'output_state': input_state[1],
                'meta': meta,
            })
            num_lines += 1

    log.info(f"Number of data in the tokenized training set: {num_lines}")

def tokenize_optimal(cfg: DictConfig):
    """Tokenize optimal trajectories into JSON-lines train and test sets.

    For each level in ``cfg.levels``, the first ``training_size`` trajectory
    folders form the training split. The next
    ``traj_per_level - training_size`` folders form the test split.

    Training examples are consecutive frame pairs (frame t -> frame t+1).
    Test examples map a trajectory's first frame to its last frame (see
    ``process_pairs``). Frames are encoded with the MUSE tokenizer, and each
    split is written one JSON object per line. Existing output files are
    replaced.

    Config keys read: ``base_dir``, ``levels``, ``training_size``,
    ``torch_device``, ``tokenized_train_set_pth``, ``tokenized_test_set_pth``,
    and optionally ``traj_per_level`` (default 2).

    Each row holds ``input_tokens``, ``output_tokens``, ``input_state`` and
    ``meta``.
    """
    base_dir = cfg.base_dir
    training_size = cfg.training_size
    torch_device = cfg.torch_device
    # Membership test, so configs without the key keep the historical default of 2.
    traj_per_level = cfg.traj_per_level if "traj_per_level" in cfg else 2
    testing_size = traj_per_level - training_size
    train_slice = slice(0, training_size)
    test_slice = slice(training_size, training_size + testing_size)

    train_img_files, train_states_lists, train_meta_list = [], [], []
    test_img_files, test_states_lists, test_meta_list = [], [], []
    for level in cfg.levels:
        # Split each level on its own. Slicing the concatenation of all levels
        # at multiples of traj_per_level is only correct when every level has
        # exactly traj_per_level trajectory folders.
        img_files, states_lists, meta_list = collect_data(base_dir, [level])
        train_img_files.extend(img_files[train_slice])
        train_states_lists.extend(states_lists[train_slice])
        train_meta_list.extend(meta_list[train_slice])
        test_img_files.extend(img_files[test_slice])
        test_states_lists.extend(states_lists[test_slice])
        test_meta_list.extend(meta_list[test_slice])

    train_pairs = process_pairs(train_img_files, train_states_lists, train_meta_list)
    test_pairs = process_pairs(test_img_files, test_states_lists, test_meta_list, False)

    tokenizer = get_tokenizer_muse().to(torch_device)

    def encode(images):
        # `images` is a stack of (H, W, C) frames from SFTDataset. The tokenizer
        # gets them channels-first, and only the flattened token ids are kept.
        pixels = torch.tensor(
            einops.rearrange(np.asarray(images, dtype=np.float32), 'b h w c -> b c h w')
        ).to(torch_device)
        _, ids = tokenizer.encode(pixels)
        return ids.reshape(-1).tolist()

    def write_split(pairs, out_path, desc):
        # Encode one split into out_path and return the number of rows written.
        input_files, target_files, states, metas = pairs
        # SFTDataset rejects an empty example list. An empty split (e.g. the
        # test set when training_size == traj_per_level) yields an empty file.
        dataset = SFTDataset(input_files, target_files, states, metas) if input_files else []
        num_rows = 0
        # no_grad because only token ids are needed, not an autograd graph.
        with torch.no_grad(), jsonlines.open(out_path, mode='w') as writer:
            for i in tqdm.tqdm(range(len(dataset)), desc=desc):
                input_images, target_images, input_state, meta, _ = dataset[i]
                writer.write({
                    'input_tokens': encode(input_images),
                    # An empty list, rather than reusing the previous example's
                    # ids, when an example has no target images.
                    'output_tokens': encode(target_images) if target_images.size > 0 else [],
                    'input_state': input_state,
                    'meta': meta,
                })
                num_rows += 1
        return num_rows

    # Prepare both output locations before any encoding. This creates each
    # file's directory and clears stale output from a previous run.
    for out_path in (cfg.tokenized_train_set_pth, cfg.tokenized_test_set_pth):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if os.path.exists(out_path):
            os.remove(out_path)

    num_train = write_split(train_pairs, cfg.tokenized_train_set_pth, "Encoding training images")
    num_test = write_split(test_pairs, cfg.tokenized_test_set_pth, "Encoding testing images")

    log.info(f"Number of data in the tokenized training set: {num_train}")
    log.info(f"Number of data in the tokenized testing set: {num_test}")

@hydra.main(config_path="configs", config_name="tokenize", version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: log the config, then run the selected tokenization mode.

    A truthy ``cfg.is_random`` runs ``tokenize_random``. Otherwise
    ``tokenize_optimal`` runs.
    """
    log.info(f"Training Config: {cfg}")
    pipeline = tokenize_random if cfg.is_random else tokenize_optimal
    pipeline(cfg)


if __name__ == "__main__":
    # hydra.main supplies `cfg`, so main() is called without arguments.
    main()