import copy
import csv
import json
import math
import random
import string
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from psutil import disk_io_counters
from scipy import signal
from scipy.stats import rankdata

# Module-wide default device: CUDA when torch reports it as available, else CPU.
DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Squeeze(nn.Module):
    """Module wrapper around ``Tensor.squeeze`` so it can live in ``nn.Sequential``.

    Args:
        dim: Dimension to squeeze. When ``None`` (the default) every
            size-1 dimension is removed.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` with the configured size-1 dimension(s) removed."""
        # Tensor.squeeze has no overload taking ``dim=None``; forwarding the
        # default straight through fails, so use the no-argument form instead.
        if self.dim is None:
            return x.squeeze()
        return x.squeeze(dim=self.dim)

def mlp(dims, activation=nn.ReLU, output_activation=None, squeeze_output=False):
    """Build a float32 multi-layer perceptron as an ``nn.Sequential``.

    Args:
        dims: Layer widths, input first and output last (at least two entries).
        activation: Module class instantiated after every hidden ``nn.Linear``.
        output_activation: Optional module class instantiated after the final
            ``nn.Linear``.
        squeeze_output: If true, append ``Squeeze(-1)``; requires ``dims[-1] == 1``.

    Returns:
        The assembled ``nn.Sequential`` network.
    """
    n_dims = len(dims)
    assert n_dims >= 2, 'MLP requires at least two dims (input and output)'

    modules = []
    # Hidden stack: every consecutive pair except the last gets Linear + activation.
    for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
        modules += [nn.Linear(fan_in, fan_out), activation()]

    # Output layer, optionally followed by its own activation and a squeeze.
    modules.append(nn.Linear(dims[-2], dims[-1]))
    if output_activation is not None:
        modules.append(output_activation())
    if squeeze_output:
        assert dims[-1] == 1
        modules.append(Squeeze(-1))

    return nn.Sequential(*modules).to(dtype=torch.float32)

def compute_batched(f, *inputs):
    """Evaluate ``f`` once on concatenated inputs and split the result back.

    Each positional argument in ``inputs`` is a sequence of tensors. The
    tensors of every sequence are concatenated along dim 0, ``f`` is called
    once on the concatenated tensors, and its output is split along dim 0
    using the lengths of the tensors in ``inputs[0]``.

    Args:
        f: Function returning either a tensor or an iterable of tensors.
        *inputs: Sequences of tensors; all sequences must have the same
            number of elements.

    Returns:
        If ``f`` returns a tensor: a tuple with one chunk per element of
        ``inputs[0]``. Otherwise: a tuple with one entry per element of
        ``inputs[0]``, each entry being a tuple holding the matching chunk of
        every output of ``f``.
    """
    if len(inputs) > 1:
        # Every argument list must hold the same number of tensors.
        counts = [len(group) for group in inputs]
        assert all(count == counts[0] for count in counts)

    merged = [torch.cat(group, dim=0) for group in inputs]
    outputs = f(*merged)

    # Split sizes come from the first argument list only.
    chunk_sizes = [len(piece) for piece in inputs[0]]
    if torch.is_tensor(outputs):
        return outputs.split(chunk_sizes)

    # Otherwise treat the result as an iterable of tensors and regroup per item.
    per_output_chunks = [out.split(chunk_sizes) for out in outputs]
    return tuple(zip(*per_output_chunks))




def update_exponential_moving_average(target, source, alpha):
    """Blend ``source`` parameters into ``target`` in place.

    For every parameter pair (matched by iteration order) the target becomes
    ``(1 - alpha) * target + alpha * source``.
    """
    paired_params = zip(target.parameters(), source.parameters())
    for tgt, src in paired_params:
        blended = tgt.data
        blended.mul_(1. - alpha)
        blended.add_(src.data, alpha=alpha)

def torchify(x):
    """Convert a NumPy array to a tensor on ``DEFAULT_DEVICE``.

    float64 input is down-cast to float32; every other dtype is kept as
    produced by ``torch.from_numpy``.
    """
    tensor = torch.from_numpy(x)
    if tensor.dtype is torch.float64:
        tensor = tensor.float()
    return tensor.to(device=DEFAULT_DEVICE)


def evaluate_policy(env, policy, max_episode_steps, deterministic=True, discount = 0.99):
    """Roll out ``policy`` in ``env`` for one episode and summarise it.

    Args:
        env: Environment exposing ``reset()`` and a ``step(action)`` that
            returns ``(next_obs, reward, done, info)``.
        policy: Object exposing ``act(obs_tensor, deterministic=...)``; if that
            call fails, ``select_action(obs)`` is used instead.
        max_episode_steps: Upper bound on the number of environment steps.
        deterministic: Forwarded to ``policy.act``.
        discount: Per-step discount used for the discounted return.

    Returns:
        ``[total_reward, discount_total_reward, success, n_steps_test]`` where
        ``success`` is ``info['task_accomplished']`` from the last step (or the
        sentinel ``10`` when unavailable) and ``n_steps_test`` is the
        zero-based index of the last executed step (``-1`` if no step ran).
    """
    obs = env.reset()
    total_reward = 0.
    discount_total_reward = 0.
    # Pre-bind so a zero-length rollout does not hit unbound locals below.
    info = {}
    i = -1
    for i in range(max_episode_steps):
        with torch.no_grad():
            try:
                action = policy.act(torchify(obs), deterministic=deterministic).cpu().numpy()
            except Exception:
                # Best-effort fallback for policies without a usable ``act``;
                # KeyboardInterrupt/SystemExit are deliberately not swallowed.
                action = policy.select_action(obs)
        next_obs, reward, done, info = env.step(action)
        total_reward += reward
        discount_total_reward += reward * discount**i
        if done:
            break
        obs = next_obs
    try:
        success = info['task_accomplished']
    except (LookupError, TypeError):
        # No success flag reported by the environment: keep the sentinel.
        success = 10
    n_steps_test = i
    return [total_reward, discount_total_reward, success, n_steps_test]


def discount_cumsum(x, discount):
    """Discounted cumulative sum.

    See https://docs.scipy.org/doc/scipy/reference/tutorial/signal.html#difference-equation-filtering  # noqa: E501
    The recurrence solved is y[t] - discount*y[t+1] = x[t], i.e. on the
    reversed sequences rev(y)[t] - discount*rev(y)[t-1] = rev(x)[t].

    Args:
        x (np.ndarray): Input (reversed along its first axis, filtered along
            its last axis; presumably 1-D in practice).
        discount (float): Discount factor.

    Returns:
        np.ndarray: Discounted cumulative sum.
    """
    numerator = [1]
    denominator = [1, float(-discount)]
    # Run the IIR filter over the reversed input, then flip the result back.
    backwards = x[::-1]
    filtered = signal.lfilter(numerator, denominator, backwards, axis=-1)
    return filtered[::-1]


# Adding future reward
# Deprecated
def add_future_reward(dataset, discount_factor):
    """Attach per-step discounted future rewards to ``dataset`` in place.

    The reward sequence is cut after every step flagged in
    ``dataset['terminals']``. Within each segment the value stored at step t
    is ``discount_cumsum(segment, discount_factor)[t + 1]``, with ``0`` for
    the segment's last step.

    Args:
        dataset: Mapping with equally long 1-D ``'terminals'`` (boolean or
            0/1 flags) and ``'rewards'`` arrays.
        discount_factor: Discount passed to ``discount_cumsum``.

    Returns:
        The same ``dataset`` object, with a new ``'future_rewards'`` array.
    """
    terminals = dataset['terminals']
    rewards = dataset['rewards']

    # flatnonzero handles boolean masks as well as 0/1 numeric flags.
    segment_ends = np.flatnonzero(terminals) + 1

    segments = []
    for segment_rewards in np.split(rewards, segment_ends):
        # A terminal flag on the very last step makes np.split emit a trailing
        # empty segment; skipping it avoids appending a spurious extra 0.
        if len(segment_rewards) == 0:
            continue
        return_to_go = discount_cumsum(segment_rewards, discount_factor)[1:]
        segments.append(np.append(return_to_go, 0))

    future_rewards = np.concatenate(segments) if segments else np.zeros(0)
    assert len(future_rewards) == len(rewards)
    dataset['future_rewards'] = future_rewards

    return dataset

def simple_lambda(future_rewards, round_threshold=0.1):
    """Map each entry to the fraction of entries that are <= it.

    Fractions above ``1 - round_threshold`` are then set to 1 and fractions
    below ``round_threshold`` are set to 0.

    Args:
        future_rewards: Tensor of values to rank against each other.
        round_threshold: Width of the bands snapped to 0 and 1.

    Returns:
        Tensor of the resulting fractions, one per entry of ``future_rewards``.
    """
    fractions = [
        (future_rewards <= future_rewards[idx]).float().mean()
        for idx in range(len(future_rewards))
    ]
    ranks = torch.stack(fractions)
    # Snap the upper band first, then the lower band (same order as before).
    ranks[ranks > 1 - round_threshold] = 1
    ranks[ranks < round_threshold] = 0
    return ranks

def sample_batch(dataset, batch_size):
    """Draw a random batch (with replacement) from a dict of equal-length arrays.

    Args:
        dataset: Mapping from name to an indexable array; all values must
            have the same length.
        batch_size: Number of indices to sample.

    Returns:
        Dict with the same keys, each value indexed by the sampled indices
        and converted with ``torchify``.
    """
    first_key = list(dataset.keys())[0]
    n = len(dataset[first_key])
    for values in dataset.values():
        assert len(values) == n, "Dataset values must have same length"

    # One shared index draw so every field stays aligned row by row.
    indices = np.random.randint(low=0, high=n, size=(batch_size,))
    batch = {}
    for name, values in dataset.items():
        batch[name] = torchify(values[indices])
    return batch

def traj_to_tuple_data(traj_data, ignores=("metadata",)):
    """Concatenate a list of trajectory dicts into one dict of np.arrays.

    Keys are taken from the first trajectory; any key containing one of the
    ``ignores`` substrings is dropped.
    """
    kept_keys = [
        key for key in traj_data[0].keys()
        if not any(ig in key for ig in ignores)
    ]
    return {
        key: np.concatenate([traj[key] for traj in traj_data])
        for key in kept_keys
    }
def tuple_to_traj_data(tuple_data, ignores=("metadata",)):
    """Split a tuple_data dict in d4rl format to list of trajectory dicts.

    Note: ``tuple_data["timeouts"][-1]`` is overwritten in place with the
    negation of the last terminal flag before splitting. Keys containing any
    of the ``ignores`` substrings are left out of the result.
    """
    terminals = tuple_data["terminals"]
    timeouts = tuple_data["timeouts"]
    timeouts[-1] = not terminals[-1]

    # A trajectory ends wherever either flag is set ...
    ends = (terminals + timeouts) > 0
    ends[-1] = False  # ... but no split is needed after the final step

    split_points = np.flatnonzero(ends) + 1
    chunks = {
        key: np.split(values, split_points)
        for key, values in tuple_data.items()
        if not any(ig in key for ig in ignores)
    }
    # Transpose the dict of per-key chunk lists into a list of per-trajectory dicts.
    return [dict(zip(chunks, pieces)) for pieces in zip(*chunks.values())]

def traj_data_to_qlearning_data(traj_data, ignores=("metadata",)):
    """Convert a list of trajectory dicts into d4rl qlearning data format.

    Works on a deep copy of ``traj_data``. Each trajectory gains a
    ``"next_observations"`` entry, after which all trajectories are merged
    with ``traj_to_tuple_data``.
    """
    trajectories = copy.deepcopy(traj_data)
    for traj in trajectories:
        obs = traj["observations"]
        if traj["terminals"][-1] > 0:
            # Terminal ending: repeat the last observation as its own successor.
            obs = np.append(obs, obs[-1:], axis=0)
        else:
            # Timeout ending: drop the last step of every other field instead.
            other_keys = [key for key in traj if key != "observations"]
            for key in other_keys:
                traj[key] = traj[key][:-1]

        # Here `obs` should be one element longer than every other field.
        traj["next_observations"] = obs[1:]
        traj["observations"] = obs[:-1]

        lengths = [len(values) for values in traj.values()]
        assert all([lengths[0] == length for length in lengths[1:]])

    return traj_to_tuple_data(trajectories, ignores=ignores)

def add_lambda(traj_data, heuristic_discount,discount, lambda_method):
    """Compute lambdas over all trajectories at once and store them per trajectory.

    For every trajectory ``(returns - rewards) / discount`` is computed, the
    results are flattened into one array and passed to ``get_lambdas``; the
    returned values are split back by trajectory length and written to
    ``traj['lambda']`` in place. Nothing is returned.
    """
    # NOTE(review): presumably `returns` is the discounted return-to-go, making
    # this quantity the return from the next step onwards - confirm with caller.
    per_traj_future = [
        (traj['returns'] - traj['rewards']) / discount for traj in traj_data
    ]
    future_rewards = np.concatenate(per_traj_future, axis = None)
    lambdas = get_lambdas(lambda_method, future_rewards, heuristic_discount)
    del per_traj_future, future_rewards

    # Cut the flat lambda array at the cumulative trajectory lengths.
    traj_lengths = np.array([len(traj['returns']) for traj in traj_data])
    boundaries = np.cumsum(traj_lengths)[:-1]
    for traj, traj_lambda in zip(traj_data, np.split(lambdas, boundaries)):
        traj['lambda'] = traj_lambda








def add_step_lambda(traj_data, heuristic_discount,discount,lambda_method):
    """Compute lambdas step by step across trajectories and store them in place.

    Each trajectory first gets ``'future_rewards'`` set to
    ``(returns - rewards) / discount`` and a ``'lambda'`` array pre-filled
    with 100. Then, for step 1, 2, ..., the future rewards at that step of all
    trajectories still long enough are passed together to ``get_lambdas`` and
    the results are written to the matching ``'lambda'`` slot. Nothing is
    returned.
    """
    n_traj = len(traj_data)
    for traj in traj_data:
        traj['lambda'] = np.ones(len(traj['observations']))*100 # for debug
        traj['future_rewards'] = (traj['returns'] - traj['rewards'])/discount

    # active[k] stays True while trajectory k still has a step at `step`.
    active = np.ones(n_traj) > 0
    step = 0
    while active.sum() > 0:
        step += 1
        step_rewards = []
        for idx in np.arange(n_traj)[active]:
            traj = traj_data[idx]
            if len(traj['lambda']) >= step:
                step_rewards.append(traj['future_rewards'][step - 1])
            else:
                active[idx] = False
        step_rewards = np.array(step_rewards)
        if len(step_rewards) == 0:
            continue

        lambdas = get_lambdas(lambda_method, step_rewards, heuristic_discount)
        # Still-active trajectories are visited in the same order in which
        # their rewards were collected, so lambdas are consumed front to back.
        for idx in np.arange(n_traj)[active]:
            traj_data[idx]['lambda'][step - 1] = lambdas[0]
            lambdas = np.delete(lambdas, 0)
