import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.autograd import Variable
import scipy.signal


def get_flat_params_from(model):
    """Concatenate all of ``model``'s parameters into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order and each is
    flattened with ``view(-1)``. The raw ``.data`` tensors are used, so the
    result is not part of the autograd graph.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        A 1-D tensor holding every parameter value, in iteration order.
    """
    return torch.cat([weight.data.view(-1) for weight in model.parameters()])


def set_flat_params_to(model, flat_params):
    """Copy values from a flat 1-D tensor back into ``model``'s parameters.

    This is the inverse of ``get_flat_params_from``. Consecutive slices of
    ``flat_params`` are reshaped to each parameter's size and copied in place
    via ``param.data.copy_``, in ``model.parameters()`` order.

    Args:
        model: a ``torch.nn.Module`` whose parameters are overwritten.
        flat_params: 1-D tensor with at least as many elements as the model
            has parameter entries.
    """
    offset = 0
    for param in model.parameters():
        count = param.numel()
        chunk = flat_params[offset:offset + count]
        param.data.copy_(chunk.view(param.size()))
        offset += count


def print_score(episode_durations, window=100):
    """Print the mean of the most recent ``window`` episode durations.

    Nothing is printed while fewer than ``window`` durations have been
    collected. Only the latest window is averaged. The original code built
    every rolling mean plus a zero-padded copy and then printed just the last
    entry.

    Args:
        episode_durations: sequence of per-episode durations (numbers).
        window: number of most recent episodes to average (default 100).

    Returns:
        The printed mean as a Python float, or ``None`` when there are not
        yet ``window`` durations.

    Raises:
        ValueError: if ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError("window must be a positive integer")
    durations_t = torch.FloatTensor(episode_durations)
    if len(durations_t) < window:
        return None
    latest_mean = durations_t[-window:].mean()
    # Print via numpy so the output format matches the original float32 print.
    print(latest_mean.numpy())
    return latest_mean.item()


def discount(x, gamma):
    """Return the discounted cumulative sum of ``x``.

    ``out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...``

    The input is reversed and run through the first-order IIR filter
    ``y[n] = x_rev[n] + gamma * y[n-1]``, then reversed back.

    Args:
        x: 1-D sequence supporting ``[::-1]`` slicing (list or ndarray).
        gamma: discount factor.

    Returns:
        ndarray of the same length as ``x``.
    """
    backwards = x[::-1]
    accumulated = scipy.signal.lfilter([1.0], [1.0, -gamma], backwards)
    return accumulated[::-1]


def get_cumu_discounted_rewards(rewards, gamma):
    """Return the discounted return from the first step of each episode.

    Args:
        rewards: iterable of per-episode reward sequences.
        gamma: discount factor passed to ``discount``.

    Returns:
        List with one entry per episode: element 0 of ``discount(episode,
        gamma)``. An empty episode raises ``IndexError``.
    """
    return [discount(episode, gamma)[0] for episode in rewards]


def calculate_scores(rewards, gamma):
    """Compute per-step scores ``gamma**t * G_t`` for each episode.

    ``G_t`` is the discounted return-to-go from ``discount``. Each entry is
    additionally scaled by ``gamma**t``, where ``t`` is the step index within
    the episode.

    Args:
        rewards: iterable of per-episode reward sequences.
        gamma: discount factor.

    Returns:
        List of lists, one inner list of scores per episode.
    """
    scores = []
    for episode in rewards:
        returns = discount(episode, gamma)
        scores.append([ret * gamma ** step for step, ret in enumerate(returns)])
    return scores


def process(data, length=1):
    """Flatten per-episode data into a 2-D ndarray of shape ``(N, length)``.

    The items of all episodes are concatenated in order, converted with
    ``np.array``, and reshaped so the first axis counts items. When each item
    does not contribute exactly ``length`` values, the reshape raises.

    Args:
        data: iterable of per-episode iterables.
        length: size of the trailing axis (default 1).

    Returns:
        ndarray of shape ``(N, length)``, where ``N`` is the total number of
        items.
    """
    flat = [item for episode in data for item in episode]
    arr = np.array(flat)
    return arr.reshape((len(arr), length))


def get_action(policy_net, state, n_state):
    """Sample an action from the distribution produced by ``policy_net``.

    ``state`` is reshaped to a ``(1, n_state)`` float32 batch and fed to
    ``policy_net``. The network's output is used as the ``probs`` of a
    ``Categorical`` distribution, and one sample is drawn.

    The original used ``np.resize``, which silently repeats or truncates
    values when ``state`` does not hold exactly ``n_state`` elements. This
    version uses ``reshape``, so such a mismatch raises ``ValueError``
    instead of feeding corrupted input to the network.

    Args:
        policy_net: callable mapping a ``(1, n_state)`` float tensor to action
            probabilities (assumed non-negative; ``Categorical`` normalizes).
        state: array-like observation with ``n_state`` elements.
        n_state: number of state features.

    Returns:
        The sampled action index as a numpy integer.

    Raises:
        ValueError: if ``state`` does not contain exactly ``n_state`` values.
    """
    batch = np.asarray(state).reshape(1, n_state)
    # ``Variable`` is a no-op since PyTorch 0.4; a plain tensor is equivalent.
    probs = policy_net(torch.from_numpy(batch).float())
    action = Categorical(probs).sample()
    return action.numpy().astype(int)[0]
