import numpy as np
from scipy.optimize import fmin_tnc
from numpy.linalg import inv
import random
from scipy.optimize import minimize, NonlinearConstraint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import copy

# Epsilon-greedy-MNL
class eps_greedy_mnl(object):
    """
    Epsilon-greedy assortment selection on top of a neural MNL utility model.

    Each round either explores (with probability eps) or offers the K items
    with the highest predicted utility.  The utility network is refit on the
    buffered (contexts, feedback) pairs whenever t + 1 is a power of two, and
    the buffers are emptied afterwards.
    """

    def __init__(self, N, K, dx, hidden_m, T, lr=1e-4, max_iter=2000,
                 eps = 0.1, eps_decay = 0.995, eps_min = 0.001):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param dx: dimension of the context vectors
        :param hidden_m: number of neurons in a hidden layer
        :param T: time horizon (stored; not read by this class)
        :param lr: learning rate of the Adam optimizer used for refitting
        :param max_iter: maximum number of gradient steps per refit
        :param eps: initial probability of exploring
        :param eps_decay: factor eps is multiplied by after every round
        :param eps_min: lower bound on eps
        """
        self.N = N
        self.K = K
        self.dx = dx
        self.hidden_m = hidden_m
        self.T = T
        self.lr = lr
        self.max_iter = max_iter
        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_min = eps_min
        self.S = None  # most recently offered assortment

        self.patience = 10 # Patience for optimization

        self.X_buffer = []  # contexts of the offered assortments since the last refit
        self.Y_buffer = []  # matching choice feedback

        # Initialization
        self.model = TwoLayerNN(self.dx, self.hidden_m).double() # initialize utility network

    def choose_S(self,t,x):
        """
        choose the epsilon-greedy assortment

        :param t: round index (not used for the choice itself)
        :param x: array of item contexts, one row per item
        :return: indices of the offered items
        """
        # Make epsilon-greedy style choice
        if np.random.rand() < self.eps:
            # NOTE(review): exploration offers a single random item rather than
            # K items -- confirm this is intended before changing it, since the
            # feedback vector length presumably follows the assortment size.
            self.S = np.random.choice(self.N, 1)
        else:
            # One batched forward pass; no autograd graph is needed for scoring.
            with torch.no_grad():
                scores = self.model(torch.tensor(x, dtype=torch.float64))
            utilities = scores.reshape(-1).cpu().numpy()
            # copy() turns the reversed (negative-stride) view into a plain array
            self.S = np.argsort(utilities)[::-1][:self.K].copy()

        # Decay epsilon
        self.eps = max(self.eps_min, self.eps * self.eps_decay)

        # Append X to X buffer
        self.X_buffer.append(x[self.S])

        return(self.S)

    def update_w(self, Y, t):
        """
        record the feedback and refit the utility network at epoch boundaries

        :param Y: choice feedback for the offered assortment; its last entry is
                  matched with the outside option, whose logit is fixed to 0
        :param t: integer round index; a refit happens when t + 1 is a power of two
        """
        # Append Y to Y buffer
        self.Y_buffer.append(Y)

        # Epoch boundary: t + 1 is a power of two (t = 0, 1, 3, 7, ...)
        if ((t + 1) & t) != 0:
            return

        self._fit_mle()

        # Discard all the X,Y pairs in buffer
        self.X_buffer = []
        self.Y_buffer = []

    def _fit_mle(self):
        """Minimise the MNL negative log-likelihood of the buffered data with Adam."""
        if not self.X_buffer:
            return  # nothing to train on

        # Convert the buffers once instead of on every optimisation step.
        xs = [torch.tensor(x_i, dtype=torch.float64) for x_i in self.X_buffer]
        ys = [torch.tensor(y_i, dtype=torch.float64) for y_i in self.Y_buffer]
        outside = torch.zeros(1, dtype=torch.float64)  # logit of the outside option

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        # Early-stopping state has to live outside the loop; resetting it on
        # every iteration would make the patience check unreachable.
        best_loss = float('inf')
        stalled = 0

        for it in range(self.max_iter):
            optimizer.zero_grad()

            loss = 0.0
            for x_i, y_i in zip(xs, ys):
                logits = torch.cat([self.model(x_i), outside])
                log_probs = torch.log_softmax(logits, dim=0)
                loss = loss - torch.sum(y_i * log_probs)
            loss.backward()
            optimizer.step()

            current = loss.item()
            if current + 1e-6 < best_loss:  # Little margin for floating-point stability
                best_loss = current
                stalled = 0
            else:
                stalled += 1
                if stalled >= self.patience:
                    print(f"Early stopping at iteration {it+1}, no improvement for {self.patience} steps.")
                    break
    
# ofu_mnl_plus
class ofu_mnl_plus(object):
    """
    Optimistic assortment selection with an online, projected parameter update.

    The estimate W_t is kept as a 1-D array of length d.  After every round it
    is moved by a Newton-like step and, if it leaves the L2 ball of radius S,
    brought back onto that ball.
    """

    def __init__(self, N, K, d, kappa = 0.5, beta = None, vzero = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: accepted but not used by this class
        :param beta: confidence radius; when None it is computed on the first
                     call to choose_S and kept from then on
        :param vzero: utility for the outside option
        """
        # immediate attributes from the constructor
        self.N = N
        self.d = d
        self.S = 1  # upper bound on the L2 norm of w
        self.K = K
        self.eta = (self.S + 1) + np.log(self.K+1)/2  # step-size parameter for the online update
        self.r_lambda = 84 * np.sqrt(2) * self.eta * self.d  # regularization parameter
        # init parameter
        self.t = 1
        self.W_t = np.zeros(self.d)  # estimated parameter, shape (d,)
        self.H = self.r_lambda * np.identity(self.d)  # hessian of loss matrix
        self.inv_H = 1/self.r_lambda * np.identity(self.d)  # inverse of H
        self.beta = beta  # confidence radius
        self.vzero = vzero  # utility for the outside option

    def choose_S(self, t, x):
        """
        choose the optimistic assortment

        :param t: round index, used only when beta still has to be computed
        :param x: array of item contexts, one row per item
        :return: indices of the K items with the largest optimistic utility
        """
        if self.beta is None:
            ## calculate beta (done once; the value is reused in later rounds)
            log_term = np.log(2*np.sqrt(1+2*t))
            lead = 3*np.log(1 + (self.K+1)*t) + 3
            quad = 17/16*self.r_lambda + 2*np.sqrt(self.r_lambda)*log_term + 16*log_term**2
            tail = np.sqrt(6)*7/6*self.eta*self.d*np.log(1 + (t+1)/(2*self.r_lambda))
            self.beta = np.sqrt(2*self.eta*(lead*quad + 2 + tail) + 4*self.r_lambda)
        means = np.dot(x, self.W_t).reshape(-1)
        xv = np.sqrt((np.matmul(x, self.inv_H) * x).sum(axis = 1))
        u = means + self.beta * xv
        self.assortment = np.argsort(u)[::-1][:self.K]
        self.chosen_vectors = x[self.assortment,:]
        return(self.assortment)

    def update_state(self, y):
        """
        update state

        :param y: choice feedback for the last assortment; the last entry
                  belongs to the outside option and is dropped
        """
        X = self.chosen_vectors
        assert isinstance(X, np.ndarray), 'np.array required'
        self.update_estimator(X, y)
        # The curvature added to H is evaluated at the freshly updated estimate.
        probs = self.Sigma(self.W_t, X)
        self.H += self._hessian(X, probs)
        self.inv_H = inv(self.H)
        self.t += 1

    def update_estimator(self, X, y):
        """
        update parameter

        :param X: contexts of the offered items, one row per item
        :param y: choice feedback including the outside option as last entry
        """
        # reshape instead of squeeze so that a single-item assortment stays 1-D
        y = np.asarray(y, dtype=float).reshape(-1)[:-1]
        w = np.asarray(self.W_t, dtype=float).reshape(-1)
        probs = self.Sigma(w, X)
        g_Wt = X.T @ (probs - y)  # d dimension
        M_t = 1/(2*self.eta) * self.H + 1/2 * self._hessian(X, probs)
        unprojected_update = w - np.linalg.solve(M_t, g_Wt)
        if np.linalg.norm(unprojected_update) > self.S:
            if self.K == 1:
                # single-item case: plain radial rescaling onto the ball
                w_new = self.S * unprojected_update / np.linalg.norm(unprojected_update)
            else:
                w_new = self.projection(unprojected_update, M_t)
        else:
            w_new = unprojected_update
        self.W_t = w_new

    def _hessian(self, X, probs):
        """Return sum_k p_k x_k x_k^T - (sum_k p_k x_k)(sum_k p_k x_k)^T."""
        weighted_mean = X.T @ probs
        return (X.T * probs) @ X - np.outer(weighted_mean, weighted_mean)

    def Sigma(self, W, X):
        """
        calculate MNL probability

        :param W: parameter vector
        :param X: contexts of the offered items, one row per item
        :return: choice probability of each offered item (outside option excluded)
        """
        z = np.matmul(X, W)
        sigma = np.exp(z)
        sigma = sigma / (sigma.sum(axis=0) + self.vzero)
        return sigma

    def proj_fun(self, W, un_projected, M):
        """Squared distance between W and un_projected in the norm induced by M."""
        diff = W-un_projected
        fun = np.dot(diff, np.dot(M, diff))
        return fun

    def projection(self, unprojected, M):
        """
        project onto the L2 ball of radius S, measuring distance in the M-norm

        :param unprojected: point to project
        :param M: matrix defining the distance that is minimised
        :return: 1-D array of length d with L2 norm at most S
        """
        unprojected = np.asarray(unprojected, dtype=float).reshape(-1)
        radius = float(self.S)
        size = np.linalg.norm(unprojected)
        if size <= radius:
            return unprojected.copy()

        # Feasible starting point, also used as fallback if the solver fails.
        radial = unprojected * (radius / size)
        sym_M = M + M.T
        ball = {
            'type': 'ineq',  # SLSQP convention: fun(w) >= 0
            'fun': lambda w: radius**2 - np.dot(w, w),
            'jac': lambda w: -2.0 * w,
        }
        opt = minimize(
            lambda w: self.proj_fun(w, unprojected, M),
            x0=radial,
            jac=lambda w: sym_M @ (w - unprojected),
            method='SLSQP',
            constraints=[ball],
        )
        w = opt.x if (opt.success and np.all(np.isfinite(opt.x))) else radial
        # Remove any residual constraint violation left by the solver tolerance.
        w_norm = np.linalg.norm(w)
        if w_norm > radius:
            w = w * (radius / w_norm)
        return w

#UCB-MNL
class ucb_mnl:
    """
    UCB-style assortment selection for the linear MNL model.

    Offered contexts and feedback are accumulated in self.X / self.Y and the
    parameter is re-estimated from the whole history after every round.
    """

    def __init__(self, N, K, d, kappa = 0.5, alpha=None, lam = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: degree of non-linearity, used when alpha is derived
        :param alpha: confidence radius; derived on the first choose_S call when None
        :param lam: regularization parameter
        """
        super(ucb_mnl, self).__init__()
        self.N, self.K, self.d = N, K, d
        self.kappa, self.lam = kappa, lam
        # Histories start with one all-zero placeholder round; update_theta
        # removes it when it is called with t == 2.
        self.X = np.zeros((K, d))[np.newaxis, ...]
        self.Y = np.zeros(K + 1)[np.newaxis, ...]
        # Holds 1 until the first choose_S call, afterwards the last assortment.
        self.S = 1
        # init parameter
        self.theta = np.zeros(d)  # estimated parameter
        self.V = np.eye(d) * lam  # gram matrix
        self.mnl = RegularizedMNLRegression()  # MLE loss function
        self.alpha = alpha  # confidence radius

    def choose_S(self,t,x):  # x is N*d matrix
        """
        pick the K items with the largest upper confidence bound

        :param t: round index, only read when alpha has not been set yet
        :param x: N x d matrix of item contexts
        :return: indices of the offered items
        """
        if self.alpha is None:
            inside_root = 2*self.d*np.log(1+t/self.d)+2*np.log(t)
            self.alpha = (1/(2*self.kappa))*np.sqrt(inside_root)

        gram_inv = inv(self.V)
        widths = np.sqrt(((x @ gram_inv) * x).sum(axis = 1))
        ucb = x @ self.theta + self.alpha * widths

        ranking = np.argsort(ucb)
        self.S = ranking[::-1][:self.K]

        offered = x[self.S, :]
        self.X = np.concatenate((self.X, offered[np.newaxis, ...]))
        # The gram matrix is updated in place as soon as the assortment is fixed.
        self.V += offered.T @ offered
        return self.S

    def update_theta(self,Y,t):
        """
        append the feedback and re-estimate theta from the full history

        :param Y: choice feedback of the last round as a numpy array
        :param t: round index; at t == 2 the placeholder round is dropped
        """
        self.Y = np.concatenate((self.Y, Y[np.newaxis, ...]))
        if t == 2:
            self.X = self.X[1:]
            self.Y = self.Y[1:]
        self.mnl.fit(self.X, self.Y, self.theta, self.lam)
        self.theta = self.mnl.w

# TS-MNL with Gaussian approximation
class ts_mnl:
    """
    Thompson-sampling assortment selection for the linear MNL model.

    A parameter is drawn from a Gaussian centred at the current estimate and
    the K items with the largest sampled utility are offered.
    """

    def __init__(self, N, K, d, kappa = 0.5, alpha=None, lam = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: degree of non-linearity, used when alpha is derived
        :param alpha: scale of the sampling covariance; derived on the first
                      choose_S call when None
        :param lam: regularization parameter
        """
        super(ts_mnl, self).__init__()
        self.N, self.K, self.d = N, K, d
        self.kappa, self.lam = kappa, lam
        # Histories start with one all-zero placeholder round; update_theta
        # removes it when it is called with t == 2.
        self.X = np.zeros((K, d))[np.newaxis, ...]
        self.Y = np.zeros(K + 1)[np.newaxis, ...]
        # Holds 1 until the first choose_S call, afterwards the last assortment.
        self.S = 1

        # init parameter
        self.theta = np.zeros(d)  # estimated parameter
        self.V = np.eye(d) * lam  # gram matrix
        self.mnl = RegularizedMNLRegression()  # MLE loss function
        self.alpha = alpha  # confidence radius

    def choose_S(self,t,x):  # x is N*d matrix
        """
        sample a parameter and offer the K items it ranks highest

        :param t: round index, only read when alpha has not been set yet
        :param x: N x d matrix of item contexts
        :return: indices of the offered items
        """
        if self.alpha is None:
            inside_root = 2*self.d*np.log(1+t/self.d)+2*np.log(t)
            self.alpha = (1/(2*self.kappa))*np.sqrt(inside_root)

        sampling_cov = np.square(self.alpha)*inv(self.V)
        theta_tilde = np.random.multivariate_normal(self.theta, sampling_cov)
        sampled_utilities = x @ theta_tilde

        ranking = np.argsort(sampled_utilities)
        self.S = ranking[::-1][:self.K]

        offered = x[self.S, :]
        self.X = np.concatenate((self.X, offered[np.newaxis, ...]))
        # The gram matrix is updated in place as soon as the assortment is fixed.
        self.V += offered.T @ offered
        return self.S

    def update_theta(self,Y,t):
        """
        append the feedback and re-estimate theta from the full history

        :param Y: choice feedback of the last round as a numpy array
        :param t: round index; at t == 2 the placeholder round is dropped
        """
        self.Y = np.concatenate((self.Y, Y[np.newaxis, ...]))
        if t == 2:
            self.X = self.X[1:]
            self.Y = self.Y[1:]
        self.mnl.fit(self.X, self.Y, self.theta, self.lam)
        self.theta = self.mnl.w

class RegularizedMNLRegression:
    """
    L2-regularised maximum-likelihood estimator for the linear MNL model.

    Inputs are x of shape (m, K, d) (m rounds, K offered items, d features) and
    y of shape (m, K + 1); the last column of y and of the probabilities
    belongs to the outside option, whose utility is fixed to 0.
    """

    def _log_prob(self, theta, x):
        """Return log choice probabilities, shape (m, K + 1), computed stably."""
        utilities = np.dot(x, theta)
        logits = np.column_stack((utilities, np.zeros(utilities.shape[0])))
        # Shifting by the row maximum keeps exp() from overflowing.
        logits = logits - logits.max(axis=1, keepdims=True)
        return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

    def compute_prob(self, theta, x):
        """
        MNL choice probabilities for every round

        :param theta: parameter vector of length d
        :param x: contexts, shape (m, K, d)
        :return: probabilities of shape (m, K + 1); the last column is the outside option
        """
        return np.exp(self._log_prob(theta, x))

    def cost_function(self, theta, x, y, lam):
        """
        averaged negative log-likelihood plus the ridge penalty lam/2 * ||theta||^2

        The penalty is the one whose derivative is the lam * theta term used in
        gradient(), so the optimiser sees a consistent objective/gradient pair.
        """
        m = x.shape[0]
        nll = -np.sum(np.multiply(y, self._log_prob(theta, x)))
        return (1/m)*nll + (1/m)*(lam/2)*np.dot(theta, theta)

    def gradient(self, theta, x, y, lam):
        """
        gradient of cost_function with respect to theta

        The likelihood part uses (prob - y), which is the exact gradient when
        every row of y sums to one.
        """
        m = x.shape[0]
        prob = self.compute_prob(theta, x)
        eps = (prob-y)[:,:-1]  # the outside option has no features
        grad = (1/m)*np.tensordot(eps,x,axes=([1,0],[1,0])) + (1/m)*lam*theta
        return grad

    def fit(self, x, y, theta, lam):
        """
        minimise cost_function with truncated Newton, starting from theta

        :return: self; the minimiser is stored in self.w
        """
        opt_weights = fmin_tnc(func=self.cost_function, x0=theta, fprime=self.gradient, args=(x, y, lam), disp=False)
        self.w = opt_weights[0]
        return self

## ONL-MNL
class TwoLayerNN(nn.Module):
    """Utility network: linear layer, sigmoid, linear layer with one output per input row."""

    def __init__(self, input_dim, hidden_dim):
        """
        :param input_dim: dimension of the context vectors
        :param hidden_dim: number of neurons in the hidden layer
        """
        super().__init__()
        self.dx, self.hidden_m = input_dim, hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the utilities of the rows of x with the trailing unit axis removed."""
        # Inputs are coerced to float64 tensors before the first layer.
        if isinstance(x, torch.Tensor):
            if x.dtype != torch.float64:
                x = x.to(dtype=torch.float64)
        else:
            x = torch.tensor(x, dtype=torch.float64)

        hidden = torch.sigmoid(self.fc1(x))
        out = self.fc2(hidden)
        return out.squeeze(-1)

    def set_model_weights_from_vector(self, w):
        """
        load all parameters from one flat vector

        The vector is read in the order fc1.weight, fc1.bias, fc2.weight,
        fc2.bias; everything after the fc2 weights is taken as the fc2 bias.
        """
        if isinstance(w, torch.Tensor):
            w = w.double()
        else:
            w = torch.tensor(w, dtype=torch.float64)

        # Offsets at which each parameter group ends inside the flat vector.
        end_w1 = self.dx * self.hidden_m
        end_b1 = end_w1 + self.hidden_m
        end_w2 = end_b1 + self.hidden_m

        with torch.no_grad():
            self.fc1.weight.copy_(w[:end_w1].view(self.hidden_m, self.dx))
            self.fc1.bias.copy_(w[end_w1:end_b1])
            self.fc2.weight.copy_(w[end_b1:end_w2].view(1, self.hidden_m))
            self.fc2.bias.copy_(w[end_w2:])
            
class LinearizedRegularizedMNL:
    """
    Regularised MNL likelihood in which every utility is linearised around a
    reference parameter: f_vals + grad_vals @ (w - w_s).

    The outside option is appended as a last alternative with logit 0.
    """

    def _log_prob(self, w, f_vals, grad_vals, w_s):
        """Return the log choice probabilities (outside option last)."""
        logits = np.append(f_vals + grad_vals @ (w - w_s), 0.0)
        logits = logits - np.max(logits)
        # log-sum-exp form: stays finite even when a probability underflows to 0
        return logits - np.log(np.sum(np.exp(logits)))

    def compute_prob(self, w, f_vals, grad_vals, w_s):
        """
        choice probabilities of one round

        :param w: parameter vector at which the model is evaluated
        :param f_vals: utilities of the offered items at the reference parameter
        :param grad_vals: gradients of those utilities, one row per item
        :param w_s: reference parameter of the linearisation
        :return: probabilities of the offered items followed by the outside option
        """
        return np.exp(self._log_prob(w, f_vals, grad_vals, w_s))

    def cost_function(self, w, f_all, grad_all, y_all, w_all, w0, lam):
        """
        negative log-likelihood over all rounds plus lam/2 * ||w - w0||^2
        """
        loss = 0.0
        for f_vals, grad_vals, y_vec, w_s in zip(f_all, grad_all, y_all, w_all):
            log_prob = self._log_prob(w, f_vals, grad_vals, w_s)
            loss += -np.sum(y_vec * log_prob)
        reg = (lam / 2) * np.sum((w - w0)**2)
        return loss + reg

    def gradient(self, w, f_all, grad_all, y_all, w_all, w0, lam):
        """
        gradient of cost_function with respect to w
        """
        grad_total = np.zeros_like(w)
        for f_vals, grad_vals, y_vec, w_s in zip(f_all, grad_all, y_all, w_all):
            prob = self.compute_prob(w, f_vals, grad_vals, w_s)
            eps = prob - y_vec
            # the outside option (last entry) does not depend on w
            grad_total += eps[:-1] @ grad_vals
        grad_total += lam * (w - w0)
        return grad_total

    def fit(self, f_all, grad_all, y_all, w_all, w0, lam, w):
        """
        minimise cost_function with truncated Newton, starting from w

        :return: self; the minimiser is stored in self.w
        """
        result = fmin_tnc(func=self.cost_function, x0=w, fprime=self.gradient, args=(f_all, grad_all, y_all, w_all, w0, lam), disp=False)
        self.w = result[0]
        return self

class onl_mnl:
    """
    Two-phase neural MNL bandit.

    Phase I (rounds t <= n) offers uniformly random assortments and, at t == n,
    trains the utility network on the collected data to obtain the pilot
    estimator w0.  Phase II (t > n) offers the K items with the largest upper
    confidence bound and re-estimates the parameter from a likelihood that is
    linearised around the previous estimates.
    """

    def __init__(self, N, K, dx, hidden_m, n, T, kappa = 0.01, lr=1e-4, max_iter=2000,
                 c_lam=7.2e-10, c_beta=1e-15, c_h = 3):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param dx: dimension of the context vectors
        :param hidden_m: number of neurons in a hidden layer
        :param n: number of Phase I rounds
        :param T: number of Phase II rounds
        :param kappa: degree of non-linearity
        :param lr: learning rate of the Adam optimizer used in Phase I
        :param max_iter: maximum number of gradient steps in Phase I
        :param c_lam: constant to scale lam
        :param c_beta: constant to scale beta
        :param c_h: constant to scale second bonus term b_2
        """
        self.N = N
        self.K = K
        self.dx = dx
        self.hidden_m = hidden_m
        self.n = n
        self.T = T
        self.kappa = kappa
        self.lr = lr
        self.max_iter = max_iter
        self.c_beta = c_beta
        self.c_lam = c_lam
        self.c_h = c_h
        self.S = None  # most recently offered assortment

        self.dw = dx * hidden_m + (2 * hidden_m) + 1 # Number of parameters of the 2-layered neural network
        self.lam = self.c_lam * ((1/self.kappa)**2.5) * (self.dw) * (np.sqrt(self.T))  # regularization parameter

        self.patience = 10 # Patience for Phase I optimization

        self.X_buffer = [] # X buffer for Phase I
        self.Y_phase1 = [] # Y buffer for Phase I

        self.grad_buffer = [] # Gradient of X buffer for Phase II
        self.Y_phase2 = [] # Y buffer for Phase II
        self.f_buffer = [] # Estimated utility buffer for Phase II
        self.w_buffer = [] # Parameter buffer for Phase II

        # Initialization
        self.model = TwoLayerNN(self.dx, self.hidden_m).double() # Initialize utility network

        self.V_inv = (1/self.lam)*torch.eye(self.dw).double()  # inverse Gram matrix
        self.mnl = LinearizedRegularizedMNL()  # MLE trainer for Phase II
        self.w0 = None # Pilot estimator
        self.w = torch.zeros(self.dw).double() # Estimator for Phase II

    def set_model_weights_from_vector(self, model, w):
        """
        load all parameters of model from one flat vector

        The vector is read in the order fc1.weight, fc1.bias, fc2.weight,
        fc2.bias; everything after the fc2 weights is taken as the fc2 bias.
        """
        w = torch.tensor(w, dtype=torch.float64) if not isinstance(w, torch.Tensor) else w.double()
        fc1_w_end = self.dx * self.hidden_m
        fc1_b_end = fc1_w_end + self.hidden_m
        fc2_w_end = fc1_b_end + self.hidden_m

        with torch.no_grad():
            model.fc1.weight.copy_(w[:fc1_w_end].view(self.hidden_m, self.dx))
            model.fc1.bias.copy_(w[fc1_w_end:fc1_b_end])
            model.fc2.weight.copy_(w[fc1_b_end:fc2_w_end].view(1, self.hidden_m))
            model.fc2.bias.copy_(w[fc2_w_end:])

    def choose_S(self,t,x):
        """
        choose the assortment: random in Phase I, optimistic in Phase II

        :param t: round index
        :param x: array of item contexts, one row per item
        :return: indices of the offered items
        """
        ## Phase I ##################################################
        if t < self.n+1:
            self.S = np.random.choice(self.N, self.K, replace=False)
            self.X_buffer.append(x[self.S])
            return(self.S)

        ## Phase II #################################################
        self.beta = self.c_beta * (1/self.kappa**4) * self.dw * (t / self.T)
        self.beta_ = self.c_beta * (1/self.kappa**4) * self.dw

        b_2 = self.c_h * self.beta_/self.lam

        # Compute the estimated utility and confidence bonus
        x_ = torch.tensor(x, dtype=torch.float64)

        means, grads, bonus = [], [], []

        for row in x_:
            u = self.model(row.unsqueeze(0))
            means.append(u.item())

            # Gradient of this item's utility with respect to all parameters
            self.model.zero_grad()
            u.backward()
            grad = torch.cat([p.grad.view(-1) for p in self.model.parameters() if p.grad is not None])
            grads.append(grad.detach().clone()) # detach from the computation graph and save a clone

            # Clamp at 0: rounding in the rank-one updates of V_inv could
            # otherwise push the quadratic form slightly negative and yield NaN.
            quad = torch.clamp(grad @ self.V_inv @ grad, min=0.0)
            b_1 = np.sqrt(self.beta)*torch.sqrt(quad).item()

            bonus.append(b_1 + b_2)

        means = np.array(means)
        bonus = np.array(bonus)
        grads_tensor = torch.stack(grads)

        # Upper confidence bound (z_{ti}) is computed by u + b1 + b2
        z = means + bonus
        # copy() turns the reversed (negative-stride) view into a plain array
        self.S = np.argsort(z)[::-1][:self.K].copy()
        self.f_buffer.append(means[self.S].reshape(-1).astype(np.float64))
        self.grad_buffer.append(grads_tensor[self.S].clone().numpy().astype(np.float64))

        # Sherman-Morrison formula to update the Gram matrix inverse
        for idx in self.S:
            g = grads_tensor[idx].unsqueeze(1)

            Vg = self.V_inv @ g
            denom = 1.0 + (g.T @ Vg).item()
            self.V_inv -= (Vg @ Vg.T) / denom

        return(self.S)

    def update_w(self, Y, t):
        """
        record the feedback and update the estimator

        :param Y: choice feedback for the offered assortment; its last entry is
                  matched with the outside option, whose logit is fixed to 0
        :param t: round index; t == n triggers the Phase I training
        """
        ### Phase I ###############################################################################################
        if t <= self.n:
            self.Y_phase1.append(Y)
            if t == self.n:
                # Update w_0
                self._fit_pilot()
                self.w0 = torch.cat([p.data.view(-1) for p in self.model.parameters()]).cpu().numpy().astype(np.float64)
                self.w_buffer.append(self.w0)
            return

        ### Phase II ##############################################################################################
        self.Y_phase2.append(Y.astype(np.float64))
        w_prev = self.w.cpu().numpy().astype(np.float64)
        self.mnl.fit(self.f_buffer, self.grad_buffer, self.Y_phase2, self.w_buffer, self.w0, self.lam, w_prev)
        self.w = torch.tensor(self.mnl.w, dtype=torch.float64)
        self.w_buffer.append(self.w.numpy().astype(np.float64))

        self.set_model_weights_from_vector(self.model, self.w)

    def _fit_pilot(self):
        """Train the utility network on the Phase I data by minimising the MNL negative log-likelihood."""
        if not self.X_buffer:
            return  # nothing to train on

        # Convert the buffers once instead of on every optimisation step.
        xs = [torch.tensor(x_i, dtype=torch.float64) for x_i in self.X_buffer]
        ys = [torch.tensor(y_i, dtype=torch.float64) for y_i in self.Y_phase1]
        outside = torch.zeros(1, dtype=torch.float64)  # logit of the outside option

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        # Early-stopping state has to live outside the loop; resetting it on
        # every iteration would make the patience check unreachable.
        best_loss = float('inf')
        stalled = 0

        for it in range(self.max_iter):
            optimizer.zero_grad()

            loss = 0.0
            for x_i, y_i in zip(xs, ys):
                logits = torch.cat([self.model(x_i), outside])
                log_probs = torch.log_softmax(logits, dim=0)
                loss = loss - torch.sum(y_i * log_probs)
            loss.backward()
            optimizer.step()

            current = loss.item()
            if current + 1e-6 < best_loss:  # Little margin for floating-point stability
                best_loss = current
                stalled = 0
            else:
                stalled += 1
                if stalled >= self.patience:
                    print(f"Early stopping at iteration {it+1}, no improvement for {self.patience} steps.")
                    break
