"""
adaptive_ensemble.py

"""

from collections import deque
import numpy as np


class AdaptiveEnsembler:
    """Blend overlapping action predictions with cosine-similarity weights.

    Every call to :meth:`ensemble_action` stores the newest prediction in a
    bounded history (at most ``pred_action_horizon // step_interval``
    entries) and returns ``step_interval`` blended actions.  Each stored
    prediction is weighted by ``exp(alpha * cos_sim)``, where ``cos_sim`` is
    its cosine similarity to the newest prediction, and the weights are
    normalised to sum to one.

    NOTE(review): the row alignment used for multi-step predictions assumes
    ``ensemble_action`` is called once every ``step_interval`` steps, so a
    prediction made ``k`` calls ago is ``k * step_interval`` rows "ahead" of
    the newest one -- confirm against the caller.
    """

    def __init__(self, pred_action_horizon, adaptive_ensemble_alpha=0.0, step_interval=5):
        """Configure the ensembler.

        Args:
            pred_action_horizon: Number of rows in each multi-step prediction
                passed to :meth:`ensemble_action`.
            adaptive_ensemble_alpha: Scale applied to the cosine similarity
                before exponentiation; ``0.0`` yields a plain average.
            step_interval: Number of blended actions returned per call, and
                the row stride between consecutive stored predictions.

        Raises:
            ValueError: If ``step_interval`` is smaller than 1 or larger than
                ``pred_action_horizon`` (the history could never hold a
                prediction, so every later call would fail).
        """
        if step_interval < 1:
            raise ValueError(f"step_interval must be >= 1, got {step_interval}")
        if pred_action_horizon < step_interval:
            raise ValueError(
                "pred_action_horizon must be >= step_interval, got "
                f"{pred_action_horizon} < {step_interval}"
            )
        self.pred_action_horizon = pred_action_horizon
        self.step_interval = step_interval
        # Oldest entry first; appending beyond maxlen drops the oldest one.
        self.action_history = deque(maxlen=pred_action_horizon // step_interval)
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha

    def reset(self):
        """Forget all stored predictions."""
        self.action_history.clear()

    def ensemble_action(self, cur_action):
        """Store ``cur_action`` and return the blended actions for this call.

        Args:
            cur_action: Newest prediction as a numpy array.  A 1-D array is
                treated as a single action; otherwise row ``i`` is indexed as
                the action predicted ``i`` steps ahead.

        Returns:
            A list with ``step_interval`` arrays, one blended action per
            step.  For 1-D input all entries hold the same values.
        """
        self.action_history.append(cur_action)
        num_actions = len(self.action_history)
        stride = self.step_interval

        result = []
        for step_idx in range(stride):
            if cur_action.ndim == 1:
                curr_act_preds = np.stack(self.action_history)
            else:
                # Align by age: the newest prediction contributes row
                # ``step_idx`` and each older one is shifted by one more
                # stride.  Deriving the offset from the *current* history
                # length keeps the alignment right while the history is
                # still filling up and when the horizon is not a multiple
                # of the stride.
                curr_act_preds = np.stack([
                    pred_actions[(num_actions - 1 - position) * stride + step_idx]
                    for position, pred_actions in enumerate(self.action_history)
                ])
            result.append(self._blend(curr_act_preds))

        return result

    def _blend(self, curr_act_preds):
        """Return the similarity-weighted average of the stacked predictions.

        The last row of ``curr_act_preds`` (the newest prediction) is the
        reference every row is compared against.
        """
        ref = curr_act_preds[-1]

        # Cosine similarity between the reference and every stored prediction;
        # the epsilon guards against division by zero for all-zero actions.
        dot_product = np.sum(curr_act_preds * ref, axis=1)
        norm_preds = np.linalg.norm(curr_act_preds, axis=1)
        norm_ref = np.linalg.norm(ref)
        cos_similarity = dot_product / (norm_preds * norm_ref + 1e-7)

        # Exponential weighting, normalised so the weights sum to one.
        weights = np.exp(self.adaptive_ensemble_alpha * cos_similarity)
        weights = weights / weights.sum()

        return np.sum(weights[:, None] * curr_act_preds, axis=0)