from utils import *
import torch
from model import *
from torch.utils.data import DataLoader
import itertools

def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass each time.

    Unlike ``itertools.cycle`` (which replays the batches cached during the
    first pass), re-iterating the DataLoader lets ``shuffle=True`` draw a new
    order on every pass and avoids holding every batch in memory.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batches,
            which would otherwise make this generator loop forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("BC_Trainer: training data produced no batches")


def BC_Trainer(env, data, Encode_type, *, epochs=80000, batch_size=32,
               actor_lr=0.001, log_every=10000):
    """Train a behavior-cloning (``BC``) actor on ``data`` and return it.

    Args:
        env: Environment object; only ``env.state_size`` and
            ``env.action_size`` are read here.
        data: Dataset passed to ``torch.utils.data.DataLoader``. Each batch is
            assumed to be a sequence whose first two entries are observations
            and actions (the original code unpacked exactly 8 fields; the
            other 6 are unused here) -- TODO confirm against the dataset.
        Encode_type: Encoding selector forwarded to ``Input_Encode``.
        epochs: Number of optimisation steps (one batch per "epoch").
        batch_size: Mini-batch size for the DataLoader.
        actor_lr: Learning rate passed to ``model.setup_optimizers``.
        log_every: Print the loss every this many steps; 0 disables logging.

    Returns:
        The trained ``BC`` model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the encodings' widths are used, to size the network's inputs and
    # outputs, so there is no need to move them to ``device``.
    obs_encode, acts_encode = Input_Encode(Encode_type, env.state_size, env.action_size)

    model = BC(
        state_dim=obs_encode.shape[1],
        action_dim=acts_encode.shape[1],
        max_action=env.action_size - 1,
        a_hidden_sizes=[64, 64],
        device=device,
    )
    # Set up the optimizer once. Calling this inside the loop (as before)
    # presumably re-created it every step, discarding optimizer state.
    model.setup_optimizers(actor_lr=actor_lr)

    batches = _infinite_batches(DataLoader(data, batch_size=batch_size, shuffle=True))

    for epoch in range(epochs):
        batch = next(batches)
        # Transfer only the fields that are actually consumed.
        observations = batch[0].to(device)
        actions = batch[1].to(device)

        # NOTE(review): this loop never calls backward()/step() itself, so
        # the parameter update presumably happens inside BC.actor_loss --
        # confirm in model.py.
        loss, _ = model.actor_loss(observations, actions)

        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch: {epoch}, Actor Loss: {loss}")

    return model
