import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
from bidding_train_env.baseline.dt_dist.utils import IndexReplayBuffer
from bidding_train_env.baseline.dt_dist.dt_embedding import EmbeddingTransformer
from torch.utils.data import DataLoader
import logging
import pickle
import torch
import pandas as pd
import faiss


# Configure the root logger once for the whole script: INFO and above, with the
# timestamp, logger name, source file/line and level in every record.
logging.basicConfig(
    format="[%(asctime)s] [%(name)s] [%(filename)s(%(lineno)d)] [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)

# Expose only GPU 0 to this process. NOTE(review): this only takes effect if
# CUDA has not been initialized yet (torch initializes it lazily) — confirm no
# import above touches CUDA.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

def run_code():
    """Script entry point: delegate to ``get_code``, which builds and saves the index."""
    get_code()


def _load_model(model_dir, state_dim, device):
    """Load the trained ``EmbeddingTransformer`` stored in ``model_dir``.

    Reads ``normalize_dict.pkl`` (its ``state_mean`` / ``state_std`` entries are
    passed to the model constructor) and the weights file ``dt.pt`` (via
    ``load_net``). The model is put in eval mode and moved to ``device``.

    Returns:
        The loaded model.
    """
    # NOTE(review): pickle.load can execute arbitrary code. This file is assumed
    # to be a locally produced training artifact, never untrusted input.
    with open(os.path.join(model_dir, "normalize_dict.pkl"), "rb") as f:
        normalize_dict = pickle.load(f)
    model = EmbeddingTransformer(
        state_dim=state_dim,
        act_dim=1,
        state_mean=normalize_dict["state_mean"],
        state_std=normalize_dict["state_std"],
    )
    model.load_net(os.path.join(model_dir, "dt.pt"))
    model.eval()
    model.to(device)
    return model


def _collect_encodings(model, dataloader, device):
    """Encode every trajectory in ``dataloader`` and gather per-step samples.

    A step becomes a sample only when its return-to-go is non-zero.

    Returns:
        A dict of lists.
        - ``encodings``, ``next_actions``, ``retrieve_rtgs``, ``trajectory_indices``
          and ``position_indices`` hold one entry per kept step.
        - ``trajectories`` holds one dict of raw numpy arrays per trajectory, in
          loader order.
    """
    encodings, next_actions, retrieve_rtgs = [], [], []
    trajectory_indices, position_indices = [], []
    trajectories = []

    with torch.no_grad():
        for traj_idx, batch in enumerate(dataloader):
            states, actions, rewards, dones, rtg, timesteps, attention_mask = batch
            states = states.to(device)
            actions = actions.to(device)
            rewards = rewards.to(device)
            dones = dones.to(device)
            # The last return-to-go entry is dropped before the forward pass.
            # Presumably the buffer emits one more rtg value than states — TODO confirm.
            rtg = rtg[:, :-1].to(device)
            timesteps = timesteps.to(device)
            attention_mask = attention_mask.to(device)

            _, _, _, _, state_encodings = model.forward(
                states, actions, rewards, rtg, timesteps, attention_mask
            )

            # Index 0 selects the single trajectory in the batch (get_code builds
            # the loader with batch_size=1). Copy to host once per trajectory
            # rather than once per step.
            enc_np = state_encodings[0].cpu().numpy()
            act_np = actions[0].cpu().numpy()
            rtg_np = rtg[0].cpu().numpy()
            for j in range(states.shape[1]):
                if rtg_np[j] != 0:
                    encodings.append(enc_np[j])
                    next_actions.append(act_np[j])
                    retrieve_rtgs.append(rtg_np[j])
                    trajectory_indices.append(traj_idx)
                    position_indices.append(j)

            trajectories.append({
                'states': states.cpu().numpy(),
                'actions': actions.cpu().numpy(),
                'rewards': rewards.cpu().numpy(),
                'dones': dones.cpu().numpy(),
                'rtg': rtg.cpu().numpy(),
                'timesteps': timesteps.cpu().numpy(),
                'attention_mask': attention_mask.cpu().numpy(),
            })
            processed = traj_idx + 1
            if processed % 100 == 0:
                logger.info(f"Processed {processed} trajectories, collected {len(encodings)} samples")

    return {
        "encodings": encodings,
        "next_actions": next_actions,
        "retrieve_rtgs": retrieve_rtgs,
        "trajectory_indices": trajectory_indices,
        "position_indices": position_indices,
        "trajectories": trajectories,
    }


def _build_faiss_index(encodings):
    """L2-normalize the rows of ``encodings`` and add them to an HNSW64 inner-product index.

    On unit-length rows, inner product equals cosine similarity. All-zero rows
    are left as zeros instead of being divided by zero, which would put NaNs
    into the index.

    Args:
        encodings: 2-D array with one encoding per row.

    Returns:
        The populated Faiss index.
    """
    norms = np.linalg.norm(encodings, axis=1, keepdims=True)
    normalized = encodings / np.where(norms > 0, norms, 1.0)

    index = faiss.index_factory(normalized.shape[1], 'HNSW64', faiss.METRIC_INNER_PRODUCT)
    logger.info(f"Faiss index is_trained: {index.is_trained}")
    # Faiss requires float32 input.
    index.add(normalized.astype(np.float32))
    return index


def get_code():
    """Encode all trajectories with the trained DT embedding model and save a retrieval index.

    Inputs:
    - Trajectories are read from ``../../data/trajectory/trajectory_data.csv``,
      relative to this script's directory.
    - The model and its normalization stats are read from
      ``<parent>/saved_model/DT_embedding/``, where ``<parent>`` is the parent
      of this script's directory.

    Outputs, written to ``<parent>/encodings_auction/``:
    - ``encodings.index``: Faiss HNSW64 inner-product index of the
      L2-normalized step encodings.
    - ``next_actions.npy``, ``retrieve_rtgs.npy``, ``trajectory_indices.npy``
      and ``position_indices.npy``: per-sample arrays aligned with the index
      rows.
    - ``trajectories.pkl``: a list of per-trajectory dicts of raw arrays.

    Only steps whose return-to-go is non-zero are indexed.

    Raises:
        RuntimeError: if no step with a non-zero return-to-go was found.
    """
    state_dim = 16
    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(current_dir, "../../data/trajectory/trajectory_data.csv")
    replay_buffer = IndexReplayBuffer(state_dim, 1, data_path)

    # Parent of this script's directory; it holds saved_model/ and encodings_auction/.
    dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    model = _load_model(os.path.join(dir_name, "saved_model", "DT_embedding"), state_dim, device)

    dataloader = DataLoader(replay_buffer, batch_size=1)
    samples = _collect_encodings(model, dataloader, device)
    if not samples["encodings"]:
        raise RuntimeError("No steps with non-zero return-to-go were collected; nothing to index.")

    all_encodings = np.array(samples["encodings"])
    all_next_actions = np.array(samples["next_actions"])
    all_retrieve_rtgs = np.array(samples["retrieve_rtgs"])
    all_trajectory_indices = np.array(samples["trajectory_indices"])
    all_position_indices = np.array(samples["position_indices"])
    trajectories = samples["trajectories"]

    index = _build_faiss_index(all_encodings)

    save_dir = os.path.join(dir_name, "encodings_auction")
    os.makedirs(save_dir, exist_ok=True)

    faiss_path = os.path.join(save_dir, "encodings.index")
    faiss.write_index(index, faiss_path)
    logger.info(f"Created and saved Faiss index to {faiss_path}")

    np.save(os.path.join(save_dir, "next_actions.npy"), all_next_actions)
    np.save(os.path.join(save_dir, "retrieve_rtgs.npy"), all_retrieve_rtgs)
    np.save(os.path.join(save_dir, "trajectory_indices.npy"), all_trajectory_indices)
    np.save(os.path.join(save_dir, "position_indices.npy"), all_position_indices)
    with open(os.path.join(save_dir, "trajectories.pkl"), 'wb') as f:
        pickle.dump(trajectories, f)

    logger.info(f"Saved {len(all_encodings)} encodings and next actions to {save_dir}")
    logger.info(f"Encodings shape: {all_encodings.shape}, Next actions shape: {all_next_actions.shape}")
    logger.info(f"Also saved trajectory indices, position indices, and {len(trajectories)} complete trajectories")

     



if __name__ == "__main__":
    run_code()
