# turborag/data/cache_handler.py
import torch
import abc
from .kSVD import kSVD
from .utils.metrics import QueryMetrics

# Maps a handler class's ``__name__`` to the class object; populated by ``@auro_register``.
class_registry = {}

def auro_register(cls):
    """Class decorator: record ``cls`` in ``class_registry`` under ``cls.__name__``.

    The class is returned unchanged, so decoration has no other effect.
    """
    class_registry.update({cls.__name__: cls})
    return cls

class CacheHandler(abc.ABC):
    """Base class for strategies that load a node's stored cache onto ``device``.

    Subclasses implement :meth:`retrieve`. :meth:`get_handler` maps a
    cache-type string to the handler class registered in ``class_registry``.
    """

    # cache_type string -> name under which the handler class is registered.
    _HANDLER_NAMES = {
        "ksvd": "KSVDCacheHandler",
        "batched_ksvd": "BatchedKSVDCacheHandler",
        "layer_batched_ksvd": "LayerBatchedKSVDCacheHandler",
        "kvcache": "KVCacheHandler",
    }

    def __init__(self, device, ksvd_compressor=None):
        # Target device that subclasses pass to ``Tensor.to``.
        self.device = device
        # Compressor the k-SVD subclasses use for reconstruction; may be None.
        self.ksvd_compressor = ksvd_compressor

    @abc.abstractmethod
    def retrieve(self, node, metrics):
        """Load the cache referenced by ``node.metadata``, timing stages via ``metrics``."""

    @staticmethod
    def get_handler(cache_type):
        """Return the registered handler class for ``cache_type``, or ``None`` if unknown.

        Declared static so it works both as ``CacheHandler.get_handler(t)`` and
        when called through a handler instance.

        Raises:
            KeyError: if ``cache_type`` is recognised but its class was never
                registered in ``class_registry``.
        """
        name = CacheHandler._HANDLER_NAMES.get(cache_type)
        if name is None:
            return None
        return class_registry[name]

@auro_register
class KSVDCacheHandler(CacheHandler):
    """Load one node's k-SVD-compressed cache and reconstruct it immediately (no batching)."""

    def retrieve(self, node, metrics):
        """Return the reconstructed cache for ``node``.

        Indices and values are loaded from the paths in ``node.metadata`` and
        moved to ``self.device``; values are cast to float16. They are then
        passed to ``ksvd_compressor.reconstruct_kSVD_with_omp_v0``. Each stage
        is timed under 'to_ram', 'to_gpu' and 'reconstruction' respectively.
        """
        meta = node.metadata
        paths = (meta["indices_file_path"], meta["values_file_path"])

        metrics.start('to_ram')
        idx, val = (torch.load(p, weights_only=True, map_location='cpu') for p in paths)
        metrics.stop('to_ram')

        metrics.start('to_gpu')
        idx = idx.to(self.device)
        val = val.to(self.device, dtype=torch.float16)
        metrics.stop('to_gpu')

        metrics.start('reconstruction')
        result = self.ksvd_compressor.reconstruct_kSVD_with_omp_v0(idx, val)
        metrics.stop('reconstruction')
        return result

@auro_register
class BatchedKSVDCacheHandler(CacheHandler):
    """Queue per-node k-SVD (indices, values) pairs and reconstruct them in groups.

    ``retrieve`` returns ``None`` until the queue length reaches a multiple of
    ``batch_size``. At that point the queued tensors are stacked along a new
    leading dim and passed to
    ``ksvd_compressor.reconstruct_and_stack_kSVD_with_omp_v0_batched``, whose
    ``(past_kv_chunk, group_chunk)`` result is returned. Call ``flush`` to
    reconstruct a final partial batch.
    """

    def __init__(self, device, ksvd_compressor=None, batch_size=32):
        super().__init__(device, ksvd_compressor)
        self.batch_indices, self.batch_values = [], []
        self.batch_size = batch_size

    def retrieve(self, node, metrics):
        meta = node.metadata
        idx_path = meta["indices_file_path"]
        val_path = meta["values_file_path"]

        metrics.start('to_ram')
        loaded_idx = torch.load(idx_path, weights_only=True, map_location='cpu')
        loaded_val = torch.load(val_path, weights_only=True, map_location='cpu')
        metrics.stop('to_ram')

        # Tensors are only queued here (still on CPU); the device transfer
        # itself happens in process_batch.
        metrics.start('to_gpu')
        self.batch_indices.append(loaded_idx)
        self.batch_values.append(loaded_val)
        metrics.stop('to_gpu')

        queue_full = not len(self.batch_indices) % self.batch_size
        # None signals that no cache is ready yet.
        return self.process_batch(metrics) if queue_full else None

    def process_batch(self, metrics):
        """Stack, transfer and reconstruct everything queued; ``None`` if the queue is empty."""
        if not self.batch_indices:
            return None

        metrics.start('to_gpu')
        stacked_idx = torch.stack(self.batch_indices, dim=0).to(self.device)
        stacked_val = torch.stack(self.batch_values, dim=0).to(self.device, dtype=torch.float16)
        metrics.stop('to_gpu')

        metrics.start('reconstruction')
        compressor = self.ksvd_compressor
        past_kv_chunk, group_chunk = compressor.reconstruct_and_stack_kSVD_with_omp_v0_batched(
            stacked_idx, stacked_val
        )
        metrics.stop('reconstruction')

        # Start a fresh queue for the next batch.
        self.batch_indices, self.batch_values = [], []
        return past_kv_chunk, group_chunk

    def flush(self, metrics):
        """Reconstruct whatever remains queued (a possibly partial batch), or return ``None``."""
        return self.process_batch(metrics)


@auro_register
class LayerBatchedKSVDCacheHandler(CacheHandler):
    """Accumulate k-SVD (indices, values) tensors and reconstruct them in fixed-size batches.

    Loaded tensors are concatenated along dim 1. Once the queued length along
    dim 1 covers ``batch_size * seq_len`` positions, exactly that many are
    passed to ``ksvd_compressor.reconstruct_kSVD_with_omp_v0_layer_wise_batch``.
    Any surplus is kept as the start of the next batch. ``flush`` reconstructs
    whatever remains.
    """

    def __init__(self, device, ksvd_compressor=None, batch_size=32, seq_len=514):
        super().__init__(device, ksvd_compressor)
        # Pending tensors (a list) until they are concatenated for a batch.
        self.batch_indices = []
        self.batch_values = []

        # Number of seq_len-long segments reconstructed per batch.
        self.batch_size = batch_size
        # Length along dim 1 of one segment.
        self.seq_len = seq_len

    def retrieve(self, node, metrics):
        """Load one node's compressed cache; reconstruct a batch once enough is queued.

        Returns the compressor's output for a full batch, or ``None`` when the
        queue is still short of ``batch_size * seq_len`` positions along dim 1.
        """
        indices_file_path = node.metadata["indices_file_path"]
        values_file_path = node.metadata["values_file_path"]

        metrics.start('to_ram')
        indices = torch.load(indices_file_path, weights_only=True, map_location='cpu')
        values = torch.load(values_file_path, weights_only=True, map_location='cpu')
        metrics.stop('to_ram')

        metrics.start('to_gpu')
        self.batch_indices.append(indices)
        self.batch_values.append(values)
        metrics.stop('to_gpu')

        # Count the total queued length rather than extrapolating from the first
        # entry: after a partial batch, the first entry is a leftover remainder
        # whose length differs from freshly loaded entries.
        queued_segments = sum(t.shape[1] for t in self.batch_indices) // self.seq_len
        if queued_segments < self.batch_size:
            return None  # Indicate that no cache is ready yet

        batch_len = self.batch_size * self.seq_len
        metrics.start('to_gpu')
        all_indices = torch.cat(self.batch_indices, dim=1)
        all_values = torch.cat(self.batch_values, dim=1)
        # process_batch consumes self.batch_* as single tensors.
        self.batch_indices = all_indices[:, :batch_len, :]
        self.batch_values = all_values[:, :batch_len, :]
        leftover_indices = all_indices[:, batch_len:, :]
        leftover_values = all_values[:, batch_len:, :]
        metrics.stop('to_gpu')

        cache = self.process_batch(metrics)

        metrics.start('to_gpu')
        # Re-queue any surplus beyond this batch; otherwise start empty.
        if leftover_indices.shape[1] > 0:
            self.batch_indices = [leftover_indices]
            self.batch_values = [leftover_values]
        else:
            self.batch_indices = []
            self.batch_values = []
        metrics.stop('to_gpu')
        return cache

    def process_batch(self, metrics):
        """Move the concatenated queue to ``self.device`` and reconstruct it.

        Expects ``self.batch_indices`` / ``self.batch_values`` to already be
        single tensors (``retrieve`` and ``flush`` concatenate them first).
        Values are cast to float16. Resets both queues to empty lists and
        returns the compressor's output.
        """
        metrics.start('to_gpu')
        batch_indices = self.batch_indices.to(self.device)
        batch_values = self.batch_values.to(self.device, dtype=torch.float16)
        metrics.stop('to_gpu')

        metrics.start('reconstruction')
        num_segments = batch_indices.shape[1] // self.seq_len
        past_kv_chunk = self.ksvd_compressor.reconstruct_kSVD_with_omp_v0_layer_wise_batch(
            batch_indices, batch_values, num_segments
        )
        # Wait for queued CUDA work so the 'reconstruction' timer measures the
        # real cost. Skipped for non-CUDA tensors, where torch.cuda.synchronize
        # would fail on machines without CUDA.
        if batch_indices.is_cuda:
            torch.cuda.synchronize(batch_indices.device)
        metrics.stop('reconstruction')

        self.batch_indices = []
        self.batch_values = []
        return past_kv_chunk

    def flush(self, metrics):
        """Reconstruct everything still queued, or return ``None`` if the queue is empty."""
        if len(self.batch_values) == 0:
            return None
        self.batch_indices = torch.cat(self.batch_indices, dim=1)
        self.batch_values = torch.cat(self.batch_values, dim=1)

        return self.process_batch(metrics)

@auro_register
class KVCacheHandler(CacheHandler):
    """Load an uncompressed cache from ``node.metadata["kvcache_file_path"]`` onto ``self.device``."""

    def retrieve(self, node, metrics):
        """Return the stored cache as a tuple of tuples of tensors on ``self.device``.

        The loaded object is assumed to be an iterable of iterables of tensors
        (presumably one inner group per layer — confirm against the writer).
        """
        path = node.metadata["kvcache_file_path"]

        metrics.start('to_ram')
        cpu_cache = torch.load(path, weights_only=True, map_location='cpu')
        metrics.stop('to_ram')

        metrics.start('to_gpu')
        moved_layers = []
        for layer_tensors in cpu_cache:
            moved_layers.append(tuple(t.to(self.device) for t in layer_tensors))
        kv_on_device = tuple(moved_layers)
        metrics.stop('to_gpu')

        return kv_on_device