import torch
import torch.nn as nn
from .compress_methods import compress_topk, compress_topk_sync, quantize_custom, quantize_dequantize_ACSGD,natural_quantize, natural_compress
from monitor_logger import monitor
import torch.distributed as dist

class Compressor(nn.Module):
    """Module applying activation/gradient compression, optionally with error feedback.

    The forward pass compresses its input with the ``forward`` compressor and the
    backward pass compresses the incoming gradient with the ``backward``
    compressor; both are looked up by name in ``compressors_dict``. When
    ``forward_EF`` or ``backward_EF`` is set, the pair
    ``(forward_EF_method, backward_EF_method)`` selects the error-feedback wrapper.
    """

    # Registry of compressor names; 'id' (None) means "no compression".
    compressors_dict = {
        'id': None,
        'topk': compress_topk,
        'topk_sync': compress_topk_sync,
        'quantization_simple': quantize_custom,
        'natural_quantize': natural_quantize,
        'quantization_acsgd': quantize_dequantize_ACSGD,
        'natural_compress': natural_compress
    }

    def _get_compressor(self, compressor_id, **compressor_params):
        """Return a ``(x, ctx) -> tensor`` callable for the compressor named ``compressor_id``.

        ``compressor_params`` are forwarded as keyword arguments on every call.

        Raises:
            ValueError: if ``compressor_id`` is not a key of ``compressors_dict``.
        """
        if compressor_id not in self.compressors_dict:
            raise ValueError(f'Compressor with name {compressor_id} Not found')
        if compressor_id == 'id':
            return lambda x, ctx: x
        return lambda x, ctx: self.compressors_dict[compressor_id](x, ctx, **compressor_params)

    def __init__(self, input_shape, forward='id', forward_params=None, backward='id', backward_params=None, forward_EF=False, backward_EF=False, forward_EF_method=None, backward_EF_method=None):
        """Configure the compressors and, if requested, the error-feedback wrapper.

        Args:
            input_shape: shape used to allocate the dense error buffers.
            forward: name (key of ``compressors_dict``) of the activation compressor.
            forward_params: extra keyword arguments for the forward compressor
                (``None`` means none).
            backward: name of the gradient compressor.
            backward_params: extra keyword arguments for the backward compressor
                (``None`` means none).
            forward_EF: enable error feedback in the forward pass.
            backward_EF: enable error feedback in the backward pass.
            forward_EF_method: forward error-feedback scheme.
            backward_EF_method: backward error-feedback scheme. Supported pairs are
                ('EF21', 'EF21'), ('AQSGD', 'EF21') and ('EF', 'EF').

        Raises:
            ValueError: for an unknown compressor name, or when error feedback is
                enabled with an unsupported method pair.
        """
        super(Compressor, self).__init__()
        # Fresh dicts per instance: a literal ``{}`` default would be shared by all calls.
        forward_params = {} if forward_params is None else forward_params
        backward_params = {} if backward_params is None else backward_params

        self.input_shape = input_shape
        self.forward_func = forward
        self.forward_params = forward_params
        self.backward_func = backward
        self.backward_params = backward_params
        self.forward_EF = forward_EF
        self.backward_EF = backward_EF
        self.forward_EF_method = forward_EF_method
        self.backward_EF_method = backward_EF_method

        # This module registers no parameters of its own, so in practice this picks
        # CUDA when available and CPU otherwise.
        self.device = next(self.parameters()).device if list(self.parameters()) else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Dense error buffers (plain attributes, not registered nn.Module buffers).
        self.EF_forward_buffer = torch.zeros(input_shape, device=self.device) if forward_EF else None
        self.EF_backward_buffer = torch.zeros(input_shape, device=self.device) if backward_EF else None

        forward_compression = self._get_compressor(forward, **forward_params)
        backward_compression = self._get_compressor(backward, **backward_params)

        if not (self.forward_EF or self.backward_EF):
            self.compressor = BaseCompressor(
                forward_compression, backward_compression
            )
            return

        if self.forward_EF_method == "AQSGD":
            # AQ-SGD keeps one entry per sample index, so the forward buffer is a
            # dict rather than a dense tensor. Dispatch still happens below;
            # previously this branch returned without setting ``self.compressor``.
            self.EF_forward_buffer = {}

        if self.forward_EF_method == "EF21" and self.backward_EF_method == "EF21":
            print("Using EF21")
            self.compressor = EF21Compressor(
                forward_compression, self.EF_forward_buffer, self.forward_EF,
                backward_compression, self.EF_backward_buffer, self.backward_EF
            )
        elif self.forward_EF_method == 'AQSGD' and self.backward_EF_method == 'EF21':
            print("Using AQSGD_EF21")
            self.compressor = AQSGD_EF21Compressor(
                forward_compression, self.EF_forward_buffer, self.forward_EF,
                backward_compression, self.EF_backward_buffer, self.backward_EF
            )
        elif self.forward_EF_method == 'EF' and self.backward_EF_method == 'EF':
            print("Using EF")
            self.compressor = EFCompressor(
                forward_compression, self.EF_forward_buffer, self.forward_EF,
                backward_compression, self.EF_backward_buffer, self.backward_EF
            )
        else:
            raise ValueError("Unsupported forward and backward EF methods")

    def forward(self, x, indices=None, compress=True):
        """Apply the configured compression autograd function to ``x``.

        ``indices`` is passed through to the autograd function; the AQSGD path
        uses it as per-sample buffer keys. With ``compress=False`` both the value
        and its gradient pass through uncompressed.
        """
        return self.compressor(x, indices, compress)

def BaseCompressor(compressor_forward, compressor_backward):
    """Build an autograd function that compresses activations and gradients without error feedback.

    Args:
        compressor_forward: ``(tensor, ctx) -> tensor`` applied to the input in the
            forward pass, or ``None`` to return the input unchanged.
        compressor_backward: ``(tensor, ctx) -> tensor`` applied to the incoming
            gradient in the backward pass, or ``None`` to pass it through.

    Returns:
        ``myCompressor.apply``, called as ``fn(input, indices=None, compress=True)``.
        With ``compress=False`` both passes return a detached copy unchanged.
    """
    class myCompressor(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, indices=None, compress=True):
            # Stored on ctx for the backward pass; ctx is also handed to the compressors.
            ctx.indices = indices
            ctx.compress = compress
            if compressor_forward is None:
                return input
            x = input.detach().clone()
            return compressor_forward(x, ctx) if compress else x

        @staticmethod
        def backward(ctx, grad_output):
            # Autograd expects one gradient per forward input (input, indices,
            # compress); only the tensor input receives one.
            if compressor_backward is None:
                return grad_output, None, None
            x = grad_output.detach().clone()
            c_x = compressor_backward(x, ctx) if ctx.compress else x
            return c_x, None, None

    return myCompressor.apply

def EFCompressor(compressor_forward, fwd_error_buffer, forward_EF, compressor_backward, bckwd_error_buffer, backward_EF):
    """Build an autograd function applying classic error feedback (EF).

    With error feedback enabled on a pass, the accumulated error ``e`` is added
    before compression and the new residual is stored in place:
    ``c = C(x + e)``, ``e <- e + x - c``. The buffers are expected to be tensors
    matching the shape of the tensors they correct (TODO confirm for the last,
    possibly smaller, batch).

    Args:
        compressor_forward: ``(tensor, ctx) -> tensor`` for activations, or ``None``
            to return the input unchanged.
        fwd_error_buffer: error tensor for the forward pass (used when ``forward_EF``).
        forward_EF: enable error feedback in the forward pass.
        compressor_backward: ``(tensor, ctx) -> tensor`` for gradients, or ``None``
            to pass them through.
        bckwd_error_buffer: error tensor for the backward pass (used when ``backward_EF``).
        backward_EF: enable error feedback in the backward pass.

    Returns:
        ``EFCompression.apply``, called as ``fn(input, indices=None, compress=True)``.
    """

    def _ef_step(x, error_buffer, compress_fn, ctx):
        """Compress ``x + e`` and write the new residual back into ``error_buffer``."""
        e = error_buffer
        c_x = compress_fn(x + e, ctx)
        error_buffer[:] = e + x - c_x
        return c_x

    class EFCompression(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, indices=None, compress=True):
            # Set before any early return: backward always reads ctx.compress.
            ctx.indices = indices
            ctx.compress = compress
            if compressor_forward is None:
                return input
            x = input.detach().clone()
            if not compress:
                return x
            if forward_EF:
                return _ef_step(x, fwd_error_buffer, compressor_forward, ctx)
            return compressor_forward(x, ctx)

        @staticmethod
        def backward(ctx, grad_output):
            # One gradient per forward input (input, indices, compress).
            if compressor_backward is None:
                return grad_output, None, None
            x = grad_output.detach().clone()
            if not ctx.compress:
                return x, None, None
            if backward_EF:
                return _ef_step(x, bckwd_error_buffer, compressor_backward, ctx), None, None
            return compressor_backward(x, ctx), None, None

    return EFCompression.apply

def EF21Compressor(compressor_forward, fwd_error_buffer, forward_EF, compressor_backward, bckwd_error_buffer, backward_EF):
    """Build an autograd function applying EF21-style compression.

    When error feedback is enabled for a pass, the buffer holds the last
    reconstructed tensor ``A0`` and only the difference is compressed:
    ``result = A0 + C(x - A0)``, after which ``result`` overwrites ``A0`` in place.
    Only the leading region shared by ``x`` and the buffer (per-dimension minimum
    of their shapes) is processed; entries of ``x`` outside it come back as zeros.

    Args:
        compressor_forward: ``(tensor, ctx) -> tensor`` for activations, or ``None``
            to return the input unchanged.
        fwd_error_buffer: tensor holding ``A0`` for the forward pass (used when ``forward_EF``).
        forward_EF: enable EF21 in the forward pass.
        compressor_backward: ``(tensor, ctx) -> tensor`` for gradients, or ``None``
            to pass them through.
        bckwd_error_buffer: tensor holding ``A0`` for the backward pass (used when ``backward_EF``).
        backward_EF: enable EF21 in the backward pass.

    Returns:
        ``EF21Compression.apply``, called as ``fn(input, indices=None, compress=True)``.
    """

    def _shared_region(x, buffer):
        """Index tuple selecting the leading region common to ``x`` and ``buffer``."""
        return tuple(slice(0, min(s1, s2)) for s1, s2 in zip(x.shape, buffer.shape))

    class EF21Compression(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, indices=None, compress=True):
            # Set before any early return: backward always reads ctx.compress.
            ctx.indices = indices
            ctx.compress = compress
            if compressor_forward is None:
                return input
            x = input.detach().clone()
            if not compress:
                return x
            if not forward_EF:
                return compressor_forward(x, ctx)

            A0 = fwd_error_buffer
            x_slice = _shared_region(x, A0)
            c_x = compressor_forward(x[x_slice] - A0[x_slice], ctx)  # compress P1 - A0
            monitor.debug(f"train/c_x_norm", torch.norm(c_x).item())
            result = torch.zeros_like(x)
            result[x_slice] = A0[x_slice] + c_x
            fwd_error_buffer[x_slice] = result[x_slice]
            monitor.debug(f"result norm: {torch.norm(result)}, device: {result.device}")
            return result

        @staticmethod
        def backward(ctx, grad_output):
            # One gradient per forward input (input, indices, compress).
            if compressor_backward is None:
                return grad_output, None, None
            x = grad_output.detach().clone()
            if not ctx.compress:
                return x, None, None
            if not backward_EF:
                return compressor_backward(x, ctx), None, None

            A0 = bckwd_error_buffer
            x_slice = _shared_region(x, A0)
            c_x = compressor_backward(x[x_slice] - A0[x_slice], ctx)  # compress the gradient difference
            monitor.debug(f"backward c_x norm: {torch.norm(c_x)}, device: {c_x.device}")
            result = torch.zeros_like(x)
            result[x_slice] = A0[x_slice] + c_x
            bckwd_error_buffer[x_slice] = result[x_slice]
            monitor.debug(f"backward result norm: {torch.norm(result)}, device: {result.device}")
            return result, None, None

    return EF21Compression.apply

def AQSGD_EF21Compressor(compressor_forward, fwd_error_buffer, forward_EF, compressor_backward, bckwd_error_buffer, backward_EF):
    """Build an autograd function with an AQ-SGD-style forward and EF21-style backward.

    Forward (with ``forward_EF``): ``fwd_error_buffer`` is a dict keyed by sample
    index, holding the last reconstructed activation of each sample (zeros the
    first time an index is seen). ``indices[i]`` identifies row ``x[i]``. The rows
    ``A0`` for the batch are stacked, only ``x - A0`` is compressed, and
    ``A0 + C(x - A0)`` is returned and written back per sample.

    Backward (with ``backward_EF``): ``bckwd_error_buffer`` is a dense tensor
    updated EF21-style over the leading region it shares with the gradient;
    gradient entries outside that region come back as zeros.

    Args:
        compressor_forward: ``(tensor, ctx) -> tensor`` for activations, or ``None``
            to return the input unchanged.
        fwd_error_buffer: per-sample dict for the forward pass (used when ``forward_EF``).
        forward_EF: enable AQ-SGD in the forward pass.
        compressor_backward: ``(tensor, ctx) -> tensor`` for gradients, or ``None``
            to pass them through.
        bckwd_error_buffer: dense tensor for the backward pass (used when ``backward_EF``).
        backward_EF: enable EF21 in the backward pass.

    Returns:
        ``AQSGD_EF21Compression.apply``, called as ``fn(input, indices=None, compress=True)``.
        In the forward pass it raises ValueError when AQ-SGD is active and
        ``indices`` is ``None``.
    """

    def _index_keys(indices):
        """Turn ``indices`` (tensor or iterable) into hashable Python keys."""
        if indices is None:
            raise ValueError("AQSGD forward error feedback requires per-sample indices")
        if torch.is_tensor(indices):
            # Tensors hash by identity, so the same sample index would never hit
            # its existing buffer entry; convert to plain Python numbers.
            indices = indices.tolist()
        return [idx.item() if torch.is_tensor(idx) else idx for idx in indices]

    def _shared_region(x, buffer):
        """Index tuple selecting the leading region common to ``x`` and ``buffer``."""
        return tuple(slice(0, min(s1, s2)) for s1, s2 in zip(x.shape, buffer.shape))

    class AQSGD_EF21Compression(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, indices=None, compress=True):
            # Set before any early return: backward always reads ctx.compress.
            ctx.indices = indices
            ctx.compress = compress
            if compressor_forward is None:
                return input
            x = input.detach().clone()
            if not compress:
                return x
            if not forward_EF:
                return compressor_forward(x, ctx)

            keys = _index_keys(indices)
            # Unseen samples start from a zero reference, shaped like a single row of x.
            for key in keys:
                if key not in fwd_error_buffer:
                    fwd_error_buffer[key] = torch.zeros_like(x[0])
            A0 = torch.stack([fwd_error_buffer[key] for key in keys])

            c_x = compressor_forward(x - A0, ctx)  # compress P1 - A0
            result = A0 + c_x

            # Clone each row: storing the view ``result[i]`` would keep the whole
            # batch tensor alive for as long as any of its rows stays in the buffer.
            for i, key in enumerate(keys):
                fwd_error_buffer[key] = result[i].clone()
            return result

        @staticmethod
        def backward(ctx, grad_output):
            # One gradient per forward input (input, indices, compress).
            if compressor_backward is None:
                return grad_output, None, None
            x = grad_output.detach().clone()
            if not ctx.compress:
                return x, None, None
            if not backward_EF:
                return compressor_backward(x, ctx), None, None

            A0 = bckwd_error_buffer
            x_slice = _shared_region(x, A0)
            c_x = compressor_backward(x[x_slice] - A0[x_slice], ctx)  # compress the gradient difference
            result = torch.zeros_like(x)
            result[x_slice] = A0[x_slice] + c_x
            bckwd_error_buffer[x_slice] = result[x_slice]
            return result, None, None

    return AQSGD_EF21Compression.apply


if __name__ == "__main__":
    # Smoke test: EF21 on both passes with identity compressors and zero-initialised
    # dense buffers, checking that the output stays connected to the autograd graph.
    x = torch.randn(10, 10, requires_grad=True)
    identity = lambda t, ctx: t
    compressor = EF21Compressor(
        identity, torch.zeros(10, 10), True,
        identity, torch.zeros(10, 10), True,
    )
    output = compressor(x)
    print(f"Output requires_grad: {output.requires_grad}")
    print(f"Output grad_fn: {output.grad_fn}")

    loss = output.sum()
    print(f"Loss requires_grad: {loss.requires_grad}")
    loss.backward()