from .base import Experiment

import numpy as np
import torch
from alg.buffer import ObjectOrientedBuffer
from alg.model import OOCModel
from core.causal_graph import ObjectOrientedCausalGraph
import alg.model.mask_generator as mg
import alg.functional as F
from core import ParalleledTaskData

class Debug(Experiment):
    """Interactive experiment that steps the environment by hand.

    Each line read from stdin drives one iteration of the loop in ``main``:

    * ``q`` (or end of input) -- leave the loop;
    * ``r`` -- reset the environment without stepping;
    * anything else, including an empty line -- take one randomly sampled
      action, print the objects and the outcome, and store the transition.
    """

    # NOTE(review): presumably consumed by the Experiment base class -- confirm there.
    use_existing_path = False

    def setup(self, args):
        """Do nothing; this experiment needs no setup and ignores ``args``."""

    def main(self):
        """Run the interactive stepping loop until ``q`` or end of input.

        Every step samples an action from ``env.action_space``, prints the
        current objects, applies the action, and adds the result to a local
        buffer. The environment is reset whenever an episode terminates or is
        truncated. The step counter keeps increasing across resets.
        """
        env = self.env
        # NOTE(review): 50 is presumably the buffer capacity -- confirm
        # against ObjectOrientedBuffer.
        buf = ObjectOrientedBuffer(50, env.info)
        t = 0

        # The observation returned by reset() is never stored: only the
        # observation produced by step() is passed to the buffer.
        env.reset()
        while True:
            try:
                x = input()
            except EOFError:
                # stdin was closed or piped input ran out; end the session
                # the same way 'q' does instead of crashing with a traceback.
                break

            if x == 'q':
                break

            if x == 'r':
                env.reset()
                continue

            a = env.action_space.sample()

            print(f"step {t}: action = {a}")
            env.data.print_objects()
            # env.print_objects(indent=1)

            obs, r, term, trunc, attrs = env.step(a)
            buf.add(attrs, obs, r)

            print(f"reward = {r}, terminated = {term}")

            if term or trunc:
                env.reset()

            t += 1

        # Disabled scratch code for sampling from the buffer and building a model:
        # attr, next_state, mask, reward = buf.sample_batch(10, torch.device('cuda'))
        # net = OOCModel(env.info, OOCModel.Args(), torch.device('cuda'), torch.float)
        # VariableEncoder(env, 32, 16, 0.01, torch.device('cuda'), torch.float)
        #
        # print(attr)
        # print(next_state)
        # print(mask)
 