from Algorithms.ETD.config_etd import ETDConfig
from Common.runner import Worker

from Common.utils import *




class ETDWorker(Worker):
    """Worker that scores embeddings by their distance to earlier ones.

    Each request carries a pair of embeddings ``(emb_s, emb_s_)``. The
    worker returns the minimum :meth:`mrn_distance_np` distance from
    ``emb_s_`` to every embedding buffered since the last :meth:`reset`.
    It then appends ``emb_s_`` to that buffer.
    """

    config: ETDConfig

    @staticmethod
    def mrn_distance_np(x, y, eps=1e-8, max_diff=1000):
        """Return the split max/L2 distance between ``x`` and ``y``.

        The last axis is split at ``d // 2``:

        * first half: ``max(clip(x - y, 0, max_diff))``. Only positive
          differences count, so this term is asymmetric in ``x`` and ``y``.
        * second half: ``sqrt(sum((x - y) ** 2) + eps)``.

        ``x`` and ``y`` must broadcast against each other and share the
        same last-axis length. The result has the broadcast shape with the
        last axis reduced away.

        Args:
            x: array-like embedding(s).
            y: array-like embedding(s).
            eps: added inside the square root. Because of it, identical
                inputs yield ``sqrt(eps)`` rather than exactly 0.
            max_diff: upper clip bound applied to the first-half
                differences.

        Returns:
            A numpy array (or numpy scalar) of distances.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        half = x.shape[-1] // 2
        positive_diff = np.clip(x[..., :half] - y[..., :half], 0, max_diff)
        max_component = np.max(positive_diff, axis=-1)
        suffix_diff = x[..., half:] - y[..., half:]
        l2_component = np.sqrt(np.square(suffix_diff).sum(-1) + eps)
        return max_component + l2_component

    def reset(self):
        """Reset the base worker and clear the embedding buffer."""
        super().reset()
        self.etd_emb_buffer = []

    def step_etd(self, emb_s, emb_s_):
        """Score ``emb_s_`` against every embedding buffered since reset.

        On the first call after :meth:`reset`, the buffer is seeded with
        ``emb_s`` so there is always at least one reference embedding.
        ``emb_s_`` is appended only after scoring, so it never contributes
        to its own reward.

        Args:
            emb_s: embedding used only to seed an empty buffer.
            emb_s_: embedding to score; it is buffered afterwards.

        Returns:
            ``(etd_reward, info)``:

            * ``etd_reward`` is the minimum distance from ``emb_s_`` to any
              buffered embedding.
            * ``info["etd_dist_all"]`` holds the distance to each buffered
              entry, in insertion order.
        """
        if not self.etd_emb_buffer:
            # Store a copy so later in-place mutation by the caller cannot
            # corrupt the buffered history.
            self.etd_emb_buffer.append(np.array(emb_s))
        # NOTE: re-stacks the whole buffer on every call, so the cost grows
        # linearly with the buffer length.
        embs_prev = np.stack(self.etd_emb_buffer)
        query = np.asarray(emb_s_)[None, ...]
        etd_dist = self.mrn_distance_np(embs_prev, query)
        etd_reward = np.min(etd_dist)
        self.etd_emb_buffer.append(np.array(emb_s_))
        return etd_reward, {
            "etd_dist_all": etd_dist,
        }

    def update_erir(self):
        """Return the constant pair ``(1, 1)``.

        NOTE(review): this looks like a placeholder (perhaps constant
        extrinsic/intrinsic reward weights expected by the base ``Worker``).
        Confirm against the caller.
        """
        return 1, 1

    def step_extra(self, conn):
        """Serve one ETD request over ``conn``.

        Receives an ``(emb_s, emb_s_)`` pair and computes the reward with
        :meth:`step_etd`. Only the reward is sent back; the info dict is
        discarded.

        Args:
            conn: object exposing ``recv()`` / ``send()``. Presumably a
                ``multiprocessing`` connection (TODO confirm).
        """
        emb_s, emb_s_ = conn.recv()
        etd_reward, _ = self.step_etd(emb_s, emb_s_)
        conn.send(etd_reward)
