import numpy as np
import quadprog

from opelab.core.baseline import Baseline
from opelab.core.data import DataType
from opelab.core.policy import Policy

    
def quadratic_solver(n, M, regularizer):
    """Solve a regularized least-squares problem over the probability simplex.

    Builds the quadratic program

        minimize    1/2 * x^T (M M^T + regularizer * I) x
        subject to  sum(x) == 1  and  x[i] >= 0 for every i

    and hands it to ``quadprog.solve_qp``, which expects constraints in the
    form ``C^T x >= b`` with the first ``meq`` of them treated as equalities.

    Args:
        n: Number of unknowns; ``M`` must have ``n`` rows.
        M: Array-like with ``n`` rows. It is converted to float64 so that
            integer or lower-precision inputs are accepted.
        regularizer: Weight of the ridge term added to the diagonal of the
            quadratic form.

    Returns:
        1-D array of length ``n`` holding the minimizer.
    """
    # quadprog only accepts double-precision buffers; casting here also keeps
    # integer-valued inputs from breaking the addition of the ridge term.
    M = np.asarray(M, dtype=np.float64)
    qp_G = np.matmul(M, M.T) + regularizer * np.eye(n)
    qp_a = np.zeros(n, dtype=np.float64)

    # Column 0 encodes sum(x) == 1 (the single equality constraint);
    # columns 1..n form an identity matrix encoding x[i] >= 0.
    qp_C = np.hstack([
        np.ones((n, 1), dtype=np.float64),
        np.eye(n, dtype=np.float64),
    ])
    qp_b = np.zeros(n + 1, dtype=np.float64)
    qp_b[0] = 1.0
    meq = 1

    res = quadprog.solve_qp(qp_G, qp_a, qp_C, qp_b, meq)
    # The first element of the result tuple is the solution vector.
    return res[0].reshape((-1,))


class IPSDiscrete(Baseline):
    """Off-policy estimator based on a tabular state density ratio.

    A per-state weight ``w`` is fitted from the logged trajectories by solving
    a simplex-constrained quadratic program, and the policy value is then
    computed as a self-normalized, discounted, importance-weighted average of
    the logged rewards.

    States are used directly as integer indices into arrays of size
    ``num_state``.
    """

    def __init__(self, num_state:int, regularizer:float=0.001) -> None:
        """Store the estimator configuration.

        Args:
            num_state: Number of discrete states (size of the weight table).
            regularizer: Ridge weight passed to ``quadratic_solver``.
        """
        self.num_state = num_state
        self.regularizer = regularizer

    def _train_density(self, data, target, behavior, gamma=1.0):
        """Fit the per-state density-ratio weights from logged trajectories.

        Args:
            data: Iterable of trajectories; each one is indexed with the keys
                ``'states'``, ``'actions'`` and ``'next-states'``.
            target: Policy being evaluated; must expose ``prob(state, action)``.
            behavior: Policy that produced ``data``; must expose
                ``prob(state, action)``.
            gamma: Discount factor.

        Returns:
            Tuple ``(x, w)`` where ``x`` is the solution of the quadratic
            program (non-negative, summing to one) and ``w`` is ``x`` divided
            by the normalized discounted state frequency, with zeros for
            states that were never visited.
        """
        gmat = np.zeros([self.num_state, self.num_state], dtype=np.float64)
        nstate = np.zeros([self.num_state, 1], dtype=np.float64)
        for tau in data:
            discounted_t = 1.0
            initial_state = tau['states'][0]
            for state, action, next_state in zip(tau['states'], tau['actions'], tau['next-states']):
                # The discount is applied before use, so the transition at
                # step t carries weight gamma ** (t + 1).
                discounted_t *= gamma
                policy_ratio = target.prob(state, action) / behavior.prob(state, action)
                gmat[state, next_state] += policy_ratio * discounted_t
                # Start-state term: (1 - gamma) * gamma ** t.
                gmat[state, initial_state] += (1 - gamma) / gamma * discounted_t
                gmat[next_state, next_state] -= discounted_t
                nstate[state] += discounted_t
            # Start-state correction. The start-state terms added inside the
            # loop total 1 - gamma ** T for a trajectory of T transitions, so
            # the matching diagonal term must subtract that same amount
            # (1 - discounted_t), not gamma ** T; otherwise the two do not
            # cancel. With gamma == 1 the correction is zero.
            gmat[initial_state, initial_state] -= 1.0 - discounted_t

        frequency = nstate.reshape(-1)
        # Rows of never-visited states are left at zero to avoid dividing by 0.
        tvalid = np.where(frequency >= 1e-20)
        frequency = frequency / np.sum(frequency)
        G = np.zeros_like(gmat)
        G[tvalid] = gmat[tvalid] / (frequency[:, None])[tvalid]
        # The constant 50.0 rescales G before the QP, which changes the
        # relative strength of the ridge regularizer.
        x = quadratic_solver(self.num_state, G / 50.0, self.regularizer)
        w = np.zeros(self.num_state)
        w[tvalid] = x[tvalid] / frequency[tvalid]
        return x, w

    def evaluate(self, data:DataType, target:Policy, behavior:Policy, gamma:float=1.0) -> float:
        """Estimate the value of ``target`` from data logged under ``behavior``.

        Args:
            data: Iterable of trajectories; each one is indexed with the keys
                ``'states'``, ``'actions'``, ``'next-states'`` and
                ``'rewards'``.
            target: Policy being evaluated.
            behavior: Policy that produced ``data``.
            gamma: Discount factor.

        Returns:
            Self-normalized weighted average of the logged rewards, where each
            step is weighted by ``w[state] * policy_ratio * gamma ** t``.
        """
        _, w = self._train_density(data, target, behavior, gamma)
        total_reward = 0.0
        normalizer = 0.0
        for tau in data:
            # Here the discount is applied after use: step t has weight gamma ** t.
            discounted_t = 1.0
            for state, action, reward in zip(tau['states'], tau['actions'], tau['rewards']):
                policy_ratio = target.prob(state, action) / behavior.prob(state, action)
                w_ratio = w[state] * policy_ratio
                total_reward += w_ratio * discounted_t * reward
                normalizer += w_ratio * discounted_t
                discounted_t *= gamma
        return total_reward / normalizer

