import sys
import numpy as np
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from bidding_train_env.common.utils import save_normalize_dict
from bidding_train_env.baseline.dt_dist.utils import EpisodeReplayBuffer
from bidding_train_env.baseline.dt_dist.dt_embedding import EmbeddingTransformer
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import logging



# Shared log line layout: timestamp, logger name, source location, level, text.
_LOG_FORMAT = (
    "[%(asctime)s] [%(name)s] [%(filename)s(%(lineno)d)] "
    "[%(levelname)s] %(message)s"
)
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)



def run_dt():
    """Run the DT embedding training pipeline with its default settings."""
    return train_model()


def train_model(step_num=10, batch_size=128, save_interval=2):
    """Train the DT embedding model on the offline trajectory dataset.

    Loads trajectories from ``../../data/trajectory/trajectory_data.csv``
    (resolved relative to this file). It saves the replay buffer's state
    mean/std via ``save_normalize_dict`` and runs ``step_num`` training
    steps. Intermediate checkpoints are written every ``save_interval``
    steps, and the final model is saved under ``saved_model/DT_embedding``.
    Finally, it logs the model's action for an all-ones test state.

    Args:
        step_num: Number of training batches (optimizer steps) to run.
            The default of 10 is a short debug run; a previous
            configuration used 400000.
        batch_size: Number of sampled trajectories per batch.
        save_interval: Write an intermediate checkpoint every this many
            steps. The last step is not checkpointed here because the
            final save covers it.

    Raises:
        ValueError: If ``step_num``, ``batch_size`` or ``save_interval``
            is not positive. This is checked before any data is loaded.
    """
    # Fail fast: WeightedRandomSampler rejects num_samples <= 0 and a zero
    # save_interval would divide by zero, but only after the slow data load.
    for name, value in (("step_num", step_num),
                        ("batch_size", batch_size),
                        ("save_interval", save_interval)):
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}")

    state_dim = 16
    act_dim = 1
    # NOTE(review): outputs are written relative to the current working
    # directory, while the training data is located relative to this file.
    save_dir = "saved_model/DT_embedding"

    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(current_dir, "../../data/trajectory/trajectory_data.csv")
    # Build the buffer with the same dimensions handed to the model so the
    # two cannot drift apart.
    replay_buffer = EpisodeReplayBuffer(state_dim, act_dim, data_path)

    save_normalize_dict({"state_mean": replay_buffer.state_mean, "state_std": replay_buffer.state_std},
                        save_dir)
    logger.info(f"Replay buffer size: {len(replay_buffer.trajectories)}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    model = EmbeddingTransformer(state_dim=state_dim, act_dim=act_dim, state_mean=replay_buffer.state_mean,
                                 state_std=replay_buffer.state_std)
    model.to(device)

    # Draw exactly step_num * batch_size samples (with replacement, weighted
    # by p_sample) so the loader yields exactly step_num full batches.
    sampler = WeightedRandomSampler(replay_buffer.p_sample, num_samples=step_num * batch_size, replacement=True)
    dataloader = DataLoader(replay_buffer, sampler=sampler, batch_size=batch_size)

    model.train()
    for i, (states, actions, rewards, dones, rtg, timesteps, attention_mask,
            _ctg, _score_to_go, _costs) in enumerate(dataloader, start=1):
        # Only the first seven fields are consumed by model.step; the
        # remaining ones are yielded by the buffer but unused here.
        states, actions, rewards, dones, rtg, timesteps, attention_mask = (
            t.to(device) for t in (states, actions, rewards, dones, rtg, timesteps, attention_mask)
        )

        train_loss = model.step(states, actions, rewards, dones, rtg, timesteps, attention_mask)
        loss_value = np.mean(train_loss)
        logger.info(f"Step: {i} Action loss: {loss_value}")
        model.scheduler.step()

        # The final step is saved unconditionally after the loop.
        if i % save_interval == 0 and i != step_num:
            model.save_net(save_dir, i)
            logger.info(f"Model saved at step {i}")
    model.save_net(save_dir)

    # Smoke check: query the trained model on an all-ones state vector.
    # NOTE(review): the model is still in train() mode here; confirm that
    # take_actions switches to eval behaviour if that matters.
    test_state = np.ones(state_dim, dtype=np.float32)
    logger.info(f"Test action: {model.take_actions(test_state)}")




if __name__ == "__main__":
    # Start training only when executed as a script, never on import.
    run_dt()
