import torch
from tensordict import TensorDict


class ReplayBuffer:
    """Normalized offline dataset of (multi-agent) transitions.

    The buffer holds two pools:

    * ``expert_*`` -- transitions treated as expert demonstrations;
    * ``union_*`` -- all transitions. ``preprocess_data`` builds the expert
      pool from the first ``exsize`` episodes of the union, so the union is
      expected to contain the expert data (``imperfect_size`` relies on it).

    Every state and observation tensor is standardized with the mean/std of the
    *union* transitions, so all tensors held by the buffer share one scale.
    """

    def __init__(self, expert_init_states, expert_init_obs, expert_transition, union_init_states, union_init_obs, union_transition):
        """Store both pools and standardize states/observations.

        Args:
            expert_init_states: first-step states of the expert episodes.
            expert_init_obs: first-step observations of the expert episodes.
            expert_transition: container (TensorDict as built by
                ``flatten_data``) with at least the keys "states",
                "next_states", "obs" and "next_obs", batch dimension first.
            union_init_states: first-step states of all episodes.
            union_init_obs: first-step observations of all episodes.
            union_transition: container like ``expert_transition`` holding
                all transitions; its "states"/"obs" define the statistics.

        The transition containers are shallow-copied before normalized entries
        are written, so the caller's objects are never modified. This also
        ensures that passing the same container for both pools does not
        normalize it twice.
        """
        # Shallow copies: replacing keys below must not touch the caller's containers.
        self.expert_transition = expert_transition.copy()
        self.union_transition = union_transition.copy()

        # Per-feature statistics over the batch dimension of the union pool.
        # The 1e-3 term keeps constant features from dividing by zero.
        union_states = self.union_transition["states"]
        self.state_shift = -torch.mean(union_states, 0)
        self.state_scale = 1.0 / (torch.std(union_states, 0) + 1e-3)

        union_obs = self.union_transition["obs"]
        self.obs_shift = -torch.mean(union_obs, 0)
        self.obs_scale = 1.0 / (torch.std(union_obs, 0) + 1e-3)

        # Initial states/obs of both pools live on the same scale as transitions.
        self.expert_init_states = self._normalize(expert_init_states, self.state_shift, self.state_scale)
        self.expert_init_obs = self._normalize(expert_init_obs, self.obs_shift, self.obs_scale)
        self.union_init_states = self._normalize(union_init_states, self.state_shift, self.state_scale)
        self.union_init_obs = self._normalize(union_init_obs, self.obs_shift, self.obs_scale)

        for transition in (self.expert_transition, self.union_transition):
            for field in ("states", "next_states"):
                transition[field] = self._normalize(transition[field], self.state_shift, self.state_scale)
            for field in ("obs", "next_obs"):
                transition[field] = self._normalize(transition[field], self.obs_shift, self.obs_scale)

    @staticmethod
    def _normalize(x, shift, scale):
        """Return ``(x + shift) * scale`` (``shift`` is the negated mean)."""
        return (x + shift) * scale

    def sample(self, n_minibatches=32, device="cpu"):
        """Yield one epoch of shuffled minibatches.

        The union transitions are shuffled and split into chunks of
        ``union_size // n_minibatches`` rows. The trailing remainder, if any,
        is dropped. Each chunk is paired with an independent draw, without
        replacement, of up to ``batch_size`` union initial states/obs and
        expert transitions. When one of those pools has fewer than
        ``batch_size`` rows, the whole (shuffled) pool is returned, so those
        tensors may be shorter than the union chunk.

        Args:
            n_minibatches: number of minibatches per epoch.
            device: device every yielded tensor is moved to.

        Yields:
            ``(union_init_states, union_init_obs, expert_transition,
            union_transition)`` minibatches on ``device``.

        Raises:
            ValueError: on first iteration, if ``n_minibatches`` is smaller
                than 1 or larger than ``union_size``, since that would
                produce empty batches.
        """
        if n_minibatches < 1:
            raise ValueError(f"n_minibatches must be >= 1, got {n_minibatches}")
        batch_size = self.union_size // n_minibatches
        if batch_size == 0:
            raise ValueError(
                f"n_minibatches={n_minibatches} exceeds the number of union transitions ({self.union_size})"
            )

        for union_indices in torch.randperm(self.union_size).split(batch_size):
            if len(union_indices) < batch_size:
                # Only the last chunk can be short; drop it so batches stay equal-sized.
                break
            union_init_indices = torch.randperm(self.init_size)[:batch_size]
            expert_indices = torch.randperm(self.expert_size)[:batch_size]

            yield (
                self.union_init_states[union_init_indices].to(device),
                self.union_init_obs[union_init_indices].to(device),
                self.expert_transition[expert_indices].to(device),
                self.union_transition[union_indices].to(device),
            )

    @property
    def init_size(self):
        """Number of union initial states."""
        return self.union_init_states.size(0)

    @property
    def expert_size(self):
        """Number of expert transitions."""
        return self.expert_transition.size(0)

    @property
    def union_size(self):
        """Number of union transitions."""
        return self.union_transition.size(0)

    @property
    def imperfect_size(self):
        """Number of non-expert transitions (assumes the union contains the expert pool)."""
        return self.union_size - self.expert_size

    @property
    def st_dim(self):
        """Size of the last dimension of the state tensors."""
        return self.expert_transition["states"].size(-1)

    @property
    def ob_dim(self):
        """Size of the last dimension of the observation tensors."""
        return self.expert_transition["obs"].size(-1)

    @property
    def ac_dim(self):
        """Size of the last dimension of "avails" (presumably the action count)."""
        return self.expert_transition["avails"].size(-1)

    @property
    def n_agents(self):
        """Size of the second-to-last dimension of the observation tensors."""
        return self.expert_transition["obs"].size(-2)

    @staticmethod
    def flatten_data(states, obs, avails, actions, next_states, next_obs, next_avails, dones, actives):
        """Merge the two leading dims of every field and keep only active rows.

        Every argument shares the same two leading dimensions (presumably
        episodes x time steps). These are flattened into one batch dimension,
        then rows where ``actives`` is false are dropped.

        Returns:
            TensorDict with keys "states", "obs", "avails", "actions",
            "next_states", "next_obs", "next_avails" and "dones", whose batch
            size is the number of active rows.
        """
        # Cast to bool: an integer 0/1 mask would be treated as gather indices
        # by tensor indexing and silently select rows 0 and 1.
        mask = actives.flatten(0, 1).bool()
        fields = {
            "states": states,
            "obs": obs,
            "avails": avails,
            "actions": actions,
            "next_states": next_states,
            "next_obs": next_obs,
            "next_avails": next_avails,
            "dones": dones,
        }
        transition = {key: value.flatten(0, 1)[mask] for key, value in fields.items()}
        return TensorDict(transition, batch_size=transition["states"].size(0))

    @staticmethod
    def preprocess_data(dataset, exsize=200):
        """Split a raw episode dataset into expert and union pools.

        Args:
            dataset: mapping with keys "states", "obs", "avails" (one more
                step along dim 1 than the per-step keys, since they are sliced
                into current/next views) and "actions", "dones", "actives".
            exsize: number of leading episodes (dim 0) used as expert data.

        Returns:
            ``(expert_init_states, expert_init_obs, expert_transition,
            union_init_states, union_init_obs, union_transition)`` in the
            order ``ReplayBuffer.__init__`` expects.
        """
        all_states = dataset["states"]
        all_obs = dataset["obs"]
        all_avails = dataset["avails"]

        union_init_states = all_states[:, 0]
        union_init_obs = all_obs[:, 0]

        # Current/next views are the same sequence shifted by one step.
        union_fields = {
            "states": all_states[:, :-1],
            "obs": all_obs[:, :-1],
            "avails": all_avails[:, :-1],
            "actions": dataset["actions"],
            "next_states": all_states[:, 1:],
            "next_obs": all_obs[:, 1:],
            "next_avails": all_avails[:, 1:],
            "dones": dataset["dones"],
            "actives": dataset["actives"],
        }
        expert_fields = {key: value[:exsize] for key, value in union_fields.items()}

        expert_transition = ReplayBuffer.flatten_data(**expert_fields)
        union_transition = ReplayBuffer.flatten_data(**union_fields)

        return (
            union_init_states[:exsize],
            union_init_obs[:exsize],
            expert_transition,
            union_init_states,
            union_init_obs,
            union_transition,
        )

    @classmethod
    def from_h5py(cls, env_name, exsize, use_llm=False, folder="dataset"):
        """Build a buffer from ``{folder}/{env_name}.h5`` (or ``..._llm.h5``).

        Every top-level HDF5 dataset is read fully into memory as a tensor,
        its shape is printed, and the result goes through ``preprocess_data``
        with ``exsize`` expert episodes.
        """
        import h5py

        suffix = "_llm" if use_llm else ""
        file_path = f"{folder}/{env_name}{suffix}.h5"
        with h5py.File(file_path, "r") as f:
            dataset = {k: torch.from_numpy(v[:]) for k, v in f.items()}

        print("dataset:")
        for k, v in dataset.items():
            print(f" - {k}: {v.shape}")

        return cls(*cls.preprocess_data(dataset, exsize=exsize))