import numpy as np
import torch
from scipy.signal import lfilter

import pdb

def soft_update_network(source_network, target_network, tau):
    """Blend ``source_network``'s parameters into ``target_network`` in place.

    Every target parameter is overwritten with
    ``tau * source + (1 - tau) * target``. Parameters are paired positionally
    through ``.parameters()`` (via ``zip``), so the two networks are presumably
    expected to share the same architecture.

    Args:
        source_network: module whose parameters are blended in (not modified).
        target_network: module whose parameters are updated in place.
        tau: interpolation weight given to the source parameters.
    """
    paired = zip(target_network.parameters(), source_network.parameters())
    for tgt, src in paired:
        blended = tau * src.data + (1 - tau) * tgt.data
        tgt.data.copy_(blended)


def dict_batch_generator(data, batch_size, keys=None):
    """Yield shuffled mini-batches drawn from a dict of equal-length arrays.

    A single random permutation of row positions is drawn up front and cut
    into consecutive slices of ``batch_size``, so each row appears exactly
    once per pass and the final batch may be smaller.

    Args:
        data: mapping from key to an array supporting integer-array indexing
            along axis 0.
        batch_size: maximum number of rows per batch.
        keys: keys to include; defaults to all keys of ``data``. The length of
            ``data[keys[0]]`` determines the number of rows.

    Yields:
        dict mapping each selected key to its rows for the current batch.
    """
    selected_keys = list(data.keys()) if keys is None else keys
    total = len(data[selected_keys[0]])
    order = np.arange(total)
    np.random.shuffle(order)
    n_batches = int(np.ceil(total / batch_size))
    for start in range(0, n_batches * batch_size, batch_size):
        rows = order[start:min(total, start + batch_size)]
        yield {k: data[k][rows] for k in selected_keys}


def dict_batch_generator_episode(data, keys=None):
    """Yield one batch per episode, visiting episodes in random order.

    Every distinct value of ``data['episode_idx']`` becomes exactly one batch
    holding all rows of that episode, in their original row order.

    Args:
        data: mapping from key to an array indexed along axis 0; must contain
            an ``'episode_idx'`` entry giving the episode id of each row.
        keys: keys to include; defaults to all keys of ``data``. The
            ``'episode_idx'`` key is always excluded from yielded batches.

    Yields:
        dict mapping each selected key (except ``'episode_idx'``) to the rows
        of one episode.
    """
    if keys is None:
        keys = list(data.keys())
    # Exclude the episode-id column by name; slicing ``keys[:-1]`` dropped the
    # wrong key whenever 'episode_idx' was not the last one.
    out_keys = [key for key in keys if key != 'episode_idx']
    episode_ids = data['episode_idx'].flatten()
    episodes = np.unique(episode_ids)
    np.random.shuffle(episodes)
    # Visit every episode. Bounding the loop by the largest episode id (as
    # before) skipped an episode for 0-based ids and over-ran for sparse ids.
    for ep_idx in episodes:
        batch_indices = np.flatnonzero(episode_ids == ep_idx)
        yield {key: data[key][batch_indices] for key in out_keys}

def dict_batch_generator_episode_split(data, batch_size, keys=None):
    """Yield mini-batches that never mix rows from different episodes.

    Episodes are visited in random order. Within an episode, rows keep their
    original order and are cut into consecutive chunks of ``batch_size``; the
    last chunk of an episode may be smaller. Every row is yielded exactly once
    and no batch is empty.

    Args:
        data: mapping from key to an array indexed along axis 0; must contain
            an ``'episode_idx'`` entry giving the episode id of each row.
        batch_size: maximum number of rows per batch.
        keys: keys to include; defaults to all keys of ``data``. The
            ``'episode_idx'`` key is always excluded from yielded batches.

    Yields:
        dict mapping each selected key (except ``'episode_idx'``) to the rows
        of one chunk of one episode.
    """
    if keys is None:
        keys = list(data.keys())
    out_keys = [key for key in keys if key != 'episode_idx']
    episode_ids = data['episode_idx'].flatten()
    episodes = np.unique(episode_ids)
    np.random.shuffle(episodes)
    for ep_idx in episodes:
        ep_indices = np.flatnonzero(episode_ids == ep_idx)
        # Slicing clamps at the end, so the final chunk simply holds the
        # remainder. The former upper bound ``ep_len - 1`` dropped the last
        # row of every episode and could yield empty batches.
        for start in range(0, len(ep_indices), batch_size):
            batch_indices = ep_indices[start:start + batch_size]
            yield {key: data[key][batch_indices] for key in out_keys}

def dict_batch_generator_episode_split2(data, batch_size, keys=None):
    """Yield episode-pure mini-batches with rows shuffled inside each episode.

    Each episode's rows are randomly permuted, episodes are then visited in
    random order, and each permuted episode is cut into consecutive chunks of
    at most ``batch_size`` rows. A batch never mixes episodes, every row is
    yielded exactly once, and no batch is empty.

    Args:
        data: mapping from key to an array indexed along axis 0; must contain
            an ``'episode_idx'`` entry giving the episode id of each row.
        batch_size: maximum number of rows per batch.
        keys: keys to include; defaults to all keys of ``data``. The
            ``'episode_idx'`` key is always excluded from yielded batches.

    Yields:
        dict mapping each selected key (except ``'episode_idx'``) to the rows
        of one chunk of one episode.
    """
    if keys is None:
        keys = list(data.keys())
    out_keys = [key for key in keys if key != 'episode_idx']
    episode_ids = data['episode_idx'].flatten()
    episodes = np.unique(episode_ids)

    # Row positions of each episode, shuffled within the episode (computed in
    # sorted-episode order, before the episode visiting order is drawn).
    shuffled_rows = [np.random.permutation(np.flatnonzero(episode_ids == ep_idx))
                     for ep_idx in episodes]

    visit_order = np.arange(len(episodes))
    np.random.shuffle(visit_order)
    for ep_pos in visit_order:
        ep_indices = shuffled_rows[ep_pos]
        # Slicing clamps at the end; the former ``ep_len - 1`` bound dropped
        # one row per episode and could yield empty batches.
        for start in range(0, len(ep_indices), batch_size):
            batch_indices = ep_indices[start:start + batch_size]
            yield {key: data[key][batch_indices] for key in out_keys}


def dict_batch_generator_episode_split3(data, batch_size, keys=None):
    """Yield episode-pure mini-batches in fully random order.

    Rows of each episode are shuffled within the episode and chunked into
    batches of at most ``batch_size`` rows. The batches of all episodes are
    then visited in random order, so consecutive batches may come from
    different episodes, but a single batch never mixes episodes.

    Args:
        data: mapping from key to an array indexed along axis 0; must contain
            an ``'episode_idx'`` entry giving the episode id of each row.
        batch_size: maximum number of rows per batch.
        keys: keys to include; defaults to all keys of ``data``. The
            ``'episode_idx'`` key is always excluded from yielded batches.

    Yields:
        dict mapping each selected key (except ``'episode_idx'``) to the rows
        of one batch.
    """
    if keys is None:
        keys = list(data.keys())
    out_keys = [key for key in keys if key != 'episode_idx']
    episode_ids = data['episode_idx'].flatten()

    # Build every batch's row indices directly. This replaces a per-batch
    # ``batch_identifier == batch_id`` scan that cost O(N) for each batch.
    batches = []
    for ep_idx in np.unique(episode_ids):
        ep_rows = np.random.permutation(np.flatnonzero(episode_ids == ep_idx))
        for start in range(0, len(ep_rows), batch_size):
            batches.append(ep_rows[start:start + batch_size])

    order = np.arange(len(batches))
    np.random.shuffle(order)
    for batch_pos in order:
        batch_indices = batches[batch_pos]
        yield {key: data[key][batch_indices] for key in out_keys}

def dict_batch_generator_episode_split4(data, batch_size, keys=None, q=None):
    """Yield episode-pure mini-batches in random order, optionally sub-sampling episodes.

    Optionally keeps only a random subset of episodes. Rows of each kept
    episode are then shuffled within the episode and chunked into batches of
    at most ``batch_size`` rows, and all batches are visited in random order.
    A single batch never mixes episodes.

    Args:
        data: mapping from key to an array indexed along axis 0; must contain
            an ``'episode_idx'`` entry giving the episode id of each row.
        batch_size: maximum number of rows per batch.
        keys: keys to include; defaults to all keys of ``data``. The
            ``'episode_idx'`` key is always excluded from yielded batches.
        q: optional per-episode inclusion probability. When given, each
            episode is independently kept when a uniform ``np.random.rand``
            draw is below ``q``. When ``None`` (default), every episode is
            used. The previous hard-coded ``q = 0.01`` selection was computed
            but never applied, so ``None`` matches the old effective behavior.

    Yields:
        dict mapping each selected key (except ``'episode_idx'``) to the rows
        of one batch.
    """
    if keys is None:
        keys = list(data.keys())
    out_keys = [key for key in keys if key != 'episode_idx']
    episode_ids = data['episode_idx'].flatten()
    episodes = np.unique(episode_ids)

    if q is not None:
        # Independent Bernoulli(q) selection of each episode.
        episodes = episodes[np.random.rand(len(episodes)) < q]

    # Build every batch's row indices directly (no O(N) scan per batch).
    batches = []
    for ep_idx in episodes:
        ep_rows = np.random.permutation(np.flatnonzero(episode_ids == ep_idx))
        for start in range(0, len(ep_rows), batch_size):
            batches.append(ep_rows[start:start + batch_size])

    order = np.arange(len(batches))
    np.random.shuffle(order)
    for batch_pos in order:
        batch_indices = batches[batch_pos]
        yield {key: data[key][batch_indices] for key in out_keys}

def minibatch_inference(args, rollout_fn, batch_size=1000, cat_dim=0):
    """Run ``rollout_fn`` over ``args`` in chunks and concatenate the results.

    Args:
        args: sequence of sliceable inputs that share their leading length; the
            length of ``args[0]`` sets the number of rows processed.
        rollout_fn: callable invoked as ``rollout_fn(*input_batch)``. It returns
            either a tensor or a tuple of tensors (decided from the first chunk).
        batch_size: maximum number of rows passed to ``rollout_fn`` per call.
        cat_dim: dimension along which per-chunk outputs are concatenated.

    Returns:
        The concatenated output: a tensor, or a tuple of tensors when
        ``rollout_fn`` returns tuples. With exactly one chunk, the output of
        ``rollout_fn`` is returned unchanged; with no rows, ``[]``.
    """
    data_size = len(args[0])
    print(f"Data Size: {data_size}")
    chunk_outputs = []
    for batch_start in range(0, data_size, batch_size):
        input_batch = [ip[batch_start:batch_start + batch_size] for ip in args]
        chunk_outputs.append(rollout_fn(*input_batch))

    if not chunk_outputs:
        return []
    if len(chunk_outputs) == 1:
        return chunk_outputs[0]
    # Concatenate once at the end; calling torch.cat inside the loop re-copied
    # the accumulated result on every iteration (quadratic in chunk count).
    if isinstance(chunk_outputs[0], tuple):
        # Return a real tuple. The former generator expression produced a lazy,
        # one-shot generator instead of a tuple of tensors.
        return tuple(torch.cat(parts, dim=cat_dim) for parts in zip(*chunk_outputs))
    # cat_dim used to be ignored for single-tensor outputs.
    return torch.cat(chunk_outputs, dim=cat_dim)


def merge_data_batch(data1_dict, data2_dict):
    """Append ``data2_dict``'s rows onto ``data1_dict`` along axis 0, in place.

    Only the keys of ``data1_dict`` are visited. NumPy array values are joined
    with ``np.concatenate`` and torch tensor values with ``torch.cat``; values
    of any other type are left untouched. ``data2_dict`` must provide every
    key whose ``data1_dict`` value is an array or tensor.

    Returns:
        ``data1_dict`` itself, after mutation.
    """
    for name in data1_dict:
        current = data1_dict[name]
        if isinstance(current, np.ndarray):
            merged = np.concatenate([current, data2_dict[name]], axis=0)
        elif isinstance(current, torch.Tensor):
            merged = torch.cat([current, data2_dict[name]], dim=0)
        else:
            continue
        data1_dict[name] = merged
    return data1_dict


def discount_cum_sum(x, discount):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    Computes ``y[t] = x[t] + discount * y[t + 1]`` by running a first-order
    IIR filter over the time-reversed input and reversing the result back.
    """
    feedback = [1, float(-discount)]
    reversed_sums = lfilter([1], feedback, x[::-1], axis=0)
    return reversed_sums[::-1]
