import numpy as np
import torch
from torch.nn import functional as F  # noqa
def evaluation_cost_tocdt(model, use_weighted_sum,
                          train_type, states, actions,
                          timesteps, attention_mask,):
    """Run ``model`` on one batch and reduce its per-step prediction to [0, 1].

    Args:
        model: Object whose ``forward(states, actions, timesteps,
            attention_mask=..., training=False)`` returns a pair; the first
            element is indexed with ``"weighted_sum"`` or ``"value"``.
        use_weighted_sum: Selects the ``"weighted_sum"`` entry when truthy,
            otherwise the ``"value"`` entry.
        train_type: Reduction over the time axis: ``"mean"``, ``"sum"``,
            ``"last"`` (final time step only) or ``"every"`` (no reduction).
        states, actions, timesteps, attention_mask: Forwarded to the model
            unchanged. ``actions`` must be 3-D; its first two dimensions are
            used as ``(B, T)``.

    Returns:
        A tensor clamped to ``[0, 1]``: shape ``(B, 1)`` for ``"mean"``,
        ``"sum"`` and ``"last"``; shape ``(B, T)`` for ``"every"``.

    Raises:
        ValueError: If ``train_type`` is not one of the supported modes.
    """
    supported = ("mean", "sum", "last", "every")
    # Validate before the forward pass: previously an unknown mode only failed
    # afterwards, with an UnboundLocalError on ``results``.
    if train_type not in supported:
        raise ValueError(
            f"Unsupported train_type {train_type!r}; expected one of {supported}"
        )

    B, T, _ = actions.shape

    trans_pred_e, _ = model.forward(
        states, actions, timesteps, attention_mask=attention_mask, training=False
    )
    key = "weighted_sum" if use_weighted_sum else "value"
    per_step = trans_pred_e[key].reshape(B, T)

    if train_type == "mean":
        results = torch.mean(per_step, dim=1).reshape(-1, 1)
    elif train_type == "sum":
        results = torch.sum(per_step, dim=1).reshape(-1, 1)
    elif train_type == "last":
        results = per_step[:, -1].reshape(-1, 1)
    else:  # "every"
        # NOTE(review): every time step is returned, including any padded
        # positions. A previously disabled variant sliced each row from the
        # first attended index of the mask; confirm whether masking is needed.
        results = per_step

    return torch.clamp(results, min=0, max=1)

def evaluation_cost(model,use_weighted_sum,train_type,states_e, actions_e, timesteps_e,attention_mask_e,states_o, actions_o, timesteps_o,attention_mask_o):
    """Compare the model's predictions on an ``_e`` batch and an ``_o`` batch.

    The model is run once on each batch, the per-step prediction is reduced
    over time according to ``train_type``, and the scalar
    ``mean(pred_e) - mean(pred_o) + 0.1 * mse(pred_e, pred_o)`` is reported
    together with its components. The result keys suggest ``_e`` is expert
    data and ``_o`` is observed data.

    Args:
        model: Object whose ``forward(states, actions, timesteps,
            attention_mask=..., training=False)`` returns a pair; the first
            element is indexed with ``"weighted_sum"`` or ``"value"``.
        use_weighted_sum: Selects the ``"weighted_sum"`` entry when truthy,
            otherwise the ``"value"`` entry.
        train_type: Reduction over the time axis: ``"mean"``, ``"sum"`` or
            ``"last"`` (final time step only).
        states_e, actions_e, timesteps_e, attention_mask_e: First batch,
            forwarded to the model unchanged. ``actions_e`` must be 3-D.
        states_o, actions_o, timesteps_o, attention_mask_o: Second batch,
            forwarded to the model unchanged.

    Returns:
        dict of Python floats with keys ``loss_eval``, ``expert_eval``,
        ``obs_eval`` and ``balance``.

    Raises:
        ValueError: If ``train_type`` is not one of the supported modes.
    """
    supported = ("mean", "sum", "last")
    # Validate before the forward passes: previously an unknown mode only
    # failed afterwards, with an UnboundLocalError on ``sum_pred_e``.
    if train_type not in supported:
        raise ValueError(
            f"Unsupported train_type {train_type!r}; expected one of {supported}"
        )

    # NOTE(review): (B, T) is taken from the first batch only and reused to
    # reshape the second batch's predictions -- both batches are assumed to
    # share the same batch size and sequence length; confirm with callers.
    B, T, _ = actions_e.shape

    trans_pred_e, _ = model.forward(
        states_e, actions_e, timesteps_e, attention_mask=attention_mask_e, training=False
    )
    trans_pred_o, _ = model.forward(
        states_o, actions_o, timesteps_o, attention_mask=attention_mask_o, training=False
    )

    key = "weighted_sum" if use_weighted_sum else "value"
    per_step_e = trans_pred_e[key].reshape(B, T)
    per_step_o = trans_pred_o[key].reshape(B, T)

    if train_type == "mean":
        sum_pred_e = torch.mean(per_step_e, dim=1).reshape(-1, 1)
        sum_pred_o = torch.mean(per_step_o, dim=1).reshape(-1, 1)
    elif train_type == "sum":
        sum_pred_e = torch.sum(per_step_e, dim=1).reshape(-1, 1)
        sum_pred_o = torch.sum(per_step_o, dim=1).reshape(-1, 1)
    else:  # "last"
        sum_pred_e = per_step_e[:, -1].reshape(-1, 1)
        sum_pred_o = per_step_o[:, -1].reshape(-1, 1)

    mean_e = torch.mean(sum_pred_e)
    mean_o = torch.mean(sum_pred_o)
    balance_ = F.mse_loss(sum_pred_e, sum_pred_o)
    loss = mean_e - mean_o + 0.1 * balance_

    return {
        'loss_eval': loss.detach().cpu().item(),
        'expert_eval': mean_e.detach().cpu().item(),
        'obs_eval': mean_o.detach().cpu().item(),
        'balance': balance_.detach().cpu().item(),
    }
        
def my_evaluate_episode_rtg(state_dim,act_dim,model,paths,max_ep_len=1000,scale=1000.,state_mean=0.,state_std=1.,
                        device='cuda',target_return=None,mode='delayed',):
    """Roll the model out autoregressively alongside each recorded path.

    For every path the rollout starts from the path's first observation. At
    each step ``model.get_action`` is called on the history built so far; the
    last predicted action, return and state are written back into that
    history, so later steps are conditioned on the model's own predictions
    rather than on the recorded data. Predictions and the recorded values at
    the same step index are collected side by side.

    The final step of each path (index ``min(len(observations), max_ep_len)
    - 1``) is still evaluated by the model but is NOT appended to the
    returned lists.

    Args:
        state_dim: Size of one state vector.
        act_dim: Size of one action vector.
        model: Object providing ``eval()``, ``to(device=...)`` and
            ``get_action(states, actions, rewards, target_return, timesteps)``
            returning ``(state_preds, action_preds, return_preds, act_loss)``.
        paths: Sequence of dicts with keys ``'observations'``, ``'actions'``
            and ``'rewards'``, each indexable by step.
        max_ep_len: Upper bound on the number of steps rolled out per path.
        scale: Divisor applied to the predicted return before it is
            subtracted from the running target return (non-delayed mode).
        state_mean, state_std: Normalisation statistics applied to the state
            history before it is passed to the model; scalars or arrays.
        device: Device the model and all working tensors are placed on.
        target_return: Initial target return for every path. Required.
        mode: ``'delayed'`` keeps the target return constant; any other value
            decrements it by ``predicted_return / scale`` each step.

    Returns:
        Tuple ``(reward_agent, reward_phy, state_agent, state_phy,
        action_agent, action_phy, die, act_loss)``. The six data lists are
        flat across all paths: ``*_agent`` hold model predictions
        (``reward_agent`` as tensors, states/actions as NumPy arrays) and
        ``*_phy`` hold the recorded values (``reward_phy`` as tensors).
        ``die`` is always an empty list. ``act_loss`` is the fourth value
        returned by the last ``get_action`` call, or ``None`` if the model
        was never called.

    Raises:
        ValueError: If ``target_return`` is ``None``.
    """
    if target_return is None:
        raise ValueError("target_return must be provided")

    model.eval()
    model.to(device=device)

    # as_tensor accepts both NumPy arrays and plain Python scalars, so the
    # scalar defaults (0. / 1.) work; from_numpy rejected them.
    state_mean = torch.as_tensor(state_mean).to(device=device)
    state_std = torch.as_tensor(state_std).to(device=device)

    # ``target_return`` is rebound to a growing tensor inside the loop, so
    # keep the caller's initial value to restart every path from it.
    initial_return = target_return

    action_agent = []
    action_phy = []
    state_agent = []
    state_phy = []
    die = []  # never populated; kept so the return tuple keeps its shape
    reward_agent = []
    reward_phy = []
    act_loss = None  # stays None when no model call happens (e.g. empty paths)

    for path in paths:
        states = path['observations']
        actions = path['actions']
        rewards = path['rewards']

        actions_dt = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
        states_dt = torch.as_tensor(states[0]).reshape(1, state_dim).to(device=device, dtype=torch.float32)
        rewards_dt = torch.zeros(0, device=device, dtype=torch.float32)
        target_return = torch.tensor(initial_return, device=device, dtype=torch.float32).reshape(1, 1)
        timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

        num = min(len(states), max_ep_len)

        for j in range(num):
            # Append a zero placeholder for the current step; it is
            # overwritten with the model's prediction below.
            actions_dt = torch.cat([actions_dt, torch.zeros((1, act_dim), device=device)], dim=0)
            rewards_dt = torch.cat([rewards_dt, torch.zeros(1, device=device)])

            state_preds, action_preds, return_preds, act_loss = model.get_action(
                (states_dt.to(dtype=torch.float32) - state_mean) / state_std,
                actions_dt.to(dtype=torch.float32),
                rewards_dt.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
            )
            # Open question from the original author: should index j be
            # used here instead of -1?
            actionp = action_preds[0, -1]
            actions_dt[-1] = actionp
            returnp = return_preds[0, -1, 0]
            rewards_dt[-1] = returnp
            statep = state_preds[0, -1]

            if j == num - 1:
                # NOTE(review): the last step is evaluated but not recorded;
                # preserved as-is because callers may rely on list lengths.
                break

            if mode != 'delayed':
                treturnp = target_return[0, -1] - (returnp / scale)
            else:
                treturnp = target_return[0, -1]

            target_return = torch.cat([target_return, treturnp.view(1, 1)], dim=1)
            timesteps = torch.cat(
                [timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (j + 1)], dim=1
            )
            # Feed the predicted state back in as the next history entry.
            states_dt = torch.cat([states_dt, statep.view(1, state_dim)], dim=0)

            reward_phy.append(torch.tensor(rewards[j]))
            reward_agent.append(returnp)

            action_phy.append(actions[j])
            action_agent.append(actionp.detach().cpu().numpy())

            state_phy.append(states[j])
            state_agent.append(statep.detach().cpu().numpy())

    return reward_agent,reward_phy,state_agent,state_phy,action_agent,action_phy,die,act_loss
