import os
import numpy as np

def convert_to_objs(d_s,mapp,n_objs):
    """
    Aggregate a state-only occupancy into a per-object occupancy.

    Args:
        d_s: np.ndarray of occupancy values, one per state entry.
        mapp: np.ndarray of non-negative integer object indices with as
            many entries as d_s (both arrays are flattened before use).
        n_objs: int, minimum length of the returned vector.

    Returns:
        1-D np.ndarray whose i-th entry is the sum of the d_s values whose
        corresponding mapp entry equals i. Its length is at least n_objs.
    """
    object_ids = mapp.ravel()
    state_mass = d_s.ravel()
    # bincount with weights adds up the mass that lands on each object id
    d_objs = np.bincount(object_ids, weights=state_mass, minlength=n_objs)
    return d_objs

def convert_to_trajectories(states,S,H):
    """
    One-hot encode the first two trajectories of a dataset of states.

    Only rows 0 and 1 of `states` and only their first H columns are read.
    The two encoded trajectories are the ones used for TC feedback.

    Args:
        states: np.ndarray of integer state indices, indexed as
            states[trajectory, step].
        S: int, number of states.
        H: int, horizon (number of steps to encode).

    Returns:
        (omega1, omega2): two np.ndarray of shape (H, S); row h of omega_i
        has a single 1 in the column of the state visited by trajectory
        i-1 at step h, and 0 elsewhere.
    """
    steps = np.arange(H)

    omega1 = np.zeros((H, S))
    # one assignment per (step, visited state) pair of the first trajectory
    omega1[steps, states[0, steps]] = 1

    omega2 = np.zeros((H, S))
    # same encoding for the second trajectory
    omega2[steps, states[1, steps]] = 1

    return omega1, omega2

def estimate_occupancy(S, A, H, states, actions, sa = False):
    """
    Empirical occupancy measure of a batch of trajectories.

    For every step h < H, the (state, action) pairs of all trajectories
    are counted and divided by the number of trajectories.

    Args:
        S: int, number of states
        A: int, number of actions
        H: int, horizon
        states: np.ndarray (n, H+1) of integer state indices; only the
            first H columns are read
        actions: np.ndarray (n, H) of integer action indices
        sa: if True return the state-action occupancy, otherwise the
            state-only occupancy (actions summed out)

    Returns:
        np.ndarray of shape (H, S, A) if sa, else of shape (H, S)
    """
    n_trajs = states.shape[0]
    n_pairs = S * A
    visit_counts = np.zeros((H, S, A), dtype=float)

    for step in range(H):
        # encode each (s, a) pair as the single integer s * A + a
        flat_pairs = states[:, step] * A + actions[:, step]
        visit_counts[step] = np.bincount(flat_pairs, minlength=n_pairs).reshape(S, A)

    # turn counts into empirical frequencies
    d_hat = visit_counts / n_trajs

    return d_hat if sa else d_hat.sum(axis=2)

def _reward_is_feasible(r, fD, fPC, fTC):
    """
    Return True iff the reward vector r satisfies every supplied constraint.

    Stops at the first violated constraint. The dot-product-only checks
    (PC, TC) run before the D checks because each D check calls
    M.compute_optimal_policy_and_performance(r).
    """
    # PC: the second occupancy must not be worse than the first by more than k
    if fPC is not None:
        for (d1, d2, k) in fPC:
            if np.dot(r, d2) < np.dot(r, d1) - k:
                return False

    # TC: same comparison applied to the two trajectory terms
    if fTC is not None:
        for (om1, om2, k) in fTC:
            if np.dot(r, om2) < np.dot(r, om1) - k:
                return False

    # D: performance of d must be within k of the optimum returned by M
    if fD is not None:
        for (d, M, k) in fD:
            J = np.dot(r, d)
            _, J_star = M.compute_optimal_policy_and_performance(r)
            if J < J_star - k:
                return False

    return True


def get_feasible_set(
    fD: list = None,
    fPC: list = None,
    fTC: list = None,
    n_points: int = 30,
):
    """
    Enumerate rewards on a grid and keep those satisfying all constraints.

    Candidate rewards are r = [0, rB, rT, rS] with rB, rT, rS each taken
    from n_points evenly spaced values in [0, 1] (so r[0] is always 0).

    Args:
        fD: optional list of (d, M, k); r is rejected if
            r.d < J_star - k, where J_star is the second value returned by
            M.compute_optimal_policy_and_performance(r).
        fPC: optional list of (d1, d2, k); r is rejected if
            r.d2 < r.d1 - k.
        fTC: optional list of (om1, om2, k); r is rejected if
            r.om2 < r.om1 - k.
        n_points: number of grid values per reward component.

    Returns:
        np.ndarray of shape (n_feasible, 3) whose rows are (rB, rT, rS),
        in grid order with rB varying slowest. The shape is (0, 3) when no
        reward is feasible.
    """
    grid = np.linspace(0, 1, n_points)
    feasible_rewards = []

    for rB in grid:
        for rT in grid:
            for rS in grid:
                # the first component is pinned to 0
                r = np.array([0, rB, rT, rS])
                if _reward_is_feasible(r, fD, fPC, fTC):
                    feasible_rewards.append((rB, rT, rS))

    # reshape so that an empty result is still a (0, 3) table
    return np.array(feasible_rewards, dtype=float).reshape(-1, 3)

def brute_force(fs, g):
    """
    Scan the feasible rewards for the extreme values of delta J.

    Each reward in fs is extended with a leading 0 and scored as
    delta J(r) = r . (g[0] - g[1]).

    Args:
        fs: iterable of rewards satisfying the constraints (without the
            leading zero component).
        g: pair (d1, d2) of occupancy measures.

    Returns:
        (max_delta, r_at_max, min_delta, r_at_min). On ties the first
        reward encountered is kept. If fs is empty the result is
        (-inf, None, inf, None).
    """
    best_hi, arg_hi = -np.inf, None
    best_lo, arg_lo = np.inf, None

    for partial_r in fs:
        # restore the component that is fixed to zero
        full_r = np.concatenate(([0], partial_r))
        gap = np.dot(full_r, g[0] - g[1])

        # strict comparisons: only a strictly better value replaces the incumbent
        if gap > best_hi:
            best_hi, arg_hi = gap, np.copy(full_r)
        if gap < best_lo:
            best_lo, arg_lo = gap, np.copy(full_r)

    return best_hi, arg_hi, best_lo, arg_lo


def save_results(folder, results):
    """
    Save every entry of a dict as a separate .npy file inside folder.

    The folder is created if it does not exist. Each value is written
    with np.save to <folder>/<key>; np.save appends the ".npy" extension
    when the key does not already end with it.

    Args:
        folder: path of the output directory, with or without a trailing
            path separator.
        results: dict mapping file stem (str) to an array-like value.
    """
    os.makedirs(folder, exist_ok=True)
    for name, values in results.items():
        # join instead of string concatenation so that a folder given
        # without a trailing separator still receives the files
        np.save(os.path.join(folder, name), values)

def load_results(folder):
    """
    Load every .npy file found directly inside folder.

    Args:
        folder: path of the directory to read, with or without a trailing
            path separator.

    Returns:
        dict mapping each file name without its ".npy" extension to the
        array returned by np.load. Entries that do not end in ".npy" are
        skipped. Keys are inserted in sorted file-name order.
    """
    suffix = ".npy"
    results = {}
    # sorted() makes the key order independent of the filesystem listing order
    for entry in sorted(os.listdir(folder)):
        if not entry.endswith(suffix):
            # skip anything that is not a .npy file, such as stray files or subdirectories
            continue
        results[entry[:-len(suffix)]] = np.load(os.path.join(folder, entry))
    return results

def print_results(results):
    """
    Print mean and standard deviation of two absolute errors per entry.

    For every key of results, the value is converted to a 2-D array and
    two errors are reported: |column 0 - column 2| (labelled x-x_hat) and
    |column 1 - column 3| (labelled I-I_hat).

    Args:
        results: dict mapping a label to a sequence of rows with at least
            four columns.
    """
    for label, rows in results.items():
        table = np.array(rows)
        print('-'*10,label)

        # absolute gaps between the paired columns
        gap_x = np.abs(table[:,0]-table[:,2])
        gap_I = np.abs(table[:,1]-table[:,3])

        mx = np.mean(gap_x)
        stdx = np.std(gap_x)
        mI = np.mean(gap_I)
        stdI = np.std(gap_I)

        print(f"mean abs error: x-x_hat={mx:.2f}",u"\u00B1",f"{stdx:.2f}")
        print(f"mean abs error: I-I_hat={mI:.2f}",u"\u00B1",f"{stdI:.2f}")

def print_results2(results):
    """
    Print mean and standard deviation of one absolute error per entry.

    For every key of results, the value is converted to a 2-D array and
    the error |column 0 - column 1| (labelled delta J-x_hat) is reported.

    Args:
        results: dict mapping a label to a sequence of rows with at least
            two columns.
    """
    for label, rows in results.items():
        table = np.array(rows)
        print('-'*10,label)

        # absolute gap between the first two columns
        gap = np.abs(table[:,0]-table[:,1])
        mx = np.mean(gap)
        stdx = np.std(gap)

        print(f"mean abs error: delta J-x_hat={mx:.2f}",u"\u00B1",f"{stdx:.2f}")