from Buffer.all_buffer import AllReplayBuffer
from tianshou.data import Batch
from ActualCausal.Utils.run_dataset import get_proximity
import numpy as np
import copy, time

def get_passive_name(name):
    """Return the passive counterpart of an object name.

    The first character of ``name`` is replaced with ``"#"``
    (e.g. ``"$Ball"`` -> ``"#Ball"``).
    """
    suffix = name[1:]
    return "#" + suffix

def add_encoding(factored_state, encodings, encoding_length):
    """Substitute per-object encodings into ``factored_state``.

    The dict is modified IN PLACE and also returned; callers (e.g. ``fill_buffer``)
    rely on the input dict itself carrying the encodings afterwards. A deep copy
    taken before any substitution is returned alongside it as the "true" state.

    Args:
        factored_state: dict mapping object/special names to state values.
        encodings: dict mapping object names to encoding vectors, or ``None`` to
            leave ``factored_state`` untouched.
        encoding_length: length of the zero vector assigned to objects that have no
            encoding (only used when ``encodings`` is not ``None`` and such an
            object exists).

    Returns:
        ``(factored_state, true_factored_state)``: the (possibly) encoded input dict
        and an unmodified deep copy of the original.
    """
    true_factored_state = copy.deepcopy(factored_state)
    if encodings is not None:
        # Keys that are not object states and must keep their raw values. "TRACE"
        # holds the interaction trace that fill_buffer reshapes; replacing it with a
        # zero encoding would break that reshape.
        # TODO: other unencoded states may still need to be added here.
        unencoded_keys = ("Action", "Reward", "Done", "VALID_NAMES", "ITR", "TRACE")
        for key in factored_state:
            if key in encodings:
                factored_state[key] = encodings[key]
            elif key not in unencoded_keys:
                # A fresh array per key so entries never alias one shared buffer,
                # and nothing is allocated when no object needs a default.
                factored_state[key] = np.zeros(encoding_length)
    return factored_state, true_factored_state

def fill_buffer(data, environment, args, extractor, norm, outcome_variable="", encodings=None, encoding_length=-1):
    """Convert a sequence of factored states into an ``AllReplayBuffer`` of transitions.

    Each consecutive pair ``(data[i], data[i + 1])`` becomes one buffer entry holding
    the normalized observation and next observation, targets, target difference,
    interaction trace, proximity and bookkeeping fields.

    NOTE: the dicts in ``data`` are modified in place (encodings substituted by
    ``add_encoding``, ``"Done"`` and ``"Action"`` entries overwritten below).

    Args:
        data: sequence of factored-state dicts.
        environment: reads ``name``, ``discrete_actions``, ``all_names``, ``pos_size``
            and calls ``get_full_trace``.
        args: config namespace; reads ``args.inter.predict_next_state``,
            ``args.inter.passive_reassign``, ``args.record.load_trace`` and
            ``args.state.proximity_epsilon``.
        extractor: builds observations/targets from factored states.
        norm: normalization callable applied to observations and targets.
        outcome_variable: name, or iterable of names, of the outcome object(s) zeroed
            out when ``predict_next_state`` is false and ``passive_reassign`` is off.
        encodings: optional per-state encodings aligned with ``data``
            (``encodings[i]`` belongs to ``data[i]``); ``None`` means no encodings.
        encoding_length: length of the zero vector used for non-encoded objects.

    Returns:
        The filled ``AllReplayBuffer`` (capacity ``len(data)``, holding
        ``len(data) - 1`` transitions).

    Raises:
        ValueError: if an observation contains NaN.
    """
    buffer = AllReplayBuffer(len(data), stack_num=1)
    # One slot per state, so zip(data[1:], encodings[1:]) below visits every
    # transition; a list of len(data) - 1 silently dropped the last one.
    if encodings is None:
        encodings = [None] * len(data)
    factored_state, true_factored_state = add_encoding(data[0], encodings[0], encoding_length)
    # add_encoding mutated data[0] in place, so both names alias the encoded first state.
    last_factored_state, last_true_factored_state = data[0], data[0]
    start = time.time()
    for i, (next_factored_state, encoding) in enumerate(zip(data[1:], encodings[1:])):
        next_factored_state, true_next_factored_state = add_encoding(next_factored_state, encoding, encoding_length)
        # Handles on the dicts before the branches below rebind the names; used to
        # advance the window at the end of the iteration.
        original_nfs, original_fs, original_tfs = next_factored_state, factored_state, true_factored_state
        if environment.name not in ["Breakout", "Phyre", "RobosuitePushing", "AirHockey"]:
            use_done = next_factored_state["Done"]
        else:
            use_done = factored_state["Done"]
        factored_state["Done"] = np.array([False]) # set the factored state done to false to avoid issues with prediction
        if args.inter.predict_next_state:
            act = next_factored_state["Action"][-1] if environment.discrete_actions else next_factored_state["Action"]
            factored_state["Action"] = next_factored_state["Action"]
            obs = extractor.get_obs(factored_state)
        else:
            # We still use buffer.next_state for prediction, but reassign the current
            # state to remove next-state information. This shifts the actions back one
            # step, at the cost of repeating the first state twice.
            act = last_factored_state["Action"][-1] if environment.discrete_actions else last_factored_state["Action"]
            factored_state, true_factored_state = last_factored_state, last_true_factored_state
            next_factored_state = factored_state # the next state is the current state
            # NOTE: this writes into last_factored_state (and thus next_factored_state)
            # before the copy is taken.
            factored_state["Action"] = act
            factored_state = copy.deepcopy(factored_state)
            if args.inter.passive_reassign:
                # Reassign each object to its passive ("#"-prefixed) counterpart.
                for ov in [n for n in extractor.names if n not in ["Action", "Reward", "Goal", "Done"]]:
                    factored_state[ov] = factored_state[get_passive_name(ov)]
            elif isinstance(outcome_variable, str):
                # isinstance also accepts str subclasses such as numpy.str_
                factored_state[outcome_variable] = factored_state[outcome_variable] * 0.0 # block out the outcome variable for this step
            else:
                for ov in outcome_variable:
                    factored_state[ov] = factored_state[ov] * 0.0
            obs = extractor.get_obs(factored_state)
        if np.isnan(np.sum(obs)):
            raise ValueError("NaN found in state")
        rew = factored_state["Reward"]
        inter = np.ones((extractor.num_objects, extractor.num_objects)) # n x n matrix of interactions
        if "VALID_NAMES" in factored_state:
            valid = np.array(factored_state["VALID_NAMES"])[extractor.kept_nidx] # don't include reward or done in validity vector
        else:
            valid = np.ones((len(extractor.names)))[:-2]

        if args.record.load_trace:
            trace = np.array(next_factored_state["TRACE"]).reshape(len(environment.all_names), len(environment.all_names))
        else:
            trace_state = true_factored_state if args.inter.predict_next_state else factored_state
            full_traces = environment.get_full_trace(trace_state, act, outcome_variable=outcome_variable, all_names=extractor.names)
            trace = np.stack([full_traces[name] for name in extractor.names], axis=0).astype(float)
            # np.pad with (0, 2) pads every axis; slicing back to the original row count
            # leaves two trailing zero columns (assuming trace is 2-D).
            trace = np.pad(trace, (0, 2))[:trace.shape[0]]
        target = extractor.get_target(factored_state)
        proximity = get_proximity(environment.pos_size, extractor.target_dim, target, args.state.proximity_epsilon)[0].reshape(-1, extractor.num_objects, extractor.num_objects)
        weight_binary = np.ones(1)

        # assign normalized selections of the state
        obs = norm(obs, form="obs")
        target = norm(target, name="all")
        next_target = norm(extractor.get_target(next_factored_state), name="all")
        target_diff = extractor.target_selector(next_factored_state) - extractor.target_selector(factored_state)
        target_diff = norm(target_diff, name="all", form="dyn")
        obs_next = norm(extractor.get_obs(next_factored_state), form="obs")
        if i % 1000 == 0:
            # guard against a zero elapsed time on coarse clocks
            print("loaded", i, i / max(time.time() - start, 1e-8))
        buffer.add(Batch(obs = obs, obs_next=obs_next, act=act, done=use_done, terminated=False, truncated=use_done, true_done=use_done, rew=rew, true_reward=rew,
            info = dict(), policy=dict(), time = i,
            target=target, next_target=next_target, target_diff=target_diff,
            eval_binary = np.ones(trace.shape[0]), confidence = np.ones(trace.shape[0] - 2), norm_confidence = np.ones(1) / len(data), # learned binary and weight, initially filled with ones
            inter=inter, trace=trace, proximity = proximity, weight_binary=weight_binary, passive_weight_binary=weight_binary, valid=valid))
        factored_state, true_factored_state, last_factored_state, last_true_factored_state = original_nfs, true_next_factored_state, original_fs, original_tfs
    return buffer
