import numpy as np
import matplotlib.pyplot as plt
from IPython import embed


class Estimator():
    """Estimate lower/upper bounds on a policy's value under a fitted linear model.

    The model is ``learner``: its ``theta`` gives the point prediction and the
    inverse Gram matrix of its training features ``learner.Phi`` gives the
    width of the interval around that prediction.
    """

    def __init__(self, learner, env, all_phis_valid):
        """Store the model and precompute the inverse Gram matrix.

        Args:
            learner: fitted model exposing ``Phi``, ``theta`` and ``d``.
            env: environment handle; stored but not used by this class.
            all_phis_valid: iterable of per-context feature matrices, one row
                per action.

        Raises:
            numpy.linalg.LinAlgError: if ``Phi.T.dot(Phi)`` is singular.
        """
        self.learner = learner
        self.all_phis_valid = all_phis_valid
        Phi = self.learner.Phi
        self.Sig = Phi.T.dot(Phi)
        self.inv = np.linalg.inv(self.Sig)
        self.n_eval = 500
        self.env = env

    def stats(self, learner_eval):
        """Return the mean lower and upper value bounds of ``learner_eval``.

        For every validation context the action chosen by ``learner_eval`` is
        scored with ``self.learner``'s model, and the per-context bounds are
        averaged over the whole validation set.

        Args:
            learner_eval: policy object exposing ``action(phis)``.

        Returns:
            Tuple ``(lower, upper)`` of scalars.

        Raises:
            ValueError: if the validation set is empty.
        """
        # Use the model this estimator was built with, not a module-level global.
        d = self.learner.d
        theta = self.learner.theta
        scale = np.sqrt(d)

        lowers = []
        uppers = []
        for phis in self.all_phis_valid:
            a = learner_eval.action(phis)
            # Keep only the coordinates the model was fitted on so the vector
            # matches the shape of ``self.inv`` and ``theta``.
            v = np.asarray(phis[a])[:d]

            # NOTE(review): the width is the quadratic form itself, without a
            # square root -- confirm this is the intended bonus.
            bonus = scale * v.dot(self.inv).dot(v)
            pred = theta.dot(v)
            lowers.append(pred - bonus)
            uppers.append(pred + bonus)

        if not lowers:
            raise ValueError("all_phis_valid is empty; no contexts to evaluate")
        return np.mean(lowers), np.mean(uppers)




class Learner():
    """Linear reward model fitted by least squares on the first ``d`` features."""

    def __init__(self, K, d):
        """Create an unfitted model.

        Args:
            K: number of actions; stored only.
            d: number of leading feature coordinates the model uses.
        """
        self.K = K
        self.d = d
        self.theta = np.zeros(self.d)
        # Defined up front so the attribute exists before the first fit.
        self.Phi = np.zeros((0, self.d))

    def fit(self, Phi, R):
        """Fit ``theta`` to the observed (feature, reward) pairs.

        Args:
            Phi: sequence of feature vectors, each with at least ``d`` entries.
            R: sequence of rewards, one per feature vector.

        Returns:
            The fitted weight vector, all zeros when no data is given.
        """
        if len(Phi) == 0:
            self.theta = np.zeros(self.d)
            self.Phi = np.zeros((0, self.d))
            return self.theta

        Phi = np.array(Phi)[:, :self.d]
        R = np.array(R)
        # The pseudo-inverse gives the minimum-norm least-squares solution.
        # It matches the normal-equation solution when Phi has full column
        # rank and stays well defined with fewer samples than features, where
        # Phi.T.dot(Phi) is singular.
        self.theta = np.linalg.pinv(Phi).dot(R)
        self.Phi = Phi
        return self.theta

    def action(self, phis):
        """Return the index of the row of ``phis`` with the highest predicted reward.

        Args:
            phis: 2-D array with one feature row per action; only the first
                ``d`` columns are used.
        """
        phis = phis[:, :self.d]
        pred_rewards = phis.dot(self.theta)
        return np.argmax(pred_rewards)

class Env():
    """Simulated linear contextual bandit with ``K`` actions.

    Each action draws a feature vector of length ``2 * d`` from its own
    Gaussian. Rewards are linear in the features through ``theta``, whose
    entries from index ``d`` onward are zero, plus Gaussian noise.
    """

    def __init__(self, K, d):
        """Draw the per-action feature distributions and the reward weights."""
        self.K = K
        self.d = d
        self.full_d = 2 * d
        dim = self.full_d

        self.Sigmas = []
        self.means = []
        for _ in range(K):
            # Covariance built as root^T root from a random square matrix.
            root = np.random.normal(0, 1, (dim, dim))
            self.Sigmas.append(root.T.dot(root))
            self.means.append(np.random.normal(0, 1, dim))

        weights = np.random.normal(0, 1, dim)
        weights[d:] = 0.0
        self.theta = weights

        self.sigma = 100.0
        self.rewards = None
        self.n_eval = 500

    def opt_action(self, phis):
        """Return the index of the row of ``phis`` with the largest true reward."""
        return np.argmax(phis.dot(self.theta))

    def sample(self):
        """Draw one context and return its feature matrix.

        Also stores the noiseless and the noisy reward of every action on the
        instance for later use by ``act`` and ``evaluate``.
        """
        phis = np.array([
            np.random.multivariate_normal(self.means[a], self.Sigmas[a])
            for a in range(self.K)
        ])
        # Noise is drawn after the features, one value per action.
        eta = np.random.normal(0, self.sigma, self.K)
        self.noiseless_rewards = phis.dot(self.theta)
        self.rewards = self.noiseless_rewards + eta
        return phis

    def act(self, a):
        """Return the noisy reward of action ``a`` for the last sampled context."""
        return self.rewards[a]

    def evaluate(self, a):
        """Return the noiseless reward of action ``a`` for the last sampled context."""
        return self.noiseless_rewards[a]

    def generate_test(self):
        """Sample ``n_eval`` contexts and store their features and noiseless rewards.

        Returns:
            Tuple ``(all_phis_eval, all_rewards_eval)`` of arrays.
        """
        feature_sets = []
        reward_sets = []
        for _ in range(self.n_eval):
            feature_sets.append(self.sample().copy())
            reward_sets.append(self.noiseless_rewards.copy())

        self.all_phis_eval = np.array(feature_sets)
        self.all_rewards_eval = np.array(reward_sets)
        return self.all_phis_eval, self.all_rewards_eval

    def eval_test_opt(self):
        """Return the mean best noiseless reward over the stored test contexts."""
        best = [np.max(self.all_rewards_eval[i]) for i in range(self.n_eval)]
        return np.mean(best)

    def eval_test(self, learner):
        """Return the mean noiseless reward ``learner`` earns on the stored test contexts.

        Args:
            learner: object exposing ``action(phis)``.
        """
        earned = []
        for i in range(self.n_eval):
            choice = learner.action(self.all_phis_eval[i])
            earned.append(self.all_rewards_eval[i, choice])
        return np.mean(earned)




if __name__ == '__main__':
    # Experiment: fit learners that use different numbers of feature
    # coordinates on uniformly random actions, and plot their test value
    # against the number of samples collected.
    K = 10
    d = 50
    n_eval = 500
    env = Env(K, d)

    # print is written in call form throughout so the script parses on both
    # Python 2 and Python 3.
    print("Generating test...")
    env.generate_test()
    print("Done generating.")

    def generate_phis_valid():
        """Draw ``n_eval`` fresh context feature matrices from ``env``."""
        return [env.sample() for _ in range(n_eval)]

    print("Generating validation phis...")
    all_phis_valid = generate_phis_valid()
    print("Done generating validation.")

    # Interactive inspection point before the experiment starts.
    embed()

    ds = [20, 50, 100]
    M = len(ds)
    options = list(range(K))

    all_vs = []
    for k in range(M):
        n = 200
        d_model = ds[k]
        Phi = []
        R = []

        learner = Learner(K, d_model)

        vs = []
        xs = []

        print(env.theta)
        for i in range(n + 1):

            # Refit and score on the test set every 5 samples, including at
            # i == 0 when no data has been collected yet.
            if i % 5 == 0:
                learner.fit(Phi, R)
                v = env.eval_test(learner)
                print(v)
                vs.append(v)
                xs.append(i)

            # Collect one more sample with a uniformly random action.
            phis = env.sample()
            a = np.random.choice(options)
            phi = phis[a]
            r = env.act(a)
            Phi.append(phi)
            R.append(r)

        all_vs.append(vs)

    # Flat reference line at the optimal policy's mean test reward.
    v_opt = env.eval_test_opt()
    vs_opt = np.ones(len(vs)) * v_opt

    plt.plot(xs, vs_opt, label='Opt')
    for k in range(M):
        plt.plot(xs, all_vs[k], label='Learner ' + str(ds[k]))
    plt.legend()
    plt.show()

    # Interactive inspection point after plotting.
    embed()



