import json
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from PIL import Image

import math
import numpy as np
import torch
from accelerate.logging import get_logger
from torch.utils.data import Dataset, DistributedSampler, Sampler
from torchvision import transforms


# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")
from .dataset_image_video import get_random_mask

# Subjects that commonly open an LLM-written caption ("This video ...").
_COMMON_BEGINNING_PHRASES = (
    "This video", "The video",
    "This clip", "The clip",
    "The animation",
    "This image", "The image",
    "This picture", "The picture",
)
# Verbs that commonly follow one of the subjects above ("... shows").
_COMMON_CONTINUATION_WORDS = (
    "shows",
    "depicts",
    "features",
    "captures",
    "highlights",
    "introduces",
    "presents",
)

# Every prefix that may be stripped from a caption: a handful of fixed
# openers followed by each "<subject> <verb>" combination (subject-major order).
COMMON_LLM_START_PHRASES = (
    "In the video,",
    "In this video,",
    "In this video clip,",
    "In the clip,",
    "Caption:",
) + tuple(
    " ".join((beginning, continuation))
    for beginning in _COMMON_BEGINNING_PHRASES
    for continuation in _COMMON_CONTINUATION_WORDS
)

# Module-level logger, created through accelerate.logging.get_logger.
logger = get_logger(__name__)

def resize_and_crop_image(image, height, width, return_pt=False):
    """Resize an image so it covers ``height`` x ``width`` and center-crop it.

    Args:
        image: A PIL image, or a ``str``/``Path`` that is opened and converted
            to RGB. An image object passed in directly is used as-is.
        height: Target height in pixels.
        width: Target width in pixels.
        return_pt: When true, return a float32 tensor instead of a PIL image.

    Returns:
        The cropped PIL image, or (when ``return_pt`` is set) a ``(C, H, W)``
        float32 tensor scaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")

    src_w, src_h = image.size
    # Use the larger of the two ratios so both target sides are fully covered;
    # the overflow on the other axis is cropped away below.
    factor = max(height / src_h, width / src_w)
    scaled_h = math.ceil(src_h * factor)
    scaled_w = math.ceil(src_w * factor)
    image = image.resize((scaled_w, scaled_h))

    # Center crop: PIL's crop box is (left, upper, right, lower).
    top = int(round((scaled_h - height) / 2.))
    left = int(round((scaled_w - width) / 2.))
    image = image.crop((left, top, left + width, top + height))

    if not return_pt:
        return image

    channels_first = np.array(image).transpose(2, 0, 1)
    tensor = torch.tensor(channels_first, dtype=torch.float32)
    return tensor / 127.5 - 1.0  # normalize to [-1, 1], shape (C, H, W)

class BaseDataset(Dataset):
    """Index caption/video pairs described by one or more JSONL metadata files.

    Every non-blank line of a metadata file is parsed as one JSON object
    ("example"). Examples shorter than the shortest resolution bucket are
    dropped; for the rest, the caption, the video path and the raw example are
    accumulated in ``self.prompts``, ``self.video_paths`` and ``self.examples``
    (parallel lists, one entry per kept example).
    """

    def __init__(
        self,
        data_root: str,
        dataset_file: str,
        caption_column: str,
        video_column: str,
        resolution_buckets: List[Tuple[int, int, int]],  # f h w
        bucket_sample_strategy: str = "random",  # or "longest"
        conditions: Optional[str] = "",
        id_token: Optional[str] = None,
        random_start: bool = True,
        remove_llm_prefixes: bool = True,
        frame_interval: Tuple[int, List[int]] = 1,
        ref_pose_drop_prompt_prob: float = 0.5,
    ) -> None:
        """Read all metadata files and prepare the frame transform.

        Args:
            data_root: Optional root directory; stored as ``self.data_root``
                (``None`` when empty). It is not joined onto the video paths.
            dataset_file: One JSONL path, or several separated by commas.
            caption_column: Caption key, or several comma-separated keys.
            video_column: Video-path key, or several comma-separated keys.
            resolution_buckets: ``(frames, height, width)`` buckets. Entries
                may also be given as strings holding a Python literal.
            bucket_sample_strategy: ``"random"`` or ``"longest"``.
            conditions: Names of the enabled conditioning signals.
            id_token: Optional token prepended (with a space) to prompts.
            random_start: Whether clips may start at a random frame.
            remove_llm_prefixes: Strip ``COMMON_LLM_START_PHRASES`` from the
                captions stored in ``self.prompts``.
            frame_interval: Frame stride; an int shared by all files, or one
                value per entry of ``dataset_file`` (or a string literal of
                either).
            ref_pose_drop_prompt_prob: Stored on the instance; not used by
                this class itself.

        Raises:
            ValueError: If a per-file ``frame_interval`` list is shorter than
                the list of dataset files, if a video path is missing, or if
                prompts and videos end up with different lengths.
        """
        super().__init__()

        # Always define the attribute: the invalid-path error message in
        # `_load_dataset_from_jsonl` formats it, and a missing attribute would
        # turn that ValueError into an AttributeError.
        self.data_root = Path(data_root) if data_root else None
        self.dataset_file = dataset_file.split(',')
        self.caption_column = caption_column if ',' not in caption_column else caption_column.split(',')
        self.video_column = video_column if ',' not in video_column else video_column.split(',')
        self.id_token = f"{id_token.strip()} " if id_token else ""
        if isinstance(resolution_buckets[0], str):
            # SECURITY NOTE(review): eval() executes arbitrary code. This is only
            # acceptable while these strings come from a trusted config/CLI;
            # consider ast.literal_eval if they can ever be user-supplied.
            resolution_buckets = [eval(bucket) for bucket in resolution_buckets]
        self.resolution_buckets = resolution_buckets
        self.bucket_sample_strategy = bucket_sample_strategy
        self.random_start = random_start
        if isinstance(frame_interval, str):
            # SECURITY NOTE(review): same eval() caveat as above.
            frame_interval = eval(frame_interval)
        self.frame_interval = frame_interval
        self.ref_pose_drop_prompt_prob = ref_pose_drop_prompt_prob
        self.conditions = conditions

        per_file_interval = isinstance(frame_interval, (list, tuple))
        if per_file_interval and len(frame_interval) < len(self.dataset_file):
            raise ValueError(
                f"Got {len(frame_interval)} frame intervals for {len(self.dataset_file)} dataset files; "
                "pass a single int or one interval per dataset file."
            )

        self.prompts, self.video_paths, self.examples = [], [], []
        for i, fn in enumerate(self.dataset_file):
            frame = frame_interval[i] if per_file_interval else frame_interval
            self._load_dataset_from_jsonl(fn, frame)

        if len(self.video_paths) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

        # Clean LLM start phrases
        if remove_llm_prefixes:
            self.prompts[:] = [self._strip_llm_prefixes(prompt) for prompt in self.prompts]

        self.video_transforms = transforms.Compose(
            [
                transforms.Lambda(self.scale_transform),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    @staticmethod
    def _strip_llm_prefixes(prompt: str) -> str:
        """Strip whitespace and any leading ``COMMON_LLM_START_PHRASES``.

        The phrases are tried once each, in tuple order, so a later phrase can
        still match after an earlier one has been removed.
        """
        prompt = prompt.strip()
        for phrase in COMMON_LLM_START_PHRASES:
            if prompt.startswith(phrase):
                prompt = prompt.removeprefix(phrase).strip()
        return prompt

    @staticmethod
    def scale_transform(x):
        """Map values from ``[0, 255]`` to ``[0, 1]``."""
        return x / 255.0

    def __len__(self) -> int:
        return len(self.video_paths)

    def get_caption(self, example):
        """Return the caption of ``example``.

        With a single caption column the value is returned directly (a missing
        key raises ``KeyError``). With several columns, one of the non-empty
        captions is picked at random; ``""`` is returned when none is present.
        """
        if isinstance(self.caption_column, str):
            return example[self.caption_column]
        if isinstance(self.caption_column, (list, tuple)):
            captions = [example[column] for column in self.caption_column if example.get(column)]
            return random.choice(captions) if captions else ""
        return ""

    def get_video(self, example):
        """Return the video path of ``example`` as a ``Path``.

        With several video columns the first one present in ``example`` wins;
        ``None`` is returned when none of them is present.
        """
        if isinstance(self.video_column, str):
            return Path(example[self.video_column])
        if isinstance(self.video_column, (list, tuple)):
            for column in self.video_column:
                if column in example:
                    return Path(example[column])
        return None

    def _load_dataset_from_jsonl(self, fn, frame_interval) -> Tuple[List[str], List[Path], List[Dict[str, Any]]]:
        """Load one JSONL file and append its usable examples to the dataset.

        Args:
            fn: Path of the JSONL metadata file.
            frame_interval: Frame stride for this file; it is written into
                every kept example under the ``'frame_interval'`` key.

        Returns:
            ``(prompts, video_paths, examples)`` for the examples kept from
            this file. The same items are also appended to ``self.prompts``,
            ``self.video_paths`` and ``self.examples``.

        Raises:
            ValueError: If any kept example has no video path or points to a
                path that is not an existing file.
        """
        with open(fn, "r", encoding="utf-8") as file:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            data = [json.loads(line) for line in file if line.strip()]

        num_origin = len(data)
        min_frames = min([bucket[0] for bucket in self.resolution_buckets])
        if min_frames > 1:
            # Keep clips that reach at least 90% of the shortest bucket's span.
            min_frames = min_frames * frame_interval * 0.9
        data = [line for line in data if line.get('frame_count', line.get('length', 1)) >= min_frames]
        print(f'Reading {fn}: Keep {len(data)} samples, filter out {num_origin - len(data)} samples with less than {min_frames} frames')
        for example in data:
            example['frame_interval'] = frame_interval

        prompts = [self.get_caption(entry) for entry in data]
        video_paths = [self.get_video(entry) for entry in data]

        # `get_video` yields None when no video column is present; treat that
        # like a missing file instead of failing on `None.is_file()`.
        if any(path is None or not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
            )

        self.prompts += prompts
        self.video_paths += video_paths
        self.examples += data

        return prompts, video_paths, data


class ImageVideoDataset(BaseDataset):
    """Bucketed image/video dataset.

    On construction every example is assigned to the resolution bucket that
    best matches its size and length. ``self.buckets`` maps each bucket key
    ``(frames, height, width, has_ref, has_conditions)`` to the list of example
    indices in it, and the key is also stored in ``example['bucket']``.
    """

    # Suffixes (lower-case) that mark a sample or control file as a still image.
    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")

    def __init__(self, *args, **kwargs) -> None:
        """Build the dataset (see ``BaseDataset``) and bucket every example.

        Each example must provide ``'width'`` and ``'height'``; its length is
        read from ``'frame_count'``, then ``'length'``, defaulting to 1.
        """
        super().__init__(*args, **kwargs)

        self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
        self.buckets = defaultdict(list)
        # `conditions` is Optional: treat None like "no conditions" so the
        # membership tests below cannot fail on `'ref' in None`.
        conditions = self.conditions or ""
        # Seed the global RNG so that every process draws the same random
        # buckets and therefore ends up with the same bucket assignment.
        random.seed(123)
        for index, example in enumerate(self.examples):
            width, height = example['width'], example['height']
            length = example.get('frame_count', example.get('length', 1))
            frame_interval = example['frame_interval']
            close_bucket = self._find_nearest_bucket(height, width, length, frame_interval)
            has_ref = 'ref' in conditions and 'ref_images' in example
            has_conditions = 'pose' in conditions and 'conditions' in example
            # tuple(): buckets given as lists can neither be concatenated with
            # a tuple nor used as dict keys.
            close_bucket = tuple(close_bucket) + (has_ref, has_conditions)
            self.buckets[close_bucket].append(index)
            example['bucket'] = close_bucket
        for bucket in self.buckets:
            logger.info(f'Bucket {bucket}: {len(self.buckets[bucket])} samples')

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Load and preprocess the sample at ``index``.

        Returns:
            A dict with the keys ``pixel_values``, ``text``, ``video_metadata``,
            ``mask_pixel_values``, ``mask``, ``clip_pixel_values``,
            ``ref_pixel_values`` and ``control_pixel_values``.
        """
        example = self.examples[index]
        # NOTE(review): the caption is read from the raw example here, so the
        # LLM-prefix cleanup that BaseDataset applies to `self.prompts` does
        # not affect the returned text - confirm this is intended.
        prompt = self.id_token + self.get_caption(example)
        nearest_bucket = example['bucket']
        current_size = example.get('frame_count', example.get('length', 1)), example['height'], example['width']
        _, target_height, target_width, _, _ = nearest_bucket
        frame_interval = example['frame_interval']

        video_path: Path = self.video_paths[index]
        # Stays None for still images so that a control *video* paired with an
        # image sample picks its own start instead of hitting an unbound name.
        start_frame = None
        if video_path.suffix.lower() in self._IMAGE_SUFFIXES:
            video = self._preprocess_image(video_path, target_height, target_width)  # (c h w)
            video = video.unsqueeze(0)  # (1 c h w)
        else:
            video, start_frame = self._preprocess_video(video_path, current_size, nearest_bucket, frame_interval)  # (f c h w), [-1,1]

        # NOTE: loading of reference / first / last images and prompt dropout
        # (`ref_pose_drop_prompt_prob`) is not performed in this method.

        control = None
        if 'control' in example:
            control_path = example['control']
            # Case-insensitive, matching the suffix check on the main sample.
            if str(control_path).lower().endswith(self._IMAGE_SUFFIXES):
                control = self._preprocess_image(control_path, target_height, target_width)  # (c h w)
                control = control.unsqueeze(0)  # (1 c h w)
            else:  # NOTE: conditions must be the same position as video
                control, _ = self._preprocess_video(control_path, current_size, nearest_bucket, frame_interval, start_frame=start_frame)  # (f c h w)

        sample = {}
        sample['pixel_values'] = video
        sample['text'] = prompt
        sample['video_metadata'] = str(self.video_paths[index])

        # NOTE(review): get_random_mask comes from dataset_image_video; the
        # arithmetic below assumes a mask that broadcasts against `video`
        # with 1 marking masked positions - confirm.
        mask = get_random_mask(video.shape, image_start_only=True)
        # Masked positions are replaced by -1, the rest keeps the video.
        mask_pixel_values = video * (1 - mask) + torch.ones_like(video) * -1 * mask
        sample["mask_pixel_values"] = mask_pixel_values  # [-1,1] (f c h w)
        sample["mask"] = mask

        clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
        clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
        sample["clip_pixel_values"] = clip_pixel_values  # [0,255] (h w c)

        ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
        if (mask == 1).all():
            # Everything is masked: do not leak the first frame as reference.
            ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
        sample["ref_pixel_values"] = ref_pixel_values  # [-1,1] (1 c h w)

        if control is not None:
            sample["control_pixel_values"] = control  # [-1,1] (f c h w)
        else:
            sample["control_pixel_values"] = torch.ones_like(video) * -1  # black control
        return sample

    def _preprocess_image(self, path: Path, target_height: int, target_width: int) -> torch.Tensor:
        """Load an image, resize/center-crop it and return a ``(c h w)`` tensor in ``[-1, 1]``."""
        # TODO(aryan): Support alpha channel in future by whitening background
        image = resize_and_crop_image(path, target_height, target_width)
        image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0  # (c h w)
        image = image * 2.0 - 1.0  # [-1,1]
        return image

    def _preprocess_video(self, path, current_size, target_size, frame_interval, start_frame=None) -> Tuple[torch.Tensor, int]:
        """Decode, resize, center-crop and normalize a clip.

        Args:
            path: Video file to decode.
            current_size: ``(length, height, width)`` of the source taken from
                the example metadata.
            target_size: Bucket key; its first three entries are
                ``(frames, height, width)``.
            frame_interval: Stride between sampled frames.
            start_frame: First frame to sample. When ``None`` it is 0, or a
                random valid start if ``self.random_start`` is set and the
                clip is longer than the sampled span.

        Returns:
            ``(frames, start_frame)``: frames as ``(f c h w)`` in ``[-1, 1]``
            with ``f == target_frame`` (short clips are padded by repeating
            the last frame) and the start frame that was actually used.
        """
        length, height, width = current_size
        target_frame, target_height, target_width = target_size[:3]

        # Scale so the frame covers the target on both axes; crop below.
        scale_ratio = max(target_height / height, target_width / width)
        resize_to_height, resize_to_width = math.ceil(height * scale_ratio), math.ceil(width * scale_ratio)

        video_reader = decord.VideoReader(uri=Path(path).as_posix(), width=resize_to_width, height=resize_to_height)
        num_available = len(video_reader)
        if start_frame is None:
            span = target_frame * frame_interval
            if self.random_start and length > span:
                start_frame = random.randrange(length - span + 1)
            else:
                start_frame = 0
        # `length` comes from metadata and `start_frame` may come from another
        # file (the main video); keep the start inside what can be decoded so
        # at least one frame is read and the padding below has a last frame.
        if 0 < num_available <= start_frame:
            start_frame = num_available - 1
        frame_indices = list(range(start_frame, num_available, frame_interval))[:target_frame]

        frames = video_reader.get_batch(frame_indices).float()  # (f h w c)

        # crop to bucket size
        top = int(round((resize_to_height - target_height) / 2.))
        left = int(round((resize_to_width - target_width) / 2.))
        frames = frames[:, top:top + target_height, left:left + target_width, :]
        frames = frames.permute(0, 3, 1, 2).contiguous()  # (f c h w)
        frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)  # [-1,1]

        # padding to bucket length with the last frame
        if len(frames) < target_frame:
            padding = frames[-1].unsqueeze(0).repeat(target_frame - len(frames), 1, 1, 1)
            frames = torch.cat([frames, padding], dim=0)  # (f c h w)

        return frames, start_frame

    def _find_nearest_bucket(self, height, width, length, frame_interval):
        """Pick the ``(frames, height, width)`` bucket that best fits a sample.

        Buckets whose frame count fits the strided length are ranked by
        aspect-ratio difference, then by absolute size difference. Among the
        buckets sharing the best resolution, the longest or a random one is
        returned depending on ``self.bucket_sample_strategy``.
        """
        # 72 frames can be put into 73-frame bucket
        max_frames = length // frame_interval + 1
        candidate_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= max_frames]
        if not candidate_buckets:  # use the shortest buckets, avoid padding too many frame
            shortest = min(bucket[0] for bucket in self.resolution_buckets)
            candidate_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] == shortest]

        sorted_buckets = sorted(
            candidate_buckets,
            key=lambda x: (abs(x[1] / x[2] - height / width), abs(x[1] - height) + abs(x[2] - width)),
        )
        best = sorted_buckets[0]
        # multiple buckets may have the same resolution; a multi-frame video
        # skips 1-frame buckets when collecting them
        nearest_buckets = [
            bucket
            for bucket in sorted_buckets
            if not (length > 1 and bucket[0] == 1) and bucket[1] == best[1] and bucket[2] == best[2]
        ]
        if not nearest_buckets:
            nearest_buckets.append(best)

        if self.bucket_sample_strategy == "longest":
            # choose the longest bucket
            return sorted(nearest_buckets, key=lambda x: x[0])[-1]
        # choose a random bucket from the nearest buckets
        return random.choice(nearest_buckets)


class BucketDistributedSampler(DistributedSampler):
    '''
    https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler

    Re-arrange indices to make sure each step (batch*replica) gets from the same bucket

    Iterating yields one list of ``batch_size`` dataset indices per step for
    this rank. ``dataset`` must expose ``buckets``, a mapping from bucket key
    to the list of dataset indices in that bucket.
    '''
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, seed=0, drop_last=False):
        """Compute how many samples each bucket contributes per epoch.

        Args:
            dataset: Object with a ``buckets`` mapping (see class docstring).
            batch_size: Batch size per replica.
            num_replicas: Number of participating processes.
            rank: Rank of this process.
            shuffle: Shuffle inside each bucket and shuffle the step order.
            seed: Base seed; the epoch number is added to it on every pass.
            drop_last: Drop the tail of each bucket instead of padding it so
                that the bucket size is a multiple of one step.
        """
        # NOTE: DistributedSampler.__init__ is not called; every attribute
        # this class reads is set explicitly below.
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        self.sampler = None

        # Number of samples consumed per step across all replicas.
        self.num_sample_per_step = self.num_replicas * self.batch_size

        self.num_samples_per_bucket = {}
        self.total_size = 0
        for b, indices_bucket in self.dataset.buckets.items():
            if self.drop_last:
                # Round down to whole steps. This also covers buckets whose
                # size is already an exact multiple of one step.
                num_steps_in_bucket = len(indices_bucket) // self.num_sample_per_step
            else:
                # Round up; the missing samples are filled by repetition.
                num_steps_in_bucket = math.ceil(len(indices_bucket) / self.num_sample_per_step)
            self.num_samples_per_bucket[b] = num_steps_in_bucket * self.num_sample_per_step
            self.total_size += self.num_samples_per_bucket[b]

        assert self.total_size % self.num_sample_per_step == 0
        self.num_steps = self.total_size // self.num_sample_per_step

    def __iter__(self):
        """Yield this rank's batch (a list of dataset indices) for every step."""
        g = None
        if self.shuffle:
            # Same seed on every rank, so all ranks build the same global order.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)

        indices = []
        for b, indices_bucket in self.dataset.buckets.items():
            if self.shuffle:
                in_indices = torch.randperm(len(indices_bucket), generator=g).tolist()
            else:
                in_indices = list(range(len(indices_bucket)))

            if not self.drop_last:
                # add extra samples to make it evenly divisible
                padding_size = self.num_samples_per_bucket[b] - len(in_indices)
                if padding_size <= len(in_indices):
                    in_indices += in_indices[:padding_size]
                else:
                    in_indices += (in_indices * math.ceil(padding_size / len(in_indices)))[:padding_size]
            else:
                # remove tail of data to make it evenly divisible
                in_indices = in_indices[:self.num_samples_per_bucket[b]]
            assert len(in_indices) == self.num_samples_per_bucket[b]

            # merge buckets: map positions inside the bucket to dataset indices
            indices += [indices_bucket[i] for i in in_indices]

        # split into steps; every bucket spans whole steps, so one step never
        # mixes samples from two buckets
        if self.shuffle:
            step_indices = torch.randperm(self.num_steps, generator=g).tolist()
        else:
            step_indices = list(range(self.num_steps))

        for step_i in step_indices:
            start = self.num_replicas * step_i * self.batch_size + self.rank * self.batch_size
            batch = indices[start: start + self.batch_size]
            yield batch

    def __len__(self):
        """Number of steps (batches per rank) in one epoch."""
        return self.num_steps


def cycle(dl, batch_sampler=None):
    """Iterate over ``dl`` forever, restarting it each time it is exhausted.

    Before every pass, ``batch_sampler.set_epoch`` (when a sampler is given)
    is called with the number of completed passes, starting at 0.
    """
    completed_passes = 0
    while True:
        if batch_sampler is not None:
            batch_sampler.set_epoch(completed_passes)
        yield from dl
        completed_passes += 1


class DistributedKRepeatSampler(Sampler):
    """Endless batch sampler in which every selected sample appears ``k`` times.

    Each step takes the next ``m = num_replicas * batch_size / k`` dataset
    indices in sequential order (wrapping around at the end of the dataset),
    repeats each of them ``k`` times, shuffles the result with a generator
    seeded identically on all ranks, and yields this rank's ``batch_size``
    slice of it.
    """

    def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
        """Store the configuration.

        Args:
            dataset: Sized dataset; only ``len(dataset)`` is used.
            batch_size: Batch size per replica.
            k: Number of times each selected sample is repeated.
            num_replicas: Total number of replicas.
            rank: Rank of this replica.
            seed: Base seed, shared by all replicas to keep them in sync.

        Raises:
            AssertionError: If ``num_replicas * batch_size`` is not divisible
                by ``k``.
        """
        self.dataset = dataset
        self.batch_size = batch_size  # batch size per replica
        self.k = k                    # repetitions of each sample
        self.num_replicas = num_replicas  # total number of replicas
        self.rank = rank              # rank of this replica
        self.seed = seed              # base seed, used to keep replicas in sync
        self.index = 0                # position of the next unused dataset index
        # Number of samples consumed per step across all replicas.
        self.total_samples = self.num_replicas * self.batch_size
        assert self.total_samples % self.k == 0, f"k can not div n*b, k{k}-num_replicas{num_replicas}-batch_size{batch_size}"
        self.m = self.total_samples // self.k  # distinct samples per step
        self.epoch = 0

    def __iter__(self):
        """Yield this rank's list of dataset indices for each step, forever.

        Raises:
            ValueError: If the dataset is empty.
        """
        dataset_len = len(self.dataset)
        if dataset_len == 0:
            raise ValueError("DistributedKRepeatSampler cannot sample from an empty dataset")

        first = self.rank * self.batch_size
        while True:
            # Deterministic generator so that all replicas shuffle identically.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)

            # Take the next m indices in order. Wrap around at the end of the
            # dataset so that every step still yields full batches instead of
            # running past the end and producing short or empty ones.
            indices = [(self.index + offset) % dataset_len for offset in range(self.m)]
            self.index = (self.index + self.m) % dataset_len

            # Repeat every index k times: num_replicas * batch_size entries.
            repeated_indices = [idx for idx in indices for _ in range(self.k)]

            # Shuffle so the repetitions are spread over the replicas.
            shuffled_indices = torch.randperm(len(repeated_indices), generator=g).tolist()
            shuffled_samples = [repeated_indices[i] for i in shuffled_indices]

            # Return only the slice that belongs to this replica.
            yield shuffled_samples[first:first + self.batch_size]

    def set_epoch(self, epoch):
        """Set the epoch that is added to the seed of the shuffling generator."""
        self.epoch = epoch


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge a list of per-sample dicts into one batch dict.

    Args:
        batch: Samples to merge. Every sample must contain ``"prompt"``. The
            optional keys are detected on the first sample only.

    Returns:
        A dict with ``"prompts"`` (list with one entry per sample) and, when
        present on the first sample: ``"video_metadata"`` (taken from the
        first sample only), ``"videos"`` (stacked ``"video"`` entries) and the
        stacked ``"ref_images"``, ``"first_image"``, ``"last_image"`` and
        ``"conditions"`` entries. An empty batch gives ``{"prompts": []}``.

    NOTE(review): ``ImageVideoDataset.__getitem__`` in this file emits
    ``"text"``/``"pixel_values"`` rather than ``"prompt"``/``"video"``, so this
    function appears to target a different sample layout - confirm which
    dataset it is meant to be used with.
    """
    item = {
        "prompts": [x["prompt"] for x in batch],
    }
    if not batch:
        return item

    first = batch[0]
    if 'video_metadata' in first:
        item['video_metadata'] = first['video_metadata']

    # (key in the sample, key in the collated batch)
    stacked_keys = (
        ("video", "videos"),
        ("ref_images", "ref_images"),
        ("first_image", "first_image"),
        ("last_image", "last_image"),
        ("conditions", "conditions"),
    )
    for sample_key, batch_key in stacked_keys:
        if sample_key in first:
            item[batch_key] = torch.stack([b[sample_key] for b in batch])
    return item