import torch
import torch.nn.functional as F
from torch import nn
import yaml
# Put tensors on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# NOTE(review): anomaly detection is a global autograd debugging mode that is
# switched on as soon as this module is imported; PyTorch documents it as
# slowing execution, so consider disabling it outside of debugging runs.
torch.autograd.set_detect_anomaly(True)



class Buffer:
    """Roll out a policy in an environment and record the visited states.

    Each state row is the concatenation of a ``cfg.state_dim``-wide
    environment part and a ``cfg.index_dim``-wide index (code) part. The index
    part is fixed for the whole rollout; only the environment part is advanced
    by ``env.step``.
    """

    def __init__(self, cfg, env):
        """Store the configuration object and the environment.

        Args:
            cfg: Config object; the rollout methods read ``state_dim``,
                ``index_dim``, ``ep_len``, ``sk_num`` and ``exp`` from it.
            env: Object whose ``step(state, action)`` returns the next
                environment part of the state (``state_dim`` columns).
        """
        self.cfg = cfg
        self.env = env

    def make_traj(self, policy, sknum, idxset=None):
        """Roll out ``policy`` for ``cfg.ep_len`` steps from a zero state.

        When ``idxset`` is None, the default index codes are the binary codes
        of ``0 .. cfg.sk_num - 1`` mapped from {0, 1} to {-1, +1}. They are
        4 bits wide when ``cfg.sk_num == 16`` and 5 bits wide otherwise.

        Args:
            policy: Callable mapping a state batch to ``(action, _, noise)``.
            sknum: Number of rollouts (rows) run in parallel.
            idxset: Optional index codes, assigned into the index columns, so
                they must be broadcastable to ``(sknum, cfg.index_dim)``.
                The default codes have ``cfg.sk_num`` rows.

        Returns:
            ``(idxset, total_state, noise_tensor)``. ``total_state`` has shape
            ``(ep_len + 1, sknum, state_dim + index_dim)``, with row 0 being the
            initial state. ``noise_tensor`` has shape ``(ep_len, sknum, 2)``.
        """
        if idxset is None:
            idxset = self._default_idxset(4 if self.cfg.sk_num == 16 else 5)
        return self._rollout(policy, sknum, idxset)

    def make_mul_traj(self, policy, sknum, idxset=None):
        """Same rollout as ``make_traj``, but with a different default code width.

        When ``idxset`` is None, the default codes are 4 bits wide if
        ``cfg.exp == 0`` and 5 bits wide otherwise. Arguments and return value
        are identical to ``make_traj``.
        """
        if idxset is None:
            idxset = self._default_idxset(4 if self.cfg.exp == 0 else 5)
        return self._rollout(policy, sknum, idxset)

    def _default_idxset(self, width):
        """Return the ``cfg.sk_num`` codes ``0 .. sk_num-1`` as rows of ±1.0 bits."""
        # NOTE(review): codes >= 2**width render wider than ``width`` and the
        # rows become ragged (torch.tensor fails); ``width`` must also equal
        # ``cfg.index_dim`` for the later column assignment to succeed.
        codes = [
            [float(bit) for bit in f"{i:0{width}b}"]
            for i in range(self.cfg.sk_num)
        ]
        return torch.tensor(codes, device=device) * 2 - 1

    def _rollout(self, policy, sknum, idxset):
        """Run the rollout loop shared by ``make_traj`` and ``make_mul_traj``."""
        cfg = self.cfg
        sd = cfg.state_dim
        width = sd + cfg.index_dim

        # The environment part starts at zero (torch.zeros); only the index
        # columns are filled in.
        state = torch.zeros(sknum, width, device=device)
        state[:, sd:width] = idxset

        total_state = torch.zeros(cfg.ep_len + 1, sknum, width, device=device)
        total_state[0] = state  # __setitem__ copies, so no clone is needed
        # assumes each step's noise is broadcastable to (sknum, 2)
        noise_tensor = torch.zeros(cfg.ep_len, sknum, 2, device=device)

        for step in range(cfg.ep_len):
            action, _, noise = policy(state)
            next_env_state = self.env.step(state, action)
            # Build a fresh tensor instead of writing into ``state`` in place.
            # The previous tensor was handed to policy/env.step and may be
            # saved by autograd, so it is left untouched.
            state = torch.cat((next_env_state, state[:, sd:]), dim=1)
            noise_tensor[step] = noise
            total_state[step + 1] = state

        return idxset, total_state, noise_tensor
