#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import numpy as np
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.nn.utils.rnn import pad_sequence
import torch

class Buffer:
    """
    A buffer for storing trajectory data and calculating returns for the policy
    and critic updates.
    """
    def __init__(self, gamma=0.99, lam=0.95, device='cuda', use_lstm=False):
        """
        Create an empty buffer.

        Args:
            gamma: Discount factor applied by ``finish_path`` when it
                accumulates returns.
            lam: Stored on the instance as ``self.lam``.
            device: Target passed to ``.to(...)`` for the batches yielded
                by ``sample``.
            use_lstm: When true, ``store`` also records the recurrent core
                state and ``get``/``sample`` return it.
        """
        self.gamma, self.lam = gamma, lam
        self.device, self.use_lstm = device, use_lstm
        # clear() creates every storage list and resets the counters.
        self.clear()

    def __len__(self):
        """Return the number of timesteps stored since the last ``clear``."""
        stored_steps = self.ptr
        return stored_steps
    
    def clear(self):
        """
        Drop all stored data and reset the bookkeeping counters.

        Both observation dicts are re-created with the default observation
        keys, each mapped to its own empty list; ``store`` adds further keys
        on demand.
        """
        observation_keys = (
            'glyphs',
            'blstats',
            'tty_chars',
            'inv_strs',
            'inv_letters',
            'tty_cursor',
        )
        self.obs = {name: [] for name in observation_keys}
        self.next_obs = {name: [] for name in observation_keys}

        # Per-timestep storage, filled by store().
        self.actions, self.rewards, self.values = [], [], []
        self.log_probs, self.logits = [], []
        self.controller_prob = []
        self.meta_controller_probs = []
        self.meta_controller_values = []
        self.meta_controller_tensor = []
        self.done = []
        self.core_state = []

        # Trajectory bookkeeping: ptr counts stored steps and traj_idx holds
        # the trajectory boundaries appended by finish_path().
        self.ptr = 0
        self.traj_idx = [0]
        self.returns = []

        # Per-trajectory statistics, for logging.
        self.ep_returns = []
        self.ep_lens = []

    def store(self, state, next_state, action, reward, value, log_probs, logits, controller_prob, meta_controller_probs, meta_controller_values, meta_controller_tensor, done, core_state):
        """
        Append one timestep of agent-environment interaction to the buffer.
        """
        for key in state:
            if key not in self.obs:
                self.obs[key] = []
            if key not in self.next_obs:
                self.next_obs[key] = []
            self.obs[key].append(state[key])
            self.next_obs[key].append(next_state[key])
        
        self.actions.append(action.squeeze())
        self.rewards.append(reward.squeeze())
        self.values.append(value.squeeze())
        self.log_probs.append(log_probs.squeeze())
        self.logits.append(logits.squeeze())
        self.controller_prob.append(controller_prob)
        self.meta_controller_probs.append(meta_controller_probs)
        self.meta_controller_values.append(meta_controller_values)
        self.meta_controller_tensor.append(meta_controller_tensor)
        self.done.append(done)
        if self.use_lstm :
            core = (core_state[0].to("cpu").detach().numpy(), core_state[1].to("cpu").detach().numpy())
            self.core_state.append(core)
        self.ptr += 1

    def finish_path(self, last_val=None):
        """
        Close the current trajectory and compute its discounted returns.

        The trajectory consists of every reward stored since the previous
        call (or since ``clear``). Returns are accumulated backwards as
        ``R_t = reward_t + gamma * R_{t+1}``, seeded with ``last_val``.

        Args:
            last_val: Bootstrap value for the state following the last
                stored step. ``None`` (the default) is treated as 0, i.e.
                nothing is bootstrapped past the end of the trajectory.

        Side effects:
            Appends the new boundary to ``traj_idx``, extends ``returns``
            with one entry per step, and appends the undiscounted reward sum
            and the trajectory length to ``ep_returns`` / ``ep_lens``.
        """
        self.traj_idx.append(self.ptr)
        start, end = self.traj_idx[-2], self.traj_idx[-1]
        rewards = self.rewards[start:end]

        # The default used to reach `gamma * None` and raise a TypeError.
        running = 0.0 if last_val is None else last_val

        # Accumulate newest-to-oldest, then reverse once; inserting at the
        # front of the list on every step would be quadratic.
        returns = []
        for reward in reversed(rewards):
            running = self.gamma * running + reward
            returns.append(running)
        returns.reverse()

        self.returns.extend(returns)
        self.ep_returns.append(np.sum(rewards))
        self.ep_lens.append(len(rewards))

    def convert_tensor_to_numpy(self, tensor):
        if tensor.numel() == 1:  # 如果张量中只有一个元素
            return tensor.item()  # 返回标量
        else:
            return tensor.cpu().numpy()

    def get(self):
        if self.use_lstm :
            return (
                {key: np.array(val) for key, val in self.obs.items()},
                {key: np.array(val) for key, val in self.next_obs.items()},
                np.array(self.actions),
                np.array(self.returns),
                np.array(self.values),
                np.array(self.log_probs),
                np.array(self.logits),
                np.array(self.controller_prob),
                np.array(self.meta_controller_probs),
                np.array(self.meta_controller_values),
                np.array(self.meta_controller_tensor),
                np.array(self.done),
                np.array(self.core_state)
            )
        else :
            def convert_tensor_to_numpy(tensor):
                """
                将PyTorch张量转换为NumPy数组或标量。
                对单元素张量使用.item()来转换为标量，否则使用.numpy()。
                """
                if isinstance(tensor, torch.Tensor):
                    if tensor.numel() == 1:  # 如果张量中只有一个元素
                        return tensor.item()  # 返回标量
                    else:
                        return tensor.cpu().numpy()  # 多元素张量转换为NumPy数组
                return np.array(tensor)  # 如果已经是NumPy数组或其他类型

            # 处理obs和next_obs
            for key, val in self.obs.items():
                for i in range(len(val)):
                    self.obs[key][i] = self.obs[key][i].cpu()

            for key, val in self.next_obs.items():
                for i in range(len(val)):
                    self.next_obs[key][i] = self.next_obs[key][i].cpu()

            # 确保所有obs的键值对能够正确转换为NumPy数组或标量
            converted_obs = {}
            for key, val in self.obs.items():
                # 如果当前键对应的值是列表，处理列表中的每个张量
                if isinstance(val, list):
                    converted_obs[key] = np.array([convert_tensor_to_numpy(v) for v in val])
                    # 特殊处理：如果列表只有一个元素，确保它不被转换为标量
                    if len(converted_obs[key]) == 1:
                        converted_obs[key] = np.expand_dims(converted_obs[key], axis=0)
                else:
                    converted_obs[key] = convert_tensor_to_numpy(val)

            # 同样的处理用于next_obs
            converted_next_obs = {}
            for key, val in self.next_obs.items():
                if isinstance(val, list):
                    converted_next_obs[key] = np.array([convert_tensor_to_numpy(v) for v in val])
                    if len(converted_next_obs[key]) == 1:
                        converted_next_obs[key] = np.expand_dims(converted_next_obs[key], axis=0)
                else:
                    converted_next_obs[key] = convert_tensor_to_numpy(val)

            # 返回处理后的obs和next_obs
            return (
                converted_obs,
                converted_next_obs,
                np.array(self.actions),
                np.array(self.returns),
                np.array(self.values),
                np.array(self.log_probs),
                np.array(self.logits),
                np.array(self.controller_prob),
                np.array(self.meta_controller_probs),
                np.array(self.meta_controller_values),
                np.array(self.meta_controller_tensor),
                np.array(self.done),
            )
            # for key, val in self.obs.items() :
            #     for i in range(len(val)) :
            #         self.obs[key][i] = self.obs[key][i].cpu()
            # for key, val in self.next_obs.items() :
            #     for i in range(len(val)) :
            #         self.next_obs[key][i] = self.next_obs[key][i].cpu()

            # return (
            #     {key: np.array(val) for key, val in self.obs.items()},
            #     {key: np.array(val) for key, val in self.next_obs.items()},
            #     np.array(self.actions),
            #     np.array(self.returns),
            #     np.array(self.values),
            #     np.array(self.log_probs),
            #     np.array(self.logits),
            #     np.array(self.controller_prob),
            #     np.array(self.meta_controller_probs),
            #     np.array(self.meta_controller_values),
            #     np.array(self.meta_controller_tensor),
            #     np.array(self.done),
            # )
            

    def sample(self, batch_size=64, recurrent=False):
        if recurrent:
            random_indices = np.random.permutation(len(self.ep_lens))
            last_index = random_indices[-1]
            sampler = []
            indices = []
            num_sample = 0
            for i in random_indices:
                indices.append(i)
                num_sample += self.ep_lens[i]
                if num_sample > batch_size or i == last_index:
                    sampler.append(indices)
                    indices = []
                    num_sample = 0
        else:
            random_indices = SubsetRandomSampler(range(self.ptr))
            sampler = BatchSampler(random_indices, batch_size, drop_last=True)

        
        if self.use_lstm :
            obs_dict, next_obs_dict, actions, returns, values, log_probs, logits, controller_prob, meta_controller_probs, meta_controller_values, meta_controller_tensor, done, core_state = self.get()
        else :
            core_state = []
            obs_dict, next_obs_dict, actions, returns, values, log_probs, logits, controller_prob, meta_controller_probs, meta_controller_values, meta_controller_tensor, done = self.get()
        
        actions = torch.tensor(actions)
        returns = torch.tensor(returns)
        values = torch.tensor(values)
        log_probs = torch.tensor(log_probs)
        logits = torch.tensor(logits)
        controller_prob = torch.tensor(controller_prob)
        meta_controller_probs = torch.tensor(meta_controller_probs)
        meta_controller_values = torch.tensor(meta_controller_values)
        meta_controller_tensor = torch.tensor(meta_controller_tensor)
        done = torch.tensor(done)
        if self.use_lstm :
            core_state = torch.tensor(core_state)
        

        advantages = returns - values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        for indices in sampler:
            if recurrent:
                batch_dict = {key: [torch.tensor(obs_dict[key][self.traj_idx[i]:self.traj_idx[i+1]]) for i in indices] for key in obs_dict}
                next_batch_dict = {key: [torch.tensor(next_obs_dict[key][self.traj_idx[i]:self.traj_idx[i+1]]) for i in indices] for key in next_obs_dict}
                
                obs_batch = {key: pad_sequence(val, batch_first=False) for key, val in batch_dict.items()}
                next_obs_batch = {key: pad_sequence(val, batch_first=False) for key, val in next_batch_dict.items()}
                
                action_batch = pad_sequence([actions[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                return_batch = pad_sequence([returns[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                advantage_batch = pad_sequence([advantages[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                values_batch = pad_sequence([values[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                mask = pad_sequence([torch.ones_like(r) for r in return_batch], batch_first=False).flatten(0, 1)
                log_prob_batch = pad_sequence([log_probs[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                logits_batch = pad_sequence([logits[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                controller_prob_batch = pad_sequence([controller_prob[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                meta_controller_prob_batch = pad_sequence([meta_controller_probs[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                meta_controller_value_batch = pad_sequence([meta_controller_values[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                meta_controller_tensor_batch = pad_sequence([meta_controller_tensor[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
                done_batch = pad_sequence([done[self.traj_idx[i]:self.traj_idx[i+1]] for i in indices], batch_first=False).flatten(0, 1)
            else:
                # for key, val in obs_dict.items() :
                #     print(key, val)
                obs_batch = {key: torch.stack([torch.tensor(val[i]) for i in indices], dim=1).squeeze(2) for key, val in obs_dict.items()}
                next_obs_batch = {key: torch.stack([torch.tensor(val[i]) for i in indices], dim=1).squeeze(2) for key, val in next_obs_dict.items()}
                action_batch = actions[indices]
                return_batch = returns[indices]
                advantage_batch = advantages[indices]
                values_batch = values[indices]
                mask = torch.FloatTensor([1])
                log_prob_batch = log_probs[indices]
                logits_batch = logits[indices]
                controller_prob_batch = controller_prob[indices]
                meta_controller_prob_batch = meta_controller_probs[indices]
                meta_controller_value_batch = meta_controller_values[indices]
                meta_controller_tensor_batch = meta_controller_tensor[indices]
                done_batch = done[indices]
                if self.use_lstm :
                    core_state_batch = core_state[indices]

            if self.use_lstm : 
                yield {key: batch.to(self.device) for key, batch in obs_batch.items()}, \
                    {key: batch.to(self.device) for key, batch in next_obs_batch.items()}, \
                    action_batch.to(self.device), \
                    return_batch.to(self.device), \
                    advantage_batch.to(self.device), \
                    values_batch.to(self.device), \
                    mask.to(self.device), \
                    log_prob_batch.to(self.device), \
                    logits_batch.to(self.device), \
                    controller_prob_batch.to(self.device), \
                    meta_controller_prob_batch.to(self.device), \
                    meta_controller_value_batch.to(self.device), \
                    meta_controller_tensor_batch.to(self.device), \
                    done_batch.to(self.device), \
                    core_state_batch.to(self.device)
            else :
                yield {key: batch.to(self.device) for key, batch in obs_batch.items()}, \
                    {key: batch.to(self.device) for key, batch in next_obs_batch.items()}, \
                    action_batch.to(self.device), \
                    return_batch.to(self.device), \
                    advantage_batch.to(self.device), \
                    values_batch.to(self.device), \
                    mask.to(self.device), \
                    log_prob_batch.to(self.device), \
                    logits_batch.to(self.device), \
                    controller_prob_batch.to(self.device), \
                    meta_controller_prob_batch.to(self.device), \
                    meta_controller_value_batch.to(self.device), \
                    meta_controller_tensor_batch.to(self.device), \
                    done_batch.to(self.device), \
                    torch.empty(1)
                      
