#!/usr/bin/env python
# coding: utf-8

import numpy as np
import matplotlib.pyplot as plt
import math
from sklearn.linear_model import LinearRegression
from lib_kernel_td import generate_samples_rewards, generate_skipped_samples_rewards, kernel_TD, V, exp_avg



# ## Parameters of the MRP

# Module-level globals: V_star_abs / V_star_cos read them directly, and they are
# passed to the sample generator and to kernel_TD. gamma plays the role of the
# discount factor (the mean reward is scaled by 1/(1 - gamma) in V_star).
eps, gamma = 0.8, 0.5

@np.vectorize
def r_abs(x):
    """Triangle-wave reward of period 1, i.e. |2*frac(x) - 1|.

    Equals 1 at integers, 0 at half-integers and is linear in between.
    Wrapped with np.vectorize so it applies elementwise to arrays.
    """
    frac = x % 1
    # For frac <= .5 this is exactly 1 - 2*frac (rounding is sign-symmetric),
    # for frac > .5 it is 2*frac - 1.
    return abs(2.*frac - 1.)
    
@np.vectorize
def V_star_abs(x):
    """Closed-form value function paired with the reward r_abs.

    Used as the ground truth when r_abs is the selected reward. Reads the
    module-level globals gamma and eps. Piecewise linear in frac(x), with the
    kink at frac(x) = .5, mirroring r_abs. Wrapped with np.vectorize.
    NOTE(review): the formula is consistent with dynamics that keep the state
    w.p. 1 - eps and resample uniformly w.p. eps; confirm in lib_kernel_td.
    """
    frac = x % 1
    denom = 1. - gamma*(1.-eps)
    base = 1./(1-gamma)
    correction = eps*gamma/(2.*(1.-gamma)*denom)
    # Linear part: decreasing on [0, .5], increasing on (.5, 1).
    numer = -2.*frac if frac <= .5 else 2.*(frac-1)
    return numer/denom + base - correction

@np.vectorize
def r_cos(x):
    """Cosine reward of period 1: 0.5 + cos(2*pi*x)/2, taking values in [0, 1].

    Wrapped with np.vectorize so it applies elementwise to arrays.
    """
    phase = 2.*(x % 1)*math.pi
    return 0.5 + np.cos(phase)/2.

@np.vectorize
def V_star_cos(x):
    """Closed-form value function paired with the reward r_cos.

    Reads the module-level globals gamma and eps. The cosine part of the reward
    is scaled by 1/(1 - gamma*(1 - eps)); its constant 0.5 by 1/(1 - gamma).
    Wrapped with np.vectorize.
    """
    frac = x % 1
    denom = 2.*(1. - gamma*(1.-eps))
    oscillation = np.cos(frac*2.*math.pi)/denom
    return oscillation + .5/(1. - gamma)


# For plots: evaluation grid on [0, 1], also used for the MSE.
x_range = np.linspace(0, 1, 100)

# Choose the value function. The reward r and its closed-form value function
# V_star must always be switched together (a mismatched pair silently gives a
# wrong ground truth), so both are selected from a single key: 'abs' or 'cos'.
problem = 'abs'
r, V_star = {
    'abs': (r_abs, V_star_abs),
    'cos': (r_cos, V_star_cos),
}[problem]


def Kb1(u, v):
    """Periodic kernel of order 1 on [0, 1).

    Returns 1 + 2*pi**2 * B2(z), where z is the fractional part of u - v and
    B2(z) = z**2 - z + 1/6 is the second Bernoulli polynomial (presumably the
    reproducing kernel of the periodic Sobolev space of order 1 — confirm).
    Works elementwise with numpy broadcasting.
    """
    diff = u - v
    z = diff - np.floor(diff)
    two_pi_sq = 2.*math.pi**2
    return two_pi_sq*z**2 - two_pi_sq*z + math.pi**2/3. + 1.

def Kb2(u, v):
    """Periodic kernel of order 2 on [0, 1).

    Returns 1 - (2*pi)**4/24 * B4(z), where z is the fractional part of u - v
    and B4(z) = z**4 - 2*z**3 + z**2 - 1/30 is the fourth Bernoulli polynomial
    (presumably the reproducing kernel of the periodic Sobolev space of order 2
    — confirm). Works elementwise with numpy broadcasting.
    """
    diff = u - v
    z = diff - np.floor(diff)
    scale = 1./24. * (2.*math.pi)**4
    bernoulli4 = z**4 - 2.*z**3 + z**2 - 1./30.
    return 1. - scale * bernoulli4

# Choose the kernel: Kb1 (order 1) or Kb2 (order 2). The icd_nmax thresholds in
# the experiment loop and the output file names are keyed on K.__name__.
K = Kb2

# ## Make the experiments

np.random.seed(0)
reps = 1  # number of repetitions (set to 10 in the experiments)

# Sample sizes n = 50, 100, ..., 2000.
n_range = np.arange(50, 2011, 50)

# Results: one row per repetition, one column per sample size.
MSE_range_noavg, MSE_range_avg = (np.zeros((reps, len(n_range))) for _ in range(2))

# Ground-truth value function on the evaluation grid.
gd_truth = V_star(x_range)



def _mse(pred, truth):
    """Return the squared L2 error between pred and truth, divided by len(truth).

    Same quantity as np.linalg.norm(pred - truth)**2 / truth.shape[0], without
    the sqrt-then-square round trip.
    """
    diff = pred - truth
    return np.sum(diff * diff) / truth.shape[0]


# Regularisation schedule lbda = n**lbda_rate (other rates tried: -0.4, -1/3).
# Loop-invariant, so set once.
lbda_rate = -4./7.

for rep in range(reps):
    print('Repetition:', rep)
    # One trajectory of the largest size per repetition; each n uses a prefix.
    X, rwds = generate_samples_rewards(max(n_range), eps, gamma, r)
    for idn, n in enumerate(n_range):
        print(n)
        lbda = n**(lbda_rate)
        # NOTE(review): rho is presumably the step size of kernel TD — confirm.
        rho = np.log(n)/(n*lbda)

        # For skipTD only
        #tau = int(np.log(1./rho) / np.log(1./(1.-eps))) + 2
        #X, rwds = generate_skipped_samples_rewards(n, tau, eps, gamma, r)

        # A non-zero icd_nmax is passed for large n only; presumably it caps the
        # rank of an incomplete-Cholesky approximation in kernel_TD — confirm.
        icd_nmax = 0
        if n >= 1500 and K.__name__ == 'Kb2':
            icd_nmax = 100
        elif n >= 4000 and K.__name__ == 'Kb1':
            icd_nmax = 150
        # n transitions use n + 1 states and n rewards.
        alpha = kernel_TD(lbda, rho, X[:n+1], rwds[:n], K, gamma, icd_nmax=icd_nmax)
        alpha_avg = exp_avg(n, alpha, rho, lbda)

        pred_noavg = V(x_range, n-1, X, alpha, K)
        pred_avg = V(x_range, n-1, X, alpha_avg, K)

        MSE_range_noavg[rep, idn] = _mse(pred_noavg, gd_truth)
        MSE_range_avg[rep, idn] = _mse(pred_avg, gd_truth)

# Making the plots


def _plot_mean_std(mse, color):
    """Scatter the mean (over repetitions) of mse against n_range, with a +/- 1 std band."""
    mean, std = mse.mean(axis=0), mse.std(axis=0)
    plt.scatter(n_range, mean, color=color, alpha=0.5)
    plt.fill_between(n_range, mean - std, mean + std, color=color, alpha=0.1)


plt.figure()
skip_pts = len(n_range) // 2  # first points skipped for regression only

# Label the run with the reward actually in use (previously hard-coded 'cos_',
# which mislabelled r_abs runs). np.vectorize keeps the wrapped function in
# .pyfunc, so 'r_abs' -> 'abs' and 'r_cos' -> 'cos'.
reward_name = getattr(r, 'pyfunc', r).__name__
if reward_name.startswith('r_'):
    reward_name = reward_name[2:]
# '{:g}' keeps every significant digit of eps (str(eps)[:3] truncated 0.25 to 0.2).
eps_str = '{:g}'.format(eps).replace('.', ',')
title_str = reward_name + '_' + K.__name__ + '_eps' + eps_str

# Log-log least-squares fit of the averaged MSE: the slope is the empirical rate.
regx = np.log(n_range[skip_pts:]).reshape(-1, 1)
regy = np.log(MSE_range_avg[:, skip_pts:].mean(axis=0))
reg = LinearRegression(fit_intercept=True).fit(regx, regy)
slope = reg.coef_[0]
plt.title(title_str + '_ Slope: ' + '{:.4f}'.format(slope), fontsize=15)

_plot_mean_std(MSE_range_noavg, 'red')
_plot_mean_std(MSE_range_avg, 'blue')

# Fitted power law, drawn as a dashed reference line.
n_fit = 1.5*n_range
plt.plot(n_fit, np.exp(slope*np.log(n_fit) + reg.intercept_), color='black',
         linewidth=1, alpha=1, linestyle='--')
# Use the largest sample size explicitly rather than the leaked loop variable n.
plt.xlim((50, 1.2*n_range[-1]))
plt.ylim((5e-4, 5e0))
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$n$', fontsize=15)
plt.ylabel(r'$|| V_n - V^* ||^2_{L^2}$', fontsize=15)

plt.savefig('rate_'+title_str+'.pdf')
plt.show()




# Compare the averaged estimate for the largest n with the true value function.
plt.figure()
plt.title(title_str)

# gd_truth was computed as V_star(x_range), and pred_avg as
# V(x_range, n-1, X, alpha_avg, K) in the last loop iteration (largest n).
# Reuse them instead of re-evaluating the kernel expansion.
plt.plot(x_range, gd_truth, linestyle='--', linewidth=1., c='black')
plt.plot(x_range, pred_avg)
plt.ylim(0, 5)
plt.savefig('approx_'+title_str+'.pdf')
plt.show()



