import os
import torch
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, Subset, get_worker_info
from torchvision import datasets
from utils.general_utils import PILtoTorch
from PIL import Image
import numpy as np
from loguru import logger
from multiprocessing import Manager, Lock
import cv2
import pickle
import lmdb
import io
# # Create the shared cache (done in the main process)
# manager = Manager()
# shared_cache = manager.dict()
# cache_lock = Lock()

class LMDBCameraDataset(Dataset):
    """Camera dataset that decodes ``meta_only`` images lazily from an LMDB store.

    The LMDB environment is opened on first use by ``load_image`` (via
    ``_init_env``) and is dropped from the pickled state in ``__getstate__``,
    so a copy of the dataset that reaches a DataLoader worker by pickling opens
    its own handle instead of reusing the parent's.
    """

    def __init__(self, lmdb_path, viewpoint_stack, white_background=True, target_time=None):
        """
        Args:
            lmdb_path: Path of the LMDB environment to read images from.
            viewpoint_stack: Sequence of camera objects to serve.
            white_background: Chooses ``self.bg`` (white if True, else black).
            target_time: Viewpoint filter, see ``filter_viewpoints``. ``None``
                (the default) means ``[-1]``, i.e. keep every viewpoint.
        """
        if target_time is None:
            target_time = [-1]
        self.lmdb_path = lmdb_path
        self.bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
        self.target_time = target_time
        self.viewpoint_stack = self.filter_viewpoints(viewpoint_stack, target_time)
        self.true_viewpoint_idx = [idx for idx, v in enumerate(self.viewpoint_stack) if v.is_true_image]
        # LMDB handles are opened lazily by _init_env in whichever process reads.
        self.env = None
        self.txn = None
        self.keys = None
        self.key_map = None

    def __getstate__(self):
        """Return a picklable state with the LMDB env/txn removed.

        Prevents the open environment and transaction from being copied into
        DataLoader worker processes when the dataset is pickled; the worker
        reopens them lazily in ``_init_env``.
        """
        state = self.__dict__.copy()
        state["env"] = None
        state["txn"] = None
        return state

    def _init_env(self):
        """Open the LMDB environment, a read transaction and the name->key map once."""
        if self.env is None:
            self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
            self.txn = self.env.begin(buffers=True)

            raw_keys = self.txn.get(b"__keys__")
            if raw_keys is None:
                raise KeyError(f"'__keys__' index entry not found in LMDB at {self.lmdb_path}")
            # NOTE: pickle.loads executes arbitrary code; only use trusted LMDB files.
            self.keys = pickle.loads(raw_keys)  # List[(key, filename)]
            self.key_map = {name: key for key, name in self.keys}

    def filter_viewpoints(self, viewpoint_stack, target_time):
        """Return the viewpoints selected by ``target_time``.

        - ``[t]`` with ``t < 0``: all viewpoints.
        - ``[start, end]``: viewpoints with ``start <= timestamp <= end``, the
          endpoints also matched with ``np.isclose``.
        - anything else: viewpoints whose timestamp is ``np.isclose`` to one of
          the listed values.

        Any exception while filtering (e.g. ``max()`` of an empty stack) is
        logged and all viewpoints are returned.
        """
        try:
            if len(target_time) == 1 and target_time[0] < 0:
                logger.info(f"Load ALL Viewpoints.")
                return viewpoint_stack
            elif len(target_time) == 2:
                start_ts, end_ts = target_time
                tss = np.array(list(set([v.timestamp for v in viewpoint_stack])))
                logger.info(f"max timestamp = {tss.max()}, target time duration: {target_time}")
                return [vp for vp in viewpoint_stack
                        if start_ts < vp.timestamp < end_ts or np.isclose([vp.timestamp], [start_ts, end_ts]).any()]
            else:
                tss = np.array(list(set([v.timestamp for v in viewpoint_stack])))
                logger.info(f"max timestamp = {tss.max()}, target time: {target_time}")
                return [vp for vp in viewpoint_stack if any(np.isclose(vp.timestamp, tt) for tt in target_time)]
        except Exception:
            logger.warning(f"Invalid target time {target_time}, load all viewpoint")
            return viewpoint_stack

    def __len__(self):
        return len(self.viewpoint_stack)

    def __getitem__(self, index):
        """Return ``(image, camera, alpha_mask)``; decodes from LMDB for ``meta_only`` cameras."""
        viewpoint_cam = self.viewpoint_stack[index]
        if viewpoint_cam.meta_only:
            viewpoint_image, gt_alpha_mask = self.load_image(viewpoint_cam)
            return viewpoint_image, viewpoint_cam, gt_alpha_mask
        return viewpoint_cam.image, viewpoint_cam, viewpoint_cam.gt_alpha_mask

    @staticmethod
    def _float_to_uint8(values):
        """Scale values in [0, 1] to 0..255 and return them as uint8 pixels.

        Rounds to nearest (a float x/255*255 round-trip may land just below x,
        which truncation would turn into x-1) and clips out-of-range values.
        uint8 is the unsigned 8-bit type PIL expects; the signed ``np.byte``
        would only hold values above 127 through undefined wraparound.
        """
        return np.clip(np.rint(np.asarray(values) * 255.0), 0, 255).astype(np.uint8)

    def load_image(self, viewpoint_cam):
        """Decode the camera's image from LMDB.

        Returns:
            ``(viewpoint_image, gt_alpha_mask)``: RGB clamped to [0, 1], and the
            alpha channel (or ones when the image is fully opaque), multiplied
            by the motion mask if ``viewpoint_cam.motion_mask_path`` is set.
        """
        self._init_env()
        name = viewpoint_cam.file_path

        rgba_bytes = self.txn.get(self.key_map[name].encode())
        # Values are npz archives holding the array under "data"; indexing below
        # expects H x W x 4 and the /255 assumes 0..255 values (TODO confirm dtype).
        # (Older stores used pickled ndarrays or PNG bytes; those are not read here.)
        norm_data = np.load(io.BytesIO(rgba_bytes))["data"]
        norm_data = norm_data.astype(np.float32) / 255.0

        rgb = norm_data[:, :, :3]
        alpha = norm_data[:, :, 3:4]
        if alpha.min() < 1:
            # Keep alpha so it can be returned as the ground-truth mask.
            pixels = np.concatenate([rgb, alpha], axis=2)
        else:
            pixels = rgb
        # PIL infers RGB / RGBA from a uint8 H x W x 3 / H x W x 4 array.
        image_load = Image.fromarray(self._float_to_uint8(pixels))
        resized_image_rgb = PILtoTorch(image_load, viewpoint_cam.resolution)
        viewpoint_image = resized_image_rgb[:3, ...].clamp(0.0, 1.0)

        gt_alpha_mask = torch.ones((1, viewpoint_cam.image_height, viewpoint_cam.image_width))
        if resized_image_rgb.shape[0] == 4:
            gt_alpha_mask = resized_image_rgb[3:4, ...]

        # #----------------------- Motion Mask Test ------------------------------# #
        if viewpoint_cam.motion_mask_path is not None:
            motion_mask = cv2.imread(viewpoint_cam.motion_mask_path)
            if motion_mask is None:
                raise FileNotFoundError(f"Cannot read motion mask: {viewpoint_cam.motion_mask_path}")
            # Binarise the first channel, then resize to the camera resolution.
            motion_mask = torch.tensor(motion_mask[:, :, 0:1] > 0.1).float()
            motion_mask = torch.nn.functional.interpolate(
                motion_mask.permute(2, 0, 1)[None,],
                size=(viewpoint_cam.image_height, viewpoint_cam.image_width),
                mode="bicubic", align_corners=True).squeeze()
            gt_alpha_mask *= motion_mask.to(viewpoint_image.device)

        return viewpoint_image, gt_alpha_mask
class SharedCameraDataset(Dataset):
    """Camera dataset with an optional cache shared between processes.

    When ``cache_activate`` is True, decoded true images are stored in
    ``shared_cache`` (e.g. a ``multiprocessing.Manager().dict()``) under
    ``cache_lock`` so later requests — possibly from other workers — skip
    decoding.
    """

    def __init__(self, viewpoint_stack, white_background, shared_cache, cache_lock, target_time=None, cache_activate=True):
        """
        Args:
            viewpoint_stack: All cameras; filtered by ``target_time``.
            white_background: Chooses ``self.bg`` (white if True, else black).
            shared_cache: Mapping used as the image cache (only if ``cache_activate``).
            cache_lock: Lock guarding ``shared_cache`` (only if ``cache_activate``).
            target_time: See ``set_target_viewpoint_stack``; ``None`` means ``[-1]``.
            cache_activate: Enable the shared cache.
        """
        if target_time is None:
            target_time = [-1]
        self.cache_activate = cache_activate
        if cache_activate:
            self.shared_cached_image = shared_cache
            self.cache_lock = cache_lock  # guards shared_cached_image across workers
        self.target_time = target_time
        self.viewpoint_stack_all = viewpoint_stack
        self.bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
        self.set_target_viewpoint_stack(self.target_time)

    def set_target_viewpoint_stack(self, target_time):
        """Select ``self.viewpoint_stack`` from all cameras by ``target_time``.

        - ``[t]`` with ``t < 0``: all cameras.
        - ``[start, end]``: inclusive range, endpoints matched with ``np.isclose``.
        - anything else: timestamps ``np.isclose`` to one of the listed values.

        Errors while filtering are logged and fall back to all cameras.
        Also recomputes ``self.true_viewpoint_idx``.
        """
        self.target_time = target_time
        try:
            if len(target_time) == 1 and target_time[0] < 0:
                logger.info(f"Load ALL Viewpoints.")
                self.viewpoint_stack = self.viewpoint_stack_all
            elif len(target_time) == 2:
                start_ts, end_ts = target_time
                tss = np.array(list(set([v.timestamp for v in self.viewpoint_stack_all])))
                logger.info(f"max timestamp = {tss.max()}, target time duration: {target_time}")
                self.viewpoint_stack = [viewpoint for viewpoint in self.viewpoint_stack_all
                                        if (viewpoint.timestamp > start_ts or np.isclose(viewpoint.timestamp, start_ts)) and
                                            (viewpoint.timestamp < end_ts or np.isclose(viewpoint.timestamp, end_ts))
                                        ]
            else:
                tss = np.array(list(set([v.timestamp for v in self.viewpoint_stack_all])))
                logger.info(f"max timestamp = {tss.max()}, target time: {target_time}")
                self.viewpoint_stack = [viewpoint for viewpoint in self.viewpoint_stack_all
                                        if any(np.isclose(viewpoint.timestamp, tt) for tt in target_time)]
        except Exception:
            logger.warning(f"Invalid target time {target_time}, load all viewpoint")
            self.viewpoint_stack = self.viewpoint_stack_all

        self.true_viewpoint_idx = [idx for idx, v in enumerate(self.viewpoint_stack) if v.is_true_image]

    def __getitem__(self, index):
        """Return ``(image, camera, alpha_mask)``, using the shared cache when enabled."""
        viewpoint_cam = self.viewpoint_stack[index]
        if not viewpoint_cam.meta_only:
            return viewpoint_cam.image, viewpoint_cam, viewpoint_cam.gt_alpha_mask

        if self.cache_activate:
            with self.cache_lock:
                if viewpoint_cam.image_path in self.shared_cached_image:
                    viewpoint_image, gt_alpha_mask = self.shared_cached_image[viewpoint_cam.image_path]
                    return viewpoint_image, viewpoint_cam, gt_alpha_mask

        # Decode outside the lock so other workers are not blocked on disk I/O.
        viewpoint_image, gt_alpha_mask = self.load_image(viewpoint_cam)
        if self.cache_activate:
            with self.cache_lock:
                if viewpoint_cam.is_true_image:
                    self.shared_cached_image[viewpoint_cam.image_path] = (viewpoint_image, gt_alpha_mask)
        return viewpoint_image, viewpoint_cam, gt_alpha_mask

    def __len__(self):
        return len(self.viewpoint_stack)

    def load_image(self, viewpoint_cam):
        """Load the camera's image from ``image_path``.

        The RGB image is NOT composited onto ``self.bg`` (keeping raw RGB gives
        correct SSIM/LPIPS losses); the alpha channel is returned as the mask.

        Returns:
            ``(viewpoint_image, gt_alpha_mask)``: RGB clamped to [0, 1], and the
            alpha channel (or ones when fully opaque), multiplied by the motion
            mask if ``viewpoint_cam.motion_mask_path`` is set.
        """
        with Image.open(viewpoint_cam.image_path) as image_load:
            im_data = np.array(image_load.convert("RGBA"))  # uint8, H x W x 4
        # Work on the uint8 pixels directly: an exact, lossless path (a float
        # round-trip cast to signed np.byte could wrap or drop values by one).
        if im_data[:, :, 3].min() < 255:
            pixels = im_data
        else:
            pixels = np.ascontiguousarray(im_data[:, :, :3])
        # PIL infers RGB / RGBA from a uint8 H x W x 3 / H x W x 4 array.
        resized_image_rgb = PILtoTorch(Image.fromarray(pixels), viewpoint_cam.resolution)
        viewpoint_image = resized_image_rgb[:3, ...].clamp(0.0, 1.0)

        gt_alpha_mask = torch.ones((1, viewpoint_cam.image_height, viewpoint_cam.image_width))
        if resized_image_rgb.shape[0] == 4:
            gt_alpha_mask = resized_image_rgb[3:4, ...]

        # #----------------------- Motion Mask Test ------------------------------# #
        if viewpoint_cam.motion_mask_path is not None:
            motion_mask = cv2.imread(viewpoint_cam.motion_mask_path)
            if motion_mask is None:
                raise FileNotFoundError(f"Cannot read motion mask: {viewpoint_cam.motion_mask_path}")
            # Binarise the first channel, then resize to the camera resolution.
            motion_mask = torch.tensor(motion_mask[:, :, 0:1] > 0.1).float()
            motion_mask = torch.nn.functional.interpolate(
                motion_mask.permute(2, 0, 1)[None,],
                size=(viewpoint_cam.image_height, viewpoint_cam.image_width),
                mode="bicubic", align_corners=True).squeeze()
            gt_alpha_mask *= motion_mask.to(viewpoint_image.device)

        return viewpoint_image, gt_alpha_mask

def generate_dataloader(dataset, scene, gaussians, batch_size, train_target_time, test_target_time,
                        num_workers=16, prefetch_factor=16, motion_windows=((0, 30),)):
    """Create the train/test camera datasets and the training DataLoader.

    Args:
        dataset: Config object; reads ``weighted_sample``, ``motion_weight_sample``
            and ``dataloader`` (truthy -> load with worker processes).
        scene: Provides ``getTrainCameras`` / ``getTestCameras``.
        gaussians: Provides ``min_timestep``, the unit for time bins and motion windows.
        batch_size: Training batch size; an incomplete last batch is dropped.
        train_target_time: Forwarded as ``target_time`` to ``getTrainCameras``.
        test_target_time: Forwarded as ``target_time`` to ``getTestCameras``.
        num_workers: Worker processes used when ``dataset.dataloader`` is truthy.
        prefetch_factor: Batches prefetched per worker in that case.
        motion_windows: ``(start, end)`` pairs in units of ``min_timestep``;
            training views inside any window (inclusive) get sampling weight 3.0
            when ``dataset.motion_weight_sample`` is set.

    Returns:
        ``(training_dataset, testing_dataset, training_dataloader, sampler)``;
        ``sampler`` is ``None`` unless ``dataset.weighted_sample`` is set
        (presumably returned so the caller can call ``sampler.set_epoch``).
    """
    training_dataset = scene.getTrainCameras(target_time=train_target_time)
    testing_dataset = scene.getTestCameras(target_time=test_target_time)
    logger.info(f"Create Dataset with {len(training_dataset)} training data, {len(testing_dataset)} testing data")
    logger.info(f"==> Training target time = {training_dataset.target_time}")
    logger.info(f"==> Testing target time = {testing_dataset.target_time}")

    if dataset.weighted_sample:
        #* Weighted Random Sample Data Loading
        logger.info("Weighted Random Sample Data Loading")
        viewpoint_stack = training_dataset.viewpoint_stack
        if dataset.motion_weight_sample:
            logger.info(f"Enable motion weight for data sampler")
            motion_weights = torch.ones(len(viewpoint_stack))
            motion_duration = [[start * gaussians.min_timestep, end * gaussians.min_timestep]
                               for start, end in motion_windows]
            for i, cam in enumerate(viewpoint_stack):
                for start_ts, end_ts in motion_duration:
                    # Inclusive bounds with float tolerance at the endpoints.
                    if ((cam.timestamp > start_ts or np.isclose(cam.timestamp, start_ts)) and
                            (cam.timestamp < end_ts or np.isclose(cam.timestamp, end_ts))):
                        motion_weights[i] = 3.0
        else:
            motion_weights = None
        sampler = EpochAwareStratifiedSampler(
            data_source=training_dataset,
            pseudo_mask=torch.tensor([not cam.is_true_image for cam in viewpoint_stack], dtype=torch.bool),
            timestamps=torch.tensor([cam.timestamp for cam in viewpoint_stack], dtype=torch.float),
            view_ids=torch.tensor([cam.camera_id for cam in viewpoint_stack], dtype=torch.int),
            time_bin_size=15 * gaussians.min_timestep,
            repeat_real=True,
            motion_weights=motion_weights,
        )
        logger.info(f"Epoch Aware Stratified Sampler inited")
    else:
        # NOTE(review): without a sampler the loader does not shuffle; confirm intended.
        sampler = None
        logger.info(f"No Sampler inited")

    # prefetch_factor may only be set when worker processes are actually used.
    use_workers = bool(dataset.dataloader) and num_workers > 0
    training_dataloader = DataLoader(
        training_dataset,
        batch_size=batch_size,
        # `list` keeps the batch as a plain list of samples (like `lambda x: x`)
        # but, unlike a lambda, can be pickled for spawn-started workers.
        collate_fn=list,
        drop_last=True,
        sampler=sampler,
        num_workers=num_workers if use_workers else 0,
        pin_memory=True,
        prefetch_factor=prefetch_factor if use_workers else None,
    )
    logger.info(f"DataLoader create, len = {len(training_dataloader)}, batch size = {batch_size}")

    return training_dataset, testing_dataset, training_dataloader, sampler

import random
class FullCoverWeightedSampler(torch.utils.data.Sampler):
    """Sampler that covers every real and fake index once per epoch, then tops
    up the smaller group with weighted draws (with replacement) so both groups
    contribute ``max(n_real, n_fake)`` indices; the result is shuffled.

    The first iteration uses the draw made in ``__init__``; every later
    iteration (i.e. every later epoch) makes a fresh draw and shuffle.
    Randomness comes from the global ``np.random`` and ``random`` states.
    """

    def __init__(self, weights, is_real_mask):
        """
        Args:
            weights: Per-index sampling weights, used for the top-up draws.
            is_real_mask: Per-index flag, truthy for real images.
        """
        self.weights = np.asarray(weights, dtype=np.float64)
        # Force a boolean mask: `~` on an integer array is bitwise NOT
        # (0 -> -1, 1 -> -2), which would mark every index as fake.
        self.is_real_mask = np.asarray(is_real_mask, dtype=bool)
        self.real_indices = np.where(self.is_real_mask)[0]
        self.fake_indices = np.where(~self.is_real_mask)[0]
        self.real_probs = self._normalize(self.weights[self.real_indices])
        self.fake_probs = self._normalize(self.weights[self.fake_indices])

        self.n_real = len(self.real_indices)
        self.n_fake = len(self.fake_indices)
        self.n_target = max(self.n_real, self.n_fake)
        # An empty group cannot be topped up (np.random.choice on an empty
        # population raises), so it contributes no extra draws.
        self.extra_real = self.n_target - self.n_real if self.n_real else 0
        self.extra_fake = self.n_target - self.n_fake if self.n_fake else 0

        self.all_indices = self._draw()
        self._consumed = False

    @staticmethod
    def _normalize(group_weights):
        """Return weights normalised to sum to 1, or None (uniform) if they sum to 0."""
        total = group_weights.sum()
        return group_weights / total if total > 0 else None

    def _draw(self):
        """Build one epoch: every index once, plus weighted extras, shuffled."""
        # 1. every index once
        real_once = self.real_indices.tolist()
        fake_once = self.fake_indices.tolist()
        # 2. extra draws so both groups reach n_target (1:1)
        real_extra = np.random.choice(
            self.real_indices, size=self.extra_real, replace=True, p=self.real_probs
        ).tolist() if self.extra_real else []
        fake_extra = np.random.choice(
            self.fake_indices, size=self.extra_fake, replace=True, p=self.fake_probs
        ).tolist() if self.extra_fake else []
        # 3. merge and shuffle
        all_indices = real_once + fake_once + real_extra + fake_extra
        random.shuffle(all_indices)
        return all_indices

    def __iter__(self):
        if self._consumed:
            self.all_indices = self._draw()
        self._consumed = True
        return iter(self.all_indices)

    def __len__(self):
        return len(self.all_indices)

from collections import defaultdict
import random
class EpochAwareStratifiedSampler(torch.utils.data.Sampler):
    """Per-epoch sampler that stratifies pseudo and real views by time bin.

    Each epoch draws ``len(data_source) // 2`` pseudo-view indices spread over
    time bins of width ``time_bin_size`` (weighted by ``motion_weights``), plus
    either the same number of real-view indices drawn the same way
    (``repeat_real=True``) or every real index once. The draw uses
    ``random.Random(seed + epoch)``, so it is reproducible and changes when
    ``set_epoch`` is called. A group with no indices contributes nothing.
    """

    def __init__(self, data_source, pseudo_mask, timestamps, view_ids,
                 time_bin_size=15, repeat_real=True,
                 motion_weights=None, seed=0):
        self.data_source = data_source
        self.pseudo_mask = torch.as_tensor(pseudo_mask, dtype=torch.bool)
        self.real_mask = ~self.pseudo_mask
        self.timestamps = torch.as_tensor(timestamps)
        self.view_ids = torch.as_tensor(view_ids)
        self.time_bin_size = time_bin_size
        self.pseudo_per_epoch = len(data_source) // 2
        self.repeat_real = repeat_real
        self.motion_weights = (torch.as_tensor(motion_weights) if motion_weights is not None
                               else torch.ones_like(self.timestamps))
        self.seed = seed

        self.real_idxs = torch.where(self.real_mask)[0].tolist()
        self.pseudo_idxs = torch.where(self.pseudo_mask)[0].tolist()

        self.pseudo_bins = self._bin_by_time(self.pseudo_idxs)
        self.real_bins = self._bin_by_time(self.real_idxs)

        self.epoch = 0

    def _bin_by_time(self, idxs):
        """Group indices by ``floor(timestamp / time_bin_size)``."""
        bins = defaultdict(list)
        for idx in idxs:
            bins[int(self.timestamps[idx] // self.time_bin_size)].append(idx)
        return bins

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __len__(self):
        n_pseudo = self.pseudo_per_epoch if self.pseudo_bins else 0
        if self.repeat_real:
            n_real = self.pseudo_per_epoch if self.real_bins else 0
        else:
            n_real = len(self.real_idxs)
        return n_pseudo + n_real

    def _stratified_sample(self, bins, total, rng, weighted=False):
        """Draw ``total`` indices spread evenly over ``bins`` (``[]`` if no bins).

        Each bin contributes up to ``max(1, total // n_bins)`` indices (weighted
        draws are with replacement, unweighted without). A shortfall is filled
        by resampling the drawn indices; an excess (more bins than ``total``)
        is reduced with a random subset so late time bins are not dropped.
        """
        if not bins or total <= 0:
            return []
        bin_keys = sorted(bins)
        per_bin = max(1, total // len(bin_keys))
        sampled = []
        for b in bin_keys:
            bin_indices = bins[b]
            if len(bin_indices) <= per_bin:
                sampled.extend(bin_indices)
            elif weighted:
                weights = self.motion_weights[bin_indices]
                weights = weights / weights.sum()
                sampled.extend(rng.choices(bin_indices, weights=weights.tolist(), k=per_bin))
            else:
                sampled.extend(rng.sample(bin_indices, k=per_bin))
        if len(sampled) < total:
            sampled.extend(rng.choices(sampled, k=total - len(sampled)))
        elif len(sampled) > total:
            # Truncating would keep only the earliest bins; pick a random subset instead.
            sampled = rng.sample(sampled, k=total)
        return sampled

    def __iter__(self):
        rng = random.Random(self.seed + self.epoch)

        pseudo_sampled = self._stratified_sample(self.pseudo_bins, self.pseudo_per_epoch, rng, weighted=True)

        if self.repeat_real:
            real_sampled = self._stratified_sample(self.real_bins, self.pseudo_per_epoch, rng, weighted=True)
        else:
            real_sampled = self.real_idxs

        full_batch = pseudo_sampled + real_sampled
        rng.shuffle(full_batch)
        return iter(full_batch)



class CameraDataset(Dataset):
    """Camera dataset returning ``(image, camera)`` with a per-instance image cache.

    Images of ``meta_only`` cameras are loaded from ``image_path`` on first
    access and composited onto ``self.bg``; loaded true images are moved to
    shared memory and cached in ``self.cached_image``. The cache is a plain
    attribute of this object, so (presumably) each DataLoader worker keeps its
    own copy — TODO confirm if cross-worker sharing is expected.
    """

    def __init__(self, viewpoint_stack, white_background, target_time=None):
        """
        Args:
            viewpoint_stack: All cameras; filtered by ``target_time``.
            white_background: Chooses ``self.bg`` (white if True, else black).
            target_time: ``None``/``[-1]`` keeps all cameras; ``[start, end]``
                keeps an inclusive range; any other list keeps cameras whose
                timestamp matches a value. Endpoints and values are compared
                with ``np.isclose``, as in the other datasets in this module.
        """
        if target_time is None:
            target_time = [-1]
        self.target_time = target_time
        try:
            if len(target_time) == 1 and target_time[0] < 0:
                self.viewpoint_stack = viewpoint_stack
            elif len(target_time) == 2:
                start_ts, end_ts = target_time
                self.viewpoint_stack = [
                    viewpoint for viewpoint in viewpoint_stack
                    if (viewpoint.timestamp > start_ts or np.isclose(viewpoint.timestamp, start_ts))
                    and (viewpoint.timestamp < end_ts or np.isclose(viewpoint.timestamp, end_ts))
                ]
            else:
                #* apply for all images (true and pseudo)
                self.viewpoint_stack = [
                    viewpoint for viewpoint in viewpoint_stack
                    if any(np.isclose(viewpoint.timestamp, tt) for tt in target_time)
                ]
        except Exception:
            logger.warning(f"Invalid target time {target_time}, load all viewpoint")
            self.viewpoint_stack = viewpoint_stack

        self.bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
        self.cached_image = {}

    def __getitem__(self, index):
        """Return ``(image, camera)``; loads and caches ``meta_only`` images."""
        viewpoint_cam = self.viewpoint_stack[index]
        if not viewpoint_cam.meta_only:
            return viewpoint_cam.image, viewpoint_cam

        if viewpoint_cam.image_path in self.cached_image:
            return self.cached_image[viewpoint_cam.image_path], viewpoint_cam

        viewpoint_image = self.load_image(viewpoint_cam)
        if viewpoint_cam.is_true_image:
            # share_memory_() moves the storage to shared memory in place and returns the tensor.
            self.cached_image[viewpoint_cam.image_path] = viewpoint_image.share_memory_()
        return viewpoint_image, viewpoint_cam

    def __len__(self):
        return len(self.viewpoint_stack)

    def load_image(self, viewpoint_cam):
        """Load the camera's image composited onto ``self.bg``, clamped to [0, 1]."""
        with Image.open(viewpoint_cam.image_path) as image_load:
            im_data = np.array(image_load.convert("RGBA"))
        norm_data = im_data / 255.0
        alpha = norm_data[:, :, 3:4]
        arr = norm_data[:, :, :3] * alpha + self.bg * (1 - alpha)
        # Round (not truncate) and use unsigned uint8; signed np.byte only held
        # values above 127 through undefined wraparound.
        pixels = np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)
        resized_image_rgb = PILtoTorch(Image.fromarray(pixels), viewpoint_cam.resolution)
        viewpoint_image = resized_image_rgb[:3, ...].clamp(0.0, 1.0)
        # Check the channel dimension (dim 0), not the height. The image was
        # converted to RGB above, so the 4-channel branch is not normally taken.
        if resized_image_rgb.shape[0] == 4:
            gt_alpha_mask = resized_image_rgb[3:4, ...]
            viewpoint_image *= gt_alpha_mask
        else:
            viewpoint_image *= torch.ones((1, viewpoint_cam.image_height, viewpoint_cam.image_width))
        return viewpoint_image



def create_balanced_dataloader(dataset, batch_size, num_workers=12, pin_memory=True):
    """Build a DataLoader over a with-replacement resample of ``dataset``,
    grouped by the camera's ``is_true_image`` flag.

    True and pseudo images are each resampled with replacement to their own
    original count, so the true/pseudo ratio is preserved rather than
    equalised. The two samples are concatenated and shuffled once; the loader
    itself does not reshuffle between epochs.

    Args:
        dataset: Training dataset. If it has a ``viewpoint_stack`` attribute the
            cameras are read from it (index-aligned with ``dataset[i]``) without
            loading any image; otherwise ``dataset[i][-1]`` is taken as the
            camera, i.e. items are expected to be ``(viewpoint_image, viewpoint_cam)``.
        batch_size: DataLoader batch size (incomplete last batch dropped).
        num_workers: Number of DataLoader worker processes.
        pin_memory: Whether to enable pin_memory.

    Returns:
        DataLoader over the resampled ``Subset``; batches are plain lists of items.

    Raises:
        ValueError: if the dataset lacks true or pseudo images.
    """
    cams = getattr(dataset, "viewpoint_stack", None)
    if cams is None:
        # Single pass over the dataset (the flag lives on the camera, item[-1]).
        cams = [dataset[i][-1] for i in range(len(dataset))]

    true_indices = [i for i, cam in enumerate(cams) if cam.is_true_image]
    pseudo_indices = [i for i, cam in enumerate(cams) if not cam.is_true_image]

    # np.random.choice cannot sample from an empty group.
    if not true_indices or not pseudo_indices:
        raise ValueError("Dataset must contain both true and pseudo images.")

    sampled_true = np.random.choice(true_indices, size=len(true_indices), replace=True)
    sampled_pseudo = np.random.choice(pseudo_indices, size=len(pseudo_indices), replace=True)

    sampled_indices = np.concatenate([sampled_true, sampled_pseudo])
    np.random.shuffle(sampled_indices)

    subset = Subset(dataset, sampled_indices.tolist())
    return DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,  # already shuffled above
        # `list` behaves like `lambda x: x` but is picklable for spawn-started workers.
        collate_fn=list,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

def load_image(image_path, viewpoint_cam, white_background):
    """Load ``image_path`` composited onto a white or black background.

    Args:
        image_path: Image file to read (converted to RGBA first).
        viewpoint_cam: Provides ``resolution``, ``image_height`` and ``image_width``.
        white_background: Composite onto white if True, else black.

    Returns:
        RGB tensor clamped to [0, 1], presumably (3, H, W) as produced by
        ``PILtoTorch`` — TODO confirm.
    """
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
    with Image.open(image_path) as image_load:
        im_data = np.array(image_load.convert("RGBA"))
    norm_data = im_data / 255.0
    alpha = norm_data[:, :, 3:4]
    arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
    # Round (not truncate) and use unsigned uint8; signed np.byte only held
    # values above 127 through undefined wraparound.
    pixels = np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)
    resized_image_rgb = PILtoTorch(Image.fromarray(pixels), viewpoint_cam.resolution)
    viewpoint_image = resized_image_rgb[:3, ...].clamp(0.0, 1.0)
    # Check the channel dimension (dim 0); testing dim 1 (height) broke on
    # 4-pixel-tall images. The image is RGB here, so normally 3 channels.
    if resized_image_rgb.shape[0] == 4:
        gt_alpha_mask = resized_image_rgb[3:4, ...]
        viewpoint_image *= gt_alpha_mask
    else:
        viewpoint_image *= torch.ones((1, viewpoint_cam.image_height, viewpoint_cam.image_width))
    return viewpoint_image
