import math
import os
import warnings
from typing import Union

import numpy as np

from utils import create_dir, get_parent_directory
from utils import load_dict, save_dict
from env.mrp import MRP
from env.create_dataset import create_dataset_transitions

################################################################

def generate_Phi(activation_fun:str, A:np.ndarray, B:np.ndarray)->np.ndarray:
    """Compute the activation kernel between every column of A and every column of B.

    For ``activation_fun == "relu"`` (case-insensitive) the closed form

        Phi[i, j] = ||a_i|| ||b_j|| / (2 pi) * (c_ij * arccos(-c_ij) + sqrt(1 - c_ij^2))

    is used, where ``a_i`` is column i of A, ``b_j`` is column j of B and
    ``c_ij`` is the cosine of the angle between them (the ReLU arc-cosine kernel).

    Args:
        activation_fun: name of the activation; only "relu" is implemented.
        A: 2-D array of shape (d, m); columns are the feature vectors.
        B: 2-D array of shape (d, n); columns are the feature vectors.

    Returns:
        Array of shape (m, n).

    Raises:
        ValueError: if ``activation_fun`` is not supported.
    """
    activation_fun = activation_fun.lower()
    if activation_fun != "relu":
        # Previously fell through and raised UnboundLocalError on `Phi`.
        raise ValueError(f"Unsupported activation function {activation_fun!r}; only 'relu' is implemented")
    normA = np.sqrt(np.sum(A**2, axis=0))
    normB = np.sqrt(np.sum(B**2, axis=0))
    # Zero-norm columns get inverse norm 0: their cosine becomes 0 and the kernel
    # value 0 (it is multiplied by the zero norm), instead of inf/NaN.
    inv_normA = np.divide(1.0, normA, out=np.zeros_like(normA, dtype=float), where=normA > 0)
    inv_normB = np.divide(1.0, normB, out=np.zeros_like(normB, dtype=float), where=normB > 0)
    # Clip to [-1, 1] on BOTH sides: rounding can push the cosine just outside the
    # interval, which would make arccos / sqrt return NaN.
    cos_AB = np.clip(inv_normA[:, None] * (A.T @ B) * inv_normB[None, :], -1.0, 1.0)
    Phi = 1 / (2 * math.pi) * normA[:, None] * (cos_AB * np.arccos(-cos_AB) + np.sqrt(1 - cos_AB**2)) * normB[None, :]
    return Phi

########################################################################

def generate_Phi_S(state_labels:np.ndarray, 
                   activation_fun:str="relu")->np.ndarray:
    """Return the kernel matrix of the state features with themselves.

    Thin wrapper around ``generate_Phi`` with ``A = B = state_labels``, so entry
    (i, j) is the kernel value between columns i and j of ``state_labels``.
    """
    return generate_Phi(activation_fun=activation_fun, A=state_labels, B=state_labels)


########################################################################

def get_matrices_th(path_mrp:str,
                    num_transitions:int,
                    seed:int=42,
                    verbose:bool=False)->(np.ndarray, np.ndarray, np.ndarray, np.ndarray, int):
    """Sample transitions from an MRP and build the matrices used by the theory.

    Args:
        path_mrp: path of the stored MRP description (read with ``load_dict``).
        num_transitions: number of transitions requested from
            ``create_dataset_transitions``.
        seed: random seed forwarded to the dataset generator.
        verbose: forwarded to the dataset generator.

    Returns:
        Phi_S: ReLU kernel matrix between the state labels
            (``generate_Phi_S`` applied to ``state_labels.T``).
        U: ``(num_states, num_transitions)`` int matrix with ``U[s, t] = 1``
            when transition ``t`` starts in state ``s``.
        V: same layout as ``U``, marking the next state of each transition.
        rtrain: array of the rewards of the sampled transitions.
        num_important_visited_states: number of states with positive
            stationary probability that occur in the dataset, either as a
            state or as a next state.
    """
    mrp_params = load_dict(path_mrp)
    for key, value in mrp_params.items():
        mrp_params[key] = np.array(value)
    # A stored value function "V" is not passed on to the MRP constructor.
    mrp_params.pop("V", None)
    num_states = mrp_params["P"].shape[0]

    stationary = MRP(**mrp_params).get_stationary_distribution()
    important_states = np.argwhere(stationary > 0)[:, 0]

    dataset = create_dataset_transitions(path_mrp=path_mrp,
                                         num_transitions=num_transitions,
                                         verbose=verbose,
                                         seed=seed,
                                         get_state_id=True,
                                         save=False)
    states = np.array(dataset["state"]).squeeze()
    next_states = np.array(dataset["next_state"]).squeeze()
    rtrain = np.array(dataset["reward"])

    # Flag every state seen in the dataset, then count the important ones among them.
    seen = np.zeros(num_states, dtype=int)
    seen[states] = 1
    seen[next_states] = 1
    num_important_visited_states = np.sum(seen[important_states])

    transition_ids = np.arange(num_transitions)
    U = np.zeros((num_states, num_transitions), dtype=int)
    V = np.zeros((num_states, num_transitions), dtype=int)
    U[states, transition_ids] = 1
    V[next_states, transition_ids] = 1

    Phi_S = generate_Phi_S(state_labels=mrp_params["state_labels"].T)
    return Phi_S, U, V, rtrain, num_important_visited_states

################################################################

def compute_delta(Phi_S:np.ndarray, 
                  U:np.ndarray,
                  V:np.ndarray,
                  num_params:int, 
                  weight_decay:float, 
                  gamma:float,
                  verbose:bool=False,
                  tol:float=1e-6,
                  max_iter:int=100_000)->float:
    """Solve the scalar fixed-point equation for ``delta`` by iteration.

    With ``lambda_k`` the eigenvalues of ``(U - gamma V)^T Phi_S U``,
    ``ratio = num_params / num_states`` and ``n = num_transitions``, iterate

        delta <- Re(sum_k lambda_k / (ratio / (1 + delta) * lambda_k + weight_decay * n)) / num_states

    from ``delta = 0`` until two successive iterates differ by at most ``tol``.

    Args:
        Phi_S: (num_states, num_states) kernel matrix.
        U, V: (num_states, num_transitions) matrices; the shape of ``U`` fixes
            ``num_states`` and ``num_transitions``.
        num_params: number of model parameters.
        weight_decay: ridge regularisation strength.
        gamma: discount factor.
        verbose: print the iterate gap while iterating.
        tol: convergence threshold on ``|delta_k - delta_{k-1}|``.
        max_iter: maximum number of fixed-point updates. If it is reached,
            a RuntimeWarning is emitted and the last iterate is returned.

    Returns:
        The (approximate) fixed point ``delta``.
    """
    num_states, num_transitions = U.shape
    ratio = num_params / num_states
    Phi_train = (U - gamma * V).T @ Phi_S @ U
    eig_Phi_train = np.linalg.eigvals(Phi_train)
    regularization = weight_decay * num_transitions  # loop-invariant
    delta = 0
    previous = -1
    num_iter = 0
    while np.abs(delta - previous) > tol:
        if num_iter >= max_iter:
            # Without a cap, a non-contracting map would loop forever.
            warnings.warn(f"compute_delta did not converge after {max_iter} iterations "
                          f"(last gap {np.abs(delta - previous)})",
                          RuntimeWarning, stacklevel=2)
            break
        if verbose:
            print("\rError {}".format(np.abs(delta - previous)), end="")
        previous = delta
        delta = np.real(np.sum(eig_Phi_train / (ratio / (1 + delta) * eig_Phi_train + regularization))) / num_states
        num_iter += 1
    if verbose:
        print("\n")
    return delta

########################################################################

def get_params_mrp(path_mrp:str)->(np.ndarray, np.ndarray, np.ndarray):
    """Load an MRP and return its transition matrix, stationary distribution and mean rewards.

    Args:
        path_mrp: path of the stored MRP description (read with ``load_dict``).

    Returns:
        P: the transition matrix stored under ``"P"``.
        stationary_distrib: ``MRP(...).get_stationary_distribution()``.
        rbar: ``sum(R * P, axis=1)``, i.e. for each state the reward weighted by
            the transition probabilities (assumes ``R`` is indexed like ``P``,
            ``R[s, s']`` — NOTE(review): confirm against the MRP file format).
    """
    params_mrp = load_dict(path_mrp)
    # A stored value function "V" is not passed to the MRP constructor. Use a
    # default so files without it work (the old `del` raised KeyError), matching
    # how get_matrices_th treats this key as optional.
    params_mrp.pop("V", None)
    for k in params_mrp.keys():
        params_mrp[k] = np.array(params_mrp[k])
    mrp = MRP(**params_mrp)
    rbar = np.sum(params_mrp["R"] * params_mrp["P"], axis=1)
    return params_mrp["P"], mrp.get_stationary_distribution(), rbar
    
################################################################

def predict_msbe_training_th(Phi_S:np.ndarray,
                       U:np.ndarray,
                       V:np.ndarray,
                       rtrain:np.ndarray,
                       num_params:int,
                       gamma:float,
                       weight_decay:float)-> float:
    """Theoretical prediction of the training MSBE.

    Args:
        Phi_S: (num_states, num_states) kernel matrix between states.
        U, V: (num_states, num_transitions) state / next-state indicator
            matrices, as built by ``get_matrices_th``.
        rtrain: rewards of the training transitions.
        num_params: number of model parameters (may be 0).
        gamma: discount factor.
        weight_decay: ridge regularisation strength.

    Returns:
        The predicted training MSBE.
    """
    num_states, num_transitions = U.shape
    delta = compute_delta(Phi_S=Phi_S,
                          U=U,
                          V=V,
                          num_params=num_params,
                          weight_decay=weight_decay,
                          gamma=gamma,
                          verbose=False)
    # Psi_S = ratio / (1 + delta) * Phi_S = num_params * scale * Phi_S.
    # Keeping `scale` explicit lets the 1 / num_params normalisations below cancel
    # analytically. Otherwise num_params == 0 gives 0/0 = NaN instead of the
    # finite zero-parameter limit.
    scale = 1 / (num_states * (1 + delta))
    Psi_S = num_params * scale * Phi_S
    U_td = U - gamma * V
    Q_bar_inv = (U_td.T @ Psi_S @ U) / num_transitions + weight_decay * np.eye(num_transitions)
    Q_bar = np.linalg.inv(Q_bar_inv)
    u = Q_bar @ rtrain
    error = u.T @ u
    Uu = U @ u
    correction_rtrain = (Uu.T @ Psi_S @ Uu) / num_transitions
    M = Q_bar @ U_td.T
    # Equals trace(M Psi_S M^T) / (num_transitions * num_params), with num_params cancelled.
    correction_1 = scale * np.trace(M @ Phi_S @ M.T) / num_transitions
    UM = U @ M
    # Equals trace(Psi_S UM^T Psi_S UM) / (num_params * num_transitions^2), with one num_params cancelled.
    correction_2 = num_params * scale**2 * np.trace(Phi_S @ UM.T @ Phi_S @ UM) / num_transitions**2
    error_correction = (correction_rtrain * correction_1) / (1 - correction_2)
    msbe_training_th = weight_decay**2 * (error + error_correction) / num_transitions
    return msbe_training_th


################################################################

def predict_msbe_test_th(Phi_S:np.ndarray,
                         U:np.ndarray,
                         V:np.ndarray,
                         rtrain:np.ndarray,
                         P:np.ndarray,
                         stationary_distrib:np.ndarray,
                         rbar:np.ndarray,
                         num_params:int,
                         gamma:float,
                         weight_decay:float)-> float:
    """Theoretical prediction of the test MSBE, weighted by the stationary distribution.

    Args:
        Phi_S: (num_states, num_states) kernel matrix between states.
        U, V: (num_states, num_transitions) state / next-state indicator
            matrices, as built by ``get_matrices_th``.
        rtrain: rewards of the training transitions.
        P: (num_states, num_states) transition matrix.
        stationary_distrib: per-state weights, used as ``D = diag(stationary_distrib)``.
        rbar: per-state expected rewards.
        num_params: number of model parameters (may be 0).
        gamma: discount factor.
        weight_decay: ridge regularisation strength.

    Returns:
        The predicted test MSBE.
    """
    num_states, num_transitions = U.shape
    D = np.diag(stationary_distrib)
    A = np.eye(num_states) - gamma * P
    delta = compute_delta(Phi_S=Phi_S,
                          U=U,
                          V=V,
                          num_params=num_params,
                          weight_decay=weight_decay,
                          gamma=gamma,
                          verbose=False)
    # Psi_S = ratio / (1 + delta) * Phi_S = num_params * scale * Phi_S.
    # Keeping `scale` explicit lets the 1 / num_params normalisations below cancel
    # analytically. Otherwise num_params == 0 gives 0/0 = NaN instead of the
    # finite zero-parameter limit.
    scale = 1 / (num_states * (1 + delta))
    Psi_S = num_params * scale * Phi_S
    U_td = U - gamma * V
    Q_bar_inv = (U_td.T @ Psi_S @ U) / num_transitions + weight_decay * np.eye(num_transitions)
    Q_bar = np.linalg.inv(Q_bar_inv)

    M = A @ Psi_S @ U @ Q_bar
    u = rbar - M @ rtrain / num_transitions
    error = u.T @ D @ u
    M = M @ U_td.T / num_transitions
    correction_trace_1 = np.trace(Psi_S @ M.T @ D @ M)
    correction_trace_2 = np.trace(A.T @ D @ M @ Psi_S)
    correction_trace_3 = np.trace(Psi_S @ A.T @ D @ A)
    correction_trace = correction_trace_1 - 2 * correction_trace_2 + correction_trace_3

    UQ = U @ Q_bar
    # Equals r^T UQ^T Psi_S UQ r / (num_params * num_transitions^2), with num_params cancelled.
    correction_1 = scale * (rtrain.T @ UQ.T @ Phi_S @ UQ @ rtrain) / num_transitions**2
    UQ_td = UQ @ U_td.T / num_transitions
    # Equals 1 - trace(Psi_S UQ_td^T Psi_S UQ_td) / num_params, with one num_params cancelled.
    correction_2 = 1 - num_params * scale**2 * np.trace(Phi_S @ UQ_td.T @ Phi_S @ UQ_td)
    correction = correction_1 / correction_2
    msbe_test_th = error + correction_trace * correction
    return msbe_test_th
    
################################################################

def predict_msbe_th(params_fun:dict):
    """Dispatch to the test or the training MSBE predictor.

    If ``params_fun`` contains a ``"P"`` entry, ``predict_msbe_test_th`` is
    called; otherwise ``predict_msbe_training_th``. The entries of
    ``params_fun`` are passed as keyword arguments.
    """
    if "P" not in params_fun:
        return predict_msbe_training_th(**params_fun)
    return predict_msbe_test_th(**params_fun)

################################################################

def predict_msbe_th_params(path_dir_logs:str,
                           path_mrp:str,
                           gamma:Union[float, np.ndarray],
                           ratio:Union[float, np.ndarray],
                           weight_decay:Union[float, np.ndarray],
                           num_transitions:int,
                           training:bool=False,
                           seed:int=42)->np.ndarray:
    """Sweep one hyper-parameter, predict the theoretical MSBE for each value and save it.

    Exactly one of ``gamma``, ``ratio`` and ``weight_decay`` must be a 1-D
    sequence (ndarray, list or tuple): it is the swept parameter. The other two
    are scalars held fixed. ``ratio`` is turned into a number of parameters as
    ``round(ratio * num_important_visited_states)``.

    Args:
        path_dir_logs: output directory (created if needed).
        path_mrp: path of the stored MRP description.
        gamma, ratio, weight_decay: hyper-parameters, exactly one of them swept.
        num_transitions: number of sampled training transitions.
        training: predict the training MSBE when True, the test MSBE otherwise.
        seed: seed for the sampled dataset.

    Returns:
        Array with one predicted MSBE per swept value. The swept values and the
        predictions are also saved to ``msbe_training_th.json`` or
        ``msbe_test_th.json`` in ``path_dir_logs``.

    Raises:
        ValueError: if no parameter, or more than one, is swept.
    """
    # Accept lists/tuples as well as arrays for the swept parameter.
    gamma, ratio, weight_decay = (np.asarray(x) if np.ndim(x) > 0 else x
                                  for x in (gamma, ratio, weight_decay))
    candidates = {"gamma": gamma, "ratio": ratio, "weight_decay": weight_decay}
    swept = [name for name, value in candidates.items() if np.ndim(value) > 0]
    # Validate before doing any expensive work.
    if not swept:
        raise ValueError("No parameters tested in predict_msbe_th_params")
    if len(swept) > 1:
        raise ValueError(f"Only one parameter can be tested at a time in predict_msbe_th_params, got {swept}")
    swept_name = swept[0]

    create_dir(path_dir_logs)
    Phi_S, U, V, rtrain, num_important_visited_states = get_matrices_th(path_mrp=path_mrp,
                                                                        num_transitions=num_transitions,
                                                                        seed=seed,
                                                                        verbose=False)
    params_predict_msbe_th = dict(Phi_S=Phi_S, U=U, V=V, rtrain=rtrain)
    if not training:
        P, stationary_distrib, rbar = get_params_mrp(path_mrp=path_mrp)
        params_predict_msbe_th.update(P=P, stationary_distrib=stationary_distrib, rbar=rbar)

    num_params = np.round(ratio * num_important_visited_states).astype(int)
    fixed_params = {"gamma": gamma, "num_params": num_params, "weight_decay": weight_decay}
    # Sweeping "ratio" means sweeping the rounded number of parameters.
    sweep_key = "num_params" if swept_name == "ratio" else swept_name
    params_tested = fixed_params.pop(sweep_key)
    params_predict_msbe_th.update(fixed_params)

    msbe_th = np.zeros(params_tested.shape[0])
    for i, value in enumerate(params_tested):
        params_predict_msbe_th[sweep_key] = value
        msbe_th[i] = predict_msbe_th(params_predict_msbe_th)

    # Save the user-facing swept values (ratio, not the rounded num_params).
    data = {swept_name: candidates[swept_name]}
    if training:
        data["msbe_training"] = msbe_th
        file_name = "msbe_training_th.json"
    else:
        data["msbe_test"] = msbe_th
        file_name = "msbe_test_th.json"
    # os.path.join also works when path_dir_logs lacks a trailing separator.
    save_dict(path=os.path.join(path_dir_logs, file_name), data=data)
    return msbe_th

################################################################

def predict_delta_params(path_dir_logs:str,
                         path_mrp:str,
                         gamma:Union[float, np.ndarray],
                         ratio:Union[float, np.ndarray],
                         weight_decay:Union[float, np.ndarray],
                         num_transitions:int,
                         seed:int=42)->np.ndarray:
    """Sweep one hyper-parameter, compute ``delta`` for each value and save it.

    Exactly one of ``gamma``, ``ratio`` and ``weight_decay`` must be a 1-D
    sequence (ndarray, list or tuple): it is the swept parameter. The other two
    are scalars held fixed. ``ratio`` is turned into a number of parameters as
    ``round(ratio * num_important_visited_states)``.

    Args:
        path_dir_logs: output directory (created if needed).
        path_mrp: path of the stored MRP description.
        gamma, ratio, weight_decay: hyper-parameters, exactly one of them swept.
        num_transitions: number of sampled training transitions.
        seed: seed for the sampled dataset.

    Returns:
        Array with one ``delta`` per swept value. The values are also saved to
        ``delta.json`` in ``path_dir_logs``.

    Raises:
        ValueError: if no parameter, or more than one, is swept.
    """
    # Accept lists/tuples as well as arrays for the swept parameter.
    gamma, ratio, weight_decay = (np.asarray(x) if np.ndim(x) > 0 else x
                                  for x in (gamma, ratio, weight_decay))
    candidates = {"gamma": gamma, "ratio": ratio, "weight_decay": weight_decay}
    swept = [name for name, value in candidates.items() if np.ndim(value) > 0]
    # Validate before doing any expensive work; the message now names this function.
    if not swept:
        raise ValueError("No parameters tested in predict_delta_params")
    if len(swept) > 1:
        raise ValueError(f"Only one parameter can be tested at a time in predict_delta_params, got {swept}")
    swept_name = swept[0]

    create_dir(path_dir_logs)
    Phi_S, U, V, _, num_important_visited_states = get_matrices_th(path_mrp=path_mrp,
                                                                   num_transitions=num_transitions,
                                                                   seed=seed,
                                                                   verbose=False)
    params_predict_delta = dict(Phi_S=Phi_S, U=U, V=V, verbose=False)

    num_params = np.round(ratio * num_important_visited_states).astype(int)
    fixed_params = {"gamma": gamma, "num_params": num_params, "weight_decay": weight_decay}
    # Sweeping "ratio" means sweeping the rounded number of parameters.
    sweep_key = "num_params" if swept_name == "ratio" else swept_name
    params_tested = fixed_params.pop(sweep_key)
    params_predict_delta.update(fixed_params)

    delta = np.zeros(params_tested.shape[0])
    for i, value in enumerate(params_tested):
        params_predict_delta[sweep_key] = value
        delta[i] = compute_delta(**params_predict_delta)

    data = dict(delta=delta)
    # Save the user-facing swept values (ratio, not the rounded num_params).
    data[swept_name] = candidates[swept_name]
    # os.path.join also works when path_dir_logs lacks a trailing separator.
    save_dict(path=os.path.join(path_dir_logs, "delta.json"), data=data)
    return delta

################################################################

# if __name__ == "__main__":
#     path_mrp="res/gridworld/env.json"
#     gamma=0.95
#     seed=42
#     delta=0.01
#     ratio=np.arange(delta,2.+delta, delta)
#     list_weights_decay=[10**(-9), 10**(-8), 10**(-6), 10**(-4), 10**(-3), 10**(-2), 10**(-1), 1, 10, 100]
#     list_dataset_size=[1_000, 3_000, 5_000]
#     training=False
#     if training:
#         path_msbe=get_parent_directory(path_mrp)+"MSBE_training_th/"
#     else:
#         path_msbe=get_parent_directory(path_mrp)+"MSBE_test_th/" 
#     for num_transitions in list_dataset_size:
#         for weight_decay in list_weights_decay:
#             path_dir_logs=path_msbe+"dataset_"+str(num_transitions)+"/"+str(weight_decay)+"/"
#             predict_msbe_th_params(path_dir_logs=path_dir_logs,
#                                    path_mrp=path_mrp,
#                                    gamma=gamma,
#                                    ratio=ratio,
#                                    weight_decay=weight_decay,
#                                    num_transitions=num_transitions,
#                                    training=training,
#                                    seed=seed)
    