# -*- coding: utf-8 -*-
"""
Created on Tue May  2 21:24:20 2023

@author: pnzha
"""

import numpy as np
from sklearn.neighbors import NearestNeighbors

class KernelRegression:
    """Truncated kernel-weighted average over a fixed-radius neighbourhood.

    The prediction at a query point is
    ``sum_j K(d_j) * y_j / sum_j K(d_j)`` over the training points within
    distance ``h`` of it, clipped to ``[-M, M]``.

    Parameters
    ----------
    h : float
        Neighbourhood radius passed to ``NearestNeighbors(radius=h)``.
    Kernel : callable
        Maps a scalar distance to a weight.
    M : float
        Truncation threshold; predictions are clipped to ``[-M, M]``.
    """

    def __init__(self, h, Kernel, M):
        self.h = h
        self.K = Kernel
        self.M = M

    def fit(self, X, y):
        """Store the training data and build the radius-neighbour index over ``X``."""
        self.X = X
        # Array form so that positional fancy indexing in ``predict`` works
        # for lists and pandas Series as well as ndarrays.
        self.y = np.asarray(y)
        self.neighbor = NearestNeighbors(radius=self.h)
        self.neighbor.fit(X)

    def predict(self, Xtest):
        """Return the clipped kernel-weighted neighbour average for each row of ``Xtest``.

        A query point whose kernel weights sum to zero (in particular one with
        no training point within ``h``) is predicted as 0 rather than NaN.
        NaN would otherwise poison downstream aggregations such as a median.

        Returns
        -------
        numpy.ndarray of shape ``(len(Xtest),)``
        """
        dists, inds = self.neighbor.radius_neighbors(Xtest)
        result = np.zeros(len(Xtest))
        for i, (dist, idx) in enumerate(zip(dists, inds)):
            weights = np.array([self.K(d) for d in dist], dtype=float)
            total = np.sum(weights)
            if total != 0:
                result[i] = np.sum(self.y[idx] * weights) / total
            # else: leave the 0 default instead of computing 0/0.
        return np.clip(result, -self.M, self.M)

class MoMRegression:
    """Median-of-means ensemble of ``KernelRegression`` models.

    The training sample is randomly split into ``m`` disjoint bins, one
    ``KernelRegression`` is fitted per bin, and the prediction is the
    element-wise median of the per-bin predictions.

    Parameters
    ----------
    h : float
        Kernel bandwidth given to every sub-model.
    Kernel : callable
        Kernel function given to every sub-model.
    M : float
        Truncation threshold of each sub-model's estimate.
    m : int
        Number of bins (and sub-models).
    """

    def __init__(self, h, Kernel, M, m):
        self.h, self.K, self.M, self.m = h, Kernel, M, m

    def fit(self, X, y):
        """Randomly partition ``(X, y)`` into ``m`` bins and fit one model per bin.

        The shuffle draws from NumPy's global random state.
        ``X`` and ``y`` must support fancy indexing with an integer array.
        """
        n_samples = len(X)
        order = np.random.permutation(n_samples)
        # Bin k covers positions [n*k//m, n*(k+1)//m) of the shuffled order.
        cuts = [n_samples * k // self.m for k in range(self.m + 1)]
        self.models = []
        for start, stop in zip(cuts[:-1], cuts[1:]):
            chosen = order[start:stop]
            model = KernelRegression(self.h, self.K, self.M)
            model.fit(X[chosen], y[chosen])
            self.models.append(model)

    def predict(self, Xtest):
        """Return the element-wise median of the first ``m`` sub-model predictions."""
        per_bin = [self.models[k].predict(Xtest) for k in range(self.m)]
        return np.median(np.vstack(per_bin), axis=0)

class HuberRegression:
    """Local kernel-weighted Huber-type regression confined to ``[-M, M]``.

    For each query point, ``predict`` looks for a zero of
    ``grad(v) = sum_j K(d_j) * phid(y_j - v)`` over the training points within
    distance ``h``. Here ``phid`` is the derivative of a Huber-type loss with
    threshold ``T``.

    Parameters
    ----------
    h : float
        Neighbourhood radius passed to ``NearestNeighbors(radius=h)``.
    Kernel : callable
        Maps a scalar distance to a weight.
    T : float
        Huber threshold: residuals larger than ``T`` in magnitude contribute a
        constant ``+/-2T`` to ``grad``.
    M : float
        Predictions are confined to ``[-M, M]``.
    """

    def __init__(self, h, Kernel, T, M):
        self.h = h
        self.K = Kernel
        self.T = T
        self.M = M

    def fit(self, X, y):
        """Store the training data and build the radius-neighbour index over ``X``."""
        self.X = X
        # Array form so positional fancy indexing in ``predict`` always works.
        self.y = np.asarray(y)
        self.neighbor = NearestNeighbors(radius=self.h)
        self.neighbor.fit(X)

    def predict(self, Xtest, tol=0.01):
        """Estimate the response at each row of ``Xtest`` by bisection on ``grad``.

        ``phid`` is non-decreasing, so with non-negative kernel weights
        (assumed) ``grad`` is non-increasing in ``v``. Hence:

        * if ``grad(-M) < 0`` the estimate is ``-M``;
        * if ``grad(M) > 0`` the estimate is ``M``;
        * otherwise ``[-M, M]`` is bisected until it is at most ``tol`` wide.

        A query point with no non-zero kernel weight (e.g. no training point
        within ``h``) has ``grad`` identically zero. It is predicted as 0.

        Parameters
        ----------
        Xtest : array-like
            Query points.
        tol : float, default 0.01
            Bisection stops once the bracket width is ``<= tol``.

        Returns
        -------
        numpy.ndarray of shape ``(len(Xtest),)``
        """
        dists, inds = self.neighbor.radius_neighbors(Xtest)
        N = len(Xtest)
        result = np.zeros(N)
        for i in range(N):
            y = self.y[inds[i]]
            kernel_values = np.array([self.K(d) for d in dists[i]], dtype=float)
            if not np.any(kernel_values):
                # Every v is a stationary point. Keep the 0 default instead of
                # the bisection artefact near -M.
                continue
            left = -self.M
            right = self.M
            if self.grad(left, kernel_values, y) < 0:
                result[i] = left
            elif self.grad(right, kernel_values, y) > 0:
                result[i] = right
            else:
                # Initialised so it is defined even if the loop never runs
                # (bracket already narrower than tol).
                mid = (left + right) / 2
                while right - left > tol:
                    mid = (left + right) / 2
                    if self.grad(mid, kernel_values, y) > 0:
                        left = mid
                    else:
                        right = mid
                result[i] = mid
        return result

    def grad(self, val, kernel_values, y):
        """Return ``sum_j kernel_values[j] * phid(y[j] - val)``."""
        residuals = np.asarray(y) - val
        return np.sum(np.asarray(kernel_values) * self.phid(residuals))

    def phid(self, u):
        """Derivative of the Huber-type loss: ``2*u`` clipped to ``[-2T, 2T]``.

        Works element-wise on arrays as well as on scalars.
        """
        return 2 * np.clip(u, -self.T, self.T)

def projection(estimate, Delta, tol=1e-5):
    """Iteratively clamp a 1-D sequence so adjacent entries differ by at most ``Delta``.

    Starting from zeros, each sweep visits entries in order. Each entry is set
    to ``estimate[i]`` clamped into ``[max(nbrs) - Delta, min(nbrs) + Delta]``,
    where ``nbrs`` are the current values of its existing left/right
    neighbours. The function returns once no entry changes by ``tol`` or more
    between consecutive sweeps. Convergence is not otherwise bounded.

    Parameters
    ----------
    estimate : array-like, 1-D
        Target values.
    Delta : float
        Allowed difference between adjacent entries.
    tol : float, default 1e-5
        Stopping threshold on the max absolute change per sweep.

    Returns
    -------
    numpy.ndarray of the same length as ``estimate``.
    Empty input gives an empty array. A single element is returned unchanged.
    """
    estimate = np.asarray(estimate, dtype=float)
    N = len(estimate)
    g = np.zeros(N)
    if N == 0:
        return g
    g_prev = np.zeros(N)
    while True:
        for i in range(N):
            nbrs = [g[k] for k in (i - 1, i + 1) if 0 <= k < N]
            # With no neighbours (N == 1) the interval is unbounded.
            hi = min(nbrs, default=np.inf) + Delta
            lo = max(nbrs, default=-np.inf) - Delta
            e = estimate[i]
            if lo <= e <= hi:
                g[i] = e
            elif e > hi:
                g[i] = hi
            else:
                g[i] = lo
        if np.max(np.abs(g - g_prev)) < tol:
            return g
        g_prev = g.copy()

def projection2d(estimate, Delta, tol=1e-5):
    """Iteratively clamp a 2-D grid so 4-neighbour entries differ by at most ``Delta``.

    This is the 2-D analogue of ``projection``. Starting from zeros, each
    sweep visits cells in row-major order. Each cell is set to
    ``estimate[i, j]`` clamped into ``[max(nbrs) - Delta, min(nbrs) + Delta]``,
    where ``nbrs`` are the current values of its in-bounds up/down/left/right
    neighbours. The function returns once no cell changes by ``tol`` or more
    between consecutive sweeps. Convergence is not otherwise bounded.

    Parameters
    ----------
    estimate : array-like, 2-D (any rows x cols shape)
        Target values.
    Delta : float
        Allowed difference between 4-neighbouring cells.
    tol : float, default 1e-5
        Stopping threshold on the max absolute change per sweep.

    Returns
    -------
    numpy.ndarray with the same shape as ``estimate``.
    """
    estimate = np.asarray(estimate, dtype=float)
    rows, cols = estimate.shape
    g = np.zeros((rows, cols))
    if g.size == 0:
        return g
    g_prev = np.zeros((rows, cols))
    while True:
        for i in range(rows):
            for j in range(cols):
                nbrs = [
                    g[a, b]
                    for a, b in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1))
                    if 0 <= a < rows and 0 <= b < cols
                ]
                # A 1x1 grid has no neighbours: the interval is unbounded.
                hi = min(nbrs, default=np.inf) + Delta
                lo = max(nbrs, default=-np.inf) - Delta
                e = estimate[i, j]
                if lo <= e <= hi:
                    g[i, j] = e
                elif e > hi:
                    g[i, j] = hi
                else:
                    g[i, j] = lo
        if np.max(np.abs(g - g_prev)) < tol:
            return g
        g_prev = g.copy()