import json
from torch.distributions.normal import Normal
from data_process import *
from utils import process_reward_eva

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fixed per-dimension scale of the Gaussian that w_offline_ab uses to score
# logged actions (two entries -- presumably a 2-D action space; confirm).
std = torch.tensor([0.1, 0.1], device=device)

def parse_his(his):
    """Unpack a logged trajectory into per-field lists for offline evaluation.

    ``his`` is a sequence of ``(s, a, r)`` triples. Each field is gathered into
    its own list. The ``next_*`` lists are the same sequences advanced by one
    step, with the final element repeated so the lengths still match. ``done``
    flags only the last step.

    Returns:
        A 9-tuple ``(state, h_state, action, next_action, next_state,
        next_h_state, response, h_response, done)``:

        - ``response`` and ``h_response`` wrap each ``r`` in a one-element list.
        - The ``h_*`` lists are separate lists holding the same entries as
          their plain counterparts.
        - ``done`` is an integer array of shape ``(len(his), 1)`` that is 1
          only in the last row.

    Raises:
        IndexError: if ``his`` is empty (there is no last step to repeat).
    """
    state, action, response = [], [], []
    for s, a, r in his:
        state.append(s)
        action.append(a)
        response.append([r])

    # "History" views: distinct containers with identical contents.
    h_state = state.copy()
    h_response = [[r] for (r,) in response]

    def _advance(seq):
        # One-step lookahead; the terminal step points at itself.
        return seq[1:] + [seq[-1]]

    next_action = _advance(action)
    next_state = _advance(state)
    next_h_state = _advance(h_state)

    # Column vector of terminal flags: 0 everywhere except the final step.
    done = np.zeros((len(state), 1), dtype=int)
    done[-1, 0] = 1
    return (state, h_state, action, next_action, next_state,
            next_h_state, response, h_response, done)

def w_offline_ab(policy, his, immediate=True):
    """Estimate the value of ``policy`` on one logged trajectory ``his``.

    Each logged step is scored by the density of its logged action under a
    Gaussian centred on the action ``policy`` selects for that step. The scale
    is the module-level ``std``. The per-step densities are normalised to sum
    to one and used to weight the per-step rewards. The result is therefore a
    density-weighted average of the trajectory's rewards.

    Steps are processed one at a time; they could also be batched.

    Args:
        policy: object exposing ``select_action(s)``. Its output is used as
            the Gaussian mean, so it presumably returns a tensor on ``device``
            whose last dimension matches ``std``. TODO confirm.
        his: sequence of ``(state, action, response)`` triples, as consumed
            by ``parse_his``.
        immediate: currently unused; kept for interface compatibility.

    Returns:
        The weighted average reward, via ``.item()``.
    """
    log_probs = []
    rewards = []
    with torch.no_grad():
        state, h_state, action, _, _, _, _, h_response, _ = parse_his(his)
        for s_t, hs_t, a_t, hr_t in zip(state, h_state, action, h_response):
            # The policy input is the state concatenated with its history view.
            policy_action = policy.select_action(np.concatenate([s_t, hs_t]))
            dist = Normal(policy_action, std)
            a = torch.tensor(a_t).to(device)
            # The joint log-density would be the SUM over action dims. The
            # MEAN is used deliberately to keep the values larger.
            log_probs.append(dist.log_prob(a).mean(axis=-1).cpu().item())
            rewards.append(process_reward_eva(hr_t))

    # Normalise in log space: subtract the max before exponentiating. With
    # exp() first, large negative log-densities underflow to 0 (especially in
    # float32), and if every step underflows, p / p.sum() is 0/0 = NaN.
    log_probs = np.asarray(log_probs, dtype=np.float64)
    weights = np.exp(log_probs - log_probs.max())
    weights /= weights.sum()
    return (weights * np.array(rewards)).sum().item()
        
def stat(returns):
    """Summarise ``returns`` as a mean and a 95% confidence half-width.

    The half-width is ``t_{0.975, n-1} * SEM``, so the mean's 95% Student-t
    confidence interval is ``mean +/- h``. Both values are also printed.

    Args:
        returns: sequence of numeric returns (``len`` must be defined).

    Returns:
        ``(mean, h)``. With fewer than two samples the interval is undefined
        and ``h`` comes out NaN.
    """
    # Import explicitly rather than relying on a star import to supply scipy.
    import scipy.stats

    avg_ret = np.mean(returns)
    se = scipy.stats.sem(returns)
    # Use the public ppf(); the private _ppf() skips argument validation and
    # is not part of SciPy's stable API.
    h = se * scipy.stats.t.ppf((1 + 0.95) / 2.0, len(returns) - 1)
    print(f"mean: {avg_ret}, h: {h}")
    return avg_ret, h

