import time
import os
import torch
import numpy as np
import torch.multiprocessing as mp

from copy import deepcopy
from functools import partial

from env.mrp import MRP
from lstd import LSTD
from rf import RF
from utils import load_dict, save_dict
from utils import create_dir

########################################################################

def extract_params_lstd(params_lstd:dict)->(dict, dict):
    """Split an LSTD configuration into feature-function and remaining parameters.

    The input dictionary is deep-copied first, so the caller's dictionary is
    never modified. The entries ``activation_fun``, ``num_features``,
    ``features_fun`` and ``num_states`` are moved (when present) into the
    feature-function dictionary; ``num_features`` is stored there under the
    key ``outputs``. If no ``features_fun`` entry was supplied it defaults
    to ``"RF"``.

    Args:
        params_lstd: Full configuration dictionary.

    Returns:
        A pair ``(params_features_fun, params)``: the extracted
        feature-function entries and a copy of the input without them.
    """
    remaining = deepcopy(params_lstd)
    # Source key -> key used in the feature-function dictionary.
    key_mapping = {
        "activation_fun": "activation_fun",
        "num_features": "outputs",
        "features_fun": "features_fun",
        "num_states": "num_states",
    }
    features_kwargs = {}
    for source_key, target_key in key_mapping.items():
        if source_key not in remaining:
            continue
        features_kwargs[target_key] = remaining.pop(source_key)
    features_kwargs.setdefault("features_fun", "RF")
    return features_kwargs, remaining

########################################################################

def _compute_errors(V_pred:torch.Tensor, 
                    V_true:torch.Tensor, 
                    P:torch.Tensor, 
                    R:torch.Tensor,
                    stationay_distrib:torch.Tensor,
                    gamma:float)->dict:
    errors={}
    diff=V_pred-V_true
    errors["MSE"] = float((stationay_distrib@(diff**2)).cpu())
    V_next = P @ V_pred
    V_next *= gamma
    V_next += torch.sum(R * P, dim=1)
    diff = V_pred - V_next
    errors["MSBE"] = float((stationay_distrib@(diff**2)).cpu())
    return errors

def _train_LSTD_offline(i:int, 
                        path:str, 
                        path_dataset:str, 
                        params_features_fun:dict, 
                        params_lstd:dict, 
                        state_labels:torch.Tensor,
                        V_true:torch.Tensor,
                        P:torch.Tensor,
                        R:torch.Tensor,
                        stationary_distrib:torch.Tensor,
                        seed:int,
                        save_models:bool=False)->dict:
    """Train a single LSTD model offline and evaluate its predicted values.

    The run seeds torch with ``seed + i``, builds the feature function named
    by ``params_features_fun["features_fun"]``, fits an ``LSTD`` instance on
    the dataset at ``path_dataset`` and evaluates the learned weights on
    ``state_labels``. The predicted values are always written to
    ``path + "lstd_model_<i>_V_pred.npz"``.

    Args:
        i: Index of the run; used for the seed offset and in file names.
        path: Prefix for the output files. It is concatenated directly with
            the file names, so it should end with a path separator.
        path_dataset: Dataset location handed to ``LSTD.learn_offline``.
        params_features_fun: Feature-function settings. Must contain
            ``"features_fun"`` (the registry name) and ``"device"``; every
            entry except ``"features_fun"`` is passed to the constructor.
        params_lstd: Keyword arguments for ``LSTD``; must contain ``"gamma"``.
        state_labels: Inputs on which the feature function is evaluated.
        V_true: Reference values forwarded to ``_compute_errors``.
        P: Forwarded to ``_compute_errors``.
        R: Forwarded to ``_compute_errors``.
        stationary_distrib: Forwarded to ``_compute_errors``.
        seed: Base seed shared by all runs.
        save_models: When true, also save the model weights to
            ``path + "lstd_model_<i>_weights.npz"``.

    Returns:
        The error dictionary produced by ``_compute_errors``.

    Raises:
        KeyError: If the requested feature function is not registered.
    """
    # Offset the base seed by the run index so each run is seeded differently.
    torch.manual_seed(seed+i)
    device=params_features_fun["device"]
    if device!="cpu":
        torch.cuda.set_device(device)

    # "features_fun" only selects the class; everything else goes to its
    # constructor. Work on a deep copy so the shared dict is left untouched.
    args_features_fun=deepcopy(params_features_fun)
    del args_features_fun["features_fun"]
    features_registry=dict(RF=RF)
    features_cls=features_registry[params_features_fun["features_fun"]]
    features_fun=features_cls(**args_features_fun)

    lstd=LSTD(features_fun=features_fun, **params_lstd)
    W=lstd.learn_offline(path=path_dataset, verbose=False)
    V_pred=features_fun(state_labels)@W

    file_prefix=path+"lstd_model_"+str(i)
    np.savez(file_prefix+"_V_pred.npz", V_pred=V_pred.cpu().numpy())
    if save_models:
        lstd.save_weights(path=file_prefix+"_weights.npz")

    return _compute_errors(V_pred=V_pred, 
                           V_true=V_true, 
                           P=P,
                           R=R,
                           stationay_distrib=stationary_distrib,
                           gamma=params_lstd["gamma"])

def train_LSTD_offline(path:str, 
                       path_dataset:str, 
                       path_mrp:str,
                       params_lstd:dict, 
                       num_training:int=10, 
                       compute_V:bool=False, 
                       save_models:bool=True,
                       num_jobs:int=1,
                       seed:int=None):
    """Train several LSTD models offline and save their errors.

    The function
      1. creates ``path + "models/"`` and writes the feature-function
         settings to ``path + "models/lstd_model.json"``;
      2. loads the MRP description from ``path_mrp`` and builds the tensors
         needed for evaluation;
      3. runs ``_train_LSTD_offline`` for each index in
         ``range(num_training)``, in a process pool when ``num_jobs > 1``;
      4. writes the per-run errors, grouped by error name, to
         ``path + "errors.json"``.

    Args:
        path: Output prefix. It is concatenated directly with file names,
            so it should end with a path separator.
        path_dataset: Dataset readable by ``np.load`` that exposes a
            ``"state"`` entry; the length of its first element sets the
            feature function's ``"inputs"``.
        path_mrp: File read with ``load_dict`` holding the MRP parameters.
            Must provide ``"P"``, ``"R"`` and ``"state_labels"``; may
            provide ``"V"``.
        params_lstd: LSTD configuration; must contain ``"gamma"``. Optional
            ``"dtype"`` and ``"device"`` entries override the defaults
            (``torch.float64`` on ``"cpu"``).
        num_training: Number of independent runs.
        compute_V: When true, the true values are computed with
            ``MRP.compute_V`` (replacing any ``"V"`` stored in the MRP file).
        save_models: Whether each run also saves its model weights.
        num_jobs: Number of worker processes; values above 1 use a pool.
        seed: Base seed; ``None`` is treated as 0.

    Raises:
        ValueError: If the MRP file has no ``"V"`` entry and ``compute_V``
            is false, so there are no true values to compare against.
    """
    if seed is None:
        seed=0
    path_models=path+"models/"
    path_model_params=path_models+"lstd_model.json"
    create_dir(path_models)
    params_features_fun, params=extract_params_lstd(params_lstd=params_lstd)

    dataset=np.load(path_dataset)
    try:
        params_features_fun["inputs"]=dataset["state"][0].shape[0]
    finally:
        # An .npz archive keeps its file handle open until closed; a plain
        # array has no close() method.
        if hasattr(dataset, "close"):
            dataset.close()

    # Saved before dtype/device are added below, so the JSON file only holds
    # the settings gathered so far.
    save_dict(path=path_model_params, data=params_features_fun)

    params_tensor=dict(dtype=torch.float64, device="cpu")
    for p in params_tensor:
        if p in params:
            params_tensor[p]=params[p]
        params_features_fun[p]=params_tensor[p]

    params_mrp={key: np.array(value) for key, value in load_dict(path_mrp).items()}
    V_true=None
    if "V" in params_mrp:
        # "V" is removed because the remaining entries are passed to MRP.
        V_true=torch.tensor(params_mrp.pop("V"), **params_tensor)
    mrp=MRP(**params_mrp)
    if compute_V:
        V_true=torch.tensor(mrp.compute_V(gamma=params_lstd["gamma"]), **params_tensor)
    if V_true is None:
        raise ValueError(
            f"MRP file {path_mrp!r} has no 'V' entry; "
            "pass compute_V=True to compute the true values."
        )
    stationary_distrib=torch.tensor(mrp.get_stationary_distribution(), **params_tensor)
    P=torch.tensor(params_mrp["P"], **params_tensor)
    R=torch.tensor(params_mrp["R"], **params_tensor)
    state_labels=torch.tensor(params_mrp["state_labels"], **params_tensor)
    del mrp

    # Same call for the pooled and the sequential path; only the run index varies.
    train_run=partial(_train_LSTD_offline, 
                      path=path_models, 
                      path_dataset=path_dataset, 
                      params_features_fun=params_features_fun, 
                      params_lstd=params,
                      state_labels=state_labels,
                      V_true=V_true,
                      P=P,
                      R=R,
                      stationary_distrib=stationary_distrib,
                      seed=seed,
                      save_models=save_models)
    if num_jobs>1:
        # NOTE(review): with a non-CPU device the tensors above are created in
        # this process before the pool starts -- confirm that the
        # multiprocessing start method in use supports that.
        with mp.Pool(processes=num_jobs) as pool:
            res=pool.map(train_run, range(num_training))
    else:
        res=[train_run(i) for i in range(num_training)]

    # Regroup the list of per-run dicts into one list per error name; with no
    # runs there is nothing to group.
    errors={e: [r[e] for r in res] for e in res[0]} if res else {}
    save_dict(path=path+"errors.json", data=errors)