import numpy as np


def find_differences(a, b, path=""):
    """Recursively compare two nested dict/list structures and describe every mismatch.

    Args:
        a: First value (a dict, a list, or any leaf value supporting ``!=``).
        b: Second value.
        path: Location prefix used in messages (``key.sub[3]`` style); empty at the root.

    Returns:
        A list of human-readable strings, one per difference. Dict keys are visited in
        the insertion order of ``a``, followed by keys present only in ``b``. The output
        order is therefore deterministic across runs.
    """
    differences = []
    if isinstance(a, dict) and isinstance(b, dict):
        # Ordered union of keys. A set union would make the message order depend on hash seeding.
        keys = list(a) + [key for key in b if key not in a]
        for key in keys:
            new_path = f"{path}.{key}" if path else key
            if key not in a:
                differences.append(f"{new_path} missing in first")
            elif key not in b:
                differences.append(f"{new_path} missing in second")
            else:
                differences += find_differences(a[key], b[key], new_path)
    elif isinstance(a, list) and isinstance(b, list):
        # Compare the common prefix element-wise, then report a length mismatch once.
        for i, (item_a, item_b) in enumerate(zip(a, b)):
            new_path = f"{path}[{i}]"
            differences += find_differences(item_a, item_b, new_path)
        if len(a) != len(b):
            differences.append(f"{path} has different lengths: {len(a)} vs {len(b)}")
    elif a != b:
        # Leaf values, or containers of mismatched types, are compared directly.
        differences.append(f"{path} differs: {a} != {b}")
    return differences

def check_spi_associations_next_state_uniqueness(pi_memory, state_pattern, action_pattern, next_state_pattern):
    """Assert that no ``(association[3], association[4])`` pair occurs twice.

    Fetches the SPI associations matching the given patterns from ``pi_memory``. Among
    the associations whose element 4 is not None, the pairs ``(element 3, element 4)``
    must be unique. Associations whose element 4 is None are ignored.

    Raises:
        AssertionError: if a duplicate pair is found.
    """
    spi_associations = pi_memory._fetch_spi_associations(state_pattern, action_pattern, next_state_pattern)
    # NOTE(review): elements 3 and 4 are presumably the (state, next_state) fields of an
    # association. Confirm against _fetch_spi_associations.
    # Single pass, so a generator result is not exhausted before it is checked.
    next_pairs = [(assoc[3], assoc[4]) for assoc in spi_associations if assoc[4] is not None]
    unique_pairs = set(next_pairs)
    assert len(next_pairs) == len(unique_pairs), (
        f"duplicate SPI next-state pairs: {len(next_pairs)} pairs but only {len(unique_pairs)} unique"
    )


def check_memory_entries(v_memory, pi_memory):
    """Assert that both memories hold the same ``(key, sub_key)`` entries in the same order.

    Each memory's ``info`` is a mapping of key -> mapping. The flattened sequences of
    ``(key, sub_key)`` pairs must be equal, including their order.

    Plain Python lists are compared, not NumPy arrays. ``np.array`` would coerce
    mixed-type tuples to strings, so ``1`` and ``"1"`` would compare equal. It would
    also raise on shape mismatches instead of failing the assertion.

    Raises:
        AssertionError: if the entry sequences differ.
    """
    v_entries = [(k, v_k) for k, v in v_memory.info.items() for v_k in v]
    pi_entries = [(k, pi_k) for k, v in pi_memory.info.items() for pi_k in v]
    assert v_entries == pi_entries

def check_pi_memory_consistency(pi_memory, v_memory):
    """Assert structural consistency of every option chain stored in ``pi_memory``.

    For each entry ``pi_memory.info[state][option]``:
      * its ``data`` must record the same ``state_pattern`` / ``pi_option`` it is stored under;
      * every entry reached by following ``next`` links must have its ``next``,
        ``previous`` and ``other_pi_options_previous`` links pointing at the very
        object stored in ``pi_memory.info`` for that link's state/option.

    A walk stops at the end of the chain, when an entry is revisited, or the second
    time a zero-length option (per ``v_memory``'s ``trajectory_info['length']``) is met.

    Raises:
        AssertionError: on any inconsistency.
        KeyError: if ``v_memory.info`` has no entry for a walked (state, option) pair.
    """
    for state, options in pi_memory.info.items():
        for option, head in options.items():
            assert head.data['state_pattern'] == state and head.data['pi_option'] == option

            visited = set()
            zero_length_options = set()
            entry = head
            current_state = state
            while id(entry) not in visited:
                visited.add(id(entry))
                # Links are validated before the zero-length cut below, so every visited
                # entry (the head included) is checked.
                _assert_pi_entry_links(pi_memory, entry)

                pi_option = entry.data['pi_option']
                if v_memory.info[current_state][pi_option]['trajectory_info']['length'] == 0:
                    # A zero-length option may repeat along the chain; stop on its second visit.
                    if pi_option in zero_length_options:
                        break
                    zero_length_options.add(pi_option)

                if not entry.next:
                    break
                entry = entry.next
                current_state = entry.data['state_pattern']


def _assert_pi_entry_links(pi_memory, entry):
    """Assert that each link held by ``entry`` is the canonical object in ``pi_memory.info``."""

    def canonical(linked):
        return pi_memory.info[linked.data['state_pattern']][linked.data['pi_option']]

    assert not entry.next or entry.next is canonical(entry.next)
    assert not entry.previous or entry.previous is canonical(entry.previous)
    for linked in entry.other_pi_options_previous.values():
        assert not linked or linked is canonical(linked)

def check_trajectories(pi_memory, v_memory):
    """Assert that v_memory's expanded trajectories match the option chains in pi_memory.

    Every (state, option) entry of each memory is expanded into a sequence that starts
    with the state and continues with option labels. The two lists of sequences are
    then compared in iteration order.

    Raises:
        AssertionError: if the two lists of expanded sequences differ.
        KeyError: if a walked pi chain reaches a (state, option) pair missing from v_memory.
    """
    v_data = [
        _expand_v_trajectory(state, option, entry)
        for state, options in v_memory.info.items()
        for option, entry in options.items()
    ]
    pi_data = [
        _expand_pi_trajectory(v_memory, state, head)
        for state, options in pi_memory.info.items()
        for head in options.values()
    ]
    assert v_data == pi_data


def _expand_v_trajectory(state, option, entry):
    """Expand one v_memory entry into [state, option * length, stationary-segment labels...]."""
    segments = [state]
    length = entry['trajectory_info']['length']
    # A zero-length trajectory contributes no copies of its own option.
    if length != 0:
        segments.extend([option] * length)
    for segment in entry['stationary_segments']:
        # NOTE(review): data[1] appears to be a repeat count and data[2] a label; confirm.
        count, label = segment.data[1], segment.data[2]
        if count != 0:
            segments.extend([label] * count)
        elif label is not None:
            segments.append(label)
    return segments


def _expand_pi_trajectory(v_memory, state, head):
    """Walk a pi_memory chain from ``head`` and return [state, pi_option, pi_option, ...]."""
    segments = [state]
    visited = set()
    zero_length_options = set()
    entry = head
    current_state = state
    while id(entry) not in visited:
        visited.add(id(entry))
        pi_option = entry.data['pi_option']
        # Same key as used everywhere else for v_memory entries ('trajectory_info').
        if v_memory.info[current_state][pi_option]['trajectory_info']['length'] == 0:
            # A zero-length option may repeat along the chain; stop on its second visit.
            if pi_option in zero_length_options:
                break
            zero_length_options.add(pi_option)
        segments.append(pi_option)

        if not entry.next:
            break
        entry = entry.next
        current_state = entry.data['state_pattern']
    return segments


def track_pi_option_entry(v_memory, pi_memory, selected_pi_options, v_debug, pi_debug):
    """Snapshot the memory entries of ``selected_pi_options`` and report whether they changed.

    Args:
        v_memory: Object whose ``info`` maps observation -> {pi_option: option dict}.
        pi_memory: Object whose ``info`` maps observation -> {pi_option: mapping}.
        selected_pi_options: Container of pi options to include in the snapshot.
        v_debug: Previous v_memory snapshot (as returned by an earlier call), or anything else.
        pi_debug: Previous pi_memory snapshot (as returned by an earlier call), or anything else.

    Returns:
        ``(changed, v_debug, pi_debug)``. ``changed`` is True when either new snapshot
        differs from the one passed in. In that case the new snapshots are returned;
        otherwise the passed-in ones are returned unchanged.
    """
    v_debug_temp = [
        (observation, pi_option, pi_options[pi_option]['action_pattern'],
         tuple(pi_options[pi_option]['cycle_info']['v_value']) if pi_options[pi_option]['cycle_info'] is not None else None,
         tuple(pi_options[pi_option]['trajectory_info']['v_value']) if pi_options[pi_option]['trajectory_info']['v_value'] is not None else None,
         pi_options[pi_option]['trajectory_info']['persistent_length'], pi_options[pi_option]['trajectory_info']['length'],
         tuple(pi_options[pi_option]['cross_policy_value']), pi_options[pi_option]['cross_policy_length'])
        for observation, pi_options in v_memory.info.items()
        for pi_option in pi_options
        if pi_option in selected_pi_options]
    pi_debug_temp = [
        (observation, pi_option, list(pi_options[pi_option].items()))
        for observation, pi_options in pi_memory.info.items()
        for pi_option in pi_options
        if pi_option in selected_pi_options]

    changed = v_debug != v_debug_temp or pi_debug != pi_debug_temp
    if changed:
        v_debug = v_debug_temp
        pi_debug = pi_debug_temp
    # Returned so callers can observe the result. Previously these were dead local stores.
    return changed, v_debug, pi_debug

def check_stationarity(pi_memory):
    """Assert that no option chain in ``pi_memory`` passes through the same state twice.

    Every entry stored in ``pi_memory.info`` is used as a chain head and followed along
    its ``next`` links. Within one walk, two distinct entries must not share a
    ``state_pattern``. Returning to an already-visited entry object ends the walk
    without an assertion.

    Raises:
        AssertionError: if a walk meets a repeated ``state_pattern`` on a new entry.
    """
    for options in pi_memory.info.values():
        for head in options.values():
            seen_ids = set()
            seen_states = set()
            node = head

            # The identity check comes first, so a cycle back to a known entry just ends the walk.
            while id(node) not in seen_ids:
                seen_ids.add(id(node))

                pattern = node.data['state_pattern']
                assert pattern not in seen_states
                seen_states.add(pattern)

                if not node.next:
                    break
                node = node.next
