import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.autograd import Variable
import scipy.signal


def get_flat_params_from(model):
    """Return all parameters of ``model`` concatenated into one 1-D tensor.

    Parameters are taken in ``model.parameters()`` order and each is
    flattened from its ``.data`` (so the result carries no autograd history).
    """
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def set_flat_params_to(model, flat_params):
    """Copy consecutive slices of ``flat_params`` into the parameters of ``model``.

    Parameters are filled in ``model.parameters()`` order; each one consumes
    as many leading elements of the remaining vector as it has entries, and
    the slice is reshaped to the parameter's shape before an in-place copy.
    """
    offset = 0
    for param in model.parameters():
        numel = param.numel()
        chunk = flat_params[offset:offset + numel]
        param.data.copy_(chunk.view(param.size()))
        offset += numel


def get_flat_grads_from(model):
    """Return the gradients of all parameters of ``model`` as one 1-D tensor.

    Gradients are taken in ``model.parameters()`` order, so the result lines
    up element-for-element with the flattened parameter vector.

    Parameters whose ``.grad`` is ``None`` (for example, parameters that did
    not participate in the last backward pass) contribute a block of zeros of
    the matching size instead of raising ``AttributeError``.
    """
    grads = []
    for param in model.parameters():
        if param.grad is None:
            # Keep the layout aligned with the parameter vector.
            grads.append(torch.zeros(param.numel(), dtype=param.dtype, device=param.device))
        else:
            grads.append(param.grad.data.reshape(-1))
    return torch.cat(grads)


def set_flat_grads_to(model, flat_grads):
    """Assign consecutive slices of ``flat_grads`` as the ``.grad`` of each parameter.

    Parameters are visited in ``model.parameters()`` order; each one takes as
    many leading elements of the remaining vector as it has entries, reshaped
    to the parameter's shape. Each assigned gradient is a detached,
    independent copy that does not require grad.
    """
    offset = 0
    for param in model.parameters():
        numel = param.numel()
        chunk = flat_grads[offset:offset + numel]
        # Clone so the gradient does not share storage with ``flat_grads``:
        # in-place gradient updates (backward accumulation, zero_grad without
        # set_to_none) would otherwise silently rewrite the caller's vector.
        param.grad = chunk.detach().reshape(param.size()).clone()
        offset += numel


def print_score(episode_durations, window=100):
    """Print the mean of the most recent ``window`` episode durations.

    Nothing is printed until at least ``window`` durations are available.

    Args:
        episode_durations: sequence of numbers, oldest first.
        window: positive number of trailing episodes to average (default 100).
    """
    durations_t = torch.FloatTensor(episode_durations)
    if len(durations_t) >= window:
        # Only the trailing window matters; computing every sliding-window
        # mean just to read the last one would cost O(n * window) per call.
        print(durations_t[-window:].mean().numpy())


def discount(x, gamma):
    """Return the discounted reward-to-go of the sequence ``x``.

    Element t of the result is ``sum_{k >= t} gamma**(k - t) * x[k]``. It is
    computed by running the IIR filter ``y[n] = x[n] + gamma * y[n-1]`` over
    the reversed sequence and reversing the output. Returns a numpy array.
    """
    backwards = x[::-1]
    accumulated = scipy.signal.lfilter([1.0], [1.0, -gamma], backwards)
    return accumulated[::-1]


def get_entropy(states, gamma):
    """For each episode, return ``np.sum(np.log(sum_k gamma**k * s_k + 1/8))``.

    The states of an episode are combined into a discounted accumulation
    ``sum_k gamma**k * s_k`` (element-wise when states are arrays). That value
    is offset by 1/8, passed through ``np.log`` and summed with ``np.sum``.
    An empty episode yields ``log(1/8)``. The offset keeps the log finite
    where the accumulation is zero.

    NOTE(review): the name suggests an entropy-like bonus; states are
    presumably numeric vectors (e.g. visitation/feature counts) — confirm
    against callers.

    Returns:
        list with one numpy scalar per episode.
    """
    return [
        np.sum(np.log(sum(gamma ** k * s for k, s in enumerate(epi_states)) + 1 / 8))
        for epi_states in states
    ]


def calculate_scores(rewards, gamma):
    """Return, per episode, the list ``gamma**t * G_t`` for every step t.

    ``G_t`` is the discounted reward-to-go produced by ``discount`` for that
    episode's rewards; each entry is further scaled by ``gamma**t``.
    """
    return [
        [ret * gamma ** t for t, ret in enumerate(discount(epi_rewards, gamma))]
        for epi_rewards in rewards
    ]


def calculate_scores_IP(rewards, gamma, ratios):
    """Return importance-weighted, discounted reward-to-go for each episode.

    For episode i, element t of the result is
    ``sum_{j >= t} ratios[i][j] * rewards[i][j] * gamma**j``.
    Note that the discount is ``gamma**j`` (absolute step index), not
    ``gamma**(j - t)``.

    Args:
        rewards: list of per-episode reward sequences.
        gamma: discount factor.
        ratios: list of per-episode importance-ratio sequences; ``ratios[i]``
            must have at least as many entries as ``rewards[i]``.

    Returns:
        list of per-episode lists, each the same length as its rewards.
    """
    scores = []
    for i, epi_rewards in enumerate(rewards):
        epi_ratios = ratios[i]
        running = 0
        epi_scores = []
        # Accumulate from the last step backwards, appending (O(1)) and
        # reversing once at the end instead of inserting at the front (O(m)).
        for j in reversed(range(len(epi_rewards))):
            running = running + epi_ratios[j] * epi_rewards[j] * gamma ** j
            epi_scores.append(running)
        epi_scores.reverse()
        scores.append(epi_scores)
    return scores


def calculate_IP(states, actions, pol_net1, pol_net2, n_state):
    """Compute running importance ratios between two discrete-action policies.

    For each episode i, ``states[i]`` is converted to a float tensor reshaped
    to ``(-1, n_state)``, and ``actions[i]`` to an integer tensor of shape
    ``(-1, 1)``. Each network is evaluated on the states, the entry for the
    taken action is gathered from its output, and the per-step ratio
    ``pol_net1 / pol_net2`` is accumulated with a cumulative product.

    Args:
        states: list of per-episode state sequences.
        actions: list of per-episode action-index sequences.
        pol_net1, pol_net2: callables returning per-action probabilities of
            shape ``(T, n_actions)``.
        n_state: number of state features.

    Returns:
        list (one per episode) of lists of numpy scalars; element t is the
        product of the ratios for steps 0..t.
    """
    ratios = []
    # Pure evaluation: no gradients are needed, so don't build autograd graphs.
    with torch.no_grad():
        for i in range(len(states)):
            epi_states = torch.from_numpy(np.array(states[i])).float().reshape(-1, n_state)
            epi_actions = torch.from_numpy(np.array(actions[i])).reshape(-1, 1).long()

            probs1 = pol_net1(epi_states).gather(1, epi_actions)
            probs2 = pol_net2(epi_states).gather(1, epi_actions)
            ratio = (probs1 / probs2).detach().numpy()
            ratios.append(list(np.cumprod(ratio)))
    return ratios


def loglikelihood(d, a, prob):
    """Log-density of ``a`` under a diagonal Gaussian, one value per row.

    ``prob`` packs the distribution row-wise: columns ``[:d]`` are the means
    and columns ``[d:]`` the standard deviations. The result has shape
    ``(batch, 1)`` and equals
    ``-0.5 * sum(((a - mean) / std)**2) - 0.5 * d * log(2*pi) - sum(log(std))``.
    """
    mean0, std0 = prob[:, :d], prob[:, d:]
    quad = ((a - mean0) / std0).pow(2).sum(dim=1, keepdim=True)
    log_norm = 0.5 * np.log(2.0 * np.pi) * d
    log_det = std0.log().sum(dim=1, keepdim=True)
    return -0.5 * quad - log_norm - log_det


def calculate_con_IP(states, actions, pol_net1, pol_net2, n_state, d):
    """Compute running importance ratios between two diagonal-Gaussian policies.

    For each episode i, ``states[i]`` is converted to a float tensor reshaped
    to ``(-1, n_state)`` and ``actions[i]`` to a tensor reshaped to
    ``(-1, d)``. Each network's output is scored with ``loglikelihood``. The
    per-step log-ratios are accumulated with a cumulative sum and then
    exponentiated, i.e. the running product of likelihood ratios computed in
    log space.

    Args:
        states: list of per-episode state sequences.
        actions: list of per-episode continuous-action sequences.
        pol_net1, pol_net2: callables whose output packs means and standard
            deviations in the layout ``loglikelihood`` expects.
        n_state: number of state features.
        d: action dimensionality.

    Returns:
        list (one per episode) of lists of numpy scalars; element t is the
        likelihood ratio accumulated over steps 0..t.
    """
    ratios = []
    # Pure evaluation: no gradients are needed, so don't build autograd graphs.
    with torch.no_grad():
        for i in range(len(states)):
            epi_states = torch.from_numpy(np.array(states[i])).float().reshape(-1, n_state)
            epi_actions = torch.from_numpy(np.array(actions[i])).reshape(-1, d)

            log_probs1 = loglikelihood(d, epi_actions, pol_net1(epi_states))
            log_probs2 = loglikelihood(d, epi_actions, pol_net2(epi_states))
            log_ratio = (log_probs1 - log_probs2).detach().numpy()
            ratios.append(list(np.exp(np.cumsum(log_ratio))))
    return ratios


def process(data, length=1):
    """Concatenate per-episode sequences and return them as an ``(N, length)`` array.

    Episodes are joined in order, so N is the total number of items across
    all episodes. Each item must contribute exactly ``length`` values for the
    reshape to succeed.
    """
    flat = np.array([item for epi_data in data for item in epi_data])
    return flat.reshape(flat.shape[0], length)


def get_action(policy_net, state, n_state):
    """Sample one discrete action from ``policy_net`` for a single state.

    Args:
        policy_net: callable mapping a float tensor of shape ``(1, n_state)``
            to action probabilities (passed to ``Categorical`` as ``probs``).
        state: array-like containing exactly ``n_state`` values.
        n_state: number of state features.

    Returns:
        The sampled action index as a numpy integer.

    Raises:
        ValueError: if ``state`` does not contain exactly ``n_state`` values.
    """
    # reshape (unlike np.resize) refuses to silently repeat or truncate a
    # state of the wrong size.
    state_t = torch.from_numpy(np.array(state).reshape(1, n_state)).float()
    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        probs = policy_net(state_t)
        action = Categorical(probs).sample()
    return action.numpy().astype(int)[0]
