#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan  7 11:19:03 2021

@author: pooya
"""

############################## Imports

from sklearn.cluster import KMeans
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import LinearRegression
from scipy.spatial.distance import cdist
from sklearn.svm import LinearSVR
import numpy as np

##############################

class RBF_lin_reg:
    """Linear regression on Gaussian radial-basis-function (RBF) features.

    Centres are found with k-means on the training inputs. A single shared
    kernel width ``sig`` is chosen by grid search, scored by mean squared
    error on a validation set. A ``LinearSVR`` is then refitted on all of the
    training data using that width.

    Attributes:
        K: number of RBF centres (k-means clusters).
        C: (K, n_features) array of centres after ``fit`` (empty list before).
        sig: selected kernel width after ``fit`` (empty list before).
        model: fitted regressor after ``fit`` (empty list before).
    """

    def __init__(self, K):
        self.K = K
        self.C = []
        self.sig = []
        self.model = []

    @staticmethod
    def _rbf_kernel(sq_dists, sig):
        """Map squared distances to Gaussian features exp(-d^2 / (2 sig^2))."""
        return np.exp(-sq_dists / (2 * (sig ** 2)))

    def fit(self, Xtr, Ytr, Xv=None, Yv=None):
        """Fit the centres, select the kernel width, and train the regressor.

        Args:
            Xtr: 2-D array of training inputs, shape (n_samples, n_features).
            Ytr: single-output training targets, shape (n_samples,) or
                (n_samples, 1).
            Xv, Yv: optional validation inputs and targets used only to choose
                the kernel width. If either one is missing, a random 20% of the
                training data is held out instead. Either way, the final model
                is refitted on all of ``Xtr``/``Ytr``.

        Raises:
            ValueError: if ``K < 2`` or all centres coincide. In both cases no
                kernel-width scale can be derived from the centre spacing.
        """
        if self.K < 2:
            raise ValueError("RBF_lin_reg needs K >= 2 centres to derive a kernel width")

        Xtr = np.asarray(Xtr)
        Ytr = np.asarray(Ytr)
        # The final regressor is refitted on the full data, not the split.
        X_all, Y_all = Xtr, Ytr

        kmeans = KMeans(n_clusters=self.K, random_state=0).fit(Xtr)
        self.C = kmeans.cluster_centers_
        C = self.C
        n_centres = C.shape[0]

        # Mean distance from each centre to the other centres. The diagonal of
        # the pairwise matrix is zero, so dividing the column sums by K-1
        # excludes each centre's distance to itself. The average over centres
        # sets the scale of the width search.
        centre_dists = cdist(C, C, 'euclidean')
        sig_base = np.mean(centre_dists.sum(axis=0) / (n_centres - 1))
        if not np.isfinite(sig_base) or sig_base <= 0:
            raise ValueError("k-means centres coincide; cannot derive a kernel width")

        if Xv is None or Yv is None:
            # Hold out 20% of the training rows for width selection.
            n = Xtr.shape[0]
            perm = np.random.permutation(n)
            n_val = int(0.2 * n)
            # 1-D row indexing works for both 1-D and (n, 1) targets.
            Xv, Yv = Xtr[perm[:n_val]], Ytr[perm[:n_val]]
            Xtr, Ytr = Xtr[perm[n_val:]], Ytr[perm[n_val:]]

        # 60 candidate widths covering [sig_base/16, sig_base*16). linspace
        # guarantees exactly 60 points, unlike arange with a float step.
        sig_range = np.linspace(sig_base / 16, sig_base * 16, 60, endpoint=False)

        regressor = LinearSVR(max_iter=100000)
        sq_tr = cdist(Xtr, C, 'sqeuclidean')
        sq_v = cdist(Xv, C, 'sqeuclidean')
        ytr = np.ravel(Ytr)
        # Ravel so that a 1-D or (m, 1) Yv is compared element-wise with the
        # 1-D prediction rather than broadcast to an (m, m) matrix.
        yv = np.ravel(Yv)

        best_err = np.inf
        best_sig = None
        for s in sig_range:
            regressor.fit(self._rbf_kernel(sq_tr, s), ytr)
            pred = regressor.predict(self._rbf_kernel(sq_v, s))
            err = np.mean((yv - pred) ** 2)
            if err < best_err:
                best_err, best_sig = err, s
        # No candidate gave a finite validation error (e.g. the held-out split
        # was empty), so fall back to the base scale instead of leaving sig unset.
        self.sig = sig_base if best_sig is None else best_sig

        sq_all = cdist(X_all, C, 'sqeuclidean')
        self.model = regressor.fit(self._rbf_kernel(sq_all, self.sig), np.ravel(Y_all))

    def predict(self, X):
        """Predict targets for the rows of ``X`` using the fitted RBF model.

        Args:
            X: 2-D array of inputs, shape (n_samples, n_features).

        Returns:
            The output of ``self.model.predict`` on the RBF features of ``X``.
        """
        sq_dists = cdist(X, self.C, 'sqeuclidean')
        return self.model.predict(self._rbf_kernel(sq_dists, self.sig))
            