# -*- coding: utf-8 -*-
"""
Created on Wed May  3 10:55:11 2023

@author: pnzha
"""

import numpy as np
from RegressionModels import HuberRegression, KernelRegression, projection
import matplotlib.pyplot as plt

def Kernel(u):
    """Return the kernel weight ``2 - |u|`` for the argument ``u``.

    ``u`` may be a scalar or a numpy array; the arithmetic is element-wise.
    No truncation is applied here, so the result is negative for ``|u| > 2``.
    """
    magnitude = abs(u)
    return 2 - magnitude

# Training sample: covariates uniform on [0, 1], responses equal to
# sin(2*pi*x) plus standard normal noise.
Ntrain = 10000
X = np.random.uniform(low=0, high=1, size=(Ntrain, 1))
y = (np.sin(2 * np.pi * X[:, 0])
     + np.random.normal(loc=0, scale=1, size=Ntrain))

# Use a 12pt default font for every figure produced below.
plt.rcParams['font.size'] = 12

# Evaluation grid: 101 equally spaced points on [0, 1], stored as a column.
Xeval = np.linspace(0, 1, 101)[:, np.newaxis]

# Noise-free regression function evaluated on the grid.
eta = np.sin(2 * np.pi * Xeval[:, 0])

attack_model = 'onedirection'
# Supported attack models (any other value leaves y untouched):
#   'random'       - q training samples chosen uniformly at random have their
#                    responses replaced by values drawn at random from
#                    {-10, 10}.
#   'onedirection' - q training samples chosen uniformly at random have their
#                    responses replaced by 10.
#   'centralized'  - the attacker targets specific regions: the q/2 samples
#                    nearest x = 0.25 are set to 10 and the q/2 samples
#                    nearest x = 0.75 are set to -10.

# Number of poisoned samples; identical for every attack model.
q = 1000
if attack_model == 'random':
    attack_indices = np.random.choice(np.arange(Ntrain), q, replace=False)
    y[attack_indices] = np.random.choice([-10, 10], q)
elif attack_model == 'onedirection':
    attack_indices = np.random.choice(np.arange(Ntrain), q, replace=False)
    y[attack_indices] = 10
elif attack_model == 'centralized':
    # Vectorized nearest-sample selection. kind='stable' keeps ties in index
    # order, which is what a stable sort over range(Ntrain) would produce.
    x_train = X.ravel()
    attack1_indices = np.argsort(np.abs(x_train - 0.25),
                                 kind='stable')[:q // 2]
    y[attack1_indices] = 10
    attack2_indices = np.argsort(np.abs(x_train - 0.75),
                                 kind='stable')[:q // 2]
    y[attack2_indices] = -10

# Baseline: plain kernel regression fitted on the (possibly attacked) data.
# NOTE(review): the meaning of the constructor arguments (0.03, Kernel, 3) is
# defined in RegressionModels -- presumably 0.03 is the bandwidth; confirm.
SimpleKernel = KernelRegression(0.03, Kernel, 3)
SimpleKernel.fit(X, y)
y_kernel = SimpleKernel.predict(Xeval)

# Plot the estimate against the ground truth and save it as a PDF named
# after the attack model.
fig, ax = plt.subplots()
ax.plot(Xeval, y_kernel, linewidth=2)
ax.plot(Xeval, eta, '--', linewidth=2)
ax.legend(['Kernel regression', 'Ground truth'])
fig.savefig('{}_kernel.pdf'.format(attack_model))

# Initial estimate with Huber loss.
# NOTE(review): the meaning of the constructor arguments (0.03, Kernel, 1, 3)
# is defined in RegressionModels -- presumably 0.03 is the bandwidth and 1 the
# Huber threshold; confirm.
Initial = HuberRegression(0.03, Kernel, 1, 3)
Initial.fit(X, y)
y_initial = Initial.predict(Xeval)

# Plot the initial estimate against the ground truth and save it as a PDF
# named after the attack model.
fig, ax = plt.subplots()
ax.plot(Xeval, y_initial, linewidth=2)
ax.plot(Xeval, eta, '--', linewidth=2)
ax.legend(['Initial estimator', 'Ground truth'])
fig.savefig('{}_initial.pdf'.format(attack_model))

# Corrected estimate: the initial Huber estimate passed through projection().
# NOTE(review): the meaning of the second argument (0.07) is defined in
# RegressionModels; confirm before tuning.
y_corrected = projection(y_initial, 0.07)

# Plot the corrected estimate against the ground truth and save it as a PDF
# named after the attack model.
fig, ax = plt.subplots()
ax.plot(Xeval, y_corrected, linewidth=2)
ax.plot(Xeval, eta, '--', linewidth=2)
ax.legend(['Corrected estimator', 'Ground truth'])
fig.savefig('{}_final.pdf'.format(attack_model))