from typing import List, Tuple, Union, Optional
import torch
import math
from torch.utils.data import Dataset, DataLoader
from dataloader.code.dataset import BlendableDataset
from dataloader.code.input_specs import RLTaskInput
from torch.utils.data.distributed import DistributedSampler

def my_collate_fn(data_list: Union[List[Tuple[Tuple[RLTaskInput, object], object]], List[RLTaskInput]]):
    """Collate a batch of RL task samples into one merged ``RLTaskInput``.

    Merging every sample of the batch into a single task object lets the model
    run one forward pass over the whole batch instead of one per sample.

    Args:
        data_list: either a list of ``RLTaskInput`` objects, or a list of tuples
            shaped ``((task, raw_obs), data_info)``.

    Returns:
        ``(task_merged, batch_data_info, batch_raw_obs)``. In the tuple case the
        last two are per-sample lists (same order as ``data_list``); otherwise
        both are ``None``.
    """
    batch_data_info, batch_raw_obs = None, None
    if isinstance(data_list[0], tuple):
        # Unpack ((task, raw_obs), data_info) samples into parallel lists.
        task_list = [sample[0][0] for sample in data_list]
        batch_raw_obs = [sample[0][1] for sample in data_list]
        batch_data_info = [sample[1] for sample in data_list]
    else:
        task_list = data_list

    task_merged = RLTaskInput.merge_into_one(task_list)
    # prefix_mask entries correspond to different timesteps for each sample, so
    # they cannot be concatenated; detach it while the other tensors are merged
    # and always put it back, even if the merge fails.
    prefix_mask = task_merged.prefix_mask
    task_merged.prefix_mask = None
    try:
        # Per the original author's note, each tensor member ends up shaped
        # (data_num, seq_len) after concatenation -- TODO confirm.
        task_merged.apply(torch.cat, dim=0)
    finally:
        task_merged.prefix_mask = prefix_mask   # one entry per sample
    return task_merged, batch_data_info, batch_raw_obs

def build_training_data_loader(
    args,
    dataset: BlendableDataset,
    consumed_samples: int=0,                    # samples already consumed; currently always left at the default 0 (multi-GPU loading is not considered)
    epoch_total_samples: Optional[int]=None,    # number of samples per epoch; defaults to len(dataset)
    is_eval=False,
    seed=0,
    current_epoch=0
):
    """Build a DataLoader for ``dataset`` according to ``args.dataloader_type``.

    ``args`` must provide ``dataloader_type``, ``batch_size``,
    ``eval_batch_size`` and ``num_workers``.

    Supported ``dataloader_type`` values:
        - ``"DDP"``: per-sample ``MyDistributedSampler`` that splits each epoch's
          indices across processes (``consumed_samples`` is not used here).
        - ``"sequential"``: ``SequentialPretrainingSampler`` batch sampler.
        - ``"random"``: ``RandomPretrainingSampler`` for training,
          ``SequentialPretrainingSampler`` when ``is_eval`` is True.

    Every loader uses ``my_collate_fn`` and ``pin_memory=False``.

    Returns:
        The DataLoader, or ``None`` when ``dataset`` is ``None``.

    Raises:
        ValueError: if ``args.dataloader_type`` is not one of the values above.
    """
    if dataset is None:
        return None

    if epoch_total_samples is None:
        epoch_total_samples = len(dataset)

    batch_size = args.eval_batch_size if is_eval else args.batch_size

    if args.dataloader_type == "DDP":
        # The sampler splits each epoch's indices across the processes so that
        # no two processes receive the same position of the epoch.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=False,
            shuffle=False,
            num_workers=args.num_workers,
            sampler=MyDistributedSampler(
                dataset=dataset,
                epoch_total_samples=epoch_total_samples,
                seed=seed,
                current_epoch=current_epoch
            ),
            collate_fn=my_collate_fn
        )

    # Pick the batch sampler class.
    if args.dataloader_type == "sequential":
        sampler_cls = SequentialPretrainingSampler
    elif args.dataloader_type == "random":
        # Evaluation always walks the data in a deterministic sequential order.
        sampler_cls = SequentialPretrainingSampler if is_eval else RandomPretrainingSampler
    else:
        raise ValueError(f"{args.dataloader_type} dataloader type is not supported.")

    batch_sampler = sampler_cls(
        epoch_total_samples=epoch_total_samples,
        dataset_total_samples=len(dataset),
        consumed_samples=consumed_samples,
        batch_size=batch_size,
        seed=seed,
    )

    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=False,
        collate_fn=my_collate_fn,
    )

class SequentialPretrainingSampler:
    """Batch sampler that walks a seeded permutation of the dataset, epoch by epoch.

    Each epoch yields ``epoch_total_samples`` indices split into batches of
    ``batch_size``. Consecutive epochs take consecutive, non-overlapping slices
    of the current permutation. When the next slice would run past the end of
    the dataset, a new permutation (seeded with ``seed + epoch``) is drawn and
    the walk restarts from its beginning.

    If ``epoch_total_samples`` exceeds the dataset size, the permutation is
    tiled so the epoch still yields ``epoch_total_samples`` indices; some
    samples then appear more than once in that epoch.

    Note: ``__len__`` returns the number of *samples* per epoch, not the number
    of batches.
    """
    def __init__(
        self,
        epoch_total_samples,
        dataset_total_samples,
        consumed_samples,
        batch_size,
        seed,
        drop_last=True,
    ):
        self.epoch_total_samples = epoch_total_samples              # number of samples per epoch
        self.dataset_total_samples = dataset_total_samples          # total number of samples in the dataset
        self.consumed_samples = consumed_samples                    # samples consumed by this sampler so far (across all epochs)
        self.current_data_offset = 0                                # current position in the permutation
        self.batch_size = batch_size
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        # Sanity checks (run before any arithmetic that divides by batch_size).
        if self.epoch_total_samples > self.dataset_total_samples:
            print(f'Note that some samples will be repeated in one epoch: there are {dataset_total_samples} samples in dataset and we want {epoch_total_samples}')
        assert self.epoch_total_samples > 0, f"no sample to consume: {self.epoch_total_samples}"
        assert self.batch_size > 0

        self.last_batch_size = self.epoch_total_samples % self.batch_size
        # Samples in full batches only (the trailing partial batch is excluded).
        self.active_total_samples = self.epoch_total_samples - self.last_batch_size

        # Initial permutation.
        self.sample_idxs = self._permutation(self.dataset_total_samples, seed)

    @staticmethod
    def _permutation(n, seed):
        """Return a random permutation of ``range(n)`` as a tensor, seeded with ``seed``."""
        try:
            g = torch.Generator()
            g.manual_seed(seed)
            return torch.randperm(n, generator=g)
        except RuntimeError:
            # Fallback to a CUDA generator; presumably needed when the default
            # tensor device is CUDA -- NOTE(review): confirm.
            g = torch.Generator(device='cuda')
            g.manual_seed(seed)
            return torch.randperm(n, generator=g)

    def __len__(self):
        return self.epoch_total_samples

    def __iter__(self):
        """Yield one epoch of index batches; epochs do not overlap until the permutation is exhausted."""
        # Current epoch. Guard against active_total_samples == 0, which happens
        # when epoch_total_samples < batch_size.
        divisor = self.active_total_samples or self.epoch_total_samples
        self.epoch = self.consumed_samples // divisor

        # This epoch would run past the end of the permutation: draw a new one.
        if self.current_data_offset + self.epoch_total_samples > self.dataset_total_samples:
            self.sample_idxs = self._permutation(self.dataset_total_samples, self.seed + self.epoch)
            self.current_data_offset = 0

        start = self.current_data_offset
        sample_idxs = self.sample_idxs[start:start + self.epoch_total_samples]
        if len(sample_idxs) < self.epoch_total_samples and self.dataset_total_samples > 0:
            # The epoch is larger than the dataset: tile the permutation so the
            # epoch really contains epoch_total_samples indices.
            reps = math.ceil(self.epoch_total_samples / self.dataset_total_samples)
            sample_idxs = self.sample_idxs.repeat(reps)[:self.epoch_total_samples]

        for begin in range(0, len(sample_idxs), self.batch_size):
            batch = sample_idxs[begin:begin + self.batch_size].tolist()
            # Counters advance even for a dropped trailing partial batch, so the
            # next epoch starts right after this epoch's slice.
            self.consumed_samples += len(batch)
            self.current_data_offset += len(batch)
            if len(batch) == self.batch_size or not self.drop_last:
                yield batch

class RandomPretrainingSampler:
    """Batch sampler drawing ``epoch_total_samples`` indices per epoch, as evenly as possible.

    Each epoch reseeds a generator with ``seed + epoch``:
        - if the epoch fits in the dataset, it takes the first
          ``epoch_total_samples`` entries of a random permutation (no repeats);
        - otherwise it permutes ``range(epoch_total_samples)`` and wraps every
          index modulo the dataset size, so each sample appears either
          ``floor`` or ``ceil`` of ``epoch_total_samples / dataset_total_samples``
          times and every index stays within the dataset.

    Note: ``__len__`` returns the number of *samples* per epoch, not the number
    of batches.
    """
    def __init__(
        self,
        epoch_total_samples,
        dataset_total_samples,
        consumed_samples,
        batch_size,
        seed,
        drop_last=True,
    ):
        self.epoch_total_samples = epoch_total_samples                          # number of samples per epoch
        self.dataset_total_samples = dataset_total_samples                      # total number of samples in the dataset
        self.consumed_samples = consumed_samples                                # samples consumed by this sampler so far (across all epochs)
        self.batch_size = batch_size
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        # Sanity checks (run before any arithmetic that divides by batch_size).
        if self.epoch_total_samples > self.dataset_total_samples:
            print(f'Note that some samples will be repeated in one epoch: there are {dataset_total_samples} samples in dataset and we want {epoch_total_samples}')
        assert self.epoch_total_samples > 0, f"no sample to consume: {self.epoch_total_samples}"
        assert self.batch_size > 0

        self.last_batch_size = self.epoch_total_samples % self.batch_size
        # Samples in full batches only (the trailing partial batch is excluded).
        self.active_total_samples = self.epoch_total_samples - self.last_batch_size

    def __len__(self):
        return self.epoch_total_samples

    def __iter__(self):
        """Yield one epoch of randomly drawn index batches (reseeded every epoch)."""
        # Current epoch. Guard against active_total_samples == 0, which happens
        # when epoch_total_samples < batch_size.
        divisor = self.active_total_samples or self.epoch_total_samples
        self.epoch = self.consumed_samples // divisor
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        if self.epoch_total_samples <= self.dataset_total_samples:
            sample_idxs = torch.randperm(self.dataset_total_samples, generator=g)[:self.epoch_total_samples]
        else:
            # Wrap into [0, dataset_total_samples) so no index is out of range.
            sample_idxs = torch.randperm(self.epoch_total_samples, generator=g) % self.dataset_total_samples

        for begin in range(0, len(sample_idxs), self.batch_size):
            batch = sample_idxs[begin:begin + self.batch_size].tolist()
            # consumed_samples also counts a dropped trailing partial batch.
            self.consumed_samples += len(batch)
            if len(batch) == self.batch_size or not self.drop_last:
                yield batch

class MyDistributedSampler(DistributedSampler):
    """DistributedSampler serving a fixed-size, non-overlapping slice of a global permutation per epoch.

    A permutation of ``range(total_samples)`` is built once (``total_samples`` is
    the dataset size rounded up to a multiple of the per-epoch size). Each call
    to ``__iter__`` takes the next ``epoch_total_samples`` positions and hands
    every ``num_replicas``-th one (starting at ``rank``) to this process. After
    a full pass the permutation is regenerated (if ``shuffle``). Padding
    positions beyond the dataset size are wrapped modulo the dataset size so
    every returned index is valid.
    """
    def __init__(
            self,
            dataset: Dataset,                   # dataset to sample from
            num_replicas: Optional[int] = None, # number of DDP processes
            rank: Optional[int] = None,         # rank of the current process
            shuffle: bool = True,               # whether to shuffle the order
            seed: int = 0,                      # random seed used for shuffling
            drop_last: bool = False,            # drop the tail that cannot be split evenly across processes
            epoch_total_samples: int = 0,       # number of samples per epoch; <= 0 means len(dataset)
            current_epoch: int = 0              # starting epoch, e.g. restored from a snapshot/checkpoint
        ) -> None:
        # The base class sets num_replicas, rank, shuffle, seed and drop_last.
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.epoch = current_epoch
        self.dataset_total_samples = len(dataset)
        if epoch_total_samples is None or epoch_total_samples <= 0:
            epoch_total_samples = self.dataset_total_samples

        # Per-process sample count for one epoch.
        if self.drop_last and epoch_total_samples % self.num_replicas != 0:
            # drop_last and not evenly divisible: round down.
            self.num_samples = math.ceil((epoch_total_samples - self.num_replicas) / self.num_replicas)
        else:
            # No drop_last (pad with repeats) or evenly divisible: round up.
            self.num_samples = math.ceil(epoch_total_samples / self.num_replicas)
        self.epoch_total_samples = self.num_samples * self.num_replicas

        # Length of one full pass over the (padded) permutation.
        if self.epoch_total_samples > self.dataset_total_samples:
            print(f'Note that some samples will be repeated in one epoch: there are {self.dataset_total_samples} samples in dataset and we want {self.epoch_total_samples}')
            self.total_samples = self.epoch_total_samples
        else:
            # Round the dataset size up to a multiple of epoch_total_samples
            # (no padding when it already divides evenly).
            padding = (-self.dataset_total_samples) % self.epoch_total_samples
            self.total_samples = self.dataset_total_samples + padding
        assert self.total_samples % self.epoch_total_samples == 0

        self.current_data_offset = self._epoch_offset(current_epoch)
        self.init_current_data_offset = self.current_data_offset

        # Initial ordering.
        if self.shuffle:
            self.sample_idxs = self._permutation(self.seed)
        else:
            self.sample_idxs = list(range(self.total_samples))

    def _epoch_offset(self, epoch: int) -> int:
        """Return the permutation offset at which ``epoch`` starts."""
        epochs_per_pass = self.total_samples // self.epoch_total_samples
        return self.epoch_total_samples * (epoch % epochs_per_pass)

    def _permutation(self, seed: int) -> List[int]:
        """Return a random permutation of ``range(self.total_samples)`` seeded with ``seed``."""
        g = torch.Generator()
        g.manual_seed(seed)
        return torch.randperm(self.total_samples, generator=g).tolist()

    def __iter__(self):
        # Take the next epoch_total_samples positions of the permutation.
        start = self.current_data_offset
        indices = self.sample_idxs[start:start + self.epoch_total_samples]
        assert len(indices) == self.epoch_total_samples
        self.current_data_offset += self.epoch_total_samples

        # A full pass is complete: restart and, if shuffling, draw a new ordering.
        if self.current_data_offset >= self.total_samples:
            self.current_data_offset = 0
            if self.shuffle:
                self.sample_idxs = self._permutation(self.seed + self.epoch)

        # This process's share of the epoch.
        indices = indices[self.rank:self.epoch_total_samples:self.num_replicas]
        assert len(indices) == self.num_samples
        # Padding positions (>= dataset size) wrap back into the dataset.
        return iter([idx % self.dataset_total_samples for idx in indices])

    def __len__(self) -> int:
        return self.num_samples

    def reset(self):
        """Rewind to the offset the sampler was constructed with."""
        self.current_data_offset = self.init_current_data_offset

    def set_epoch(self, epoch: int, is_train: bool = True) -> None:
        """Record ``epoch``; for validation, also draw a fresh ordering from the start.

        For training, asserts that the sampler's position matches ``epoch``,
        which keeps the data of consecutive epochs non-overlapping.
        """
        self.epoch = epoch
        if is_train:
            assert self.current_data_offset == self._epoch_offset(epoch)
        else:
            # Validation runs every few training epochs and draws new data each time.
            self.sample_idxs = self._permutation(self.seed + self.epoch)
            self.current_data_offset = 0