import json
import json_repair
import numpy as np


def empty_instance_dict():
    """Return a new instance dict with every field initialised to an empty list.

    A fresh dict (and fresh lists) is built on every call, so callers may
    append to the fields without affecting other instances.
    """
    field_names = ('timestep', 'translation', 'rotation', 'size', 'attribute_label')
    return {name: [] for name in field_names}


def fix_dimensions(instance, key, instance_id, historical_data=None):
    """Normalize the dimensions of ``instance[key]`` and sanity-check the instance.

    ``instance[key]`` is squeezed and then re-expanded: a scalar becomes
    shape (1, 1); a 1-D array becomes a single row when there is exactly one
    timestep and a column otherwise. Finally 'timestep' and 'attribute_label'
    have their second axis squeezed away so they end up 1-D. ``instance`` is
    modified in place and also returned.

    Args:
        instance: instance dict of NumPy arrays; must contain 'timestep' (and
            'translation' and 'rotation' when a history exists for it).
        key: the field of ``instance`` to normalize.
        instance_id: id used to look the instance up in ``historical_data``.
        historical_data: optional mapping of instance ids to historical
            instance dicts, used to build a static fallback.

    Returns:
        ``(instance, True)`` when there is no usable history for this instance
        or its translation and rotation both have 3 components on the last
        axis; otherwise ``(static_instance_dict(historical_data[instance_id]),
        False)``.
    """
    # atleast_1d guards against a scalar timestep (e.g. ``"timestep": 8`` or
    # ``[[8]]``), whose squeezed array has no ``shape[0]``.
    num_valid_timesteps = np.atleast_1d(instance['timestep'].squeeze()).shape[0]

    # Remove unnecessary dimensions (if any), then restore the expected ones.
    instance[key] = instance[key].squeeze()
    if instance[key].ndim == 0:
        instance[key] = np.expand_dims(instance[key], axis=(0, 1))
    elif instance[key].ndim == 1:
        # One timestep: the vector is that timestep's row.
        # Several timesteps: one value per timestep, i.e. a column.
        axis = 0 if num_valid_timesteps == 1 else 1
        instance[key] = np.expand_dims(instance[key], axis=axis)
    if key in ['timestep', 'attribute_label']:
        instance[key] = instance[key].squeeze(1)

    if (historical_data is None) or (instance_id not in historical_data) or (historical_data[instance_id] is None):
        return instance, True

    # Without 3-component translations/rotations the instance cannot be used;
    # fall back to a static dict built from the last historical values.
    if (instance['rotation'].shape[-1] != 3) or (instance['translation'].shape[-1] != 3):
        return static_instance_dict(historical_data[instance_id]), False

    return instance, True


def convert_to_numpy(scene_dict, history_dict=None):
    """Convert every field of every instance in ``scene_dict`` to a NumPy array.

    A field is accepted when ``np.array`` succeeds and yields an integer,
    floating-point or string dtype. Otherwise, if ``history_dict`` is given,
    the field is repaired with ``clean_and_convert``; when that repair fails
    the whole instance is replaced by
    ``static_instance_dict(history_dict[instance_id])``.

    Args:
        scene_dict: mapping of instance ids to instance dicts; modified in place.
        history_dict: optional mapping of instance ids to historical instance
            dicts. When None, no repair is attempted.

    Returns:
        ``(scene_dict, correct_format)`` where ``correct_format`` maps each
        instance id to False if any of its fields needed repair, else True.

    Raises:
        ValueError: if a field is invalid and ``history_dict`` is None.
    """
    correct_format = {}
    for instance_id, instance in scene_dict.items():
        is_instance_correct = True
        for key, value in list(instance.items()):
            try:
                array = np.array(value)
                instance[key] = array

                # Sanity check the numpy array. ``np.unicode_`` was only an
                # alias of ``np.str_`` and was removed in NumPy 2.0, where
                # referencing it raised AttributeError for every numeric array.
                dtype = array.dtype
                if not (np.issubdtype(dtype, np.integer)
                        or np.issubdtype(dtype, np.floating)
                        or np.issubdtype(dtype, np.str_)):
                    raise ValueError(f'Invalid array for instance {instance_id}, trying to fix it')

            except ValueError:
                # If a historical scene dict is not provided, no fix will be attempted and the original error will be raised
                if history_dict is None:
                    raise

                # If one of the keys of the instance is unreadable, the instance is considered incorrect
                is_instance_correct = False

                # Try to fix the entries of the current key
                try:
                    fixed = clean_and_convert(value)
                except (TypeError, AttributeError):
                    # e.g. a scalar or None where a sequence of rows was expected
                    fixed = None
                instance[key] = fixed

                # If it's impossible to fix the key, the instance will be assumed to be static
                if fixed is None:
                    instance = static_instance_dict(history_dict[instance_id])
                    # Stop here: the remaining raw fields must not overwrite
                    # the values of the static fallback dict.
                    break

        scene_dict[instance_id] = instance
        correct_format[instance_id] = is_instance_correct
    return scene_dict, correct_format


def load_instance_data(data_file, numpy=False, historical_data=None):
    """Load a (possibly malformed) JSON scene dict and repair common issues.

    The file is parsed with ``json_repair`` and misspelled field names are
    mapped to their canonical names. When ``numpy`` is True, every field is
    converted to a NumPy array (``convert_to_numpy``), has its dimensions
    normalized (``fix_dimensions``) and is resampled to the number of
    timesteps if its length differs. When ``historical_data`` is given, the
    result is validated with ``dictionary_check``, which may replace the
    whole scene with static forecasts.

    Args:
        data_file: path of the JSON file to read.
        numpy: convert the fields to NumPy arrays and repair their shapes.
        historical_data: optional mapping of instance ids to historical
            instance dicts used for static fallbacks.

    Returns:
        ``(data, correct_format)``; ``correct_format`` maps instance ids to
        True only when the instance needed no repair. Every id of
        ``historical_data`` is present in it.

    Raises:
        OSError: if ``data_file`` cannot be opened.
        Exception: any loading/repair error when ``historical_data`` is None,
            since there is nothing to fall back on.
    """
    # Table to correct possible misspells of key names
    key_conversion = {
        'timesteps': 'timestep',
        'time': 'timestep',
        'times': 'timestep',
        'timestamp': 'timestep',
        'timestamps': 'timestep',
        'translations': 'translation',
        'position': 'translation',
        'positions': 'translation',
        'rotations': 'rotation',
        'angle': 'rotation',
        'angles': 'rotation',
        'orientation': 'rotation',
        'orientations': 'rotation',
        'sizes': 'size',
        'attribute_labels': 'attribute_label',
    }

    # None (not {}) so that, if parsing fails, dictionary_check rejects it and
    # falls back to static forecasts instead of accepting an empty scene.
    data = None
    with open(data_file, 'r') as f:
        # Try loading and correcting the scene_dict for some common issues, use static scene dict in case of failure
        try:
            data = json_repair.load(f)

            # Correct misspells of key names. Iterate over a snapshot of the
            # keys: popping/inserting while iterating the live dict raises
            # "RuntimeError: dictionary keys changed during iteration".
            for instance in data.values():
                for key in list(instance):
                    if key in key_conversion:
                        instance[key_conversion[key]] = instance.pop(key)

            if numpy:
                data, correct_format = convert_to_numpy(data, historical_data)

                for instance_id, instance in data.items():
                    for key in list(instance):
                        if instance[key] is None:
                            correct_format[instance_id] = False
                            continue

                        # Remove unnecessary dimensions and check correct number of dimensions per attribute
                        new_instance, correct_format_dim = fix_dimensions(instance, key, instance_id, historical_data)
                        instance[key] = new_instance[key]
                        if not correct_format_dim:
                            correct_format[instance_id] = False

                        # Try to correct instance if the length of all fields isn't equal
                        if len(instance[key]) != len(instance['timestep']):
                            print(f'WARNING: the length of {key} in {instance_id} does not correspond to the number of timesteps, trying to interpolate')
                            instance[key] = resample_array(instance[key], instance['timestep'])
                            correct_format[instance_id] = False
            else:
                # Without conversion, all instances are considered to be well formatted
                correct_format = {instance_id: True for instance_id in data}
        except Exception:
            # Without a history there is nothing to fall back on: surface the real error.
            if historical_data is None:
                raise
            # Otherwise the error is handled by the dictionary_check below; in
            # the meantime the format is considered wrong for all instances.
            correct_format = {instance_id: False for instance_id in historical_data}

    if historical_data is not None:
        data, correct_global = dictionary_check(data, historical_data)
    else:
        correct_global = True

    for instance_id in correct_format:
        correct_format[instance_id] = correct_format[instance_id] and correct_global

    # Instances known from the history but missing from the prediction are flagged.
    if historical_data is not None:
        for instance_id in historical_data:
            correct_format.setdefault(instance_id, False)

    return data, correct_format


def instance_data_scored_timesteps(forecasted_data, predicted_data, instance_id, verbose=True):
    """Select the predicted entries of one instance that line up with its g.t. timesteps.

    For every ground-truth timestep, the first predicted entry with the same
    timestep is kept; ground-truth timesteps without a matching prediction
    are skipped. The ground-truth arrays are returned unfiltered.

    Args:
        forecasted_data: mapping of instance ids to ground-truth instance dicts.
        predicted_data: mapping of instance ids to predicted instance dicts.
        instance_id: the instance to extract.
        verbose: print a warning when the prediction has no 'size' field.

    Returns:
        ``(translations_gt, rotations_gt, sizes_gt, translations_pred,
        rotations_pred, sizes_pred, scored_timesteps)``; ``sizes_pred`` is
        None when the prediction has no 'size' field, and
        ``scored_timesteps`` holds the selected indices into the prediction.

    Raises:
        ValueError: if the selected indices do not fit the predicted
            translation array (the error carries its shape and the indices).
    """
    instance_gt = forecasted_data[instance_id]
    translations_gt = instance_gt['translation']
    rotations_gt = instance_gt['rotation']
    sizes_gt = instance_gt['size']

    instance_pred = predicted_data[instance_id]

    # Some instances might have missing g.t. annotations for some timesteps,
    # so only predicted timesteps that also appear in the g.t. are scored.
    scored_timesteps = []
    for val in instance_gt['timestep']:
        matches = np.where(instance_pred['timestep'] == val)[0]
        if matches.size > 0:
            scored_timesteps.append(matches[0])
    scored_timesteps = np.array(scored_timesteps, dtype=int)

    try:
        translations_pred = instance_pred['translation'][scored_timesteps]
    except IndexError as err:
        # Chain the original error so the offending index stays visible.
        raise ValueError(instance_pred['translation'].shape, scored_timesteps) from err
    rotations_pred = instance_pred['rotation'][scored_timesteps]
    try:
        sizes_pred = instance_pred['size'][scored_timesteps]
    except KeyError:
        if verbose:
            print("WARNING: predicted forecast does not contain the 'size' attribute, using last value from history")
        sizes_pred = None

    return translations_gt, rotations_gt, sizes_gt, translations_pred, rotations_pred, sizes_pred, scored_timesteps



def clean_and_convert(list_of_lists, expected_len=3):
    """Repair rows of the wrong length and convert the result to a NumPy array.

    Each row whose length differs from ``expected_len`` is repaired. Rows
    with no length at all, such as scalars or None, also count as malformed.
    A malformed row is replaced by the element-wise mean of its two
    neighbours when both have the expected length, otherwise by whichever
    neighbour does. The previous neighbour is read from the already-repaired
    rows.

    Args:
        list_of_lists: sequence of rows (e.g. per-timestep [x, y, z] vectors).
        expected_len: required length of every row (3 by default).

    Returns:
        ``np.array`` of the repaired rows, or None when some row has no valid
        neighbour or the input is not iterable.
    """
    def _row_len(row):
        # None for objects without a length so they are treated as malformed.
        try:
            return len(row)
        except TypeError:
            return None

    try:
        rows = list(list_of_lists)
    except TypeError:
        print("Cannot repair: input is not a sequence")
        return None
    cleaned = list(rows)
    last_index = len(cleaned) - 1

    for i, item in enumerate(rows):
        length = _row_len(item)
        if length == expected_len:
            continue
        print(f"WARNING: Inhomogeneous part at index {i} with length {length}. Trying to interpolate based on nearest neighbors.")

        # Try to average neighbors
        prev_row = cleaned[i - 1] if i > 0 and _row_len(cleaned[i - 1]) == expected_len else None
        next_row = cleaned[i + 1] if i < last_index and _row_len(cleaned[i + 1]) == expected_len else None

        if prev_row is not None and next_row is not None:
            cleaned[i] = [(a + b) / 2 for a, b in zip(prev_row, next_row)]
        elif prev_row is not None:
            cleaned[i] = list(prev_row)
            print(f"  Replaced with previous element at index {i - 1}.")
        elif next_row is not None:
            cleaned[i] = list(next_row)
            print(f"  Replaced with next element at index {i + 1}.")
        else:
            print(f"Cannot repair inhomogeneous part at index {i}")
            return None

    return np.array(cleaned)


def resample_array(array: np.ndarray, timesteps: np.ndarray) -> np.ndarray:
    """Resample ``array`` along its first axis to ``len(timesteps)`` entries.

    Numeric arrays are linearly interpolated on a normalized [0, 1] grid, so
    the first and last entries are preserved. Arrays of any rank are
    supported, and every trailing component is interpolated independently.
    String arrays (e.g. 'attribute_label') repeat their first entry instead.

    Returns:
        ``array`` itself when its length already matches ``timesteps``,
        otherwise a new array with ``len(timesteps)`` entries.
    """
    n = len(array)
    m = len(timesteps)

    if n == m:
        return array

    # If the array contains strings (e.g., in the case of "attribute_label"), just replicate the first entry
    if np.issubdtype(array.dtype, np.str_):
        return np.array(m * [array[0]])

    # Generate normalized indices for interpolation
    original_indices = np.linspace(0, 1, n)
    target_indices = np.linspace(0, 1, m)

    if array.ndim == 1:
        return np.interp(target_indices, original_indices, array)

    # Flatten all trailing axes into columns, interpolate each column, then
    # restore the trailing shape (a no-op reshape for 2-D input).
    columns = array.reshape(n, -1)
    resampled = np.stack([
        np.interp(target_indices, original_indices, columns[:, j])
        for j in range(columns.shape[1])
    ], axis=1)
    return resampled.reshape((m,) + array.shape[1:])


def static_instance_dict(history_dict, timesteps=None):
    """Build a constant forecast that repeats the last historical observation.

    Args:
        history_dict: historical instance dict; the last entry (``[-1]``) of
            its 'translation', 'rotation', 'size' and 'attribute_label'
            fields is used.
        timesteps: forecast timesteps. Defaults to 8, 9, ..., 15, the
            previously hard-coded horizon.

    Returns:
        dict whose 'timestep' is ``timesteps`` and whose other fields hold the
        last historical value repeated once per timestep.
    """
    if timesteps is None:
        timesteps = np.array([8, 9, 10, 11, 12, 13, 14, 15])
    timesteps = np.atleast_1d(timesteps)
    horizon = len(timesteps)

    static_dict = {
        'timestep': timesteps,
        'translation': np.array(horizon * [history_dict['translation'][-1]]),
        'rotation': np.array(horizon * [history_dict['rotation'][-1]]),
        'size': np.array(horizon * [history_dict['size'][-1]]),
        'attribute_label': np.array(horizon * [history_dict['attribute_label'][-1]]),
    }

    return static_dict


def _is_valid_instance_dict(instance_id, instance_dict):
    """Return True if one predicted instance passes the format checks of ``dictionary_check``."""
    # Keys must be string instance ids and values must be instance dicts.
    if not isinstance(instance_id, str) or not isinstance(instance_dict, dict):
        return False

    # 'timestep', 'translation' and 'rotation' must be present as np.ndarrays.
    arrays = [instance_dict.get(key) for key in ('timestep', 'translation', 'rotation')]
    if not all(isinstance(array, np.ndarray) for array in arrays):
        return False
    timestep, translation, rotation = arrays

    # Translations and rotations need 3 components on their last axis
    # (0-d arrays have no last axis and are rejected).
    for array in (translation, rotation):
        if array.ndim == 0 or array.shape[-1] != 3:
            return False

    return bool(np.issubdtype(timestep.dtype, np.integer)
                and np.issubdtype(translation.dtype, np.floating)
                and np.issubdtype(rotation.dtype, np.floating))


def dictionary_check(predicted_data, historical_data):
    """Perform the final format check on a predicted scene dict.

    Every instance must map a string id to a dict holding integer
    'timestep', floating 'translation' and floating 'rotation' np.ndarrays,
    with translations and rotations having 3 components on their last axis.
    The checks are explicit rather than ``assert`` statements, which would
    be stripped when Python runs with ``-O``.

    Args:
        predicted_data: the (repaired) predicted scene dict.
        historical_data: mapping of instance ids to historical instance dicts,
            used to build the static fallback.

    Returns:
        ``(predicted_data, True)`` if every check passes, otherwise
        ``(static_scene_dict, False)`` with one ``static_instance_dict`` per
        historical instance.
    """
    try:
        items = predicted_data.items()
    except AttributeError:
        # Not a mapping at all (e.g. None, or a list/str from a failed parse).
        items = None

    if items is not None and all(_is_valid_instance_dict(instance_id, instance_dict)
                                 for instance_id, instance_dict in items):
        return predicted_data, True

    print('WARNING: Malformed dictionary after all attempted fixes, assuming static dict')

    static_scene_dict = {
        instance_id: static_instance_dict(instance_dict)
        for instance_id, instance_dict in historical_data.items()
    }
    return static_scene_dict, False