import numpy as np
from collections import namedtuple, deque
import torch
import random
import tqdm
import torchvision.transforms as T
from scipy.optimize import minimize
from scipy.optimize import LinearConstraint, NonlinearConstraint
from utils import *
from env import Network, MDP
import copy
np.set_printoptions(precision=3)

def confidence_set(theta: np.ndarray, theta_hat: np.ndarray, sigma: np.ndarray):
    """Return the negated, sigma-weighted squared distance from ``theta_hat`` to ``theta``.

    Computes ``-sum((theta - theta_hat)**2 * sigma)`` using NumPy broadcasting.
    The value is 0 at ``theta == theta_hat`` and, for non-negative weights,
    never positive.

    NOTE(review): with a 2-D ``sigma`` the broadcast multiplies coordinate ``j``
    by column ``j`` of ``sigma``. That is a column-sum weighting, not the
    quadratic form ``(theta - theta_hat) @ sigma @ (theta - theta_hat)``.
    Confirm which one is intended.
    """
    deviation = theta - theta_hat
    weighted_sq = (deviation * deviation) * sigma
    return -np.sum(weighted_sq)

def confidence_set_hess(theta: np.ndarray, theta_hat: np.ndarray, sigma: np.ndarray):
    """Return the Hessian of ``confidence_set`` with respect to ``theta``.

    ``confidence_set`` equals ``-sum_j w_j * (theta_j - theta_hat_j)**2``, where
    ``w_j`` is the total weight ``sigma`` puts on coordinate ``j`` after
    broadcasting against ``theta``:

    - ``w = sigma`` for a 1-D ``sigma``;
    - ``w`` is the column sums for a 2-D ``sigma``;
    - ``w = sigma * ones`` for a scalar ``sigma``.

    Its Hessian is therefore the constant diagonal matrix ``-2 * diag(w)``.
    It does not depend on ``theta`` or ``theta_hat``; ``theta`` is used only
    for its length.

    Returns:
        A ``(K, K)`` float array, where ``K == len(theta)``.
    """
    theta = np.asarray(theta, dtype=float)
    # Reproduce confidence_set's broadcast of sigma against the squared deviation,
    # then collapse every leading axis so each coordinate gets one weight.
    weights = np.ones_like(theta) * np.asarray(sigma, dtype=float)
    weights = weights.reshape(-1, theta.shape[-1]).sum(axis=0)
    # d^2/dx^2 of -w * x**2 is -2w (the previous version dropped the factor 2).
    return -2.0 * np.diag(weights)

class UCB():
    """Optimistic learner for a transition model parameterised by ``theta``.

    ``func(state_onehot, action_onehot)`` returns feature maps that are combined
    with a weight vector ``theta`` (see ``update_value_function``).

    The learner fits ``theta_hat`` by ridge regression on
    ``phi_v = func(s, a) @ V`` against the targets ``V[s']``:
    ``sigma = lamdba * I + sum(phi_v phi_v^T)``, ``b = sum(phi_v * V[s'])`` and
    ``theta_hat = sigma^{-1} b``.

    It refits and re-plans only when ``det(sigma)`` has more than doubled since
    the last refit.
    """

    def __init__(self, S: int, A: int, K: int, R: np.ndarray, gamma: float, func: "Network", lamdba: float, beta: float, U: int) -> None:
        """Store the problem sizes and hyperparameters, and initialise the learner state.

        Args:
            S: number of states.
            A: number of actions.
            K: dimension of ``theta``.
            R: reward table indexed as ``R[s][a]``.
            gamma: discount factor.
            func: feature network called as ``func(state_onehot, action_onehot)``.
            lamdba: ridge regularisation; ``sigma`` starts at ``lamdba * I``.
            beta: forwarded to ``solve_ucb_SciPy`` (presumably the confidence
                radius; not verified here).
            U: number of value-iteration sweeps per re-plan.
        """
        self.S = S
        self.A = A
        self.K = K
        self.R = R
        self.lamdba = lamdba
        self.gamma = gamma
        self.beta = beta
        self.U = U
        self.func = func
        # Ridge-regression statistics.
        self.sigma = lamdba * np.identity(K)
        self.b = np.zeros(K)  # float from the start (was an int array)
        self.theta_hat = np.zeros(K)  # equals solve(sigma, b) while b == 0
        # Optimistic initialisation at the largest possible discounted return
        # for unit rewards.
        self.Q = np.ones((S, A)) / (1 - self.gamma)
        self.V = np.ones(S) / (1 - self.gamma)
        # det(sigma) at the last refit; a refit happens once det(sigma)
        # exceeds twice this value.
        self.update_threshold = np.linalg.det(self.sigma)
        # One-hot encodings fed to func.
        self.action_space = np.eye(A)
        self.state_space = np.eye(S)
        self.history = []
        self.theta = np.random.rand(self.K)
        self.theta /= np.sum(self.theta)
        self.history_count = np.zeros((self.S, self.A))  # visit counts per (state, action)
        self.theta_list = []

    def _device(self) -> torch.device:
        """Return the device of ``func``'s parameters, or CUDA/CPU if it exposes none."""
        try:
            return next(self.func.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _features(self, state: int, action: int) -> np.ndarray:
        """Evaluate ``func`` on the one-hot (state, action) pair and return a NumPy array."""
        device = self._device()
        state_vec = torch.from_numpy(np.array([self.state_space[state]])).float().to(device)
        action_vec = torch.from_numpy(np.array([self.action_space[action]])).float().to(device)
        with torch.no_grad():
            phi = self.func(state_vec, action_vec)
        if torch.is_tensor(phi):
            # NumPy cannot consume CUDA tensors or tensors that require grad.
            phi = phi.detach().cpu().numpy()
        return phi

    def select(self, state: int) -> int:
        """Return the greedy action for ``state`` and record the visit.

        Ties go to the lowest action index, since ``np.argmax`` returns the
        first maximum.
        """
        action = int(np.argmax(self.Q[state]))
        self.history_count[state][action] += 1
        return action

    def train(self, state: int, action: int, reward: float, next_state: int) -> None:
        """Add one transition to the regression and re-plan once det(sigma) has doubled.

        Appends ``[state, action, next_state, reward]`` to ``self.history``.
        ``reward`` itself is not used in the regression update.
        """
        self.history.append([state, action, next_state, reward])

        phi = self._features(state, action)
        # Flatten so that b stays a length-K vector even if func adds a batch axis.
        phi_v = np.matmul(phi, self.V).reshape(-1)
        self.sigma = self.sigma + np.outer(phi_v, phi_v)
        self.b = self.b + phi_v * self.V[next_state]

        # Measure the determinant AFTER the update. The stored threshold must
        # belong to the same sigma that theta_hat is computed from.
        det_sigma = np.linalg.det(self.sigma)
        if det_sigma > 2 * self.update_threshold:
            self.update_threshold = det_sigma
            self.theta_hat = np.linalg.solve(self.sigma, self.b)
            self.update_value_function()
            self.V[:] = self.Q.max(axis=1)

    def update_value_function(self):
        """Run ``U`` sweeps of optimistic value iteration and store the result in ``self.Q``.

        For each (s, a), ``solve_ucb_SciPy`` is asked for a ``theta``. When it
        returns one, ``Q[s][a]`` is backed up as ``<theta @ phi(s, a), R[s][a] + gamma * V>``.
        When it returns ``None``, the previous ``next_Q[s][a]`` is kept.

        ``self.V[s]`` is refreshed inside the sweep from the Q values computed so
        far. ``self.theta_list`` holds the thetas chosen during the last sweep.
        """
        next_Q = np.ones((self.S, self.A)) / (1 - self.gamma)

        for _ in tqdm.tqdm(range(self.U)):  # value iteration
            self.theta_list = []
            for s in range(self.S):
                self.V[s] = max(next_Q[s])
                for a in range(self.A):
                    phi = self._features(s, a)
                    phi_v = np.matmul(phi, self.V)
                    temp = solve_ucb_SciPy(self.S, self.A, self.func, self.sigma, self.theta_hat, self.beta, phi_v)
                    if temp is None:
                        continue
                    self.theta = temp
                    self.theta_list.append(self.theta)
                    # Only the (s, a) row is needed. The old code rebuilt the full
                    # S x A table with S*A extra network calls and then read just this entry.
                    # Assumes phi is (K, S), so this row is a distribution over next states -- TODO confirm.
                    transition = np.dot(np.array([self.theta]), phi)[0]
                    next_Q[s][a] = np.einsum('i,i -> ', transition, self.R[s][a] + self.gamma * self.V)
        self.Q = next_Q
    
if __name__ == '__main__':
    # Experiment hyperparameters.
    S = 5              # number of states
    A = 3              # number of actions
    K = 2              # len of theta
    gamma = 0.9        # discount factor
    seed = 1           # base seed; a negative value disables seeding
    temperature = 0.1  # passed to Network (0.1 is the value the experiment actually used)
    horizon = 300      # steps per episode (was `T`, which shadowed torchvision.transforms)
    Episodes = 3
    U = 100            # value-iteration sweeps per re-plan
    lamdba = 1.0       # ridge regularisation: UCB.sigma starts at lamdba * I
    beta = 10          # forwarded to solve_ucb_SciPy
    total_regrets = []
    cumulative_reward = 0
    # `tqdm` is the imported module, so the progress-bar wrapper is tqdm.tqdm.
    for e in tqdm.tqdm(range(Episodes)):
        if seed >= 0:
            np.random.seed(seed + e)
            torch.manual_seed(seed + e)
            random.seed(seed + e)
        theta = np.random.rand(K)
        func = Network(S + A, (S, len(theta)), temperature=temperature)
        env = MDP(S, A, func, theta)
        learner = UCB(S, A, K, env.R, gamma, func, lamdba, beta, U)
        state = env.state
        regrets = []

        for _ in range(horizon):
            action = learner.select(state)
            next_state, reward = env.step(action)
            cumulative_reward += reward
            # UCB.train takes (state, action, reward, next_state), but env.step
            # yields (next_state, reward). Keywords keep the two from being swapped.
            learner.train(state, action, reward=reward, next_state=next_state)
            state = next_state
            # TODO: regret tracking, pending env.v_star / env.value_iteration:
            # regret = env.v_star[env.state] - env.value_iteration(learner, gamma)[env.state]
            # regrets.append(regret)
        # total_regrets.append(regrets)
    