import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import minimize, minimize_scalar, bisect
from sklearn.datasets import load_svmlight_file
from tqdm import tqdm

def logistic_loss(p, y):
    """Return the multiclass logistic (softmax cross-entropy) loss.

    Evaluates ``-p[y] + log(sum(exp(p)))``. The scores are shifted by their
    maximum before exponentiating (log-sum-exp trick) so that large positive
    scores do not overflow ``exp`` and large negative scores do not collapse
    to ``log(0)``.

    Args:
        p: Array-like of scores, one entry per class.
        y: Index into ``p`` of the observed class.

    Returns:
        The scalar loss value.
    """
    p = np.asarray(p)
    observed_score = p[y]
    shift = np.max(p)
    if not np.isfinite(shift):
        # Shifting by +/-inf or nan would only create additional nans.
        shift = 0.0
    return -observed_score + shift + np.log(np.sum(np.exp(p - shift)))

def sig(x):
    """Return the softmax of ``x``, normalised over all of its entries.

    The maximum entry is subtracted before exponentiating; this leaves the
    result mathematically unchanged but prevents ``exp`` from overflowing
    (which would otherwise yield ``inf / inf = nan``).

    Args:
        x: Array-like of scores.

    Returns:
        Array of the same shape as ``x`` whose entries are positive and sum
        to one.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(x)
    weights = np.exp(x - np.max(x))
    return weights / np.sum(weights)

def sig_batch(x):
    """Return the softmax of ``x`` taken along axis 1.

    Each slice along axis 1 is shifted by its own maximum before
    exponentiating, which does not change the result mathematically but
    keeps ``exp`` from overflowing to ``inf`` (and the ratio from becoming
    ``nan``).

    Args:
        x: Array-like with at least two dimensions; axis 1 indexes classes.

    Returns:
        Array of the same shape as ``x`` whose entries sum to one along
        axis 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; a max over an empty axis would raise.
        return np.exp(x)
    weights = np.exp(x - np.max(x, axis=1, keepdims=True))
    return weights / np.sum(weights, axis=1, keepdims=True)

def e(i, K):
    """Return the one-hot vector of length ``K`` with a 1 at position ``i``.

    The result is a freshly allocated float array; every entry other than
    index ``i`` is zero.
    """
    basis = np.zeros(K)
    basis[i] = 1
    return basis

class OGD:
    """Forecaster that takes a constant-step gradient step per observation.

    The parameter ``theta`` is a flat vector of length ``K * d`` holding one
    block of ``d`` weights per class. Scores for an input ``x`` are
    ``theta @ phi(x)`` and each update subtracts the gradient of the
    logistic loss of those scores, divided by ``lbd``.
    """

    def __init__(self, d, K, lbd=1.):
        self.K = K
        self.d = d
        self.lbd = lbd
        self.theta = np.zeros(K*d)

    def phi(self, x):
        """Return the ``(d*K, K)`` feature matrix whose column ``i`` holds
        ``x`` in rows ``d*i .. d*(i+1)-1`` and zeros elsewhere."""
        classes = np.arange(self.K)
        blocks = np.zeros((self.K, self.d, self.K))
        # blocks[i, :, i] = x for every class i; flattening the first two
        # axes then places x in the i-th block of rows of column i.
        blocks[classes, :, classes] = x
        return blocks.reshape(self.d * self.K, self.K)

    def update(self, x, y):
        """Take one gradient step on the logistic loss of ``(x, y)``."""
        features = self.phi(x)
        # softmax(scores) - one_hot(y) is the loss gradient w.r.t. the scores.
        residual = sig(self.theta @ features) - e(y, self.K)
        grad = features @ residual
        self.theta -= grad / self.lbd

    def predict(self, x):
        """Return the ``K`` class scores for ``x``."""
        return self.theta @ self.phi(x)

class ONS:
    """Forecaster that preconditions each gradient step with a matrix ``A``.

    ``A`` starts as ``lbd * I`` and accumulates ``beta * g g^T`` for every
    observed gradient ``g``; the parameter then moves by ``pinv(A) @ g``.
    ``theta`` is a flat vector of length ``K * d`` with one block of ``d``
    weights per class.
    """

    def __init__(self, d, K, beta=1., lbd=1.):
        self.K = K
        self.d = d
        self.beta = beta
        self.theta = np.zeros(K*d)
        self.A = lbd * np.eye(K*d)

    def phi(self, x):
        """Return the ``(d*K, K)`` feature matrix whose column ``i`` holds
        ``x`` in rows ``d*i .. d*(i+1)-1`` and zeros elsewhere."""
        classes = np.arange(self.K)
        blocks = np.zeros((self.K, self.d, self.K))
        # blocks[i, :, i] = x for every class i; flattening the first two
        # axes then places x in the i-th block of rows of column i.
        blocks[classes, :, classes] = x
        return blocks.reshape(self.d * self.K, self.K)

    def update(self, x, y):
        """Accumulate the gradient outer product and take one step."""
        features = self.phi(x)
        # softmax(scores) - one_hot(y) is the loss gradient w.r.t. the scores.
        residual = sig(self.theta @ features) - e(y, self.K)
        grad = features @ residual
        self.A += self.beta * np.outer(grad, grad)
        self.theta -= np.linalg.pinv(self.A) @ grad

    def predict(self, x):
        """Return the ``K`` class scores for ``x``."""
        return self.theta @ self.phi(x)

class GAF:
    """Forecaster built on a running quadratic surrogate of past losses.

    The state is a quadratic ``b @ theta + theta @ A @ theta`` (``A`` starts
    as ``lbd * I``). Each update minimises that quadratic plus the logistic
    loss of the new observation, then folds a second-order expansion of the
    new loss (curvature scaled by ``beta``) into ``A`` and ``b``.

    With ``proper=False`` predictions are the log of an average of softmax
    outputs over Gaussian-sampled scores; with ``proper=True`` they are the
    plain linear scores ``theta @ phi(x)``.
    """

    def __init__(self, d, K, beta=1., lbd=1., proper=False, nb_samples=100):
        self.A = lbd * np.eye(K * d)
        self.b = np.zeros(K * d)
        self.theta = np.zeros(K * d)
        self.K = K
        self.d = d
        self.beta = beta
        self.t = 1  # round counter, used to weight the uniform smoothing
        self.proper = proper
        self.nb_samples = nb_samples

    def phi(self, x):
        """Return the ``(d*K, K)`` feature matrix whose column ``i`` holds
        ``x`` in rows ``d*i .. d*(i+1)-1`` and zeros elsewhere."""
        phi_x = np.zeros((self.d * self.K, self.K))
        for i in range(self.K):
            phi_x[self.d * i:self.d * (i + 1), i] = x
        return phi_x

    def update(self, x, y):
        """Refit ``theta`` on ``(x, y)`` and extend the quadratic surrogate."""
        phi_x = self.phi(x)
        one_hot = e(y, self.K)

        def objective(theta):
            return (self.b @ theta + theta @ self.A @ theta
                    + logistic_loss(theta @ phi_x, y))

        def gradient(theta):
            # Reuse phi_x instead of rebuilding it on every BFGS evaluation.
            return (self.b + 2 * self.A @ theta
                    + phi_x @ (-one_hot + sig(theta @ phi_x)))

        opt = minimize(objective, self.theta, jac=gradient, method='BFGS')
        self.theta = opt.x

        p = sig(self.theta @ phi_x)
        H = self.beta * phi_x @ (np.diag(p) - np.outer(p, p)) @ phi_x.T
        # Linear term of the expansion around the new theta: g - 2 * H theta.
        self.b += -phi_x @ one_hot + phi_x @ p - 2 * self.theta @ H
        self.A += H
        self.t += 1

    def predict(self, x):
        """Return the ``K`` scores (log-probabilities if improper) for ``x``.

        In proper mode this is ``theta @ phi(x)`` and no sampling is done,
        so neither the matrix inverse nor the global NumPy random state is
        touched. Otherwise ``nb_samples`` score vectors are drawn from a
        Gaussian with mean ``theta @ phi(x)`` and covariance
        ``phi(x).T @ inv(A) @ phi(x)``, passed through the softmax, mixed
        with the uniform distribution with weight ``1 / t``, averaged, and
        returned in log space.
        """
        phi_x = self.phi(x)
        mean_scores = self.theta @ phi_x
        if self.proper:
            return mean_scores

        covariance = phi_x.T @ np.linalg.inv(self.A) @ phi_x
        samples = np.random.multivariate_normal(mean_scores, covariance,
                                                size=self.nb_samples)
        probs = sig_batch(samples)
        smoothed = (1-1/self.t)*probs + np.ones(self.K)/(self.t*self.K)
        return np.log(np.mean(smoothed, axis=0))

class oracle:
    """Fixed forecaster fitted once on the whole dataset.

    The constructor finds a ``(d, K)`` weight matrix minimising the summed
    logistic loss of the scores ``X @ theta`` against the labels ``Y`` using
    BFGS; ``update`` is a no-op afterwards.
    """

    def __init__(self,d,K,X,Y):
        """Fit the weight matrix.

        Args:
            d: Number of features (columns of ``X``).
            K: Number of classes.
            X: Feature matrix with one row per sample.
            Y: Integer class labels, one per row of ``X``.
        """
        n_samples = X.shape[0]
        rows = np.arange(n_samples)
        labels = np.asarray(Y)[:n_samples]

        def shifted_scores(flat_theta):
            # Scores and their row-wise maxima; shifting by the maximum keeps
            # exp() from overflowing without changing the loss or softmax.
            scores = np.asarray(X @ np.reshape(flat_theta, (d, K)))
            return scores, np.max(scores, axis=1, keepdims=True)

        def L(flat_theta):
            scores, shift = shifted_scores(flat_theta)
            log_norm = shift[:, 0] + np.log(np.sum(np.exp(scores - shift), axis=1))
            return np.sum(log_norm - scores[rows, labels])

        def Lp(flat_theta):
            scores, shift = shifted_scores(flat_theta)
            residual = np.exp(scores - shift)
            residual /= np.sum(residual, axis=1, keepdims=True)
            residual[rows, labels] -= 1.0  # softmax - one_hot(label)
            # d(loss)/d(theta) is X^T @ residual with shape (d, K); flatten it
            # in the same row-major order used to reshape theta.
            return np.asarray(X.T @ residual).ravel()

        fit = minimize(L, np.zeros(d*K), jac=Lp, method='BFGS')
        self.theta = np.reshape(fit.x, (d,K))

    def update(self, x, y):
        """Do nothing: the weights are fixed after construction."""
        pass

    def predict(self, x):
        """Return the ``K`` class scores ``x @ theta``."""
        return x @ self.theta

class zero_forecaster:
    """Baseline forecaster that outputs an all-zero score vector of length K."""

    def __init__(self,K):
        self.K = K

    def update(self, x, y):
        """Ignore the observation; there is no state to learn."""
        return None

    def predict(self, x):
        """Return ``K`` zeros regardless of ``x``."""
        scores = np.zeros(self.K)
        return scores

def run_on_dataset(forecaster,X,Y):
    """Run ``forecaster`` online over the dataset and return cumulative loss.

    For each round ``t`` the forecaster first predicts on ``X[t, :]``, the
    logistic loss of that prediction against ``Y[t]`` is recorded, and only
    then is the forecaster updated with the pair.

    Args:
        forecaster: Object exposing ``predict(x)`` and ``update(x, y)``.
        X: Feature matrix indexed as ``X[t, :]``.
        Y: Labels; ``len(Y)`` sets the number of rounds.

    Returns:
        Array of length ``len(Y)`` whose entry ``t`` is the total loss
        incurred in rounds ``0..t``.
    """
    n_rounds = len(Y)
    round_losses = np.zeros(n_rounds)

    for t in tqdm(range(n_rounds)):
        x_t, y_t = X[t,:], Y[t]
        round_losses[t] = logistic_loss(forecaster.predict(x_t), y_t)
        forecaster.update(x_t, y_t)

    # One running sum at the end replaces adding each loss to the whole tail
    # of the array every round (quadratic in the number of rounds).
    return np.cumsum(round_losses)