import tarfile
import pickle
import os
import copy
import threading
import random
import queue
import io

import gcsfs
import numpy as np
import decord
import zerofun
from ml_collections import ConfigDict

from PIL import Image

import jax
from jax.experimental.multihost_utils import host_local_array_to_global_array
from jax.sharding import PartitionSpec as PS
from tux import open_file


def shard_batch_to_global(batch, mesh):
    """Keep this host's slice of the sequence axis and build global jax arrays.

    The sequence axis (dim 1 of every leaf) is split into ``sp_nodes_size`` equal
    chunks, where ``sp_nodes_size = max(1, mesh.shape['sp'] // local_device_count)``;
    this host keeps chunk ``process_index % sp_nodes_size``. The result is assembled
    with ``PartitionSpec(('dp', 'fsdp'), 'sp')``.

    Args:
        batch: an array or a pytree (dict / tuple / list) of host-local arrays whose
            leaves are at least 2-D and share dim 1.
        mesh: a ``jax.sharding.Mesh`` with ``'dp'``, ``'fsdp'`` and ``'sp'`` axes.

    Returns:
        The same pytree structure holding global ``jax.Array`` values.
    """
    # All leaves share the sequence length; read it from the first one.
    seq_length = jax.tree_util.tree_leaves(batch)[0].shape[1]
    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
    sp_nodes_rank = jax.process_index() % sp_nodes_size
    assert seq_length % sp_nodes_size == 0, (seq_length, sp_nodes_size)
    seq_chunk_size = seq_length // sp_nodes_size
    start = sp_nodes_rank * seq_chunk_size
    end = start + seq_chunk_size
    # jax.tree_map was deprecated and removed; jax.tree_util.tree_map is the stable API.
    batch = jax.tree_util.tree_map(lambda x: x[:, start:end], batch)
    batch = host_local_array_to_global_array(batch, mesh, PS(('dp', 'fsdp'), 'sp'))
    return batch


class DatasetFactory(object):
    """Static helper that resolves a dataset config and builds the dataset.

    Only ``type == 'vision_dataset'`` is supported; it builds a ``VisionDataset``
    from the ``vision_dataset`` sub-config.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default factory config, optionally overlaid with ``updates``."""
        config = ConfigDict()
        config.type = 'vision_dataset'
        config.vision_dataset = VisionDataset.get_default_config()
        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, **kwargs):
        """Build the dataset selected by ``config.type``; extra kwargs go to its constructor.

        Raises:
            ValueError: if ``config.type`` is not a known dataset type.
        """
        resolved = cls.get_default_config(config)
        if resolved.type != 'vision_dataset':
            raise ValueError(f"Unknown dataset type: {resolved.type}")
        return VisionDataset(resolved.vision_dataset, **kwargs)

    def __init__(self):
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')


def _data_reader_worker(config, file_paths, read_queue, init):
    """Stream raw samples from ``file_paths`` into ``read_queue`` forever, epoch after epoch.

    Args:
        config: unused in this function.
        file_paths: paths (read with a ``gs://`` prefix). ``.tar`` shards are
            downloaded and split into their ``.jpg`` / ``.mp4`` members; any other
            path is forwarded as its UTF-8 encoded name with type ``'file'``.
        read_queue: receives ``(dataset_loc, payload_bytes, file_type)`` tuples, where
            ``dataset_loc = (epoch, file_index, member_index)`` and ``file_type`` is
            ``'image'``, ``'video'`` or ``'file'``.
        init: ``(epoch, file_index, member_index)`` to resume from. In that epoch,
            files before ``file_index`` are skipped, as are members of that file with
            index ``<= member_index``. Later epochs always start from file 0.
    """
    start_epoch, start_i, start_j = init

    def _read(file_path):
        # Tar shards are fetched whole; any other path is passed through as its name.
        if file_path.endswith('.tar'):
            with open_file(f"gs://{file_path}", 'rb') as f:
                data = f.read()
        else:
            data = file_path.encode('utf-8')
        return (data, file_path)

    epoch = start_epoch
    while True:
        # The order depends only on the epoch, so a resumed run sees the same order.
        epoch_paths = list(file_paths)
        random.Random(1234 + epoch).shuffle(epoch_paths)
        # Only the resumed epoch skips already-consumed files. Previously start_i was
        # applied to every epoch, dropping the first start_i files from all epochs.
        first_i = start_i if epoch == start_epoch else 0
        for i in range(first_i, len(epoch_paths)):
            data, file_path = _read(epoch_paths[i])
            if file_path.endswith('.tar'):
                with tarfile.open(fileobj=io.BytesIO(data), mode='r') as tar:
                    for j, member in enumerate(tar.getmembers()):
                        # NOTE(review): with the default init (0, 0, 0) this also skips
                        # member 0 of the first shard on a fresh start — confirm intended.
                        if epoch == start_epoch and i == start_i and j <= start_j:
                            continue
                        if member.name.endswith('.jpg'):
                            member_data = tar.extractfile(member).read()
                            read_queue.put(((epoch, i, j), member_data, 'image'))
                        elif member.name.endswith('.mp4'):
                            member_data = tar.extractfile(member).read()
                            read_queue.put(((epoch, i, j), member_data, 'video'))
            else:
                read_queue.put(((epoch, i, 0), data, 'file'))
        epoch += 1


def _data_submit_worker(config, read_queue, futures_queue, port_numbers):
    """Send items from ``read_queue`` to the process workers, round-robin over ports.

    Each ``(dataset_loc, raw_bytes, file_type)`` item is wrapped as a uint8 array and
    submitted via ``client.process``; the returned future goes to ``futures_queue``.
    A ``None`` item ends the loop and is forwarded as ``None`` to ``futures_queue``.
    ``config`` is unused here.
    """
    clients = [zerofun.Client(f'tcp://localhost:{port}') for port in port_numbers]
    for client in clients:
        client.connect()

    n_submitted = 0
    item = read_queue.get()
    while item is not None:
        dataset_loc, raw, file_type = item
        payload = {
            'dataset_loc': dataset_loc,
            'data': np.frombuffer(raw, dtype=np.uint8),
            'file_type': file_type,
        }
        target = clients[n_submitted % len(clients)]
        futures_queue.put(target.process(payload))
        n_submitted += 1
        item = read_queue.get()
    futures_queue.put(None)


def _data_comm_worker(futures_queue, data_queue):
    """Resolve worker futures in order and push ``(dataset_loc, vision, shape)`` tuples.

    Blocks on each ``future.result()``. A ``None`` future ends the loop and is
    forwarded as ``None`` to ``data_queue``.
    """
    future = futures_queue.get()
    while future is not None:
        payload = future.result()
        # Remove the loc first so the remaining keys are just the _pre_send payload.
        loc = payload.pop('dataset_loc')
        vision, shape = _post_receive(payload)
        data_queue.put((loc, vision, shape))
        future = futures_queue.get()
    data_queue.put(None)


def _process_image(config, data_np):
    """Decode an image and patchify it as a static clip of ``frames_per_block`` frames.

    The image is resized so its short side equals ``config.resolution``. It is then
    either center-cropped to a square or cropped to a multiple of the patch size
    (optionally snapped to a fixed aspect-ratio subset first).

    Args:
        config: dataset config. Reads ``resolution``, ``process_square``,
            ``prob_square`` (default ``0.0``), ``limit_ar_subset`` (default ``False``),
            ``min_aspect_ratio`` / ``max_aspect_ratio`` and ``elastic_config``
            (``frames_per_block``, ``patch_size``).
        data_np: uint8 numpy array holding the encoded image bytes.

    Returns:
        ``([patches], (T, H, W))`` with ``patches`` of shape
        ``(n_patches, prod(patch_size) * 3)``, or ``([], (0, 0, 0))`` if the cropped
        aspect ratio is outside the configured bounds.
    """
    image_size = config.resolution
    block_size = config.elastic_config.frames_per_block
    Tp, Hp, Wp = config.elastic_config.patch_size
    # Optional keys: read with defaults so configs that omit them don't raise.
    prob_square = config.get('prob_square', 0.0)
    limit_ar_subset = config.get('limit_ar_subset', False)

    image = Image.open(io.BytesIO(data_np.data)).convert('RGB')
    image = _image_resize_shortest(image, image_size)
    if config.process_square or random.random() < prob_square:
        image = _image_center_crop(image, (image_size, image_size))
    else:
        if limit_ar_subset:
            image = _image_crop_next_ar(
                image, image_size, [9/16, 3/4, 1/1, 4/3, 16/9])
        image = _image_crop_patch_multiple(image, (Hp, Wp))

    ar = image.size[0] / image.size[1]
    if ar < config.min_aspect_ratio or ar > config.max_aspect_ratio:
        return [], (0, 0, 0)
    # Repeat the still frame so an image fills one temporal block like a video does.
    video = np.array(image)[None].repeat(block_size, axis=0)

    T, H, W, C = video.shape
    # (T, H, W, C) -> (nT, nH, nW, Tp, Hp, Wp, C) -> one flattened row per patch.
    video = video.reshape(T // Tp, Tp, H // Hp, Hp, W // Wp, Wp, C)
    video = np.transpose(video, (0, 2, 4, 1, 3, 5, 6))
    video = video.reshape(-1, np.prod(config.elastic_config.patch_size) * C)
    return [video], (T, H, W)


def to_seed(data):
    """Map a hashable value to an unsigned 32-bit seed (low 32 bits of ``hash``).

    ``hash`` of ints and tuples of ints is the same in every process, but str/bytes
    hashes are salted per process, so those seeds would not be reproducible.
    """
    return hash(data) & 0xFFFFFFFF


def _process_video(config, data_np, dataset_loc):
    """Decode a video and cut it into patchified blocks of ``frames_per_block`` frames.

    Frames are either center-cropped to ``resolution`` squares (``process_square``) or
    cropped to patch multiples. Block selection is seeded by ``dataset_loc``, so the
    same sample gives the same blocks.

    Args:
        config: dataset config. Reads ``resolution``, ``process_square``,
            aspect-ratio bounds, ``seq_length``, ``one_seq_per_elem``,
            ``n_block_skip``, the ``use_synthetic`` / ``synthetic_*`` options and
            ``elastic_config``.
        data_np: uint8 numpy array holding the encoded video bytes.
        dataset_loc: sequence of scalars exposing ``.item()`` (presumably numpy
            scalars after transport — confirm), used to seed the sampler.

    Returns:
        ``(blocks, (n_frames, H, W))`` with each block of shape
        ``(n_patches, prod(patch_size) * C)``, or ``([], (0, 0, 0))`` when the video
        is undecodable, too short, or out of aspect-ratio bounds.
    """
    Tp, Hp, Wp = config.elastic_config.patch_size
    block_size = config.elastic_config.frames_per_block
    image_size = config.resolution
    try:
        vr = decord.VideoReader(io.BytesIO(data_np.data))
    except Exception as e:
        # Best-effort: undecodable videos are skipped instead of killing the worker.
        print(f"Error: {e}")
        return [], (0, 0, 0)
    if len(vr) < block_size:
        return [], (0, 0, 0)

    # Crop one frame to learn the output H, W and check the aspect ratio.
    one_frame = vr.get_batch([0]).asnumpy()
    if config.process_square:
        one_frame = _video_center_crop_np(one_frame, (block_size, image_size, image_size))
    else:
        one_frame = _video_crop_patch_multiple_np(one_frame, (1, Hp, Wp))
    _, H, W, C = one_frame.shape
    ar = W / H
    if ar < config.min_aspect_ratio or ar > config.max_aspect_ratio:
        return [], (0, 0, 0)

    def _to_patches(frames):
        # Crop a block of frames like `one_frame`, then flatten every (Tp, Hp, Wp, C)
        # patch into one row: (nT, nH, nW, Tp, Hp, Wp, C) -> (n_patches, patch_dim).
        if config.process_square:
            frames = _video_center_crop_np(frames, (block_size, image_size, image_size))
        else:
            frames = _video_crop_patch_multiple_np(frames, (Tp, Hp, Wp))
        frames = frames.reshape(block_size // Tp, Tp, H // Hp, Hp, W // Wp, Wp, C)
        frames = np.transpose(frames, (0, 2, 4, 1, 3, 5, 6))
        return frames.reshape(-1, np.prod(config.elastic_config.patch_size) * C)

    n_blocks_per_seq = config.seq_length // config.elastic_config.max_toks
    n_blocks = len(vr) // block_size
    rng = np.random.RandomState(to_seed(tuple(x.item() for x in dataset_loc)))
    output = []

    if config.use_synthetic:
        n_frames = len(vr)
        dist = config.synthetic_frame_dist
        if n_frames < dist + 1:
            return [], (0, 0, 0)
        # randint's upper bound is exclusive: idx <= n_frames - dist - 1, so
        # idx + dist is a valid frame index (the old bound could index past the end).
        idx = rng.randint(0, n_frames - dist)
        f1 = vr[idx].asnumpy()[None].repeat(block_size, axis=0)
        f2 = vr[idx + dist].asnumpy()[None].repeat(block_size, axis=0)
        for val in config.synthetic_pattern:
            if val == 'a':
                v = f1
            elif val == 'b':
                v = f2
            else:
                raise Exception(config.synthetic_pattern)
            output.append(_to_patches(v))
    else:
        start_block, end_block = 0, n_blocks
        if config.one_seq_per_elem:
            # Take a single random window of exactly one sequence worth of blocks.
            if n_blocks < n_blocks_per_seq:
                return [], (0, 0, 0)
            start_block = rng.randint(0, n_blocks - n_blocks_per_seq + 1)
            end_block = start_block + n_blocks_per_seq
            n_blocks = n_blocks_per_seq
        if config.n_block_skip > 0 and n_blocks >= n_blocks_per_seq + config.n_block_skip:
            start_block = rng.randint(0, config.n_block_skip + 1)

        # Emit runs of n_blocks_per_seq consecutive blocks, skipping n_block_skip between.
        block_idx = start_block
        while block_idx < end_block:
            for i in range(block_idx, min(block_idx + n_blocks_per_seq, end_block)):
                idxs = list(range(i * block_size, (i + 1) * block_size))
                output.append(_to_patches(vr.get_batch(idxs).asnumpy()))
            block_idx += n_blocks_per_seq + config.n_block_skip
    return output, (len(output) * block_size, H, W)


def _pre_send(data):
    """Flatten ``(vision_chunks, shape)`` into a dict for transport.

    Chunk ``k`` is stored under ``"vision_<k>"`` and the shape under ``"shape"``.
    """
    vision_chunks, shape = data
    payload = {f"vision_{idx}": chunk for idx, chunk in enumerate(vision_chunks)}
    payload["shape"] = shape
    return payload


def _post_receive(data):
    """Rebuild ``(vision_chunks, shape)`` from the flat dict produced by ``_pre_send``.

    Pops ``'shape'`` from ``data``, so the dict is mutated. Only ``vision_<k>`` keys
    are counted as chunks. Any other keys left in ``data`` (e.g. ``'dataset_loc'``)
    are ignored rather than causing a ``KeyError``.
    """
    shape = data.pop('shape')
    n_chunks = sum(
        1 for key in data if isinstance(key, str) and key.startswith('vision_'))
    vision = [data[f"vision_{i}"] for i in range(n_chunks)]
    return vision, shape


def _data_process_worker(config, port_number):
    """Serve ``'process'`` requests on ``tcp://*:<port_number>`` until the server stops.

    Each request dict carries ``'dataset_loc'``, ``'data'`` (uint8 array) and
    ``'file_type'``:
      * ``'image'`` / ``'video'``: ``data`` holds encoded media bytes.
      * ``'file'``: ``data`` holds a UTF-8 path; the file is fetched from ``gs://``
        and dispatched by its ``.jpg`` / ``.mp4`` extension.
    The reply is the ``_pre_send`` payload plus the original ``'dataset_loc'``.

    Raises (inside the handler):
        ValueError: for an unknown ``file_type`` or an unsupported file extension.
    """
    def _process(data):
        dataset_loc = data['dataset_loc']
        data_np, file_type = data['data'], data['file_type']
        if file_type == 'image':
            out = _process_image(config, data_np)
        elif file_type == 'video':
            out = _process_video(config, data_np, dataset_loc)
        elif file_type == 'file':
            file_path = bytes(data_np.data).decode('utf-8')
            # Use a context manager so the remote file handle is always closed.
            with open_file(f"gs://{file_path}", 'rb') as f:
                data_bytes = f.read()
            data_np = np.frombuffer(data_bytes, dtype=np.uint8)
            if file_path.endswith('.jpg'):
                out = _process_image(config, data_np)
            elif file_path.endswith('.mp4'):
                out = _process_video(config, data_np, dataset_loc)
            else:
                raise ValueError(f"Unsupported file type: {file_path}")
        else:
            raise ValueError(f"Unsupported file type: {file_type}")
        out = _pre_send(out)
        out['dataset_loc'] = dataset_loc
        return out

    server = zerofun.Server(f'tcp://*:{port_number}')
    server.bind('process', _process)
    server.run()


class MaskSampler:
    """Produce boolean ``encoding_mask`` rows of length ``max_toks``.

    ``config.elastic_config.mask_type`` selects how the token count is chosen:
      * ``'fixed_<frac>'``: ``ceil(max_toks * frac)``;
      * ``'elastic'``: ``rng.randint(min_toks, max_toks)`` (upper bound exclusive).
    Any other value raises ``NotImplementedError`` at construction.
    """

    def __init__(self, config, rng):
        elastic = config.elastic_config
        self.mask_type = elastic.mask_type
        self.min_toks = elastic.min_toks
        self.max_toks = elastic.max_toks
        self.rng = rng
        is_fixed = self.mask_type.startswith('fixed')
        if not is_fixed and self.mask_type != 'elastic':
            raise NotImplementedError(self.mask_type)
        if is_fixed:
            # e.g. 'fixed_0.5' -> 0.5
            self.threshold = float(self.mask_type.split('_')[1])

    def __call__(self, ntoks=None):
        """Return a ``(max_toks,)`` bool mask; ``ntoks`` overrides the sampled count."""
        if ntoks is None:
            if self.mask_type == 'elastic':
                ntoks = self.rng.randint(self.min_toks, self.max_toks)
            elif self.mask_type.startswith('fixed'):
                ntoks = int(np.ceil(self.max_toks * self.threshold))
            else:
                raise NotImplementedError(self.mask_type)
        # NOTE(review): `<=` marks ntoks + 1 positions True (fixed_0.5 with
        # max_toks=256 keeps 129). If exactly `ntoks` kept tokens is intended, this
        # should be `<` (with randint(min_toks, max_toks + 1) for elastic) — confirm.
        return np.arange(self.max_toks) <= ntoks


def _collater(config, node_info, batch_queue, data_queues, n_elems_per_dataset):
    """Pack per-dataset block streams into fixed-size batches on ``batch_queue``; runs forever.

    Items read from ``data_queues[j]`` are ``(dataset_loc, blocks, shape)`` with
    ``blocks`` a list of patchified blocks. Dataset ``j`` fills
    ``n_elems_per_dataset[j]`` sequences of every batch. Blocks of a clip that do not
    fit are cached and used first in the next batch.

    Args:
        config: dataset config; reads ``batch_size``, ``seq_length``,
            ``use_data_sharded_loader`` and ``elastic_config`` (``max_toks`` is the
            number of tokens per block; ``patch_size``, ``frames_per_block``).
        node_info: dict with ``'dp_node_size'`` and ``'dp_node_rank'``.
        batch_queue: receives ``(dataset_locs, batch)``. ``dataset_locs`` holds, per
            dataset, the loc of the most recently fetched clip (``None`` if none yet).
            ``batch`` maps ``vision`` / ``encoding_mask`` / ``segment_ids`` /
            ``attention_mask`` / ``position_ids`` to arrays shaped
            ``(batch_size, seq_length, ...)``.
        data_queues: one input queue per dataset; may be ``None`` where the matching
            ``n_elems_per_dataset`` entry is 0.
        n_elems_per_dataset: number of sequences each dataset contributes per batch.
    """
    assert len(data_queues) == len(n_elems_per_dataset), (len(data_queues), len(n_elems_per_dataset))
    batch_size = config.batch_size
    if config.use_data_sharded_loader:
        batch_size = batch_size // node_info['dp_node_size']
    block_size = config.elastic_config.max_toks
    n_blocks = batch_size * config.seq_length // block_size
    n_blocks_per_seq = config.seq_length // block_size
    patch_dim = np.prod(config.elastic_config.patch_size) * 3  # RGB channels
    rng = np.random.RandomState(node_info['dp_node_rank'])
    mask_sampler = MaskSampler(config, rng)

    n_datasets = len(n_elems_per_dataset)
    # Per dataset: (loc of last fetched clip, its not-yet-placed blocks, their shape).
    cache_per_dataset = [(None, [], None) for _ in range(n_datasets)]
    while True:
        batch = dict(
            vision=np.zeros((n_blocks, block_size, patch_dim), dtype=np.uint8),
            encoding_mask=np.ones((n_blocks, block_size), dtype=bool),
            segment_ids=np.zeros((n_blocks, block_size), dtype=np.int32),
            attention_mask=np.zeros((n_blocks, block_size), dtype=bool),
            position_ids=np.zeros((n_blocks, block_size), dtype=np.int32),
        )
        cur_idx, cur_segment_id, cur_block_id = 0, 0, 0
        for j, (data_queue, sub_batch_size) in enumerate(zip(data_queues, n_elems_per_dataset)):
            if sub_batch_size == 0:
                continue
            n_sub_blocks = sub_batch_size * config.seq_length // block_size
            end_idx = cur_idx + n_sub_blocks
            cur_dataset_loc, cur_vision, cur_shape = cache_per_dataset[j]
            while cur_idx < end_idx:
                if len(cur_vision) == 0:
                    cur_dataset_loc, cur_vision, cur_shape = data_queue.get()
                    if len(cur_vision) == 0:
                        # Rejected clip (no blocks): fetch the next one.
                        continue
                    # A new clip starts a new segment with positions from 0.
                    cur_segment_id += 1
                    cur_block_id = 0
                if cur_idx % n_blocks_per_seq == 0:
                    # First block of a sequence row: positions restart.
                    cur_block_id = 0
                batch['vision'][cur_idx] = cur_vision.pop(0)
                batch['encoding_mask'][cur_idx] = mask_sampler()
                batch['segment_ids'][cur_idx] = cur_segment_id
                batch['position_ids'][cur_idx] = np.arange(block_size, dtype=np.int32) + cur_block_id * block_size
                batch['attention_mask'][cur_idx] = True

                cur_idx += 1
                cur_block_id += 1
            n_blocks_left = len(cur_vision)
            if cur_shape is not None:
                cur_shape = (n_blocks_left * config.elastic_config.frames_per_block, *cur_shape[1:])
            cache_per_dataset[j] = (cur_dataset_loc, cur_vision, cur_shape)
        assert cur_idx == n_blocks, (cur_idx, n_blocks)
        # Flat dict of numpy arrays: fold the blocks into (batch, seq_length, ...).
        # (Replaces jax.tree_map, which was removed from recent JAX releases.)
        batch = {
            key: value.reshape((batch_size, config.seq_length, *value.shape[2:]))
            for key, value in batch.items()
        }
        dataset_locs = tuple([c[0] for c in cache_per_dataset])
        batch_queue.put((dataset_locs, batch))


def _video_center_crop_np(video, crop_size):
    """Center-crop a ``(T, H, W, C)`` array to ``crop_size = (T', H', W')``.

    No bounds checking is done. If a target exceeds the input, the offset goes
    negative and numpy slicing semantics apply (e.g. a 1-frame input keeps 1 frame).
    """
    n_frames, height, width, _ = video.shape
    t_len, h_len, w_len = crop_size
    t0 = (n_frames - t_len) // 2
    y0 = (height - h_len) // 2
    x0 = (width - w_len) // 2
    window = (slice(t0, t0 + t_len), slice(y0, y0 + h_len), slice(x0, x0 + w_len), slice(None))
    return video[window]


def _video_crop_patch_multiple_np(video, patch_size):
    """Center-crop ``(T, H, W, C)`` so T, H, W are multiples of ``patch_size = (Tp, Hp, Wp)``."""
    n_frames, height, width, _ = video.shape
    Tp, Hp, Wp = patch_size
    target = (n_frames - n_frames % Tp, height - height % Hp, width - width % Wp)
    return _video_center_crop_np(video, target)


def _image_center_crop(image, crop_size):
    """Center-crop a PIL image to ``crop_size = (width, height)``."""
    crop_w, crop_h = crop_size[0], crop_size[1]
    width, height = image.size
    x0 = (width - crop_w) // 2
    y0 = (height - crop_h) // 2
    return image.crop((x0, y0, x0 + crop_w, y0 + crop_h))


def _image_crop_patch_multiple(image, patch_size):
    """Center-crop a PIL image so its height/width are multiples of ``patch_size = (Hp, Wp)``."""
    Hp, Wp = patch_size
    width, height = image.size
    return _image_center_crop(image, (width - width % Wp, height - height % Hp))


def _image_crop_next_ar(image, shortest_side_res, aspect_ratios):
    """Center-crop the long side to the widest listed aspect ratio that still fits.

    The image's short side must already equal ``shortest_side_res`` (asserted).
    Square images are returned unchanged. Raises ``Exception(image.size)`` if no
    candidate length fits.
    """
    width, height = image.size
    if height == shortest_side_res and width == shortest_side_res:
        return image  # square

    if height != shortest_side_res:
        # Portrait: width is the short side; choose the longest allowed height <= H.
        assert height > shortest_side_res, height
        assert width == shortest_side_res, image.size
        candidates = sorted((int(shortest_side_res / ar) for ar in aspect_ratios), reverse=True)
        chosen = next((c for c in candidates if c <= height), None)
        if chosen is None:
            raise Exception(image.size)
        target = (shortest_side_res, chosen)
    else:
        # Landscape: height is the short side; choose the longest allowed width <= W.
        assert width > shortest_side_res, width
        assert height == shortest_side_res, image.size
        candidates = sorted((int(shortest_side_res * ar) for ar in aspect_ratios), reverse=True)
        chosen = next((c for c in candidates if c <= width), None)
        if chosen is None:
            raise Exception(image.size)
        target = (chosen, shortest_side_res)

    cropped = _image_center_crop(image, target)
    assert cropped.size == target, (cropped.size, target)
    return cropped



def _image_resize_shortest(image, image_size):
    """Bilinearly resize a PIL image so its shorter side equals ``image_size``.

    The long side is scaled proportionally and truncated. A result of
    ``image_size - 1`` (float rounding) is snapped up to ``image_size``.
    """
    width, height = image.size

    def _scaled(length, short_side):
        value = int(length * (image_size / short_side))
        return image_size if value == image_size - 1 else value

    if width > height:
        new_size = (_scaled(width, height), image_size)
    else:
        new_size = (image_size, _scaled(height, width))
    return image.resize(new_size, Image.BILINEAR)


class VisionDataset(object):
    """Streaming image/video dataset that mixes several sources into packed batches.

    Each comma-separated entry of ``config.paths`` is one source. For every source that
    gets a non-zero share of this data-parallel node's batch, ``_init_data`` starts a
    reader/submit/receive thread pipeline plus ``zerofun`` worker processes. ``__iter__``
    packs their output with ``_collater`` and, with ``use_data_sharded_loader``, yields
    global jax arrays.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()

        # Arguments below require to be comma-separated lists (as str) of equal length
        config.paths = ''
        config.extensions = ''
        config.batch_split_ratios = ''
        config.data_process_workers = '4'
        config.max_read_queue_size = '4'
        config.max_futures_queue_size = '4'
        config.max_data_queue_size = '4'

        # Other arguments
        config.batch_size = 4
        config.resolution = 256
        config.seq_length = 8192
        config.max_batch_queue_size = 4
        config.one_seq_per_elem = False
        config.n_block_skip = 0
        config.use_data_sharded_loader = True
        config.min_aspect_ratio = 0.5
        config.max_aspect_ratio = 2.0
        config.process_square = False
        # Read by image processing: probability of a square center crop when
        # process_square is False, and whether to snap to a fixed aspect-ratio subset.
        config.prob_square = 0.0
        config.limit_ar_subset = False
        config.port = 2222
        config.seed = 1234

        # Synthetic Videos
        config.use_synthetic = False
        config.synthetic_pattern = ''
        config.synthetic_frame_dist = 1

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, node_info, mesh, elastic_config):
        self.config = self.get_default_config(config)
        assert self.config.paths != ''
        assert self.config.extensions != ''
        assert self.config.batch_split_ratios != ''
        assert self.config.use_data_sharded_loader, "Only sharded loader is supported."
        ratio_sum = sum(map(float, self.config.batch_split_ratios.split(',')))
        assert np.allclose(ratio_sum, 1.0), ratio_sum
        self.config.elastic_config = elastic_config
        self.mesh = mesh
        self.node_info = node_info

        if self.config.use_synthetic:
            n_blocks = self.config.seq_length // self.config.elastic_config.max_toks
            assert len(self.config.synthetic_pattern) == n_blocks, self.config.synthetic_pattern

        self._n_datasets = len(self.config.paths.split(','))
        # One (epoch, file_index, member_index) resume position per dataset.
        self._dataset_locs = tuple([(0, 0, 0) for _ in range(self._n_datasets)])

    def __iter__(self):
        # Split the global batch across datasets by ratio; truncation leftovers go to
        # the first datasets.
        batch_sizes = [
            int(float(x) * self.config.batch_size)
            for x in self.config.batch_split_ratios.split(',')
        ]
        leftover = self.config.batch_size - sum(batch_sizes)
        batch_sizes = [b + (i < leftover) for i, b in enumerate(batch_sizes)]

        # Split each dataset's share across dp nodes, distributing remainders round-robin.
        dp_size = self.node_info['dp_node_size']
        assert self.config.batch_size % dp_size == 0, (self.config.batch_size, dp_size)
        remainders = [batch_sizes[j] % dp_size for j in range(self._n_datasets)]
        n_elems_per_dataset = [
            [batch_sizes[j] // dp_size for j in range(self._n_datasets)]
            for _ in range(dp_size)
        ]
        dp_idx = 0
        for j in range(self._n_datasets):
            for _ in range(remainders[j]):
                n_elems_per_dataset[dp_idx][j] += 1
                dp_idx = (dp_idx + 1) % dp_size
        assert all([sum(n_elems) == self.config.batch_size // dp_size for n_elems in n_elems_per_dataset]), n_elems_per_dataset
        assert sum(map(sum, n_elems_per_dataset)) == self.config.batch_size, sum(map(sum, n_elems_per_dataset))

        # Each node reads a contiguous, proportional slice of every dataset's files.
        # A dataset with batch size 0 contributes nothing (and must not divide by 0).
        proportions = []
        for i in range(dp_size):
            proportions.append([
                n_elems_per_dataset[i][j] / batch_sizes[j] if batch_sizes[j] > 0 else 0.0
                for j in range(self._n_datasets)
            ])
        prop_ranges = []
        dp_rank = self.node_info['dp_node_rank']
        for j in range(self._n_datasets):
            start = sum([proportions[i][j] for i in range(dp_rank)])
            end = start + proportions[dp_rank][j]
            prop_ranges.append((start, end))
        n_elems_per_dataset = n_elems_per_dataset[dp_rank]

        loader_infos = []
        for j in range(self._n_datasets):
            if n_elems_per_dataset[j] == 0:
                loader_infos.append({'data_queue': None})
            else:
                loader_infos.append(self._init_data(j, prop_ranges[j]))
        data_queues = [info['data_queue'] for info in loader_infos]
        batch_queue = queue.Queue(maxsize=self.config.max_batch_queue_size)
        # The collater loops forever; daemon=True keeps it from blocking interpreter exit.
        collate_worker = threading.Thread(
            target=_collater,
            args=(self.config, self.node_info, batch_queue, data_queues, n_elems_per_dataset),
            daemon=True,
        )
        collate_worker.start()
        while True:
            self._dataset_locs, batch = batch_queue.get()
            if self.config.use_data_sharded_loader:
                batch = shard_batch_to_global(batch, self.mesh)
            yield batch

    def _init_data(self, idx, prop_range):
        """Start the reader/submit/receive threads and worker processes for dataset ``idx``."""
        def extract(elem):
            # Pick this dataset's entry from a comma-separated field, as int/float/str.
            elem = elem.split(',')[idx]
            try:
                return int(elem)
            except ValueError:
                pass
            try:
                return float(elem)
            except ValueError:
                pass
            return elem

        fs = gcsfs.GCSFileSystem()
        path = extract(self.config.paths)
        if path.endswith('.pkl'):
            # NOTE: unpickling can execute arbitrary code; only point `paths` at trusted files.
            with fs.open(path, 'rb') as f:
                file_paths = pickle.load(f)
        else:
            ext = extract(self.config.extensions)
            file_paths = fs.glob(os.path.join(path, f'*.{ext}'))
            if len(file_paths) == 0:
                file_paths = fs.glob(os.path.join(path, '**', f'*.{ext}') )
            file_paths.sort()
        random.Random(self.config.seed).shuffle(file_paths)
        print(f"Found {len(file_paths)} files in {path}")

        start_prop, end_prop = prop_range
        start_idx, end_idx = int(start_prop * len(file_paths)), int(end_prop * len(file_paths))
        file_paths = file_paths[start_idx:end_idx]
        print(f"Local file_paths chunk: {len(file_paths)}")

        read_queue = queue.Queue(maxsize=extract(self.config.max_read_queue_size))
        future_queue = queue.Queue(maxsize=extract(self.config.max_futures_queue_size))
        data_queue = queue.Queue(maxsize=extract(self.config.max_data_queue_size))

        # Pipeline threads loop forever; daemon=True so they don't block interpreter exit.
        read_worker = threading.Thread(
            target=_data_reader_worker,
            args=(self.config, file_paths, read_queue, self._dataset_locs[idx]),
            daemon=True,
        )
        read_worker.start()

        # Each dataset owns a contiguous range of ports, one per worker process.
        n_workers = list(map(int, self.config.data_process_workers.split(',')))
        start, end = sum(n_workers[:idx]), sum(n_workers[:idx+1])
        port_numbers = list(range(self.config.port + start, self.config.port + end))
        submit_worker = threading.Thread(
            target=_data_submit_worker,
            args=(self.config, read_queue, future_queue, port_numbers),
            daemon=True,
        )
        submit_worker.start()

        comm_worker = threading.Thread(
            target=_data_comm_worker, args=(future_queue, data_queue), daemon=True)
        comm_worker.start()

        proc_workers = [
            zerofun.Process(_data_process_worker, self.config, port_number, start=True)
            for port_number in port_numbers
        ]

        return {
            'read_queue': read_queue,
            'future_queue': future_queue,
            'data_queue': data_queue,
            'read_worker': read_worker,
            'submit_worker': submit_worker,
            'comm_worker': comm_worker,
            'proc_workers': proc_workers,
        }

    def get_state_dict(self):
        return dict(
            config=self.config,
            dataset_locs=self._dataset_locs,
        )

    def load_state_dict(self, state_dict):
        """Restore config overrides and per-dataset resume positions."""
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        # The default must match the per-dataset layout (one loc per dataset). A flat
        # (0, 0, 0) would be indexed per dataset and fail to unpack in the reader.
        default_locs = tuple((0, 0, 0) for _ in range(self._n_datasets))
        self._dataset_locs = state_dict.get('dataset_locs', default_locs)

