import os
import torch
import argparse
import lang_hrl
from lang_hrl.utils.trainer import Config, load

# Command-line interface. ``--path`` is a directory that must contain
# ``config.yaml`` and ``best_model.pt`` (both are read below).
parser = argparse.ArgumentParser()
# Both paths are mandatory: without them os.path.join / the dataset loader
# would receive None and fail with an unhelpful error further down.
parser.add_argument("--path", type=str, required=True, help="path to saved inverse model")
parser.add_argument("--dataset", type=str, required=True, help="path to dataset")
parser.add_argument("--device", default="cuda", help="device")
parser.add_argument("--batch-size", default=64, type=int)

args = parser.parse_args()

# Rebuild the model (and its environment) from the saved config + weights.
config = Config.load(os.path.join(args.path, "config.yaml"))
model_path = os.path.join(args.path, "best_model.pt")
model, env = load(config, model_path, device=args.device)

# Device that evaluation batches are moved to before prediction.
device = torch.device(args.device)

# Load the dataset. The type of the wrapped environment decides which dataset
# class is used, how batches are collated, and whether each sample is handled
# as a sequence (is_temporal=True) or as a single step (is_temporal=False).
if isinstance(env.env, lang_hrl.envs.FullyObsLanguageWrapper):
    from lang_hrl.datasets.datasets import BehaviorCloningDataset
    collate_fn = None  # let the DataLoader use its default collation
    dataset = BehaviorCloningDataset.load(args.dataset).to_tensor_dataset()
    is_temporal = False

elif isinstance(env.env, lang_hrl.envs.LanguageWrapper):
    from lang_hrl.datasets.datasets import BabyAITrajectoryDataset, traj_collate_fn
    collate_fn = traj_collate_fn
    dataset = BabyAITrajectoryDataset.load(args.dataset)
    is_temporal = True

else:
    # Fail loudly here instead of hitting a NameError on `dataset` later on.
    # TODO: MazeBase environments (lang_hrl.envs.mazebase.MazebaseGame) are not
    # wired up in this script. An earlier sketch built a
    # crafting_dataset.CraftingDataset from the dataset and vocab and used
    # partial(crafting_dataset.collate_fn, vocab_size=len(vocab)) as collate_fn.
    raise NotImplementedError(
        "Unsupported environment wrapper for evaluation: "
        + type(env.env).__name__
    )

def flatten_seq_dim(tensor):
    """Merge the first two dimensions of ``tensor`` into a single one.

    An input of shape ``(b, s, *rest)`` is returned with shape
    ``(b * s, *rest)``; the trailing dimensions are left untouched.
    """
    shape = tuple(tensor.shape)
    merged = shape[0] * shape[1]
    return tensor.reshape((merged,) + shape[2:])

# Evaluate in dataset order (shuffle=False): the non-temporal path below pairs
# consecutive rows of each batch as (before, after) observations, so the order
# of the samples matters.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
)

num_correct = 0
num_preds = 0
with torch.no_grad():
    for i, (obs, actions) in enumerate(dataloader):
        # Move the batch to the evaluation device (non-tensor dict entries are
        # passed through unchanged).
        if isinstance(obs, dict):
            obs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in obs.items()}
        else:
            obs = obs.to(device)
        actions = actions.to(device).long()  # Must convert to long

        if is_temporal:
            # Sequence batches: only the 'image' entry of a dict observation is
            # used. Step t is paired with step t+1 along dim 1, then the first
            # two dims are merged so the model sees one flat batch.
            if isinstance(obs, dict):
                obs = obs['image']
            obs_before = flatten_seq_dim(obs[:, :-1])
            obs_after = flatten_seq_dim(obs[:, 1:])
            actions = flatten_seq_dim(actions[:, :-1])

        else:
            # Single-step batches: row j is paired with row j+1 of the same
            # batch, so the last row of every batch has no successor.
            # NOTE(review): this also pairs rows across episode boundaries if
            # the dataset stores several episodes back to back - confirm.
            if isinstance(obs, dict):
                obs_before = {k: v[:-1] if isinstance(v, torch.Tensor) else v for k, v in obs.items()}
                obs_after = {k: v[1:] if isinstance(v, torch.Tensor) else v for k, v in obs.items()}
            else:
                # Plain tensor observations have no .items(); slice directly.
                obs_before = obs[:-1]
                obs_after = obs[1:]
            actions = actions[:-1]

        # Now actually get the predictions. A batch that is empty after the
        # shift (e.g. a trailing batch with a single row) contributes nothing,
        # so the model is not called on it.
        if actions.numel() > 0:
            pred = model.predict(obs_before, obs_after, batched=True, is_tensor=True)
            num_correct += (pred == actions).sum().item()
            # Actions equal to -100 are left out of the denominator
            # (presumably a padding / ignore label - verify against the collate fn).
            num_preds += (actions != -100).sum().item()

        if (i + 1) % 100 == 0:
            running = num_correct / num_preds if num_preds else float("nan")
            print("Finished", i+1, "batches:", running)

# Report nan rather than raising ZeroDivisionError when nothing was scored
# (empty dataset or every action masked out).
final_accuracy = num_correct / num_preds if num_preds else float("nan")
print("FINAL ACCURACY:", final_accuracy)

    


