import os, random
from buffer import *
import json
from tqdm import tqdm

# Number of leading entries kept from each request's "action" list
# (parse_json slices every action with [:action_dim]).
action_dim=8

def sample_user(replay_buffer, number, fold, folder_path, save_path):
    """Load one fold of per-user trajectory files into ``replay_buffer``.

    The entries of ``folder_path`` are listed once and sorted by name, then the
    contiguous slice ``[fold*number:(fold+1)*number]`` is taken as this fold's
    users. ``os.listdir`` returns entries in arbitrary order, so sorting is what
    makes the fold assignment reproducible and keeps folds from overlapping.

    Each selected file is read as JSON and converted with ``parse_json``. A user
    whose trajectory cannot be parsed, contains NaN values, or is rejected by
    ``replay_buffer.add`` is skipped (best effort) and reported through
    ``tqdm.write`` instead of being dropped silently.

    Side effects:
        * writes ``stat-{fold}-fold.json`` into ``save_path`` with the number
          of users, sessions and requests that were actually added;
        * calls ``replay_buffer.save`` with the path ``{save_path}/{fold}-fold``.

    Args:
        replay_buffer: object providing ``add(...)`` (nine positional arrays,
            in the order returned by ``parse_json``) and ``save(path)``.
        number: number of user files per fold.
        fold: zero-based index of the fold to load.
        folder_path: directory holding one JSON trajectory file per user.
        save_path: existing directory receiving the statistics file and the
            saved buffer.

    Returns:
        The same ``replay_buffer`` object, after the fold has been added.

    Raises:
        AssertionError: if ``number`` exceeds the number of entries in
            ``folder_path``.
    """
    # List the directory once; sorted so that every run (and every fold of the
    # same run) sees the same ordering.
    user_files = sorted(os.listdir(folder_path))
    assert number <= len(user_files), (
        f"requested {number} users per fold but only {len(user_files)} files in {folder_path}"
    )
    user_list = user_files[fold * number:(fold + 1) * number]

    n_user = 0
    n_session = 0
    n_request = 0
    for user in tqdm(user_list):
        file_path = os.path.join(folder_path, user)
        # Only the read needs the file handle; close it before the heavy work.
        with open(file_path, 'r') as f:
            trajactory = json.load(f)
        try:
            # Transitions in the trajectory of a single user, in the order:
            # state, h_state, action, next_action, next_state, next_h_state,
            # response, h_response, done.
            transitions = parse_json(trajactory)
            for array in transitions:
                # Explicit raise (not ``assert``) so the check survives ``-O``.
                if np.isnan(array).any():
                    raise ValueError("NaN value in parsed trajectory")
            replay_buffer.add(*transitions)
        except Exception as e:
            # Best effort: a bad user must not abort the whole fold, but the
            # reason is reported so data problems do not go unnoticed.
            tqdm.write(f"sample_user: skipping {user}: {type(e).__name__}: {e}")
            continue
        n_user += 1
        n_session += len(trajactory)
        n_request += len(transitions[0])

    with open(os.path.join(save_path, f"stat-{fold}-fold.json"), "w") as f:
        json.dump({"n_user":n_user, "n_session":n_session, "n_request":n_request},f,separators=(',', ':'),indent=4)
    # NOTE: per-action statistics (replay_buffer.stat_actions()) are
    # deliberately not dumped here; the calculation is incorrect because the
    # buffer is padded with zeros.
    replay_buffer.save(os.path.join(save_path, f"{fold}-fold"))
    return replay_buffer

def parse_json(trajactory):
    """Flatten one user's sessions into aligned per-request transition lists.

    ``trajactory`` is iterated as a sequence of sessions; each session's
    ``"request"`` entry is iterated as a sequence of requests, and from every
    request the keys ``"state"``, ``"h_state"``, ``"action"``, ``"response"``
    and ``"h_response"`` are read. Requests of all sessions are concatenated in
    order. Each action is truncated to its first ``action_dim`` entries.

    The ``next_*`` lists are the corresponding lists shifted by one request;
    the final request has no successor, so it is paired with itself.

    Args:
        trajactory: sequence of session dicts as described above.

    Returns:
        A 9-tuple ``(state, h_state, action, next_action, next_state,
        next_h_state, response, h_response, done)``. All items are Python
        lists of equal length N except ``done``, which is an integer NumPy
        array of shape ``(N, 1)`` that is 0 everywhere and 1 for the last
        request.

    Raises:
        ValueError: if an action has fewer than 8 or more than 12 entries, or
            if the trajectory contains no requests at all.
        KeyError: if a session or request lacks one of the expected keys.
    """
    state = []
    h_state = []
    action = []
    response = []
    h_response = []

    for session in trajactory:
        for request in session["request"]:
            n_entries = len(request["action"])
            # Explicit raise (not ``assert``) so validation survives ``-O``.
            # NOTE(review): the lower bound 8 coincides with action_dim;
            # confirm whether the two are meant to be tied together.
            if not 8 <= n_entries <= 12:
                raise ValueError(f'invalid action: {n_entries}')
            state.append(request["state"])
            h_state.append(request["h_state"])
            action.append(request["action"][:action_dim])
            response.append(request["response"])
            h_response.append(request["h_response"])

    if not state:
        # Without this guard the code below would fail with an opaque
        # IndexError on ``action[-1]``.
        raise ValueError("trajectory contains no requests")

    # Shift by one request; the last request is its own successor.
    next_action = action[1:] + [action[-1]]
    next_state = state[1:] + [state[-1]]
    next_h_state = h_state[1:] + [h_state[-1]]

    # Column vector flagging only the final request of the whole trajectory.
    done = np.zeros((len(state), 1), dtype=int)
    done[-1] = 1
    return state,h_state,action,next_action,next_state,next_h_state, response,h_response,done

def stat_actions(action, eps= 1e-3):
    """Summarise an action array along its first axis.

    Args:
        action: array exposing ``shape``, ``mean``, ``std``, ``max`` and
            ``min`` (used as a NumPy array here).
        eps: constant added to the standard deviation.

    Returns:
        A 5-tuple ``(mean, std, max, min, shape)`` of plain Python lists.
        ``mean`` keeps the reduced axis (one extra level of nesting), while
        ``std``, ``max`` and ``min`` do not; ``shape`` is ``action.shape`` as
        a list.
    """
    dims = list(action.shape)
    lowest = action.min(axis=0)
    highest = action.max(axis=0)
    # eps is added after the reduction, so it offsets every column equally.
    spread = action.std(axis=0, keepdims=False) + eps
    centre = action.mean(axis=0, keepdims=True)
    return (
        centre.tolist(),
        spread.tolist(),
        highest.tolist(),
        lowest.tolist(),
        dims,
    )
            
