import random
import torch
import torch.nn.functional as F
from collections import deque
import numpy as np

class ReplayBuffer:
    """Replay buffer that stores whole episodes together with per-episode labels.

    Steps are collected in ``current_episode``. When a step with ``done`` set
    arrives, the episode and its style index are pushed onto bounded deques,
    as are the epoch and contrast indexes when an epoch index is given. The
    oldest episodes are therefore evicted once ``max_episodes`` is reached.

    Args:
        epoch_index: If truthy, the ``sample_*`` methods return the sampled
            epoch indexes as the third element of their result.
        max_episodes: Maximum number of completed episodes retained.
    """

    def __init__(self, epoch_index=False, max_episodes=10000):
        self.buffer = deque(maxlen=max_episodes)
        self.style_indexes = deque(maxlen=max_episodes)
        self.epoch_indexes = deque(maxlen=max_episodes)
        self.contrast_indexes = deque(maxlen=max_episodes)
        self.current_episode = []
        # Controls what the sampling methods return. It does NOT control
        # whether epoch indexes are recorded: add_transition records them
        # whenever an epoch_index is passed.
        self.epoch_indexes_label = bool(epoch_index)

    def add_transition(self, state, done, style_index, epoch_index=None, action=None):
        """Append one step to the current episode and commit the episode on ``done``.

        Args:
            state: Observation for this step.
            done: Whether this step ends the episode.
            style_index: Style label stored for the episode when it ends.
            epoch_index: Optional epoch label. When given on the final step,
                it is stored along with ``epoch_index * 3 + style_index`` as
                the contrast index.
            action: Optional action. When given, the step is stored as a
                ``(state, action)`` tuple instead of the bare state.
        """
        # Use `is None`, not `== None`: for array-valued actions `==` compares
        # elementwise, and the resulting array has no unambiguous truth value.
        if action is None:
            self.current_episode.append(state)
        else:
            self.current_episode.append((state, action))
        if done:
            self.buffer.append(self.current_episode)
            self.style_indexes.append(style_index)
            self.current_episode = []
            if epoch_index is not None:
                self.epoch_indexes.append(epoch_index)
                # NOTE(review): the factor 3 is presumably the number of
                # styles, which makes this unique per (epoch, style) pair
                # only while style_index < 3. Confirm with callers.
                self.contrast_indexes.append(epoch_index * 3 + style_index)

    # crop random length of trajectories in buffer for each batch
    def sample_episodes_random(self, batch_size, min_len=5, max_len=25):
        """Sample episodes and crop each to its own random length, zero-padded.

        Each sampled episode gets an independent crop length drawn from
        ``[min_len, max_len]``, capped at the episode's own length, and a
        random start. Crops are right-padded with zeros to ``max_len`` rows.

        Returns:
            None if ``batch_size`` exceeds the number of stored episodes.
            Otherwise ``(episodes, style_indexes, epoch_indexes)`` when the
            buffer was built with ``epoch_index=True``, else
            ``(episodes, style_indexes, masks)``. ``episodes`` stacks the
            padded crops, so it has ``max_len`` rows per episode (assumes each
            state is a 1-D vector — TODO confirm). ``masks`` is
            ``(batch_size, max_len)`` with 1 for real steps and 0 for padding.
        """
        buffer_len = len(self.buffer)
        if batch_size > buffer_len:
            print(f"Warning: batch_size ({batch_size}) > buffer size ({buffer_len}). Skipping sampling.")
            return None
        sampled_indices = random.sample(range(buffer_len), batch_size)
        sampled_style_indexes = [self.style_indexes[i] for i in sampled_indices]
        traj_list = []
        mask_list = []

        for i in sampled_indices:
            traj = self.buffer[i]
            crop_len = np.random.randint(min_len, max_len + 1)
            # Take short episodes whole (they are padded below). Otherwise
            # randint would get low >= high and raise.
            crop_len = min(crop_len, len(traj))
            start_idx = np.random.randint(0, len(traj) - crop_len + 1)
            crop = np.array(traj[start_idx:start_idx + crop_len])

            pad_len = max_len - crop_len
            if pad_len > 0:
                crop = np.concatenate([crop, np.zeros((pad_len, crop.shape[1]))], axis=0)
            traj_list.append(crop)
            mask_list.append(np.concatenate([np.ones(crop_len), np.zeros(pad_len)], axis=0))

        episodes = np.stack(traj_list)
        masks = np.stack(mask_list)

        if self.epoch_indexes_label:
            sampled_epoch_indexes = [self.epoch_indexes[i] for i in sampled_indices]
            return episodes, sampled_style_indexes, sampled_epoch_indexes
        return episodes, sampled_style_indexes, masks

    # crop same length of trajectories in buffer for each batch
    def sample_episodes_uniform(self, batch_size, min_len=5, max_len=25):
        """Sample episodes and crop all of them to one shared random window.

        A single crop length in ``[min_len, max_len]`` and a single start in
        ``[0, max_len - crop_length]`` are drawn for the whole batch.

        Returns:
            None if ``batch_size`` exceeds the number of stored episodes.
            Otherwise ``(cropped, style_indexes, epoch_indexes)`` when the
            buffer was built with ``epoch_index=True``, else
            ``(cropped, style_indexes)``, where ``cropped`` is
            ``episodes[:, start:start + crop_length]``.
        """
        buffer_len = len(self.buffer)
        if batch_size > buffer_len:
            print(f"Warning: batch_size ({batch_size}) > buffer size ({buffer_len}). Skipping sampling.")
            return None
        sampled_indices = random.sample(range(buffer_len), batch_size)
        sampled_style_indexes = [self.style_indexes[i] for i in sampled_indices]
        crop_length = random.randint(min_len, max_len)
        # NOTE(review): np.array needs all sampled episodes to have the same
        # length to form a regular array.
        sampled_episodes = np.array([self.buffer[i] for i in sampled_indices])
        # NOTE(review): the start is bounded by max_len, not by the episode
        # length. Episodes shorter than start + crop_length yield a shorter
        # slice; confirm this is intended.
        start_point = random.randint(0, (max_len - crop_length))
        end_point = start_point + crop_length
        cropped_episodes = sampled_episodes[:, start_point:end_point]
        if self.epoch_indexes_label:
            sampled_epoch_indexes = [self.epoch_indexes[i] for i in sampled_indices]
            # Return the crop in this branch too. Returning the uncropped
            # episodes here silently ignored the crop.
            return cropped_episodes, sampled_style_indexes, sampled_epoch_indexes
        return cropped_episodes, sampled_style_indexes

    def __len__(self):
        """Number of completed episodes currently stored."""
        return len(self.buffer)

    def clear(self):
        """Drop all stored episodes, all labels, and the in-progress episode."""
        self.buffer.clear()
        self.style_indexes.clear()
        # Clear these unconditionally. add_transition fills them whenever an
        # epoch_index is passed, regardless of epoch_indexes_label, so
        # skipping them would leave stale entries misaligned with `buffer`.
        self.epoch_indexes.clear()
        self.contrast_indexes.clear()
        self.current_episode = []
        print("Replay buffer cleared.")


class ReplayBufferDict:
    """Bounded FIFO store of finished episodes, each kept as a dict record.

    Every record has the keys ``'trajectory'`` (the episode's states packed
    with ``np.array``), ``'style'``, ``'epoch'`` and ``'contrast_index'``.
    Once more than ``max_episodes`` records are held, the oldest is dropped.
    """

    def __init__(self, max_episodes=10000):
        self.current_episode = []
        self.max_episodes = max_episodes
        self.trajectories = []

    def add_transition(self, state, done, style_index, epoch_index=None):
        """Record one step and, on ``done``, commit the episode as a record.

        The contrast index is ``epoch_index * 3 + style_index`` when an epoch
        index is supplied, and ``None`` otherwise.
        """
        self.current_episode.append(state)
        if not done:
            return

        if epoch_index is None:
            contrast_index = None
        else:
            contrast_index = epoch_index * 3 + style_index

        record = {
            'trajectory': np.array(self.current_episode),  # presumably (T, obs_dim)
            'style': style_index,
            'epoch': epoch_index,
            'contrast_index': contrast_index,
        }
        self.trajectories.append(record)
        self.current_episode = []

        # Evict the oldest record once capacity is exceeded.
        if len(self.trajectories) > self.max_episodes:
            del self.trajectories[0]

    def sample_episodes(self, batch_size, contrast=False):
        """Draw ``batch_size`` distinct records without replacement.

        Returns None if fewer than ``batch_size`` records are stored.
        Otherwise returns ``(trajs, styles, epochs)``, with the list of
        contrast indexes appended as a fourth element when ``contrast`` is
        true.
        """
        if batch_size > len(self.trajectories):
            return None

        picked = random.sample(self.trajectories, batch_size)
        trajs, styles_idx, epochs_idx, contrast_idx = [], [], [], []
        for record in picked:
            trajs.append(record['trajectory'])
            styles_idx.append(record['style'])
            epochs_idx.append(record['epoch'])
            contrast_idx.append(record['contrast_index'])

        result = (trajs, styles_idx, epochs_idx)
        if not contrast:
            return result
        return result + (contrast_idx,)

    def __len__(self):
        """Number of stored episode records."""
        return len(self.trajectories)

    def clear(self):
        """Remove every stored record and discard the in-progress episode."""
        self.trajectories.clear()
        self.current_episode = []

        
class ReplayBufferDict_MA:
    """Bounded FIFO store of finished multi-agent episodes.

    Each record is ``{'trajectory': np.array(steps), 'contrast_index': label}``,
    where ``label`` is the ``style_index`` given on the episode's final step.
    Once more than ``max_episodes`` records are held, the oldest is dropped.
    """

    def __init__(self, max_episodes=10000):
        self.current_episode = []
        self.max_episodes = max_episodes
        self.trajectories = []

    def add_transition(self, state, done, style_index):
        """Record one step and, on ``done``, commit the episode.

        Per the original author, ``state`` is the concatenation of all
        agents' states.
        """
        self.current_episode.append(state)
        if not done:
            return

        record = {
            'trajectory': np.array(self.current_episode),  # presumably (T, obs_dim)
            'contrast_index': style_index,
        }
        self.trajectories.append(record)
        self.current_episode = []

        # Evict the oldest record once capacity is exceeded.
        if len(self.trajectories) > self.max_episodes:
            del self.trajectories[0]

    def sample_episodes(self, batch_size):
        """Draw ``batch_size`` distinct records without replacement.

        Returns None if fewer than ``batch_size`` records are stored, else a
        ``(trajectories, contrast_indexes)`` pair of parallel lists.
        """
        if batch_size > len(self.trajectories):
            return None

        picked = random.sample(self.trajectories, batch_size)
        trajs = []
        contrast_idx = []
        for record in picked:
            trajs.append(record['trajectory'])
            contrast_idx.append(record['contrast_index'])
        return trajs, contrast_idx

    def __len__(self):
        """Number of stored episode records."""
        return len(self.trajectories)

    def clear(self):
        """Remove every stored record and discard the in-progress episode."""
        self.trajectories.clear()
        self.current_episode = []

class MemoryBank:
    """Fixed-size ring buffer of feature vectors held in one torch tensor.

    The bank starts out filled with L2-normalised random rows. ``update``
    overwrites rows in FIFO order starting at ``ptr``, wrapping around to
    row 0 when it reaches the end.

    Args:
        feature_dim: Width of each stored feature vector.
        bank_size: Number of rows in the bank.
        momentum: Stored on the instance but not read by any method here.
        device: Device the bank tensor is moved to.
    """

    def __init__(self, feature_dim, bank_size=4096, momentum=0.9, device='cuda'):
        self.bank_size = bank_size
        self.feature_dim = feature_dim
        # NOTE(review): not used by update(); kept for interface compatibility.
        self.momentum = momentum
        self.device = device

        # Unit-norm random rows, so the bank is usable before the first update.
        self.memory = torch.randn(bank_size, feature_dim).to(device)
        self.memory = F.normalize(self.memory, dim=1)
        self.ptr = 0

    @torch.no_grad()
    def update(self, features):
        """Write ``features`` (shape ``[B, feature_dim]``) into the bank at ``ptr``.

        Rows that would run past the end of the bank wrap around to index 0.
        Afterwards ``ptr`` is advanced by ``B`` modulo ``bank_size``. Writing
        exactly ``bank_size`` rows replaces the whole bank.

        Raises:
            AssertionError: if ``B`` exceeds ``bank_size``.
        """
        batch_size = features.shape[0]
        assert batch_size <= self.bank_size, "Batch size exceeds memory bank capacity."

        # Branch on the unwrapped end index. Comparing the wrapped pointer
        # against ptr breaks when B == bank_size (it wraps back to ptr).
        end = self.ptr + batch_size
        if end <= self.bank_size:
            self.memory[self.ptr:end] = features
        else:
            head = self.bank_size - self.ptr
            self.memory[self.ptr:] = features[:head]
            self.memory[:end - self.bank_size] = features[head:]
        self.ptr = end % self.bank_size

    def get_memory(self):
        """Return the bank tensor itself (not a copy)."""
        return self.memory