import einops
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import IterableDataset, DataLoader

from transformer_lens.hook_points import HookedRootModule
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

import accelerate

import numpy as np

import time

from datasets import load_dataset
from sae import TrainingConfig

import os
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class ReplayBuffer:
    def __init__(self, buffer_size: int, batch_size: int, device: str, reconstruct: bool = True):
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.buffer = None
        self.position = 0
        self.used_positions = set()
        self.device = device
        self.reconstruct = reconstruct
        
    def add(self, activations: tuple[torch.Tensor, torch.Tensor] | torch.Tensor):
        if self.reconstruct:
            activations = activations.cpu()
            
            if self.buffer is None:
                self.buffer = torch.zeros((self.buffer_size,) + activations.shape[1:], 
                                    dtype=activations.dtype)
            
            if self.used_positions:
                positions = list(self.used_positions)[:activations.shape[0]]
                self.buffer[positions] = activations[:len(positions)]
                self.used_positions -= set(positions)
                
                remaining = activations[len(positions):]
                if len(remaining) > 0:
                    space_left = self.buffer_size - self.position
                    if space_left >= len(remaining):
                        new_positions = list(range(self.position, self.position + len(remaining)))
                        self.buffer[self.position:self.position + len(remaining)] = remaining
                        self.position += len(remaining)
                    else:
                        new_positions = list(range(len(remaining)))
                        self.position = len(remaining)
                        self.buffer[:len(remaining)] = remaining
                    self.used_positions -= set(new_positions)
            else:
                batch_size = activations.shape[0]
                space_left = self.buffer_size - self.position
                
                if space_left >= batch_size:
                    new_positions = list(range(self.position, self.position + batch_size))
                    self.buffer[self.position:self.position + batch_size] = activations
                    self.position += batch_size
                else:
                    new_positions = list(range(self.position, self.buffer_size)) + list(range(batch_size - space_left))
                    self.buffer[self.position:] = activations[:space_left]
                    self.buffer[:batch_size - space_left] = activations[space_left:]
                    self.position = batch_size - space_left
                self.used_positions -= set(new_positions)
        else:
            activations_a = activations[0].cpu()
            activations_b = activations[1].cpu().to(activations_a.dtype)
            
            if self.buffer is None:
                self.buffer = torch.zeros((2, self.buffer_size,) + activations_a.shape[1:], 
                                    dtype=activations_a.dtype)
            
            if self.used_positions:
                positions = list(self.used_positions)[:activations_a.shape[0]]
                self.buffer[0, positions] = activations_a[:len(positions)]
                self.buffer[1, positions] = activations_b[:len(positions)]
                self.used_positions -= set(positions)
                
                remaining_a = activations_a[len(positions):]
                remaining_b = activations_b[len(positions):]
                if len(remaining_a) > 0:
                    space_left = self.buffer_size - self.position
                    if space_left >= len(remaining_a):
                        new_positions = list(range(self.position, self.position + len(remaining_a)))
                        self.buffer[0, self.position:self.position + len(remaining_a)] = remaining_a
                        self.buffer[1, self.position:self.position + len(remaining_b)] = remaining_b
                        self.position += len(remaining_a)
                    else:
                        new_positions = list(range(len(remaining_a)))
                        self.position = len(remaining_a)
                        self.buffer[0, :len(remaining_a)] = remaining_a
                        self.buffer[1, :len(remaining_a)] = remaining_a
                    self.used_positions -= set(new_positions)
            else:
                batch_size = activations_a.shape[0]
                space_left = self.buffer_size - self.position
                
                if space_left >= batch_size:
                    new_positions = list(range(self.position, self.position + batch_size))
                    self.buffer[0, self.position:self.position + batch_size] = activations_a
                    self.buffer[1, self.position:self.position + batch_size] = activations_b
                    self.position += batch_size
                else:
                    new_positions = list(range(self.position, self.buffer_size)) + list(range(batch_size - space_left))
                    self.buffer[0, self.position:] = activations_a[:space_left]
                    self.buffer[1, :batch_size - space_left] = activations_b[space_left:]
                    self.position = batch_size - space_left
                self.used_positions -= set(new_positions)
            
    def sample(self) -> torch.Tensor:
        if self.buffer is None:
            raise RuntimeError("Buffer is empty")
        
        available_positions = set(range(min(self.position, self.buffer_size))) - self.used_positions
        if len(available_positions) < self.batch_size:
            self.used_positions.clear()
            available_positions = set(range(min(self.position, self.buffer_size)))
        
        sample_positions = torch.tensor(list(available_positions))[torch.randperm(len(available_positions))[:self.batch_size]]
        self.used_positions.update(sample_positions.tolist())
        
        if self.reconstruct:
            return self.buffer[sample_positions].to(self.device)
        else:
            return (self.buffer[0, sample_positions].to(self.device), self.buffer[1, sample_positions].to(self.device))

    def is_full(self) -> bool:
        return self.position >= self.buffer_size
    
    def get_mean_std(self):
        return self.buffer.mean().item(), self.buffer.std(-1).mean().item()


class DistributedActivationsStore:
    """Feeds a replay buffer with model activations captured at one or two hook points.

    When ``cfg.hook_point`` is a tuple, two hook points are cached and every
    batch is a pair of tensors; otherwise a single hook point is cached.
    """

    def __init__(
        self,
        n_ctx: int,
        cfg: TrainingConfig,
        num_processes: int = 1,
    ):
        """Resolve hook names from ``cfg`` and create the (still empty) replay buffer.

        Args:
            n_ctx: Upper bound on the context length; the effective context
                size is ``min(cfg.seq_len, n_ctx)``.
            cfg: Training configuration. Fields read by this class:
                ``hook_point``, ``layer``, ``seq_len``, ``model_batch_size``,
                ``device``, ``num_batches_in_buffer``, ``batch_size`` and
                ``act_size``.
            num_processes: Divisor applied to the replay-buffer capacity.
        """
        if isinstance(cfg.hook_point, tuple):
            self.hook_point = (
                utils.get_act_name(cfg.hook_point[0], cfg.layer),
                utils.get_act_name(cfg.hook_point[1], cfg.layer),
            )
        else:
            self.hook_point = utils.get_act_name(cfg.hook_point, cfg.layer)
        self.context_size = min(cfg.seq_len, n_ctx)
        self.model_batch_size = cfg.model_batch_size
        self.device = cfg.device
        self.num_batches_in_buffer = cfg.num_batches_in_buffer
        self.cfg = cfg

        # Initialize replay buffer with 256 sequences
        buffer_size = 256 * cfg.seq_len // max(1, num_processes)
        self.replay_buffer = ReplayBuffer(
            buffer_size=buffer_size,
            batch_size=cfg.batch_size,
            device=cfg.device,
            reconstruct=not isinstance(cfg.hook_point, tuple),
        )

    def fill_buffer(self, model, dataloader):
        """Fill the replay buffer, then record its mean/std in ``self.mean`` / ``self.std``."""
        # Initial fill of buffer
        print("Filling buffer...")
        self._fill_buffer(model, dataloader)
        print("Buffer filled.")
        self.mean, self.std = self.replay_buffer.get_mean_std()

    def get_batch_tokens(self, model: HookedTransformer, dataloader):
        """Pull text from ``dataloader`` until one model batch of tokens is available.

        Padding tokens and ``'<|endoftext|>'`` tokens are removed and the
        remaining tokens of successive batches are concatenated into one stream.

        Args:
            model: Model whose ``to_tokens`` / ``tokenizer`` are used.
            dataloader: Iterator advanced with ``next()``.

        Returns:
            A ``torch.long`` tensor on ``self.device`` with shape
            ``(model_batch_size, context_size + 1)``.

        Raises:
            StopIteration: If ``dataloader`` is exhausted before enough tokens
                were collected.
        """
        needed = self.model_batch_size * (self.context_size + 1)

        tokenizer = model.tokenizer
        # Filter on the padding *token id*. `pad_token_type_id` is the segment
        # id used for padding (normally 0) and would drop a real vocabulary item.
        excluded_ids = {
            tokenizer.pad_token_id,
            tokenizer.convert_tokens_to_ids('<|endoftext|>'),
        }
        excluded_ids.discard(None)  # tokenizer may define neither

        chunks = []
        collected = 0
        while collected < needed:
            try:
                batch = next(dataloader)
            except StopIteration:
                # Exhaustion is permanent; retrying would spin forever.
                raise
            except Exception as e:
                # Best-effort retry for transient loading failures.
                print(f"Error loading batch: {e}")
                time.sleep(10)
                continue
            tokens = model.to_tokens(batch, truncate=True, move_to_device=True, prepend_bos=True)
            flat_tokens = tokens.flatten()
            for token_id in excluded_ids:
                flat_tokens = flat_tokens[flat_tokens != token_id]
            chunks.append(flat_tokens.detach().cpu())
            collected += flat_tokens.numel()

        stream = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.long)
        token_tensor = stream[:needed].to(device=self.device, dtype=torch.long)
        return token_tensor.view(self.model_batch_size, self.context_size + 1)

    def get_activations(self, batch_tokens: torch.Tensor, model: HookedTransformer):
        """Run ``model`` up to ``cfg.layer`` and return the cached hook activations.

        The forward pass runs without gradients under bfloat16 autocast.  The
        first sequence position of every row is discarded (labelled "drop BOS"
        in the original code).

        Returns:
            One tensor, or a pair of tensors when two hook points are configured.
        """
        # Works for both string specs ("cuda:0") and torch.device objects.
        device_type = torch.device(self.device).type
        paired = isinstance(self.hook_point, tuple)
        names = [*self.hook_point] if paired else [self.hook_point]

        with torch.no_grad(), torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
            _, cache = model.run_with_cache(
                batch_tokens,
                names_filter=names,
                stop_at_layer=self.cfg.layer + 1,
            )

        if paired:
            return (cache[self.hook_point[0]][:, 1:, :], cache[self.hook_point[1]][:, 1:, :])
        return cache[self.hook_point][:, 1:, :]

    def _flatten_activations(self, activations):
        """Reshape activations (or each element of a pair) to ``(-1, cfg.act_size)``."""
        if isinstance(self.hook_point, tuple):
            return (
                activations[0].reshape(-1, self.cfg.act_size),
                activations[1].reshape(-1, self.cfg.act_size),
            )
        return activations.reshape(-1, self.cfg.act_size)

    def _push_activations(self, batch_tokens: torch.Tensor, model: HookedTransformer):
        """Compute activations for ``batch_tokens`` and add them to the replay buffer."""
        activations = self.get_activations(batch_tokens, model)
        self.replay_buffer.add(self._flatten_activations(activations))

    def _fill_buffer(self, model: HookedTransformer, iter_dataloader):
        """Keep adding fresh activations until the replay buffer reports it is full."""
        while not self.replay_buffer.is_full():
            self._push_activations(self.get_batch_tokens(model, iter_dataloader), model)

    def get_batch(self, model: HookedTransformer, distributed_dataloader, add_new: bool = True):
        """Sample a training batch, optionally refilling from ``distributed_dataloader``.

        The sample is taken before the refill, so the slots it frees can be
        reused by the newly computed activations.
        """
        outp = self.replay_buffer.sample()
        if add_new:
            self._push_activations(self.get_batch_tokens(model, distributed_dataloader), model)
        return outp

    def get_processed_batch(self, model: HookedTransformer, batch, add_new: bool = True):
        """Sample a training batch, optionally refilling from already-tokenized ``batch``.

        ``batch`` is passed straight to the model as its input tokens.
        """
        outp = self.replay_buffer.sample()
        if add_new:
            self._push_activations(batch, model)
        return outp