# -*- coding: utf-8 -*-
"""
Created on Tue Feb 23 11:11:30 2021

@author: grego
"""

import numpy as np
import pandas as pd
from scipy.optimize import linprog
from math import sin, cos, log

from polynomial_dist import PolynomialDist, MixtureModel
from sampling_funcs import phi_vec, sample_Z

# Literal approximation of pi (agrees with math.pi to roughly 1e-12).
PI: float = 3.14159265359
# How many linearly spaced points on [-1, 1] test_irl asks
# check_bellman_optimality to examine for each learned reward vector.
BELLMAN_CHECK_SIZE: int = 100



def compute_alpha_opt(Z_list, size, cover_size, gamma, return_res = False,
                      maxiter=20):
    r'''
    Construct and solve the linear program described in Algorithm 2 of the
    paper.

    For each non-reference matrix Z_i (i >= 1) the matrix
    F_i = (I - gamma * Z_0)^{-1} (Z_0 - Z_i) is formed. The reward
    coefficients are written as alpha = v - u with u, v >= 0 (scipy's default
    ``linprog`` bounds), and the LP minimises sum(u) + sum(v) subject to

        alpha^T F_i phi(s) >= 1   for every s in the cover and every i >= 1,

    where phi(s) is the feature vector returned by ``phi_vec(s, size)``.

    Parameters
    ----------
    Z_list : list
        List of 'Z'-matrices from mixture models. May be exact or estimated.
        ``Z_list[0]`` is the reference matrix; at least two matrices are
        required.
    size : int
        Highest degree of trigonometric polynomial in 'Z'-matrix.  Matrices
        in 'Z_list' should be of dimension (2*size + 1) x (2*size + 1)
    cover_size : int
        Size of the covering set \bar{S}: that many linearly spaced points on
        [-1, 1]. Must be at least 1.
    gamma : float
        Discount factor. Number in the open interval (0,1).
    return_res : bool, optional
        Returns the result object returned by scipy.optimize.linprog if set
        to true (even when the solver did not succeed). Otherwise, returns
        the reward vector alpha_opt. The default is False.
    maxiter : int, optional
        Iteration limit forwarded to scipy.optimize.linprog via ``options``.
        The default of 20 matches the previously hard-coded value.

    Returns
    -------
    numpy.array or scipy.optimize.OptimizeResult
        The coefficients of the optimal reward function by default, or the
        raw linprog result when ``return_res`` is True.

    Raises
    ------
    ValueError
        If ``Z_list`` has fewer than two matrices or ``cover_size < 1`` (the
        constraint matrix would be empty).
    RuntimeError
        If ``return_res`` is False and linprog returned no solution vector.
    '''
    k = 2*size + 1

    # Validate up front: otherwise np.vstack fails later with an opaque
    # "need at least one array" error.
    if len(Z_list) < 2:
        raise ValueError("Z_list must contain at least two Z-matrices.")
    if cover_size < 1:
        raise ValueError("cover_size must be at least 1.")

    # Solve (I - gamma Z_0) F_i = (Z_0 - Z_i) rather than forming the inverse
    # explicitly: same result, better numerical stability.
    lhs = np.eye(k) - gamma * Z_list[0]
    F_list = [np.linalg.solve(lhs, Z_list[0] - Z) for Z in Z_list[1:]]

    cover = np.linspace(-1, 1, num=cover_size)

    # One constraint row (F phi(s))^T per (cover point, F) pair.
    A = np.vstack([phi_vec(s, size).T @ F.T for s in cover for F in F_list])
    # Variables are x = [u, v]; [A, -A] x <= -1  <=>  A (v - u) >= 1.
    constraint_mat = np.concatenate([A, -A], axis=1)

    c_vec = np.ones(constraint_mat.shape[1])
    b_vec = -np.ones(constraint_mat.shape[0])

    res = linprog(c_vec, A_ub=constraint_mat, b_ub=b_vec,
                  options={'maxiter': maxiter})

    # Return the raw result before touching res['x'], which may be None when
    # the solver fails (e.g. infeasible problem).
    if return_res:
        return res

    x = res['x']
    if x is None:
        raise RuntimeError(
            "linprog returned no solution vector: %s" % res.get('message', ''))

    # alpha = v - u
    return x[k:] - x[:k]


def check_bellman_optimality(alpha, F_list, size, num_checks):
    '''
    Empirically check that a given reward is Bellman optimal in a given MDP
    by checking that Equation 6 in the paper is fulfilled at
    ``num_checks`` linearly spaced points on the interval [-1,1].

    The checked condition is the strict inequality
    ``alpha^T F phi(s) > 0`` for every check point s and every F in F_list,
    where phi(s) is the vector returned by ``phi_vec(s, size)``.

    Parameters
    ----------
    alpha : numpy.array
        Coefficients of an arbitrary reward function.
    F_list : list of numpy.array
        List of 'F'-matrices for a given MDP.
    size : int
        Highest degree trigonometric polynomial in truncated series.
        Length of alpha and dimensions of 'F' in 'F_list' should be (2*size+1).
    num_checks : int
        Number of points on [-1,1] to check.

    Returns
    -------
    correctness : bool
        Returns 'True' if the Bellman Optimality condition was fulfilled at all
        checked points (vacuously True when there is nothing to check);
        otherwise, it returns 'False'. Always a plain Python bool.

    '''
    check_set = np.linspace(-1, 1, num=num_checks)

    for s in check_set:
        phi = phi_vec(s, size)
        for F in F_list:
            # np.all collapses a size-1 array (e.g. when phi is a column
            # vector) to a single truth value, so the result is always a bool.
            if not np.all(alpha.T @ F @ phi > 0):
                return False  # one violation is enough; stop early

    return True

def test_irl(size, gamma, cover_size, seeds, N_list, repetitions):
    r'''
    This method takes a list of three seeds and generates an associated IRL problem.
    It then solves the problem using Algorithm 2 of the paper for each sampling
    parameter in 'N_list' repeated 'repetitions' times, and checks each learned
    reward against the exact 'F'-matrices of the problem.

    Parameters
    ----------
    size : int
        Maximum degree trigonometric polynomial in series (k = 2*size + 1).
    gamma : float
        Discount factor. Number in the open interval (0,1).
    cover_size : int
        Size of the covering set \bar{S}.
    seeds : list[int]
        List of three numpy seeds used to generate the transition functions.
        Only the first three entries are used.
    N_list : list[int]
        List of numbers to be used as values for the sampling parameter 'N'.
    repetitions : int
        Number of times the algorithm will be run for each value of 'N' in 'N_list'.

    Returns
    -------
    pandas.DataFrame
        Returns a Pandas dataframe. The first column records the number of
        samples used in the run (int), and the second column denotes
        whether the learned reward function was measured to be Bellman
        optimal (bool).

    '''
    # NOTE(review): because the name starts with "test_", pytest will try to
    # collect this function as a test if the module is imported under pytest.

    # Build the three models in the same order as before: seeds[0], [1], [2].
    model_list = [MixtureModel(lambda s: s**2, seed)
                  for seed in (seeds[0], seeds[1], seeds[2])]

    # With no argument, np.random.seed() re-seeds the global RNG from fresh
    # OS entropy; presumably this keeps the sampling below from being fixed
    # by the seeds used to build the models (TODO confirm intent).
    np.random.seed()

    # Exact F-matrices F_i = (I - gamma Z_0)^{-1} (Z_0 - Z_i), computed with a
    # linear solve instead of an explicit inverse.
    Z_list = [mod.compute_Zmat(size) for mod in model_list]
    lhs = np.eye(2*size + 1) - gamma * Z_list[0]
    F_list = [np.linalg.solve(lhs, Z_list[0] - Z) for Z in Z_list[1:]]

    # Learn a reward from sampled Z-matrices for every (N, repetition) pair.
    alpha_hat_list = []
    for N in N_list:
        for _ in range(repetitions):
            print("NumSamples:", N)
            Zh_list = [sample_Z(model, size, N) for model in model_list]
            alpha_hat_list.append(
                (N, compute_alpha_opt(Zh_list, size, cover_size, gamma)))

    # Judge each learned reward against the exact F-matrices.
    alpha_hat_res = [
        (N, check_bellman_optimality(alpha, F_list, size, BELLMAN_CHECK_SIZE))
        for N, alpha in alpha_hat_list
    ]

    return pd.DataFrame(alpha_hat_res)