import math, copy
import numpy as np

from torch.utils.data import Dataset
from utils import logs_handler, process_image

logger = logs_handler.get_logger(__name__)

def check_terminal_idxs(terminal_idxs):
    """Assert that ``terminal_idxs`` is already in ascending order.

    Args:
        terminal_idxs: list or ``np.ndarray`` of episode end indices.

    Raises:
        AssertionError: at the first position whose value differs from the
            value found at that position in the sorted sequence.
    """
    if isinstance(terminal_idxs, np.ndarray):
        values = terminal_idxs.tolist()
    else:
        values = terminal_idxs
    ordered = sorted(values)
    # Compare position by position against the sorted copy.
    for actual, expected in zip(values, ordered):
        assert actual == expected, f'{actual} != {expected}'

def make_terminal_range(terminal_idxs):
    """Pair every episode end index with the start index of that episode.

    The first episode starts at 0; each following episode starts where the
    previous one ended.

    Args:
        terminal_idxs: list or ``np.ndarray`` of episode end indices.

    Returns:
        A ``zip`` iterator of ``(start, end)`` tuples, one per episode.
    """
    if isinstance(terminal_idxs, np.ndarray):
        ends = terminal_idxs.tolist()
    else:
        ends = terminal_idxs
    starts = [0] + ends[:-1]
    return zip(starts, ends)

def get_episode_length(terminal_idxs):
    """Return the number of steps in every episode.

    Args:
        terminal_idxs: list or ``np.ndarray`` of episode end indices.

    Returns:
        List with ``end - start`` for each ``(start, end)`` pair produced by
        ``make_terminal_range``.
    """
    return [stop - start for start, stop in make_terminal_range(terminal_idxs)]

def get_cumulative_rewards(rewards, terminal_idxs=None, gamma=1.0):
    """Compute the discounted reward-to-go of every step, episode by episode.

    For a step ``t`` of an episode that ends (exclusively) at ``end`` the value
    is ``G_t = r_t + gamma * G_{t+1}`` with ``G_end = 0``, i.e. the discounted
    sum of the rewards from ``t`` to the end of that episode.

    Args:
        rewards: sliceable sequence of per-step rewards (list or 1-D array).
        terminal_idxs: list or ``np.ndarray`` of episode end indices. When
            ``None`` all rewards are treated as one single episode.
        gamma: discount factor applied per step.

    Returns:
        ``np.ndarray`` with one reward-to-go per entry of ``rewards``.

    Raises:
        AssertionError: if the episode ranges do not cover exactly
            ``len(rewards)`` steps.
    """
    if terminal_idxs is None:
        terminal_idxs = [len(rewards)]
    rtgs = []
    for start, end in make_terminal_range(terminal_idxs):
        running = 0.0
        backward = []
        # The recursion G_t = r_t + gamma * G_{t+1} has to be evaluated from
        # the last step of the episode towards the first one.
        for r in rewards[start:end][::-1]:
            running = r + gamma * running
            backward.append(running)
        # Values were produced last-step-first; restore chronological order.
        rtgs.extend(backward[::-1])
    rtgs = np.array(rtgs)
    assert len(rewards) == len(rtgs), f'{len(rewards)} | {len(rtgs)}'
    return rtgs

def gather_returns(rtgs, terminal_idxs=None):
    """Collect the value stored at the first step of every episode.

    Args:
        rtgs: indexable sequence of per-step values (e.g. rewards-to-go).
        terminal_idxs: list or ``np.ndarray`` of episode end indices. When
            ``None`` the whole sequence is treated as one episode.

    Returns:
        List with ``rtgs[start]`` for the start index of each episode.
    """
    episode_ends = [len(rtgs)] if terminal_idxs is None else terminal_idxs
    return [rtgs[first] for first, _ in make_terminal_range(episode_ends)]

def target_to_segments(target, terminal_idxs, min_length=None, max_length=None):
    """Split ``target`` into ``(start, end)`` runs of constant value.

    A run is closed whenever the value at the current position differs from
    the value at the start of the run. A closed run is kept as a whole when
    its length lies in ``[min_length, max_length]`` and ``check_overlap``
    reports no overlap; otherwise it is cut into consecutive pieces of at most
    ``max_length`` steps and every piece is tested on its own (minimum length
    and overlap). The run that is still open when the end of ``target`` is
    reached is not emitted.

    Args:
        target: sequence of per-step values; converted with ``np.array``.
        terminal_idxs: list or ``np.ndarray`` of episode end indices.
        min_length: minimum accepted segment length; falsy means 0.
        max_length: maximum accepted segment length; falsy means
            ``len(target)``.

    Returns:
        List of ``(start, end)`` index tuples with ``end`` exclusive.
    """
    episode_ranges = list(make_terminal_range(terminal_idxs))
    total = len(target)
    shortest = min_length or 0
    longest = max_length or total
    values = np.array(target)

    def acceptable(lo, hi):
        # Long enough and fully inside one episode according to check_overlap.
        return (shortest <= hi - lo) and not check_overlap(lo, hi, episode_ranges)

    segments = []
    run_start = 0
    for run_end in range(1, total):
        if values[run_start] != values[run_end]:
            fits_whole = (run_end - run_start <= longest) and acceptable(run_start, run_end)
            if fits_whole:
                segments.append((run_start, run_end))
            else:
                # Walk through the run in chunks of at most `longest` steps.
                cursor = run_start
                while cursor < run_end:
                    piece_end = min(cursor + longest, run_end)
                    if acceptable(cursor, piece_end):
                        segments.append((cursor, piece_end))
                    cursor = piece_end
            run_start = run_end
    return segments

def check_overlap(start, end, terminal_ranges):
    """Tell whether ``[start, end]`` fails to fit strictly inside one range.

    Binary search over ``terminal_ranges``, which must be pre-sorted.

    Args:
        start: first index of the span.
        end: end index of the span; it must be strictly smaller than the end
            of the containing range for the span to count as non-overlapping.
        terminal_ranges: sorted list of ``(range_start, range_end)`` tuples.

    Returns:
        ``False`` when some range satisfies ``range_start <= start`` and
        ``end < range_end``; ``True`` otherwise (including an empty list).
    """
    lo, hi = 0, len(terminal_ranges) - 1
    while lo <= hi:
        pivot = (lo + hi) // 2
        range_start, range_end = terminal_ranges[pivot]
        if start >= range_start and end < range_end:
            return False
        if start > range_start:
            # The span begins after this range starts: look to the right.
            lo = pivot + 1
        else:
            hi = pivot - 1
    return True

def find_episode_idx(end, terminal_idxs):
    """Locate the episode whose end index is the first one greater than ``end``.

    Binary search over ``terminal_idxs``, which must be pre-sorted.

    Args:
        end: step index to look up.
        terminal_idxs: sorted sequence of episode end indices.

    Returns:
        Position ``i`` such that ``end < terminal_idxs[i]`` and either
        ``i == 0`` or ``end >= terminal_idxs[i - 1]``.

    Raises:
        AssertionError: if no such position exists.
    """
    lo, hi = 0, len(terminal_idxs) - 1
    while lo <= hi:
        probe = (lo + hi) // 2
        if end < terminal_idxs[probe]:
            # Candidate episode: accept it if the previous episode ended at or
            # before `end`, otherwise continue on the left half.
            if probe == 0 or end >= terminal_idxs[probe - 1]:
                return probe
            hi = probe - 1
        else:
            lo = probe + 1
    assert False, '...'

def find_episode_rng(index, terminal_ranges):
    """Locate the range that contains ``index`` using binary search.

    Args:
        index: step index to look up.
        terminal_ranges: sorted list of ``(low, high)`` tuples, ``high``
            exclusive.

    Returns:
        Position of the range with ``low <= index < high``.

    Raises:
        AssertionError: if no range contains ``index`` (including an empty
            ``terminal_ranges``).
    """
    # The upper bound must be the last valid position; starting at
    # len(terminal_ranges) lets the probe run past the end of the list.
    lo, hi = 0, len(terminal_ranges) - 1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        low, high = terminal_ranges[mid]
        if low <= index < high:
            return mid
        if low <= index:
            lo = mid + 1
        else:
            hi = mid - 1
    raise AssertionError('...')

def sample_indices(total_len, sample_ratio):
    """Draw a random subset of ``range(total_len)`` without replacement.

    Uses NumPy's global random state (``np.random.shuffle``).

    Args:
        total_len: size of the population to sample from.
        sample_ratio: fraction of the population to keep; the sample size is
            ``ceil(sample_ratio * total_len)``.

    Returns:
        ``np.ndarray`` holding the sampled indices in shuffled order.

    Raises:
        AssertionError: if the sample size is not in ``(0, total_len]``.
    """
    sample_size = int(math.ceil(sample_ratio * total_len))
    assert 0 < sample_size <= total_len, '...'
    shuffled = np.arange(total_len)
    np.random.shuffle(shuffled)
    return shuffled[:sample_size]

# Note: make_split may turn out to be unnecessary, because the learned policy can be used to generate new data/trajectories for evaluation instead of splitting the existing data.
def make_split(terminal_idxs, test_size):
    """Randomly split whole episodes into a train part and a test part.

    Episodes are picked for the test part with ``sample_indices``; all
    remaining episodes go to the train part. For each part the step indices of
    its episodes are concatenated and the episode end indices are rebuilt so
    that they are cumulative within that part.

    Args:
        terminal_idxs: list or ``np.ndarray`` of episode end indices.
        test_size: fraction of episodes assigned to the test part.

    Returns:
        ``(train_split, test_split)`` where each split is a tuple
        ``(step_indices, terminal_idxs)`` of ``np.ndarray``.
    """
    held_out = set(sample_indices(len(terminal_idxs), test_size).tolist())
    # Per part: (step indices, cumulative episode ends with a leading 0).
    parts = {True: ([], [0]), False: ([], [0])}
    for episode, (start, stop) in enumerate(make_terminal_range(terminal_idxs)):
        step_ids, boundaries = parts[episode in held_out]
        step_ids.extend(range(start, stop))
        boundaries.append(boundaries[-1] + stop - start)

    train_ids, train_bounds = parts[False]
    test_ids, test_bounds = parts[True]
    # Drop the leading 0 that only served as the accumulation seed.
    train_split = np.array(train_ids), np.array(train_bounds[1:])
    test_split = np.array(test_ids), np.array(test_bounds[1:])
    return train_split, test_split

# TODO: is TemporalIterator meant to be a generic class? Clarify its purpose.
class TemporalIterator:
    """Iterate over the start indices of consecutive, non-overlapping windows.

    Each window covers ``step = sampling_rate * max_seq_len`` time steps. The
    first window starts at ``start_index`` and every following one starts
    ``step`` later; a start index is produced only while the whole window
    still fits, i.e. ``index + step <= time_steps``.
    """

    def __init__(self, time_steps, max_seq_len, start_index=0, sampling_rate=1):
        """Set up the iteration.

        Args:
            time_steps: total number of available time steps.
            max_seq_len: number of sampled frames per window.
            start_index: start index of the first window.
            sampling_rate: distance between two sampled frames of a window.

        Raises:
            AssertionError: if an argument is outside its valid range.
        """
        assert 0 <= start_index < time_steps, '...'
        assert 0 < max_seq_len <= time_steps, '...'
        # A non-positive stride would never advance and iterate forever.
        assert sampling_rate > 0, 'sampling_rate must be positive'

        self.max_seq_len = max_seq_len
        self.time_steps = time_steps
        self.start_index = start_index
        self.sampling_rate = sampling_rate

        # Distance between the starts of two consecutive windows.
        self.step = sampling_rate * max_seq_len
        # Placed one step before the first window so that the first call to
        # __next__ lands exactly on start_index.
        self.current = start_index - self.step

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next window start, or stop once a window no longer fits."""
        candidate = self.current + self.step
        if candidate + self.step > self.time_steps:
            raise StopIteration
        self.current = candidate
        return candidate

    def __len__(self):
        """Number of windows produced by a full iteration from ``start_index``."""
        return (self.time_steps - self.start_index) // self.step

    @property
    def indices(self):
        """All window start indices as a list.

        Computed from the configuration, so reading the property neither
        consumes the iterator nor depends on how far it has advanced.
        """
        return [self.start_index + k * self.step for k in range(len(self))]


class BaseDataset(Dataset):
    """Base class for datasets that serve windows of consecutive frames.

    The constructor stores the (optionally subset) per-step arrays, derives
    the episode ranges, computes rewards-to-go when rewards are given and
    pre-computes ``temporal_indices``: the start indices of all frame windows
    that ``check_overlap`` accepts as lying inside a single episode.
    Subclasses must implement ``__getitem__`` and ``process_observation``.
    """

    def __init__(self, frames, actions=None, rewards=None, extra_data_dict=None, terminal_idxs=None, gamma=1.0, 
                 image_size=(84, 84), num_frames=4, frame_rate=1, skip_frames=0, overlap_ratios=None, num_clips=None, 
                 channel_last=False, use_dynamic_range=False, use_rtg=False, indices=None, actions_discrete=True):
        """Build the dataset.

        Args:
            frames: per-step observations; must support ``len`` and indexing
                with ``indices``.
            actions: optional per-step actions, indexed like ``frames``.
            rewards: optional per-step rewards, indexed like ``frames``.
            extra_data_dict: optional dict of additional per-step arrays. It is
                deep-copied and every value is subset with ``indices``.
            terminal_idxs: sorted episode end indices. When ``None`` all
                frames are treated as one single episode.
            gamma: discount factor handed to ``get_cumulative_rewards``.
            image_size: stored on the instance; not used by this base class.
            num_frames: number of frames per window.
            frame_rate: distance between two consecutive frames of a window.
            skip_frames: smallest start index used for windows.
            overlap_ratios: ``None``/``'all'`` makes every index in
                ``range(self.steps)`` a candidate window start. Otherwise an
                iterable of ratios; each ratio adds one strided pass starting
                at ``max(skip_frames, int(num_frames * (1 - ratio)))``.
            num_clips: when given, ``num_frames`` must be divisible by it.
            channel_last: stored on the instance; not used by this base class.
            use_dynamic_range: selects the normalisation used by ``normalize``.
            use_rtg: when true ``process_frame`` reports the reward-to-go
                instead of the raw reward.
            indices: positions used to subset all per-step arrays; defaults to
                all positions of ``frames``.
            actions_discrete: stored on the instance; not used by this base
                class.

        Raises:
            AssertionError: if ``terminal_idxs`` is not sorted, if
                ``overlap_ratios`` is a string other than ``'all'``, if the
                per-step arrays differ in length, or if ``num_frames`` is not
                divisible by ``num_clips``.
        """
        self.image_size = image_size
        self.num_frames = num_frames
        self.frame_rate = frame_rate
        self.skip_frames = skip_frames

        self.channel_last = channel_last
        self.use_dynamic_range = use_dynamic_range

        self.gamma = gamma
        self.use_rtg = use_rtg
        self.indices = np.arange(0, len(frames)) if (indices is None) else indices
        self.frames = self.subset(frames)
        self.actions = self.subset(actions)
        self.rewards = self.subset(rewards)
        self.extra_data_dict = {} if (extra_data_dict is None) else copy.deepcopy(extra_data_dict)
        for key in self.extra_data_dict:
            self.extra_data_dict[key] = self.subset(self.extra_data_dict[key])

        self.rtgs = None
        self.terminal_idxs = np.array([len(self.frames)]) if terminal_idxs is None else terminal_idxs
        check_terminal_idxs(self.terminal_idxs)
        # Always work with the normalised self.terminal_idxs: the raw argument
        # may be None, which make_terminal_range cannot handle.
        self.terminal_ranges = list(make_terminal_range(self.terminal_idxs))
        self.actions_discrete = actions_discrete
        self.overlap_ratios = overlap_ratios or 'all'

        if self.rewards is not None:
            self.rtgs = get_cumulative_rewards(self.rewards, terminal_idxs=self.terminal_idxs, gamma=gamma)

        assert (not isinstance(self.overlap_ratios, str)) or (self.overlap_ratios in {'all'}), '...'

        if isinstance(self.overlap_ratios, str) and (self.overlap_ratios == 'all'):
            self.temporal_indices = list(range(0, self.steps))
        else:
            temporal_indices = set(self.get_temporal_indices(start_index=self.skip_frames))
            visited = {self.skip_frames}
            for ratio in self.overlap_ratios:
                start_idx = max(self.skip_frames, int(self.num_frames * (1.0 - ratio)))
                # Different ratios can map to the same offset; scan each once.
                if start_idx not in visited:
                    temporal_indices.update(self.get_temporal_indices(start_index=start_idx))
                    visited.add(start_idx)
            self.temporal_indices = list(temporal_indices)

        logger.info(f'Number of Cutoffs: {len(self.temporal_indices)}')

        # Drop every window that check_overlap does not place inside one episode.
        window = frame_rate * num_frames
        self.temporal_indices = [
            start for start in self.temporal_indices
            if not check_overlap(start, start + window, self.terminal_ranges)
        ]

        # Step index -> episode id; read by process_frame, so it has to exist
        # even when no terminal_idxs were passed in.
        self.temporal_indices_groups = {}
        for episode_id, (low, high) in enumerate(self.terminal_ranges):
            for index in range(low, high):
                self.temporal_indices_groups[index] = episode_id

        self.episode_length = get_episode_length(self.terminal_idxs)
        self.num_episodes = len(self.terminal_idxs)

        logger.info('Cutoffs got updated!')
        logger.info(f'Number of Cutoffs: {len(self.temporal_indices)}')

        logger.info(f'Number of Episodes: {len(self.terminal_idxs)}')
        logger.info(f'Min Episode Length: {min(self.episode_length)}')
        logger.info(f'Max Episode Length: {max(self.episode_length)}')

        self.check_if_all_valid()
        assert (num_clips is None) or (self.num_frames % num_clips == 0), 'num-frames must be divisible by num-clips'
        logger.info('Ready!')

    def __getitem__(self, index):
        """Return one sample; must be implemented by subclasses."""
        raise NotImplementedError('...')

    def process_observation(self, index):
        """Return the observation for step ``index``; must be implemented by subclasses."""
        raise NotImplementedError('...')

    def check_if_all_valid(self):
        """Assert that all stored per-step arrays have the same length.

        Raises:
            AssertionError: if frames, actions, rewards, rewards-to-go and the
                extra arrays (ignoring those that are ``None``) do not share
                one common length.
        """
        all_items = [self.frames, self.actions, self.rewards, self.rtgs]
        all_items += list(self.extra_data_dict.values())
        sizes = {len(items) for items in all_items if items is not None}
        assert len(sizes) == 1, f'inconsistent size > {sizes}'

    def subset(self, items):
        """Index ``items`` with ``self.indices``; ``None`` is passed through."""
        if items is None:
            return None
        return items[self.indices]

    def get_temporal_indices(self, start_index):
        """Return the strided window start indices beginning at ``start_index``.

        Delegates to ``TemporalIterator`` with ``self.steps`` time steps,
        ``self.num_frames`` frames per window and ``self.frame_rate`` as the
        sampling rate.
        """
        indices = TemporalIterator(time_steps=self.steps,
                                   max_seq_len=self.num_frames,
                                   start_index=start_index,
                                   sampling_rate=self.frame_rate).indices

        return indices

    def process_frame(self, index, exclude_observation=False):
        """Assemble the observation and the target dict of a single step.

        Args:
            index: step index into the stored arrays.
            exclude_observation: when true the observation is not loaded and
                ``None`` is returned in its place.

        Returns:
            ``(image, target)`` where ``target`` holds ``'action'`` and
            ``'reward'`` (only when the respective array exists),
            ``'episode_id'``, ``'episode_length'``, ``'timestep'`` (offset of
            the step inside its episode) and one entry per key of
            ``extra_data_dict``.
        """
        target = {}
        image = None

        if not exclude_observation:
            image = self.process_observation(index)

        episode_id = self.temporal_indices_groups[index]
        low, _ = self.terminal_ranges[episode_id]

        if self.actions is not None:
            target['action'] = self.actions[index]
        if self.rewards is not None:
            target['reward'] = self.rtgs[index] if self.use_rtg else self.rewards[index]
        target['episode_id'] = episode_id
        target['episode_length'] = self.episode_length[episode_id]
        target['timestep'] = index - low
        for key in self.extra_data_dict:
            target[key] = self.extra_data_dict[key][index]
        return image, target

    def process_sequential_frames(self, start_index, num_frames=None, 
                                  frame_rate=None, exclude_observation=False):
        """Collect the frames and targets of one window.

        Args:
            start_index: index of the first frame; raised to ``skip_frames``
                when smaller.
            num_frames: number of frames to read; a falsy value falls back to
                ``self.num_frames``.
            frame_rate: distance between frames; a falsy value falls back to
                ``self.frame_rate``.
            exclude_observation: forwarded to ``process_frame``.

        Returns:
            ``(frames, targets)``: the list of observations and a dict that
            maps every target key to the list of its per-frame values.
        """
        start_index = max(self.skip_frames, start_index)
        num_frames = num_frames or self.num_frames
        frame_rate = frame_rate or self.frame_rate

        frames, targets = [], {}

        for i in range(0, num_frames):
            frame_idx = start_index + i * frame_rate

            frame, target = self.process_frame(frame_idx, exclude_observation=exclude_observation)
            frames.append(frame)

            for key, value in target.items():
                targets.setdefault(key, []).append(value)

        return frames, targets

    def process_segment(self, start, end):
        """Read the frames of ``[start, end)`` and keep the first frame's targets.

        Returns:
            ``(inputs, targets)`` where ``inputs`` is the list of observations
            and every entry of ``targets`` is the value taken at the first
            frame of the segment.
        """
        num_frames = end - start
        inputs, targets = self.process_sequential_frames(start, num_frames=num_frames, frame_rate=1)
        for key in targets:
            targets[key] = targets[key][0]
        return inputs, targets

    def make_ready(self, inputs, targets=None):
        """Convert list inputs and all target values to ``np.ndarray``.

        Returns:
            ``(inputs, targets)`` when ``targets`` is given, otherwise only
            ``inputs``. ``targets`` is updated in place.
        """
        if isinstance(inputs, list):
            inputs = np.array(inputs)
        if targets is not None:
            for key in targets:
                targets[key] = np.array(targets[key])
            return inputs, targets
        return inputs

    def normalize(self, inputs):
        """Return ``inputs`` as float32, normalised.

        With ``use_dynamic_range`` the work is delegated to
        ``process_image.normalize``; otherwise the values are divided by 255.
        """
        inputs = inputs.astype('float32')
        if self.use_dynamic_range:
            inputs = process_image.normalize(inputs)
        else:
            inputs /= 255.0
        return inputs

    @property
    def steps(self):
        """Number of time steps handed to the window index computation.

        Equals ``len(frames) - (num_frames * frame_rate - skip_frames)``,
        clipped at 0.
        """
        nk = self.num_frames * self.frame_rate - self.skip_frames
        count = max(0, len(self.frames) - nk)
        return count
