import numpy as np
import torch



def candidate_bids_from_gmm(
    pi, mu, sigma,
    taus=(0.5, 0.75, 1.0, 1.5, 2.0),
    top_ms=(None, 2, 3),
    risk_lams=(0.0, 0.25, 0.5),
    clip=None,
    use_tanh=False
):
    """Build a deterministic set of candidate actions from a Gaussian mixture.

    Components are ranked by ``pi / mean(|sigma|)``. For every subset size in
    ``top_ms``, every temperature in ``taus`` and every coefficient in
    ``risk_lams``, one candidate ``mixture_mean - lam * mixture_std`` is built.
    The subset weights are renormalised and reshaped as
    ``softmax(log(pi) / tau)``. The mean/std are moment-matched over the
    (sub)mixture.

    The full-mixture mean ("baseline") is added to the set. The set is then
    de-duplicated after rounding to 6 decimals and sorted by Euclidean
    distance to the baseline. The baseline gets the same tanh/clip
    post-processing as the candidates.

    Args:
        pi:    [M] mixture weights. With leading dims ([..., M]) only the last
               mixture is used; mu/sigma are then [..., M, act_dim].
        mu:    [M, act_dim] component means.
        sigma: [M, act_dim] component stds.
        taus: temperatures applied to the log-weights (clamped below at 1e-6).
        top_ms: subset sizes; ``None`` or a value >= M means all components.
        risk_lams: multiples of the mixture std subtracted from the mean.
        clip: optional ``(low, high)`` bounds; either bound may be ``None``.
        use_tanh: apply ``tanh`` before clipping.

    Returns:
        Tensor [K, act_dim] of unique candidates, closest to the baseline first.
        Empty ``[0, act_dim]`` when the parameter grids yield no candidates.
    """
    if pi.dim() > 1:
        # Keep only the most recent mixture. mu/sigma are reshaped per
        # component so a whole [M, act_dim] slice is kept, not a single row.
        n_comp = pi.size(-1)
        pi = pi.reshape(-1, n_comp)[-1]
        mu = mu.reshape(-1, n_comp, mu.size(-1))[-1]
        sigma = sigma.reshape(-1, n_comp, sigma.size(-1))[-1]

    eps = 1e-12
    M = pi.numel()
    device, dtype = pi.device, pi.dtype
    act_dim = mu.size(-1)

    lo, hi = clip if clip is not None else (None, None)

    def _squash_and_clip(x):
        # Post-processing shared by every candidate AND the baseline.
        if use_tanh:
            x = torch.tanh(x)
        if hi is not None:
            x = torch.min(x, torch.as_tensor(hi, device=x.device, dtype=x.dtype))
        if lo is not None:
            x = torch.max(x, torch.as_tensor(lo, device=x.device, dtype=x.dtype))
        return x

    # Prefer components with high weight and small spread.
    sigma_scale = sigma.abs().mean(dim=-1)
    score = pi / (sigma_scale + 1e-8)
    order = torch.argsort(score, descending=True)

    cands = []
    for m in top_ms:
        if m is None or m >= M:
            pi_sub, mu_sub, sigma_sub = pi, mu, sigma
        else:
            idx = order[:m]
            pi_sub, mu_sub, sigma_sub = pi[idx], mu[idx], sigma[idx]
        pi_sub = pi_sub / (pi_sub.sum() + eps)

        for tau in taus:
            pi_tau = torch.softmax(torch.log(pi_sub + eps) / max(tau, 1e-6), dim=-1)
            w = pi_tau.unsqueeze(-1)
            # Moment-matched mixture mean and std, each [act_dim].
            a_mean = (w * mu_sub).sum(dim=0)
            second_moment = (w * (sigma_sub ** 2 + mu_sub ** 2)).sum(dim=0)
            var_mix = torch.clamp(second_moment - a_mean ** 2, min=0.0)
            std_mix = torch.sqrt(var_mix + 1e-12)
            for lam in risk_lams:
                a = a_mean.clone() if lam == 0.0 else a_mean - lam * std_mix
                cands.append(_squash_and_clip(a))

    if not cands:
        return torch.empty(0, act_dim, device=device, dtype=dtype)
    cands = torch.stack(cands, dim=0)  # [K0, act_dim]

    baseline = _squash_and_clip((pi.unsqueeze(-1) * mu).sum(dim=0))  # [act_dim]
    all_cands = torch.cat([baseline.unsqueeze(0), cands], dim=0)
    all_unique = torch.unique(all_cands.round(decimals=6), dim=0)
    dists = (all_unique - baseline.unsqueeze(0)).norm(dim=-1)
    return all_unique[torch.argsort(dists)]

def sample_actions_from_gmm(pi, mu, sigma, num_samples, action_low=None, action_high=None):
    """Draw ``num_samples`` actions from a diagonal Gaussian mixture.

    Each sample picks a component index from ``Categorical(pi)``. It then
    adds standard-normal noise, scaled by that component's std, to the
    component mean. All inputs are cast to float32 first.

    Args:
        pi:    [M] mixture weights.
        mu:    [M, act_dim] component means (must be 2-D).
        sigma: [M, act_dim] component standard deviations.
        num_samples: number of actions to draw.
        action_low, action_high: optional bounds (array-like or tensor);
            clipping happens only when BOTH are provided.

    Returns:
        Tensor [num_samples, act_dim] (float32).
    """
    weights = pi.float()
    means = mu.float()
    stds = sigma.float()
    _n_components, _act_dim = means.shape  # also enforces a 2-D mu

    picked = torch.distributions.Categorical(weights).sample((num_samples,))
    chosen_means = means[picked]
    chosen_stds = stds[picked]
    samples = chosen_means + torch.randn_like(chosen_means) * chosen_stds

    if action_low is None or action_high is None:
        return samples

    lower = torch.as_tensor(action_low, device=samples.device, dtype=samples.dtype)
    upper = torch.as_tensor(action_high, device=samples.device, dtype=samples.dtype)
    return torch.max(torch.min(samples, upper), lower)



def evaluate_episode_rtg_dist(
        env,
        state_dim,
        act_dim,
        model,
        max_ep_len=1000,
        scale=1000.,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        target_return=None,
        mode='normal',
    ):
    """Roll out one episode, acting with the mean of the model's Gaussian mixture.

    At each step ``model.get_action`` returns ``(pi, mu, sigma)``. The executed
    action is ``sum_m pi[m] * mu[m]``. A leading batch dim on ``pi``/``mu``
    is dropped.

    Args:
        env: environment with ``reset()`` and a 4-tuple ``step(action)``.
        state_dim, act_dim: state / action sizes.
        model: provides ``eval()``, ``to()`` and ``get_action(states, actions,
            rewards, returns_to_go, timesteps)``.
        max_ep_len: maximum number of environment steps.
        scale: the return-to-go is decremented by ``reward / scale`` each step.
        state_mean, state_std: normalisation statistics (scalar or array-like).
        device: torch device for all histories.
        target_return: initial return-to-go (already divided by ``scale``).
        mode: ``'noise'`` adds N(0, 0.1) noise to the initial observation;
            ``'delayed'`` keeps the return-to-go constant.

    Returns:
        ``(episode_return, episode_length, normalized_score)``. The last item
        is ``env.get_normalized_score(return) * 100`` when the env provides
        it, else ``None``.
    """
    model.eval()
    model.to(device=device)

    # as_tensor (not from_numpy) so the scalar defaults 0. / 1. also work.
    state_mean = torch.as_tensor(state_mean, device=device, dtype=torch.float32)
    state_std = torch.as_tensor(state_std, device=device, dtype=torch.float32)

    state = env.reset()
    if mode == 'noise':
        state = state + np.random.normal(0, 0.1, size=state.shape)

    # Histories are kept on the device. The newest action/reward slot is
    # padding that gets filled after the model has been queried.
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    target_return = torch.tensor(target_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    episode_return, episode_length = 0, 0
    for t in range(max_ep_len):
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        # Pure inference. Keep autograd out even if the caller did not
        # disable it, because the action is written back into `actions`.
        with torch.no_grad():
            pi, mu, _sigma = model.get_action(
                (states.to(dtype=torch.float32) - state_mean) / state_std,
                actions.to(dtype=torch.float32),
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
            )
        if pi.dim() > 1:
            # Drop a leading batch dim, as the retrieval evaluators do.
            pi, mu = pi[0], mu[0]
        action = (pi.unsqueeze(-1) * mu).sum(dim=0)  # mixture mean, [act_dim]
        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        if mode != 'delayed':
            pred_return = target_return[0, -1] - (reward / scale)
        else:
            pred_return = target_return[0, -1]
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps, torch.full((1, 1), t + 1, device=device, dtype=torch.long)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break

    norm_score = None
    if hasattr(env, "get_normalized_score"):
        norm_score = env.get_normalized_score(episode_return) * 100.0

    return episode_return, episode_length, norm_score



def evaluate_episode_retrieve_cql(
        env,
        state_dim,
        act_dim,
        model,
        q1, q2,
        faiss_index=None,
        faiss_actions=None,
        faiss_rtgs=None,
        K=20,
        num_samples=5,
        num_retrieved=5,
        max_ep_len=1000,
        scale=1000.,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        target_return=None,
        mode='normal',
    ):
    """Roll out one episode, picking each action with a twin-critic vote.

    Candidate actions at every step:
      * deterministic GMM candidates from ``candidate_bids_from_gmm``;
      * when ``faiss_index`` and ``faiss_actions`` are given, actions of the
        FAISS neighbours of the model's code for the last ``K`` steps. Of
        ``3 * num_retrieved`` neighbours searched, the ``num_retrieved`` with
        the highest stored return-to-go are kept.
    Candidates are clipped to the action space (when it exposes ``low``/
    ``high``), de-duplicated after rounding to 6 decimals, and scored with
    ``min(q1(s, a), q2(s, a))``. The arg-max candidate is executed.

    Args:
        env: environment with ``reset()``, 4-tuple ``step(action)`` and
            ``action_space``.
        state_dim, act_dim: state / action sizes.
        model: provides ``get_action(...) -> (pi, mu, sigma)`` and
            ``forward_codes(...)`` (output indexed as [batch, step, code]).
        q1, q2: critics called as ``q(states[Kc, state_dim], actions[Kc, act_dim])``.
        faiss_index: object with ``search(query, k) -> (D, I)``.
        faiss_actions: array of actions indexed by FAISS ids.
        faiss_rtgs: array of returns-to-go indexed by FAISS ids. It is used to
            rank neighbours; when ``None`` the nearest neighbours are kept.
        K: context length fed to ``forward_codes`` (left-padded when shorter).
        num_samples: unused by this variant (kept for interface parity).
        num_retrieved: number of retrieved actions added per step.
        max_ep_len: maximum number of environment steps.
        scale: the return-to-go is decremented by ``reward / scale``.
        state_mean, state_std: normalisation statistics (scalar or array-like).
        device: torch device for all histories and networks.
        target_return: initial return-to-go (already divided by ``scale``).
        mode: ``'noise'`` perturbs the initial observation with N(0, 0.1);
            ``'delayed'`` keeps the return-to-go constant.

    Returns:
        ``(episode_return, episode_length, normalized_score)``. The last item
        is ``env.get_normalized_score(return) * 100`` when available, else
        ``None``.
    """
    model.eval()
    model.to(device=device)
    q1.eval()
    q2.eval()
    q1.to(device)
    q2.to(device)

    # as_tensor (not from_numpy) so the scalar defaults 0. / 1. also work.
    state_mean = torch.as_tensor(state_mean, device=device, dtype=torch.float32)
    state_std = torch.as_tensor(state_std, device=device, dtype=torch.float32)
    if hasattr(env.action_space, "low"):
        act_low = torch.as_tensor(env.action_space.low, device=device, dtype=torch.float32)
        act_high = torch.as_tensor(env.action_space.high, device=device, dtype=torch.float32)
    else:
        act_low = act_high = None

    def _clip(x):
        # Clamp to the action-space box when the environment defines one.
        if act_low is None:
            return x
        return torch.max(torch.min(x, act_high), act_low)

    use_retrieval = (faiss_index is not None) and (faiss_actions is not None)
    faiss_actions_np = np.asarray(faiss_actions) if faiss_actions is not None else None
    faiss_rtgs_np = np.asarray(faiss_rtgs) if faiss_rtgs is not None else None

    state = env.reset()
    if mode == 'noise':
        state = state + np.random.normal(0, 0.1, size=state.shape)

    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    target_return = torch.tensor(target_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    episode_return, episode_length = 0, 0

    for t in range(max_ep_len):
        # Padding slot for the action/reward about to be chosen/observed.
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)], dim=0)

        with torch.no_grad():
            pi, mu, sigma = model.get_action(
                (states.to(dtype=torch.float32) - state_mean) / state_std,
                actions.to(dtype=torch.float32),
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
            )
        if pi.dim() > 1:
            pi = pi[0]
            mu = mu[0]
            sigma = sigma[0]

        cand_actions_gmm = candidate_bids_from_gmm(
            pi, mu, sigma,
            taus=(1.0, 0.7, 1.3, 0.8, 1.2),
            top_ms=(5,),
            risk_lams=(0.2,),
            # Only pass bounds that exist; (None, None) cannot be turned into tensors.
            clip=(act_low, act_high) if act_low is not None else None,
            use_tanh=False,
        )
        if not isinstance(cand_actions_gmm, torch.Tensor):
            cand_actions_gmm = torch.from_numpy(cand_actions_gmm).float()
        cand_actions_list = [_clip(cand_actions_gmm.to(device=device))]

        if use_retrieval:
            # Last-K context window, left-padded to exactly K steps.
            L = min(K, states.shape[0])
            pad = K - L
            s_seq = states[-L:].unsqueeze(0)
            a_seq = actions[-L:].unsqueeze(0)
            r_seq = rewards[-L:].view(1, L, 1)
            rtg_seq = target_return[:, -L:].view(1, L, 1)
            ts_seq = timesteps[:, -L:].view(1, L)
            mask = torch.ones(1, L, device=device)
            if pad > 0:
                s_seq = torch.cat([torch.zeros(1, pad, state_dim, device=device), s_seq], dim=1)
                # Padded actions are filled with -10 (kept from the original code).
                a_seq = torch.cat([torch.full((1, pad, act_dim), -10.0, device=device), a_seq], dim=1)
                r_seq = torch.cat([torch.zeros(1, pad, 1, device=device), r_seq], dim=1)
                rtg_seq = torch.cat([torch.zeros(1, pad, 1, device=device), rtg_seq], dim=1)
                ts_seq = torch.cat([torch.zeros(1, pad, dtype=torch.long, device=device), ts_seq], dim=1)
                mask = torch.cat([torch.zeros(1, pad, device=device), mask], dim=1)

            with torch.no_grad():
                state_codes = model.forward_codes(
                    states=(s_seq.to(dtype=torch.float32) - state_mean) / state_std,
                    actions=a_seq.to(dtype=torch.float32),
                    rewards=r_seq.to(dtype=torch.float32),
                    returns_to_go=rtg_seq.to(dtype=torch.float32),
                    timesteps=ts_seq,
                    attention_mask=mask.to(dtype=torch.float32),
                )

            # The L2-normalised code of the most recent step is the FAISS query.
            query = state_codes[0, -1].detach().cpu().numpy().astype(np.float32).reshape(1, -1)
            query = query / np.maximum(np.linalg.norm(query, axis=1, keepdims=True), 1e-12)

            _, nbr_ids = faiss_index.search(query, num_retrieved * 3)
            nbr_ids = nbr_ids[0]
            nbr_ids = nbr_ids[nbr_ids >= 0]  # FAISS pads missing results with -1
            retr_actions_np = faiss_actions_np[nbr_ids]
            if faiss_rtgs_np is not None:
                # Keep the neighbours with the highest stored return-to-go.
                retr_rtg_np = faiss_rtgs_np[nbr_ids].reshape(-1)
                keep = np.argsort(retr_rtg_np)[::-1][:num_retrieved]
            else:
                keep = np.arange(min(num_retrieved, len(nbr_ids)))
            retr_actions = torch.as_tensor(retr_actions_np[keep], device=device, dtype=torch.float32)
            cand_actions_list.append(_clip(retr_actions))

        cand_actions = torch.cat(cand_actions_list, dim=0)
        cand_unique = torch.unique(cand_actions.round(decimals=6), dim=0)
        K_cand = cand_unique.size(0)

        if K_cand == 0:
            best_action = _clip((pi.unsqueeze(-1) * mu).sum(dim=0))
        else:
            # NOTE(review): critics are fed the raw (un-normalised) observation;
            # confirm that matches how q1/q2 were trained.
            s_t = torch.as_tensor(state, device=device, dtype=torch.float32)
            s_batch = s_t.unsqueeze(0).repeat(K_cand, 1)
            a_batch = cand_unique.to(device=device)
            with torch.no_grad():
                q1_vals = q1(s_batch, a_batch)
                q2_vals = q2(s_batch, a_batch)
            if q1_vals.dim() > 1:
                q1_vals = q1_vals.squeeze(-1)
            if q2_vals.dim() > 1:
                q2_vals = q2_vals.squeeze(-1)
            q_vals = torch.min(q1_vals, q2_vals)  # pessimistic twin-critic estimate
            best_action = a_batch[torch.argmax(q_vals).item()]

        actions[-1] = best_action
        action_np = best_action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action_np)

        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward
        if mode != 'delayed':
            pred_return = target_return[0, -1] - (reward / scale)
        else:
            pred_return = target_return[0, -1]

        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps, torch.full((1, 1), t + 1, device=device, dtype=torch.long)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break

    norm_score = None
    if hasattr(env, "get_normalized_score"):
        norm_score = env.get_normalized_score(episode_return) * 100

    return episode_return, episode_length, norm_score


def evaluate_episode_retrieve_iql(
    env,
    state_dim,
    act_dim,
    model,
    qf,
    faiss_index,
    faiss_actions,
    faiss_rtgs,
    K,
    num_samples,
    num_retrieved,
    max_ep_len,
    scale,
    target_return,
    mode,
    state_mean,
    state_std,
    device,
):
    """Roll out one episode, picking each action with a single critic.

    Candidate actions at every step:
      * ``num_samples`` stochastic draws from the model's Gaussian mixture
        (``sample_actions_from_gmm``);
      * when ``faiss_index`` and ``faiss_actions`` are given, actions of the
        FAISS neighbours of the model's code for the last ``K`` steps. Of
        ``3 * num_retrieved`` neighbours searched, the ``num_retrieved`` with
        the highest stored return-to-go are kept.
    Candidates are clipped to the action space (when it exposes ``low``/
    ``high``), de-duplicated after rounding to 6 decimals, and the one with
    the largest ``qf(s, a)`` is executed.

    Args:
        env: environment with ``reset()``, 4-tuple ``step(action)`` and
            ``action_space``.
        state_dim, act_dim: state / action sizes.
        model: provides ``get_action(...) -> (pi, mu, sigma)`` and
            ``forward_codes(...)`` (output indexed as [batch, step, code]).
        qf: critic called as ``qf(states[Kc, state_dim], actions[Kc, act_dim])``.
        faiss_index: object with ``search(query, k) -> (D, I)``, or ``None``.
        faiss_actions: array of actions indexed by FAISS ids, or ``None``.
        faiss_rtgs: array of returns-to-go indexed by FAISS ids. It is used to
            rank neighbours; when ``None`` the nearest neighbours are kept.
        K: context length fed to ``forward_codes`` (left-padded when shorter).
        num_samples: number of GMM samples per step.
        num_retrieved: number of retrieved actions added per step.
        max_ep_len: maximum number of environment steps.
        scale: the return-to-go is decremented by ``reward / scale``.
        target_return: initial return-to-go (already divided by ``scale``).
        mode: ``'noise'`` perturbs the initial observation with N(0, 0.1);
            ``'delayed'`` keeps the return-to-go constant.
        state_mean, state_std: normalisation statistics (scalar or array-like).
        device: torch device for all histories and networks.

    Returns:
        ``(episode_return, episode_length, normalized_score)``. The last item
        is ``env.get_normalized_score(return) * 100`` when available, else
        ``None``.
    """
    model.eval()
    model.to(device=device)
    qf.eval()
    qf.to(device)

    # as_tensor (not from_numpy) so scalar statistics also work.
    state_mean = torch.as_tensor(state_mean, device=device, dtype=torch.float32)
    state_std = torch.as_tensor(state_std, device=device, dtype=torch.float32)

    if hasattr(env.action_space, "low"):
        act_low = torch.as_tensor(env.action_space.low, device=device, dtype=torch.float32)
        act_high = torch.as_tensor(env.action_space.high, device=device, dtype=torch.float32)
    else:
        act_low = act_high = None

    def _clip(x):
        # Clamp to the action-space box when the environment defines one.
        if act_low is None:
            return x
        return torch.max(torch.min(x, act_high), act_low)

    use_retrieval = (faiss_index is not None) and (faiss_actions is not None)
    faiss_actions_np = np.asarray(faiss_actions) if faiss_actions is not None else None
    faiss_rtgs_np = np.asarray(faiss_rtgs) if faiss_rtgs is not None else None

    state = env.reset()
    if mode == 'noise':
        state = state + np.random.normal(0, 0.1, size=state.shape)

    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    target_return = torch.tensor(target_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    episode_return, episode_length = 0, 0

    for t in range(max_ep_len):
        # Padding slot for the action/reward about to be chosen/observed.
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)], dim=0)

        with torch.no_grad():
            pi, mu, sigma = model.get_action(
                (states.to(dtype=torch.float32) - state_mean) / state_std,
                actions.to(dtype=torch.float32),
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
            )

        if pi.dim() > 1:
            pi = pi[0]
            mu = mu[0]
            sigma = sigma[0]

        cand_actions_gmm = sample_actions_from_gmm(
            pi, mu, sigma, num_samples,
            action_low=act_low.detach().cpu().numpy() if act_low is not None else None,
            action_high=act_high.detach().cpu().numpy() if act_high is not None else None,
        )
        if not isinstance(cand_actions_gmm, torch.Tensor):
            cand_actions_gmm = torch.from_numpy(cand_actions_gmm).float()
        cand_actions_list = [_clip(cand_actions_gmm.to(device=device))]

        if use_retrieval:
            # Last-K context window, left-padded to exactly K steps.
            L = min(K, states.shape[0])
            pad = K - L
            s_seq = states[-L:].unsqueeze(0)
            a_seq = actions[-L:].unsqueeze(0)
            r_seq = rewards[-L:].view(1, L, 1)
            rtg_seq = target_return[:, -L:].view(1, L, 1)
            ts_seq = timesteps[:, -L:].view(1, L)
            mask = torch.ones(1, L, device=device)
            if pad > 0:
                s_seq = torch.cat([torch.zeros(1, pad, state_dim, device=device), s_seq], dim=1)
                # Padded actions are filled with -10 (kept from the original code).
                a_seq = torch.cat([torch.full((1, pad, act_dim), -10.0, device=device), a_seq], dim=1)
                r_seq = torch.cat([torch.zeros(1, pad, 1, device=device), r_seq], dim=1)
                rtg_seq = torch.cat([torch.zeros(1, pad, 1, device=device), rtg_seq], dim=1)
                ts_seq = torch.cat([torch.zeros(1, pad, dtype=torch.long, device=device), ts_seq], dim=1)
                mask = torch.cat([torch.zeros(1, pad, device=device), mask], dim=1)

            with torch.no_grad():
                state_codes = model.forward_codes(
                    states=(s_seq.to(dtype=torch.float32) - state_mean) / state_std,
                    actions=a_seq.to(dtype=torch.float32),
                    rewards=r_seq.to(dtype=torch.float32),
                    returns_to_go=rtg_seq.to(dtype=torch.float32),
                    timesteps=ts_seq,
                    attention_mask=mask.to(dtype=torch.float32),
                )

            # The L2-normalised code of the most recent step is the FAISS query.
            query = state_codes[0, -1].detach().cpu().numpy().astype(np.float32).reshape(1, -1)
            query = query / np.maximum(np.linalg.norm(query, axis=1, keepdims=True), 1e-12)

            _, nbr_ids = faiss_index.search(query, num_retrieved * 3)
            nbr_ids = nbr_ids[0]
            nbr_ids = nbr_ids[nbr_ids >= 0]  # FAISS pads missing results with -1
            retr_actions_np = faiss_actions_np[nbr_ids]
            if faiss_rtgs_np is not None:
                # Keep the neighbours with the highest stored return-to-go.
                retr_rtg_np = faiss_rtgs_np[nbr_ids].reshape(-1)
                keep = np.argsort(retr_rtg_np)[::-1][:num_retrieved]
            else:
                keep = np.arange(min(num_retrieved, len(nbr_ids)))
            retr_actions = torch.as_tensor(retr_actions_np[keep], device=device, dtype=torch.float32)
            cand_actions_list.append(_clip(retr_actions))

        cand_actions = torch.cat(cand_actions_list, dim=0)
        cand_unique = torch.unique(cand_actions.round(decimals=6), dim=0)
        K_cand = cand_unique.size(0)

        if K_cand == 0:
            best_action = _clip((pi.unsqueeze(-1) * mu).sum(dim=0))
        else:
            # NOTE(review): the critic is fed the raw (un-normalised) observation;
            # confirm that matches how qf was trained.
            s_t = torch.as_tensor(state, device=device, dtype=torch.float32)
            s_batch = s_t.unsqueeze(0).repeat(K_cand, 1)
            a_batch = cand_unique.to(device=device)
            with torch.no_grad():
                q_values = qf(s_batch, a_batch)
            best_action = a_batch[torch.argmax(q_values).item()]

        actions[-1] = best_action
        action_np = best_action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action_np)

        cur_state = torch.from_numpy(state).to(device=device, dtype=torch.float32).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        if mode != 'delayed':
            pred_return = target_return[0, -1] - (reward / scale)
        else:
            pred_return = target_return[0, -1]

        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps, torch.full((1, 1), t + 1, device=device, dtype=torch.long)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break

    norm_score = None
    if hasattr(env, "get_normalized_score"):
        norm_score = env.get_normalized_score(episode_return) * 100

    return episode_return, episode_length, norm_score

