import math
import numpy as np
from tqdm import tqdm

import sys
from pathlib import Path
parent_folder = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_folder))

from utils import save_dict, load_dict

sys.path.remove(str(parent_folder))
       
##############################################################################



class MRP:
    """Tabular Markov Reward Process.

    ``P[s, sp]`` is the probability of moving from state ``s`` to state ``sp``
    (rows are normalised on construction) and ``R[s, sp]`` is the reward
    collected on that transition. ``state_labels[s]`` is the observation that
    :meth:`reset` and :meth:`step` return for state ``s``. A state whose row of
    ``P`` equals the corresponding identity row (self-loop with probability 1)
    is treated as terminal.
    """

    def __init__(self, P:np.ndarray, 
                       R:np.ndarray, 
                       state_labels:np.ndarray=None, 
                       init_state_distrib:np.ndarray=None) -> None:
        """
        Args:
            P: square transition-weight matrix; each row is divided by its sum.
            R: reward matrix with the same shape as ``P``.
            state_labels: 2-D array of per-state observations (row ``s`` for
                state ``s``); defaults to the state indices as a
                (num_states, 1) column.
            init_state_distrib: probabilities of the initial state; defaults to
                the uniform distribution.
        """
        assert P.shape == R.shape, "P and R must have the same shape."
        # Non-float inputs (e.g. integer transition counts) are cast to float:
        # in-place true division on an integer array raises a casting error.
        self.P = P.copy() if np.issubdtype(P.dtype, np.inexact) else P.astype(np.float64)
        self.P /= self.P.sum(axis=1, keepdims=True)
        self.R = R.copy()
        self.num_states = self.P.shape[0]
        self.state_labels = state_labels if state_labels is not None else np.arange(self.num_states)[:, np.newaxis]
        self.state_dim = self.state_labels.shape[1]
        self.init_state_distrib = init_state_distrib if init_state_distrib is not None else np.ones(self.num_states) / self.num_states
        # Terminal (absorbing) states: the row of P is exactly the identity row.
        self.terminal_states = (self.P == np.eye(self.num_states)).all(axis=1)
        # Index of the current state; None until reset() is called.
        self.state = None

    def _expected_reward(self) -> np.ndarray:
        """Return ``r[s] = sum_sp P[s, sp] * R[s, sp]``, the expected one-step reward."""
        return np.sum(self.R * self.P, axis=1)

    def get_stationary_distribution(self) -> np.ndarray:
        """Return a stationary distribution ``d`` of ``P`` (``d @ P == d``, ``sum(d) == 1``).

        Computed from the eigenvectors of ``P.T`` whose eigenvalue is close to 1.
        NOTE(review): when several eigenvalues are close to 1 (reducible chain),
        their eigenvectors are summed, so the result is one of many stationary
        distributions and may be ill-defined if they cancel.
        """
        eigenvalues, eigenvectors = np.linalg.eig(self.P.T)
        stationary_vector = np.sum(eigenvectors[:, np.isclose(eigenvalues, 1.0)], axis=1)
        # Dividing by the sum fixes both scale and (complex) phase of the eigenvector.
        stationary_distribution = np.real(stationary_vector / stationary_vector.sum()).ravel()
        return stationary_distribution

    def compute_V_iter(self, gamma:float, threshold=1e-6, verbose:bool=True) -> np.ndarray:
        """Compute the value function by repeatedly applying the Bellman operator.

        Starts from ``V = 0`` and iterates ``V <- r + gamma * P @ V`` until the
        max-norm change between successive iterates is at most ``threshold``.
        Only terminates if the iteration converges (e.g. ``gamma < 1``).
        """
        expected_reward = self._expected_reward()
        V = np.zeros(self.num_states)
        error = math.inf
        while error > threshold:
            # Vectorised form of sum_sp P[s, sp] * (R[s, sp] + gamma * V[sp]).
            V_new = expected_reward + gamma * (self.P @ V)
            error = np.max(np.abs(V - V_new))
            if verbose:
                print("\rValue iteration error {}".format(error), end="")
            V = V_new
        if verbose:
            print("\n")
        return V

    def compute_V(self, gamma:float)->np.ndarray:
        """Solve the Bellman equation ``(I - gamma * P) V = r`` exactly.

        Uses a linear solve instead of forming the inverse explicitly (cheaper
        and numerically more stable). Raises ``np.linalg.LinAlgError`` if the
        system is singular (e.g. ``gamma == 1``).
        """
        return np.linalg.solve(np.eye(self.num_states) - gamma * self.P, self._expected_reward())

    def reset(self) -> np.ndarray:
        """Sample an initial state from ``init_state_distrib`` and return its label."""
        self.state = np.random.choice(self.num_states, p=self.init_state_distrib)
        return self.state_labels[self.state]

    def step(self):
        """Sample one transition from the current state.

        Returns:
            ``(label of next state, R[state, next_state], next_state is terminal)``.

        Raises:
            RuntimeError: if called before :meth:`reset`.
        """
        if self.state is None:
            raise RuntimeError("Call reset() before step().")
        next_state = np.random.choice(self.num_states, p=self.P[self.state])
        reward = self.R[self.state, next_state]
        self.state = next_state
        return self.state_labels[next_state], reward, self.terminal_states[next_state]

    def compute_errors(self, V_pred:np.ndarray, gamma:float, threshold=1e-6, V_true:np.ndarray=None)->dict:
        """Return errors of ``V_pred`` weighted by the stationary distribution ``d``.

        ``"MSE"``:  ``sum_s d[s] * (V_pred[s] - V_true[s])**2``, where ``V_true``
        defaults to :meth:`compute_V_iter` with the given ``gamma``/``threshold``.
        ``"MSBE"``: ``sum_s d[s] * (V_pred[s] - (r[s] + gamma * (P @ V_pred)[s]))**2``.
        """
        d = self.get_stationary_distribution()
        if V_true is None:
            V_true = self.compute_V_iter(gamma=gamma, threshold=threshold, verbose=False)
        errors = {}
        diff = V_pred - V_true
        errors["MSE"] = np.sum(d * diff * diff)
        V_next = self._expected_reward() + gamma * (self.P @ V_pred)
        diff = V_pred - V_next
        errors["MSBE"] = np.sum(d * diff * diff)
        return errors

    def create_datatset(self, 
                        path:str=None, 
                        num_transitions:int=10_000, 
                        sigma_reward:float=None,
                        verbose:bool=True, 
                        seed:int=None)->dict:
        """Sample ``num_transitions`` transitions by rolling out episodes.

        An episode starts with :meth:`reset` and ends when :meth:`step` reports
        ``done``; a new episode is then started until enough transitions exist.

        Args:
            path: if given, the dataset is also written with ``np.savez``.
            num_transitions: total number of transitions to collect.
            sigma_reward: if given, i.i.d. Gaussian noise with this standard
                deviation is added to the rewards (which become a float array).
            verbose: show a tqdm progress bar.
            seed: seed for ``np.random`` (global state).

        Returns:
            dict with keys ``"state"``, ``"reward"``, ``"next_state"``, ``"done"``.
        """
        if seed is not None:
            np.random.seed(seed)
        dataset = dict(state=[], reward=[], next_state=[], done=[])
        c = 0
        pbar = tqdm(total=num_transitions, disable=not verbose, desc="Creating transitions")
        while c < num_transitions:
            state = self.reset()
            done = False
            while not done and c < num_transitions:
                next_state, reward, done = self.step()
                dataset["state"].append(state)
                dataset["reward"].append(reward)
                dataset["next_state"].append(next_state)
                dataset["done"].append(done)
                state = next_state
                c += 1
                pbar.update(1)
        pbar.close()
        if sigma_reward is not None:
            # NOTE(review): re-seeding here makes the reward noise reuse the same
            # random stream that generated the transitions, so noise and
            # trajectories are not independent. Kept for reproducibility of
            # existing datasets; consider dropping the re-seed.
            if seed is not None:
                np.random.seed(seed)
            dataset["reward"] = np.array(dataset["reward"], dtype=np.float64)
            dataset["reward"] = dataset["reward"] + sigma_reward * np.random.randn(dataset["reward"].shape[0])
        if path is not None:
            np.savez(path, **dataset)
        return dataset

    # Correctly spelled alias; the original name is kept for existing callers.
    create_dataset = create_datatset
    
##############################################################################

class Finite_MRP(MRP):
    """ 
        MRP with a finite time horizon.

        Behaves like :class:`MRP`, except that :meth:`step` reports ``done`` once
        ``T`` steps have been taken since the last :meth:`reset` (or earlier if a
        terminal state is reached), and raises if stepped past that horizon.
    """

    def __init__(self, P: np.ndarray, 
                 R: np.ndarray, 
                 T:int=10, 
                 state_labels:np.ndarray=None, 
                 init_state_distrib:np.ndarray=None) -> None:
        """
        Args:
            P, R, state_labels, init_state_distrib: forwarded to :class:`MRP`.
            T: horizon, i.e. the maximum number of steps per episode.
        """
        super().__init__(P=P, R=R, state_labels=state_labels, init_state_distrib=init_state_distrib)
        self.T = T
        # Steps taken in the current episode; starts at 0, the same value
        # reset() uses, so the counter means the same thing before and after it.
        self.timesteps = 0

    def reset(self) -> np.ndarray:
        """Start a new episode: zero the step counter and sample an initial state."""
        self.timesteps = 0
        return super().reset()

    def step(self):
        """Take one step; ``done`` is True at the horizon or in a terminal state.

        Raises:
            RuntimeError: if ``T`` steps have already been taken this episode.
        """
        if self.timesteps >= self.T:
            raise RuntimeError("Maximum number of steps exceeded.")
        self.timesteps += 1
        next_state, reward, done = super().step()
        return next_state, reward, (self.timesteps == self.T or done)
    
##############################################################################

def generate_reward_matrix(num_states: int) -> np.ndarray:
    """Return a (num_states, num_states) matrix with one U[0, 1) draw per row.

    Every entry of row ``s`` holds the same random scalar, so the reward only
    depends on the state being left.
    """
    draws = np.random.rand(num_states)
    return np.tile(draws[:, np.newaxis], (1, num_states))

def generate_random_matrix_V_Thomas(P:np.ndarray, ratio_noise:float) -> np.ndarray:
    """Build a reward matrix from the stationary distribution of ``P`` plus noise.

    Row ``s`` of the returned (num_states, num_states) matrix is filled with the
    scalar ``R[s]``. Here ``R`` is the stationary distribution of ``P``, perturbed
    by Gaussian noise with standard deviation ``ratio_noise`` and rescaled to
    unit Euclidean norm.

    Raises:
        ValueError: if ``P`` has no eigenvalue close to 1.
    """
    num_states = P.shape[0]
    eigenvalues, eigenvectors = np.linalg.eig(P.T)
    unit = np.isclose(eigenvalues, 1.0)
    if not unit.any():
        raise ValueError("P has no eigenvalue close to 1; is it row-stochastic?")
    # Sum the eigenvectors of all unit eigenvalues (as MRP.get_stationary_distribution
    # does). A reducible chain with several unit eigenvalues then still yields a
    # length-num_states vector, rather than a flattened (num_states, k) block that
    # cannot be added to the noise below.
    stationary_vector = np.sum(eigenvectors[:, unit], axis=1)
    stationary_distrib = np.real(stationary_vector / stationary_vector.sum())
    # Larger ratio_noise flattens the peak of the stationary distribution.
    R = stationary_distrib + ratio_noise * np.random.randn(num_states)
    R /= np.linalg.norm(R)
    return np.repeat(R, num_states).reshape(num_states, num_states)
    
def generate_expected_reward_matrix(num_states: int, low:float=0., high:float=1.) -> np.ndarray:
    """Return a (num_states, num_states) matrix whose rows are constant.

    One value per state is drawn from U[low, high) and copied across that
    state's row, so the reward depends only on the source state.
    """
    per_state = np.random.uniform(low=low, high=high, size=(num_states,))
    return np.repeat(per_state[:, np.newaxis], num_states, axis=1)

def generate_random_matrix_from_targets(target:np.ndarray, P:np.ndarray, gamma:float)->np.ndarray:
    """Build a reward matrix whose value function reproduces ``target``.

    Entries of ``target`` at terminal states (``P[i, i] == 1``) are set to 0. The
    per-state reward is then ``(I - gamma * P) @ target``, which inverts the
    Bellman equation ``V = r + gamma * P @ V`` for ``V = target``. Row ``s`` of
    the returned (num_states, num_states) matrix is filled with that reward.
    The caller's ``target`` array is not modified.
    """
    num_states = P.shape[0]
    # Work on a float copy; zeroing terminal entries must not touch the caller's array.
    target = np.array(target, dtype=np.float64)
    terminal_states = np.diag(P) == 1
    target[terminal_states] = 0.
    R = (np.eye(num_states) - gamma * P) @ target
    return np.repeat(R, num_states).reshape(num_states, num_states)

##############################################################################

def create_MRP_from_MC(path:str="env/config/env.json", target:np.ndarray=None, R_V_Thomas:bool=False, state_labels:np.ndarray=None, gamma:float=0.95):
    """Load a Markov chain from ``path``, attach rewards and a value function, save it back.

    The dict returned by ``load_dict(path)`` must contain ``"P"``; any stored
    ``"V"`` is discarded and every entry is converted to a NumPy array. The
    reward matrix ``"R"`` is chosen in this order of precedence:

    - from ``target``, via ``generate_random_matrix_from_targets``;
    - from the noisy stationary distribution, if ``R_V_Thomas``;
    - otherwise, uniformly at random.

    A uniform ``"init_state_distrib"`` and, if given, ``"state_labels"`` are
    added. The exact value function ``"V"`` for discount ``gamma`` is computed
    with :meth:`MRP.compute_V`, and the dict is written to ``path`` with
    ``save_dict``.
    """
    params_mrp = load_dict(path)
    # A stale value function is dropped; it is recomputed for the new rewards.
    params_mrp.pop("V", None)
    for key, value in params_mrp.items():
        params_mrp[key] = np.array(value)
    P = params_mrp["P"]
    num_states = P.shape[0]

    if target is not None:
        reward_matrix = generate_random_matrix_from_targets(target=target, P=P, gamma=gamma)
    elif R_V_Thomas:
        reward_matrix = generate_random_matrix_V_Thomas(P=P, ratio_noise=2/(num_states*4))
    else:
        reward_matrix = generate_expected_reward_matrix(num_states=num_states)
    params_mrp["R"] = reward_matrix

    params_mrp["init_state_distrib"] = np.ones(num_states)/num_states
    if state_labels is not None:
        params_mrp["state_labels"] = state_labels

    mrp = MRP(**params_mrp)
    params_mrp["V"] = mrp.compute_V(gamma=gamma)
    save_dict(data=params_mrp, path=path)

##############################################################################


# if __name__ == "__main__":    
#     path="env.json"
#     gamma=0.95
#     num_states=len(load_dict(path)["P"])
#     target=None
#     d=50
#     state_labels=np.ones((num_states, d))
#     state_labels[:,1:]=np.random.randn(num_states, d-1)
#     create_MRP_from_MC(path=path, target=target, R_V_Thomas=True, state_labels=state_labels, gamma=gamma)
    
    
        

