import h5py
import pickle
import numpy as np
import torch
import collections
from pathlib import Path


def discount_cumsum(x, gamma):
    """Reference (loop) implementation of the discounted reverse cumulative sum.

    Along axis 0: ``new_x[t] = x[t] + gamma * new_x[t + 1]``, i.e.
    ``new_x[t] = sum_{k >= t} gamma**(k - t) * x[k]``.

    Args:
        x: Array indexed along axis 0 (e.g. per-step rewards).
        gamma: Discount factor.
    Returns:
        Array created with ``np.zeros_like(x)``, so it has the shape and dtype of
        ``x`` (an integer ``x`` therefore yields truncated integer results).
        An empty ``x`` yields an empty array.
    """
    new_x = np.zeros_like(x)
    if x.shape[0] == 0:
        # nothing to accumulate; indexing x[-1] below would raise IndexError
        return new_x
    new_x[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        new_x[t] = x[t] + gamma * new_x[t + 1]
    return new_x

def discount_cumsum_np(x, gamma):
    """Vectorised (much faster) version of ``discount_cumsum`` for 1-D inputs.

    Computes ``out[t] = sum_{k >= t} gamma**(k - t) * x[k]``, the same recurrence
    as the loop version, and returns it as a contiguous float32 array.

    Within a chunk ``[s, e)`` the discounted sum comes from one reverse cumulative
    sum of ``x[k] * gamma**(k - s)``, divided by ``gamma**(t - s)``. The value
    already computed for ``e`` is then added, discounted by ``gamma**(e - t)``.
    Chunks are bounded so ``|gamma|**chunk_len`` stays far from float64
    under/overflow, so long sequences with small gamma never produce inf/nan.

    Args:
        x: 1-D array-like of per-step values (e.g. rewards).
        gamma: Scalar discount factor.
    Returns:
        Contiguous ``np.float32`` array with the same length as ``x``.
    """
    gamma = float(gamma)
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[0]
    if gamma == 0.0:
        # only the immediate value survives; also avoids log(0) below
        return np.ascontiguousarray(x, dtype=np.float32)
    log_gamma = abs(np.log(abs(gamma)))
    # |gamma|**chunk_len stays within [exp(-300), exp(300)]: safe in float64
    chunk_len = n if log_gamma == 0.0 else max(1, min(n, int(300.0 / log_gamma)))
    out = np.empty(n, dtype=np.float64)
    carry = 0.0  # out[end]: discounted sum of everything after the current chunk
    end = n
    while end > 0:
        start = max(0, end - chunk_len)
        length = end - start
        weights = gamma ** np.arange(length)
        seg = np.flip(np.cumsum(np.flip(x[start:end] * weights, 0)), 0) / weights
        # discount the tail back to each step t in the chunk: gamma**(end - t)
        seg += carry * gamma ** np.arange(length, 0, -1)
        out[start:end] = seg
        carry = seg[0]
        end = start
    return np.ascontiguousarray(out, dtype=np.float32)


def discount_cumsum_torch(x, gamma):
    """Torch version of the discounted reverse cumulative sum for 1-D tensors.

    Computes ``out[t] = sum_{k >= t} gamma**(k - t) * x[k]`` on ``x.device`` and
    returns a contiguous float32 tensor.

    The work dtype is ``torch.promote_types(x.dtype, torch.float32)``: float32 for
    integer, bool and half inputs, float64 for float64 inputs. The sequence is
    processed in chunks from the end. Each chunk uses one reverse cumsum of
    ``x[k] * gamma**(k - s)`` divided by ``gamma**(t - s)``, plus the following
    chunk's value discounted by ``gamma**(end - t)``. The chunk length keeps
    ``|gamma|**chunk_len`` inside half of the work dtype's log-range, so small gamma
    on long sequences cannot underflow to 0 and divide into inf/nan.

    Args:
        x: 1-D tensor of per-step values (e.g. rewards).
        gamma: Scalar discount factor.
    Returns:
        Contiguous ``torch.float32`` tensor with the same length as ``x``.
    """
    gamma = float(gamma)
    work_dtype = torch.promote_types(x.dtype, torch.float32)
    xw = x.to(dtype=work_dtype)
    n = x.shape[0]
    if gamma == 0.0:
        # only the immediate value survives; also avoids log(0) below
        return xw.contiguous().to(dtype=torch.float32)
    finfo = torch.finfo(work_dtype)
    log_budget = 0.5 * min(-float(np.log(finfo.tiny)), float(np.log(finfo.max)))
    log_gamma = abs(float(np.log(abs(gamma))))
    chunk_len = n if log_gamma == 0.0 else max(1, min(n, int(log_budget / log_gamma)))
    out = torch.empty(n, dtype=work_dtype, device=x.device)
    carry = torch.zeros((), dtype=work_dtype, device=x.device)  # out[end]
    end = n
    while end > 0:
        start = max(0, end - chunk_len)
        steps = torch.arange(end - start, dtype=work_dtype, device=x.device)
        weights = gamma ** steps
        seg = torch.flip(torch.cumsum(torch.flip(xw[start:end] * weights, [0]), 0), [0]) / weights
        # discount the tail back to each step t in the chunk: gamma**(end - t)
        seg = seg + carry * gamma ** (steps.shape[0] - steps)
        out[start:end] = seg
        carry = seg[0]
        end = start
    return out.contiguous().to(dtype=torch.float32)


def compute_rtg_from_target(x, target_return):
    """Build a returns-to-go sequence that starts at ``target_return``.

    ``new_x[0] = target_return``. Each later entry subtracts the previous step's
    value and is capped at ``target_return``:
    ``new_x[i] = min(new_x[i - 1] - x[i - 1], target_return)``.

    Args:
        x: 1-D array of per-step values (e.g. rewards).
        target_return: Scalar starting value and upper cap.
    Returns:
        Array from ``np.zeros_like(x)`` (same shape/dtype as ``x``). An empty ``x``
        yields an empty array.
    """
    new_x = np.zeros_like(x)
    if x.shape[0] == 0:
        # nothing to fill; assigning new_x[0] would raise IndexError
        return new_x
    new_x[0] = target_return
    for i in range(1, x.shape[0]):
        new_x[i] = min(new_x[i - 1] - x[i - 1], target_return)
    return new_x


def split_train_valid(trajectories, p=1, trj_len_dict=None, seed=1):
    """Deterministically split trajectories into train and validation sets.

    The training set is drawn without replacement, with probability proportional
    to trajectory length. Every trajectory not drawn goes to validation.

    Args:
        trajectories: Sequence of trajectories. Without ``trj_len_dict`` each must
            support ``len(t["observations"])``.
        p: Fraction intended for validation. The training set has
            ``int(len(trajectories) * (1 - p))`` entries, so the default ``p=1``
            gives an empty training set.
        trj_len_dict: Optional mapping ``str(trajectory) -> length`` used instead
            of reading observations.
        seed: Seed for a private ``np.random.RandomState``. This keeps the split
            reproducible and leaves the global RNG untouched.
    Returns:
        ``(train_trjs, valid_trjs)``. Train order follows sampling order;
        validation order follows input order.
    """
    if trj_len_dict is None:
        lengths = [len(trj["observations"]) for trj in trajectories]
    else:
        lengths = [trj_len_dict[str(trj)] for trj in trajectories]
    total = sum(lengths)
    weights = [length / total for length in lengths]
    n_train = int(len(trajectories) * (1 - p))
    rng = np.random.RandomState(seed=seed)
    picked = rng.choice(len(trajectories), size=n_train, p=weights, replace=False)
    chosen = set(picked)
    train_trjs = [trajectories[i] for i in picked]
    valid_trjs = [trj for i, trj in enumerate(trajectories) if i not in chosen]
    return train_trjs, valid_trjs


def filter_top_p_trajectories(trajectories, top_p=1, epname_to_return=None, bottom=False):
    """Keep ``int(len(trajectories) * top_p)`` trajectories ranked by return.

    Returns are read from ``epname_to_return[str(trj)]`` when that mapping is
    given. Otherwise they are the sum of ``trj.rewards`` (if the first trajectory
    has a ``rewards`` attribute) or of ``trj.get("rewards")``.

    Args:
        trajectories: Sequence of trajectories.
        top_p: Fraction of trajectories to keep.
        epname_to_return: Optional mapping ``str(trajectory) -> return``.
        bottom: If False, keep the highest-return trajectories (ascending order).
            If True, keep the lowest-return ones (descending order).
    Returns:
        List of the kept trajectories. An empty input yields an empty list.
    """
    if len(trajectories) == 0:
        # nothing to rank; trajectories[0] below would raise IndexError
        return []
    start = len(trajectories) - int(len(trajectories) * top_p)
    if epname_to_return is None:
        if hasattr(trajectories[0], "rewards"):
            def sort_fn(x): return np.array(x.rewards).sum()
        else:
            def sort_fn(x): return np.array(x.get("rewards")).sum()
    else:
        def sort_fn(x): return epname_to_return[str(x)]
    # ascending by default, so the tail holds the best; reversed when bottom=True
    sorted_trajectories = sorted(trajectories, key=sort_fn, reverse=bottom)
    return sorted_trajectories[start:]


def filter_trajectories_uniform(trajectories, p=1):
    """Randomly keep ``int(len(trajectories) * p)`` trajectories without replacement.

    Despite the name, the draw is weighted by trajectory length
    (``len(t["observations"])``). It uses the global ``np.random`` state, so
    results depend on any seeding done by the caller.

    Args:
        trajectories: Sequence of trajectories supporting ``t["observations"]``.
        p: Fraction of trajectories to keep.
    Returns:
        List of the sampled trajectories, in sampling order.
    """
    lengths = [len(trj["observations"]) for trj in trajectories]
    n_steps = sum(lengths)
    n_pick = int(len(trajectories) * p)
    picked = np.random.choice(
        len(trajectories),
        size=n_pick,
        p=[length / n_steps for length in lengths],
        replace=False,
    )
    return [trajectories[i] for i in picked]


def filter_trajectories_first(trajectories, p=1):
    """Return the leading ``int(len(trajectories) * p)`` trajectories, order preserved."""
    n_keep = int(len(trajectories) * p)
    return trajectories[:n_keep]


def filter_trajectories_last(trajectories, p=1):
    """Return the trajectories that follow the first ``int(len(trajectories) * p)``.

    This keeps the trailing ``(1 - p)`` fraction, i.e. the complement of
    ``filter_trajectories_first`` for the same ``p``.
    NOTE(review): with the default ``p=1`` the result is empty; confirm this
    "skip p" meaning is what callers intend rather than "keep the last p".
    """
    n_skip = int(len(trajectories) * p)
    return trajectories[n_skip:]


def filter_trajectories_cntq(trajectories, p=1, sort=False):
    """
    Filters trajectories by their cumulative normalized trajectory quality (CNTQ).

    First, computes the normalized quality of each trajectory from its return R:
        TQ(tau) = (R(tau) - R_min) / (R_max - R_min)
    where R_min and R_max are the minimum and maximum returns of the trajectories.
    Then computes, in list order,
        CNTQ(tau_k) = sum_{i=1}^{k} TQ(tau_i) / sum_{i=1}^{N} TQ(tau_i).
    Returns the trajectories preceding the first one whose CNTQ exceeds p.
    If no CNTQ exceeds p (e.g. p >= 1), or all returns are equal so CNTQ is
    undefined, every trajectory is returned.

    Args:
        trajectories: List of trajectories; each must support np.sum(trj["rewards"]).
        p: Float. CNTQ cutoff.
        sort: Bool. If True, order trajectories by ascending return before
            filtering. The caller's list is not modified.
    Returns: list of filtered trjs (empty for empty input).
    """
    if len(trajectories) == 0:
        return []
    all_returns = np.array([np.sum(trj["rewards"]) for trj in trajectories])
    if sort:
        # stable sort, like list.sort, so equal returns keep their input order
        order = np.argsort(all_returns, kind="stable")
        trajectories = [trajectories[i] for i in order]
        all_returns = all_returns[order]
    min_return, max_return = all_returns.min(), all_returns.max()
    trj_qualities = (all_returns - min_return) / ((max_return - min_return) + 1e-8)
    total_quality = trj_qualities.sum()
    if total_quality == 0:
        # all returns equal: CNTQ would be 0/0, so there is no cutoff
        return trajectories[:]
    cumulative_normalized_trj_quality = np.cumsum(trj_qualities) / total_quality
    exceeds = cumulative_normalized_trj_quality > p
    # argmax on an all-False mask returns 0, which would wrongly drop everything
    idx = int(np.argmax(exceeds)) if exceeds.any() else len(trajectories)
    return trajectories[:idx]


def filter_trajectories_bucketized(trajectories, n_buckets=100, quantiles=False):
    """
    Bucketized filtering: subsample trajectories evenly across return buckets.

    Trajectory returns (np.sum of "rewards") are split into ``n_buckets`` buckets.
    The buckets are equal-width between the min and max return or, with
    ``quantiles=True``, bounded by return quantiles. From each bucket, at most
    ``len(trajectories) // n_buckets`` trajectories are taken at evenly spaced
    positions of the bucket sorted by return.

    Args:
        trajectories: List of trajectories; each must support trj["rewards"].
        n_buckets: Int. Number of buckets.
        quantiles: Bool. Use return quantiles as bucket edges instead of
            uniformly spaced edges.
    Returns: filtered trjs, grouped by bucket (low to high) and ordered by return
        within each bucket. Empty for empty input. NOTE: with fewer trajectories
        than buckets the per-bucket quota is 0, so the result is empty.
    """
    if len(trajectories) == 0:
        return []
    all_returns = np.array([np.sum(trj["rewards"]) for trj in trajectories])
    n_per_bucket = len(all_returns) // n_buckets
    if quantiles:
        bucket_ranges = np.quantile(all_returns, np.linspace(0, 1, n_buckets + 1))
    else:
        bucket_ranges = np.linspace(all_returns.min(), all_returns.max(), n_buckets + 1)
    # digitize is 1-based and puts values >= the last edge (the max return) one
    # past the end, so shift down and clip into [0, n_buckets - 1]
    bucket_indices = np.digitize(all_returns, bucket_ranges, right=False) - 1
    bucket_indices = np.clip(bucket_indices, 0, n_buckets - 1)

    # Group trajectories by bucket
    bucket_contents = collections.defaultdict(list)
    for bucket, trj, return_ in zip(bucket_indices, trajectories, all_returns):
        bucket_contents[int(bucket)].append((trj, return_))

    # select evenly spaced trjs (by return rank) from every bucket
    trjs = []
    for bucket in range(n_buckets):
        members = sorted(bucket_contents[bucket], key=lambda item: item[1])
        n_take = min(len(members), n_per_bucket)
        for i in np.linspace(0, len(members) - 1, n_take, dtype=int):
            trjs.append(members[i][0])
    return trjs
    

def load_npz(path, start_idx=None, end_idx=None):
    """Load a trajectory stored in an ``.npz`` archive.

    Args:
        path: Path to the archive. Reads the keys "states", "actions", "rewards",
            "dones" and, if present, "returns_to_go".
        start_idx, end_idx: When both are given, only ``[start_idx:end_idx]`` of
            the per-step arrays is returned; otherwise the full trajectory is.
    Returns:
        (observations, actions, rewards, dones, returns_to_go).
        - For a window, observations/actions/rewards are cast to float32; for the
          full trajectory they keep their stored dtype.
        - returns_to_go is float32, or None when the key is absent.
        - dones is ``trj["dones"]`` wrapped in a new leading axis of length 1.
    """
    returns_to_go = None
    # `is not None` (not truthiness) so that start_idx=0 still counts as a window
    subtrajectory = start_idx is not None and end_idx is not None
    # NOTE(review): np.load applies mmap_mode to .npy files; for .npz archives it
    # appears to have no effect. Confirm before relying on memory mapping here.
    with np.load(path, mmap_mode="r" if subtrajectory else None) as trj:
        if subtrajectory:
            observations = trj["states"][start_idx:end_idx].astype(np.float32)
            actions = trj["actions"][start_idx:end_idx].astype(np.float32)
            rewards = trj["rewards"][start_idx:end_idx].astype(np.float32)
            if "returns_to_go" in trj:
                returns_to_go = trj["returns_to_go"][start_idx:end_idx].astype(np.float32)
        else:
            # full trajectory
            observations, actions, rewards = trj["states"], trj["actions"], trj["rewards"]
            if "returns_to_go" in trj:
                returns_to_go = trj["returns_to_go"].astype(np.float32)
        dones = np.array([trj["dones"]])
    return observations, actions, rewards, dones, returns_to_go


def load_hdf5(path, start_idx=None, end_idx=None, img_is_encoded=False):
    """Load a trajectory from an HDF5 file, optionally only ``[start_idx:end_idx]``.

    Args:
        path: Path to the HDF5 file.
        start_idx, end_idx: When both are given, only that window of each per-step
            dataset is read; otherwise the full datasets are read.
        img_is_encoded: Read observations from "states_encoded" instead of "states".
    Returns:
        (observations, actions, rewards, dones, returns_to_go).
        - returns_to_go is None when the file has no "returns_to_go" dataset.
        - A scalar "dones" dataset cannot be sliced, so its value is returned
          wrapped in a length-1 array.
    Raises:
        KeyError (from h5py) if the file has no "dones" dataset.
    """
    returns_to_go = None
    if start_idx is not None and end_idx is not None:
        window = slice(start_idx, end_idx)  # subtrajectory only
    else:
        window = slice(None)  # full trajectory, same as [:]
    obs_key = "states_encoded" if img_is_encoded else "states"
    with h5py.File(path, "r") as f:
        observations = f[obs_key][window]
        actions = f["actions"][window]
        rewards = f["rewards"][window]
        if "returns_to_go" in f:
            returns_to_go = f["returns_to_go"][window]
        dones_ds = f["dones"]
        if dones_ds.shape == ():
            # scalar dataset: read its single value explicitly instead of relying
            # on a failed slice being swallowed
            dones = np.array([dones_ds[()]])
        else:
            dones = dones_ds[window]
    return observations, actions, rewards, dones, returns_to_go

    
def append_to_hdf5(path, new_vals, compress_kwargs=None):
    """Write datasets into an HDF5 file opened in append mode, replacing existing ones.

    Args:
        path: Path to the HDF5 file (converted with ``str``).
        new_vals: Mapping of dataset name -> data.
        compress_kwargs: Extra ``create_dataset`` keyword arguments. Defaults to
            gzip compression at level 1. Chunk/filter options are dropped for
            scalar values, because h5py rejects them on scalar datasets.
    """
    if compress_kwargs is None:
        compress_kwargs = {"compression": "gzip", "compression_opts": 1}
    filter_keys = {"chunks", "compression", "compression_opts", "shuffle",
                   "fletcher32", "scaleoffset", "maxshape"}
    scalar_kwargs = {k: v for k, v in compress_kwargs.items() if k not in filter_keys}
    # open in append mode, add new vals
    with h5py.File(str(path), 'a') as f:
        for key, value in new_vals.items():
            if key in f:
                # replace rather than fail on an already existing dataset name
                del f[key]
            kwargs = compress_kwargs if np.ndim(value) > 0 else scalar_kwargs
            f.create_dataset(key, data=value, **kwargs)


def load_pkl(path, start_idx=None, end_idx=None):
    """Load a pickled trajectory dict, optionally only ``[start_idx:end_idx]``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load trusted files.

    Args:
        path: Path to the pickle file, which holds a dict with "states",
            "actions", "rewards", "dones" and optionally "returns_to_go".
        start_idx, end_idx: When both are given, each per-step entry is sliced
            to that window; otherwise the stored objects are returned as-is.
    Returns:
        (observations, actions, rewards, dones, returns_to_go). returns_to_go is
        None when the key is absent; dones is ``trj["dones"]`` wrapped in a new
        leading axis of length 1.
    """
    with open(path, "rb") as fh:
        trj = pickle.load(fh)
    has_rtg = "returns_to_go" in trj
    if start_idx is None or end_idx is None:
        # full trajectory: hand back the stored objects unchanged
        observations = trj["states"]
        actions = trj["actions"]
        rewards = trj["rewards"]
        returns_to_go = trj["returns_to_go"] if has_rtg else None
    else:
        window = slice(start_idx, end_idx)
        observations = trj["states"][window]
        actions = trj["actions"][window]
        rewards = trj["rewards"][window]
        returns_to_go = trj["returns_to_go"][window] if has_rtg else None
    dones = np.array([trj["dones"]])
    return observations, actions, rewards, dones, returns_to_go


def compute_start_end_context_idx(idx, seq_len, cache_len, future_cache_len, full_context_len=True, dynamic_len=False):
    """Compute the ``[start, end)`` context window around position ``idx``.

    The base window is ``[idx - cache_len, idx + future_cache_len)`` clipped to
    ``[0, seq_len]``.
    - With ``dynamic_len``, start is drawn uniformly from ``[start, idx]`` and end
      from ``[idx, end]`` using the global ``np.random`` state.
    - Otherwise, with ``full_context_len``, a window shorter than
      ``cache_len + future_cache_len`` is widened: backwards when it does not
      start at 0, else forwards. It is then clipped to ``[0, seq_len]`` again.

    Returns:
        Tuple ``(start, end)``.
    """
    lo = max(0, idx - cache_len)
    hi = min(seq_len, idx + future_cache_len)
    if dynamic_len:
        lo = np.random.randint(lo, idx + 1)
        hi = np.random.randint(idx, hi + 1)
        return lo, hi
    if not full_context_len:
        return lo, hi
    missing = (cache_len + future_cache_len) - (hi - lo)
    if missing > 0:
        if lo > 0:
            lo -= missing
        else:
            hi += missing
        lo, hi = max(0, lo), min(seq_len, hi)
    return lo, hi
