import torch
from diffusion_policy.sampler.metric import euclidean_distance, coverage_distance
import torch.nn.functional as F
import pdb
# Process-wide torch setting applied at import time: tensors print with one
# decimal place and never in scientific notation.
torch.set_printoptions(sci_mode=False, precision=1)


def ema_sampler(policy, prior, obs_dict, beta, weak=None):
    """Predict an action chunk and blend its leading steps with ``prior``.

    The policy is queried with ``obs_dict`` passed as both of its first two
    arguments. When ``prior`` is given, the first ``prior.shape[1]`` steps of
    ``action_dict['action_pred']`` are replaced in place by
    ``beta * prior + (1 - beta) * action_pred``. ``'action'`` is then re-sliced
    from the blended prediction.

    Args:
        policy: Object exposing ``predict_action``, ``oa_step_convention``,
            ``n_obs_steps`` and ``n_action_steps``.
        prior: Tensor blended into the prediction prefix, or None to skip
            blending. Indexed as (batch, time, ...) — presumably
            (B, T, action_dim); confirm against callers.
        obs_dict: Observations forwarded to ``predict_action``.
        beta: Weight given to ``prior`` in the blend.
        weak: Forwarded unchanged to ``predict_action``.

    Returns:
        The dict produced by ``policy.predict_action``, possibly modified.
    """
    result = policy.predict_action(obs_dict, obs_dict, weak=weak)
    if prior is None:
        return result

    # Frame matching: the executed window starts one step earlier under the
    # oa-step convention.
    offset = 1 if policy.oa_step_convention else 0
    first = policy.n_obs_steps - offset
    last = first + policy.n_action_steps
    pred = result['action_pred']
    assert (result['action'] == pred[:, first:last]).all().item()

    # EMA update of the overlapping prefix (writes into the same tensor).
    horizon = prior.shape[1]
    pred[:, :horizon] = beta * prior + (1. - beta) * pred[:, :horizon]
    result['action'] = pred[:, first:last]
    return result


def ours_sampler(policy, prior, obs_dict, previous_obs_dict, beta, weak=None):
    """Predict an action chunk, keeping ``prior`` where it still agrees.

    The policy is queried with the current and the previous observation dicts.
    When ``prior`` is given, the executed window of the new prediction and the
    same window of ``prior`` are compared step by step. The comparison uses
    cosine similarity in the normalized action space. For each batch element
    whose similarity is at least ``beta`` on every step, the first
    ``prior.shape[1]`` steps of ``action_pred`` are overwritten with ``prior``.
    All other elements keep the fresh prediction.

    Args:
        policy: Object exposing ``predict_action``, ``n_obs_steps``,
            ``n_action_steps``, ``normalizer['action']`` and optionally
            ``oa_step_convention`` (treated as True when absent).
        prior: Tensor compared against the prediction, or None to skip the
            comparison. Indexed as (batch, time, ...) — presumably
            (B, T, action_dim); confirm against callers.
        obs_dict: Current observations forwarded to ``predict_action``.
        previous_obs_dict: Previous observations forwarded to
            ``predict_action``.
        beta: Cosine-similarity threshold for keeping ``prior``.
        weak: Forwarded unchanged to ``predict_action``.

    Returns:
        The dict produced by ``policy.predict_action``. When ``prior`` is
        given, ``'action_pred'`` is updated in place and ``'action'`` is
        re-sliced from it.
    """
    action_dict = policy.predict_action(obs_dict, previous_obs_dict, weak=weak)
    if prior is None:
        return action_dict

    # Frame matching, using the same convention as ema_sampler. Policies that
    # do not define the flag keep the previous default (n_obs_steps - 1).
    if getattr(policy, 'oa_step_convention', True):
        start = policy.n_obs_steps - 1
    else:
        start = policy.n_obs_steps
    end = start + policy.n_action_steps
    action_pred = action_dict['action_pred']
    assert (action_dict['action'] == action_pred[:, start:end]).all().item()

    horizon = prior.shape[1]
    normalize = policy.normalizer['action'].normalize
    new = normalize(action_pred[:, :horizon])[:, start:end]
    old = normalize(prior)[:, start:end]
    # Similarity over dim 2, giving one value per (batch, step).
    cos_sim = F.cosine_similarity(new, old, dim=2, eps=1e-8)

    # An element diverges if any step falls below the threshold (not just < 0).
    diverged = (cos_sim < beta).any(dim=1)
    keep = ~diverged

    # The basic slice is a view, so the masked write lands in action_pred.
    action_pred[:, :horizon][keep] = prior[keep]
    action_dict['action'] = action_pred[:, start:end]
    return action_dict

def ac_sampler(policy, prior, obs_dict, previous_obs_dict, beta, weak=None):
    """Predict an action chunk from current observations only, keeping
    ``prior`` where it still agrees.

    This matches ``ours_sampler`` except that the policy is queried with
    ``obs_dict`` passed as both of its first two arguments.
    ``previous_obs_dict`` is accepted for signature parity but is not used.

    When ``prior`` is given, the executed window of the new prediction and the
    same window of ``prior`` are compared step by step. The comparison uses
    cosine similarity in the normalized action space. For each batch element
    whose similarity is at least ``beta`` on every step, the first
    ``prior.shape[1]`` steps of ``action_pred`` are overwritten with ``prior``.

    Args:
        policy: Object exposing ``predict_action``, ``n_obs_steps``,
            ``n_action_steps``, ``normalizer['action']`` and optionally
            ``oa_step_convention`` (treated as True when absent).
        prior: Tensor compared against the prediction, or None to skip the
            comparison. Indexed as (batch, time, ...) — presumably
            (B, T, action_dim); confirm against callers.
        obs_dict: Observations forwarded to ``predict_action``.
        previous_obs_dict: Unused.
        beta: Cosine-similarity threshold for keeping ``prior``.
        weak: Forwarded unchanged to ``predict_action``.

    Returns:
        The dict produced by ``policy.predict_action``. When ``prior`` is
        given, ``'action_pred'`` is updated in place and ``'action'`` is
        re-sliced from it.
    """
    action_dict = policy.predict_action(obs_dict, obs_dict, weak=weak)
    if prior is None:
        return action_dict

    # Frame matching, using the same convention as ema_sampler. Policies that
    # do not define the flag keep the previous default (n_obs_steps - 1).
    if getattr(policy, 'oa_step_convention', True):
        start = policy.n_obs_steps - 1
    else:
        start = policy.n_obs_steps
    end = start + policy.n_action_steps
    action_pred = action_dict['action_pred']
    assert (action_dict['action'] == action_pred[:, start:end]).all().item()

    horizon = prior.shape[1]
    normalize = policy.normalizer['action'].normalize
    new = normalize(action_pred[:, :horizon])[:, start:end]
    old = normalize(prior)[:, start:end]
    # Similarity over dim 2, giving one value per (batch, step).
    cos_sim = F.cosine_similarity(new, old, dim=2, eps=1e-8)

    # An element diverges if any step falls below the threshold (not just < 0).
    diverged = (cos_sim < beta).any(dim=1)
    keep = ~diverged

    # The basic slice is a view, so the masked write lands in action_pred.
    action_pred[:, :horizon][keep] = prior[keep]
    action_dict['action'] = action_pred[:, start:end]
    return action_dict


