import numpy as np
import math

class RobustLinTSStruct:
	"""Per-user state for a weighted linear Thompson-sampling learner.

	The learner keeps a weighted Gram matrix ``XTX`` and a weighted response
	vector ``XTy``. From them it derives:

	* ``AInv``       -- inverse of ``lambda_ * I + XTX``
	* ``Mean``       -- ``AInv @ XTy`` (the ridge-regression estimate)
	* ``Covariance`` -- ``v**2 * AInv``, where
	  ``v = NoiseScale * sqrt(9 * d * log(time / delta))``

	Each observation is weighted by ``w = min(1, 2 * alpha / ||x||_AInv)``,
	where ``||x||_AInv = sqrt(x^T AInv x)`` is evaluated with the inverse Gram
	matrix from *before* the observation is added.

	Args:
		featureDimension: Length ``d`` of the feature vectors.
		lambda_: Ridge regularisation added to the diagonal of ``XTX``.
		NoiseScale: Multiplier of the exploration scale ``v``.
		alpha: Threshold controlling the observation weight ``w``.
		delta: Constant inside the logarithm of ``v``. Must lie in (0, 1) so
			that ``log(time / delta)`` is positive from ``time = 1`` on.
			Defaults to 0.01, the value that used to be hard-coded.
	"""

	def __init__(self, featureDimension, lambda_, NoiseScale, alpha, delta=0.01):
		self.delta = delta
		self.d = featureDimension
		self.lambda_ = lambda_
		self.NoiseScale = NoiseScale
		self.alpha = alpha

		# Weighted sufficient statistics; both start empty.
		self.XTX = np.zeros([self.d, self.d])
		self.XTy = np.zeros(self.d)

		self.time = 1
		self.w = 1     # weight given to the most recent observation
		self.num = 0   # number of updates whose weight was not exactly 1

		self._refreshPosterior()

	def _refreshPosterior(self):
		"""Recompute ``v``, ``AInv``, ``Covariance`` and ``Mean`` from the current statistics."""
		self.v = self.NoiseScale * np.sqrt(9 * self.d * np.log(self.time / self.delta))
		# Keep the unscaled inverse around: dividing Covariance by v**2 to get
		# it back is 0/0 when NoiseScale == 0, and silently wrong whenever v
		# has changed since Covariance was last computed.
		self.AInv = np.linalg.inv(self.lambda_ * np.identity(n=self.d) + self.XTX)
		self.Covariance = self.AInv * self.v**2
		self.Mean = np.dot(self.AInv, self.XTy)

	def updateParameters(self, articlePicked_FeatureVector, click):
		"""Fold one (feature vector, reward) observation into the posterior.

		Args:
			articlePicked_FeatureVector: Feature vector of length ``d`` of the
				article that was shown.
			click: Observed reward for that article.
		"""
		x = np.asarray(articlePicked_FeatureVector, dtype=float)
		self.time += 1

		# Squared norm of x under the inverse Gram matrix of the data seen so
		# far (i.e. before this observation is added).
		quadForm = np.dot(np.dot(x, self.AInv), x)
		if quadForm > 0:
			self.w = min(1, 2 * self.alpha / np.sqrt(quadForm))
		else:
			# Zero feature vector: the ratio is unbounded, so the cap of 1 applies.
			self.w = 1

		if self.w != 1:
			self.num += 1

		self.XTX += self.w * np.outer(x, x)
		self.XTy += self.w * x * click

		self._refreshPosterior()

	def getSample(self):
		"""Draw one parameter vector from ``N(Mean, Covariance)``."""
		return np.random.multivariate_normal(self.Mean, self.Covariance)

	def getTheta(self):
		"""Return the ridge estimate ``(XTX + lambda_ * I)^-1 XTy``, recomputed from the statistics."""
		return np.dot(np.linalg.inv(self.XTX + self.lambda_ * np.identity(self.d)), self.XTy)

class RobustLinTS:
	"""Multi-user front end that keeps one ``RobustLinTSStruct`` per user ID.

	Args:
		dimension: Length of the article feature vectors.
		NoiseScale: Passed through to every per-user learner.
		lambda_: Passed through to every per-user learner.
		alpha: Passed through to every per-user learner.
	"""

	def __init__(self, dimension, NoiseScale, lambda_, alpha):
		self.dimension = dimension
		self.NoiseScale = NoiseScale
		self.lambda_ = lambda_
		self.alpha = alpha
		self.CanEstimateUserPreference = True
		# userID -> per-user learner, created on first use in decide().
		self.users = {}

	def decide(self, pool_articles, userID):
		"""Pick the article scoring highest under one sampled parameter vector.

		A learner is created for ``userID`` if none exists yet. Ties go to the
		earliest article in ``pool_articles``. Returns ``None`` when no article
		scores strictly above negative infinity (e.g. for an empty pool).
		"""
		try:
			learner = self.users[userID]
		except KeyError:
			learner = RobustLinTSStruct(self.dimension, self.lambda_, self.NoiseScale, self.alpha)
			self.users[userID] = learner

		sampledTheta = learner.getSample()

		bestArticle = None
		bestScore = float('-inf')
		for candidate in pool_articles:
			score = np.dot(sampledTheta, candidate.featureVector)
			# Strict comparison keeps the first of several equal scores.
			if score > bestScore:
				bestArticle, bestScore = candidate, score

		return bestArticle

	def updateParameters(self, article_picked, click, userID):
		"""Forward the observed reward for ``article_picked`` to the user's learner."""
		learner = self.users[userID]
		learner.updateParameters(article_picked.featureVector, click)

	def getTheta(self, userID):
		"""Return the current parameter estimate of the user's learner."""
		learner = self.users[userID]
		return learner.getTheta()

