import torch
from torch import nn

class ObserverBase(nn.Module):
    """Base class for observers that record a (min, max) range as module buffers.

    Subclasses implement ``forward`` to compute batch statistics and
    ``update_range`` to fold them into the ``min_val`` / ``max_val`` buffers.

    Args:
        channels: If ``> 0``, one range per channel is tracked and the buffers
            have shape ``(channels, 1)``; otherwise a single range is tracked
            and the buffers have shape ``(1,)``.
    """

    def __init__(self, channels):
        super(ObserverBase, self).__init__()
        self.channels = channels
        if self.channels > 0:  # only weights are used
            shape = (channels, 1)
        else:
            shape = (1,)
        # Registered as buffers so they follow .to(device) and appear in state_dict;
        # registration order (max_val, then min_val) is kept for state_dict key order.
        self.register_buffer('max_val', torch.zeros(shape, dtype=torch.float32))
        self.register_buffer('min_val', torch.zeros(shape, dtype=torch.float32))
        # 0 means "no statistics folded in yet": the next update_range call
        # overwrites the buffers instead of combining with them.
        self.num_flag = 0

    def update_range(self, min_val, max_val):
        """Fold batch statistics into the stored range; must be overridden."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement update_range()")

    def reset_range(self):
        """Reset the stored range to sentinel values and restart accumulation.

        ``min_val`` is set to +10000 and ``max_val`` to -10000 (same shape,
        dtype and device as before). ``num_flag`` is also cleared so the next
        update overwrites the range rather than combining with the sentinels,
        which matters for averaging observers.
        """
        self.min_val = torch.zeros_like(self.min_val) + 10000
        self.max_val = torch.zeros_like(self.max_val) - 10000
        self.num_flag = 0


class NineNineObserver(ObserverBase):
    """Observer whose range is [minimum, ``nine``-quantile] of the input.

    The lower bound is the true minimum; the upper bound is the element at
    index ``int(n * nine)`` of the ascending-sorted data (default 0.999), which
    clips rare large outliers. Ranges accumulate across calls as the running
    minimum of lower bounds and running maximum of upper bounds.

    Args:
        channels: See ``ObserverBase``; ``> 0`` selects per-channel ranges.
        nine: Quantile in (0, 1] used for the upper bound.
    """

    def __init__(self, channels, nine=0.999):
        super(NineNineObserver, self).__init__(channels)
        self.nine = nine

    def _quantile_index(self, n):
        """Return the sorted-position index for ``self.nine`` among ``n`` values."""
        # int(n * 1.0) == n would be one past the end; clamp into [0, n - 1].
        return min(max(int(n * self.nine), 0), n - 1)

    def forward(self, x):
        """Compute batch [min, quantile] statistics of ``x`` and update the range."""
        if self.channels > 0:
            # Dim 0 is treated as the channel axis (matching the (channels, 1)
            # buffers); all remaining dims are flattened into one sample axis.
            # For 2-D input this is a no-op.
            indata = x.reshape(x.shape[0], -1)
            nine_nine_id = self._quantile_index(indata.shape[1])
            indata, _ = torch.sort(indata, dim=1)
            min_val = indata[:, 0]
            max_val = indata[:, nine_nine_id]
        else:
            indata = torch.flatten(x)
            nine_nine_id = self._quantile_index(indata.shape[0])
            indata, _ = torch.sort(indata)

            min_val = indata[0]
            max_val = indata[nine_nine_id]
        self.update_range(min_val, max_val)

    def update_range(self, min_val, max_val):
        """Merge batch statistics into the ``min_val`` / ``max_val`` buffers.

        The first call stores the values directly; later calls keep the
        element-wise minimum of lower bounds and maximum of upper bounds.
        Inputs must have the same number of elements as the buffers.
        """
        min_val = torch.reshape(min_val, self.min_val.shape)
        max_val = torch.reshape(max_val, self.max_val.shape)

        if self.num_flag == 0:
            min_val_new = min_val
            max_val_new = max_val
            self.num_flag += 1
        else:
            min_val_new = torch.min(min_val, self.min_val)
            # Compare against the stored *max*, not the stored min.
            max_val_new = torch.max(max_val, self.max_val)

        self.min_val.copy_(min_val_new.detach())
        self.max_val.copy_(max_val_new.detach())


class MinMaxObserver(ObserverBase):
    """Observer whose batch statistics are the plain minimum and maximum.

    This class only defines ``forward``; subclasses supply ``update_range``
    to decide how the batch extremes are combined with the stored range.
    """

    def __init__(self, channels):
        super(MinMaxObserver, self).__init__(channels)

    def forward(self, x):
        """Compute the batch min/max of ``x`` and pass them to ``update_range``."""
        per_channel = self.channels > 0
        if not per_channel:
            # Single global range over every element of x.
            lo, hi = torch.min(x), torch.max(x)
        else:
            # Reduce along dim 1 only, keeping it so each row yields a column entry.
            lo = torch.min(x, 1, keepdim=True).values
            hi = torch.max(x, 1, keepdim=True).values
        self.update_range(lo, hi)



class NormalMinMaxObserver(MinMaxObserver):
    """Min/max observer that keeps the running extremes over all observed batches.

    The first update stores the batch min/max directly. Later updates keep the
    element-wise minimum of the mins and the maximum of the maxes.
    """

    def __init__(self, channels):
        super(NormalMinMaxObserver, self).__init__(channels)

    def update_range(self, min_val, max_val):
        """Merge batch statistics into the ``min_val`` / ``max_val`` buffers.

        Args:
            min_val: Batch minimum; must have as many elements as ``self.min_val``.
            max_val: Batch maximum; must have as many elements as ``self.max_val``.
        """
        min_val = torch.reshape(min_val, self.min_val.shape)
        max_val = torch.reshape(max_val, self.max_val.shape)

        if self.num_flag == 0:
            min_val_new = min_val
            max_val_new = max_val
            self.num_flag += 1
        else:
            min_val_new = torch.min(min_val, self.min_val)
            # Compare against the stored *max*, not the stored min.
            max_val_new = torch.max(max_val, self.max_val)

        self.min_val.copy_(min_val_new.detach())
        self.max_val.copy_(max_val_new.detach())

class MovingAvgMinMaxObserver(MinMaxObserver):
    """Min/max observer that tracks an exponential moving average of the extremes.

    Args:
        channels: See ``ObserverBase``; ``> 0`` selects per-channel ranges.
        momentum: Weight given to each new batch statistic
            (``stored += momentum * (batch - stored)``).
    """

    def __init__(self, channels, momentum=0.1):
        super(MovingAvgMinMaxObserver, self).__init__(channels)
        self.momentum = momentum

    def update_range(self, min_val, max_val):
        """Blend batch min/max into the stored buffers.

        The very first update copies the batch values; every later update
        moves the stored values a ``momentum`` fraction toward the batch values.
        """
        batch_min = min_val.reshape(self.min_val.shape)
        batch_max = max_val.reshape(self.max_val.shape)

        if self.num_flag != 0:
            blended_min = self.min_val + self.momentum * (batch_min - self.min_val)
            blended_max = self.max_val + self.momentum * (batch_max - self.max_val)
        else:
            # Nothing to average with yet: seed the EMA with this batch.
            self.num_flag += 1
            blended_min, blended_max = batch_min, batch_max

        self.min_val.copy_(blended_min.detach())
        self.max_val.copy_(blended_max.detach())

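#########################################################################################
# NOTE(review): Disabled legacy variant (size-adjusting observers), kept for reference.
# It would not run if uncommented: ObserverBase.__init__ reads `target_size` but does
# not accept it as a parameter, while subclasses pass `target_size=` to it.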
#import torch
#from torch import nn
#
#class ObserverBase(nn.Module):
#    def __init__(self, channels):
#        super(ObserverBase, self).__init__()
#        self.channels = channels
#        self.target_size = target_size  # target size for reshaping
#        if self.channels > 0:  # only weights are used
#            self.register_buffer('max_val', torch.zeros((channels, 1), dtype=torch.float32))
#            self.register_buffer('min_val', torch.zeros((channels, 1), dtype=torch.float32))
#        else:
#            self.register_buffer('max_val', torch.zeros((1), dtype=torch.float32))
#            self.register_buffer('min_val', torch.zeros((1), dtype=torch.float32))
#        self.num_flag = 0
#        print('init max_val', self.max_val)
#        print('init min_val', self.min_val)
#
#    def reset_range(self):
#        self.min_val = torch.zeros_like(self.min_val) + 10000
#        self.max_val = torch.zeros_like(self.max_val) - 10000
#
#    def adjust_size(self, x):
#        """Adjust the input size to match target_size"""
#        flattened_x = torch.flatten(x)
#
#        if flattened_x.size(0) > self.target_size:
#            # If the input is larger than target size, slice it
#            return flattened_x[:self.target_size]
#        elif flattened_x.size(0) < self.target_size:
#            # If the input is smaller, pad with zeros
#            padding_size = self.target_size - flattened_x.size(0)
#            return torch.cat([flattened_x, torch.zeros(padding_size, device=flattened_x.device)])
#        return flattened_x
#
#class NineNineObserver(ObserverBase):
#    def __init__(self, channels, nine=0.999, target_size=4096):
#        super(NineNineObserver, self).__init__(channels, target_size=target_size)
#        self.nine = nine
#
#    def forward(self, x):
#        # Adjust the input size
#        adjusted_x = self.adjust_size(x)
#
#        if self.channels > 0:
#            nine_nine_id = int(adjusted_x.shape[0] * self.nine)
#            indata, _ = torch.sort(adjusted_x)
#            min_val = indata[0]
#            max_val = indata[nine_nine_id]
#        else:
#            indata = torch.flatten(adjusted_x)
#            nine_nine_id = int(indata.shape[0] * self.nine)
#            indata, _ = torch.sort(indata)
#
#            min_val = indata[0]
#            max_val = indata[nine_nine_id]
#        self.update_range(min_val, max_val)
#
#    def update_range(self, min_val, max_val):
#        min_val = torch.reshape(min_val, self.min_val.shape)
#        max_val = torch.reshape(max_val, self.max_val.shape)
#
#        if self.num_flag == 0:
#            min_val_new = min_val
#            max_val_new = max_val
#            self.num_flag += 1
#        else:
#            min_val_new = torch.min(min_val, self.min_val)
#            max_val_new = torch.max(max_val, self.min_val)
#
#        self.min_val.copy_(min_val_new.detach())
#        self.max_val.copy_(max_val_new.detach())
#
#
#class MinMaxObserver(ObserverBase):
#    def __init__(self, channels):
#        super(MinMaxObserver, self).__init__(channels)
#
#    def forward(self, x):
#        # Adjust the input size
#        adjusted_x = self.adjust_size(x)
#
#        if self.channels > 0:
#            min_val = torch.min(adjusted_x, 0, keepdim=True)[0]
#            max_val = torch.max(adjusted_x, 0, keepdim=True)[0]
#        else:
#            min_val = torch.min(adjusted_x)
#            max_val = torch.max(adjusted_x)
#        self.update_range(min_val, max_val)
#
#class NormalMinMaxObserver(MinMaxObserver):
#    def __init__(self, channels):
#        super(NormalMinMaxObserver, self).__init__(channels)
#
#    def update_range(self, min_val, max_val):
#    # Check if min_val and max_val are compatible with self.min_val.shape
#        if min_val.numel() != self.min_val.numel():
#            # If not, we need to reshape min_val and max_val accordingly
#            min_val = torch.reshape(min_val, (-1,) + self.min_val.shape[1:])
#            max_val = torch.reshape(max_val, (-1,) + self.max_val.shape[1:])
#        
#        if min_val.shape != self.min_val.shape:
#            min_val = min_val.expand_as(self.min_val)
#            max_val = max_val.expand_as(self.max_val)
#    
#        if self.num_flag == 0:
#            min_val_new = min_val
#            max_val_new = max_val
#            self.num_flag += 1
#        else:
#            min_val_new = torch.min(min_val, self.min_val)
#            max_val_new = torch.max(max_val, self.max_val)
#    
#        self.min_val.copy_(min_val_new.detach())
#        self.max_val.copy_(max_val_new.detach())
#
#class MovingAvgMinMaxObserver(MinMaxObserver):
#    def __init__(self, channels, momentum=0.1, target_size=4096):
#        super(MovingAvgMinMaxObserver, self).__init__(channels, target_size=target_size)
#        self.momentum = momentum
#
#    def update_range(self, min_val, max_val):
#        min_val = torch.reshape(min_val, self.min_val.shape)
#        max_val = torch.reshape(max_val, self.max_val.shape)
#
#        if self.num_flag == 0:
#            min_val_new = min_val
#            max_val_new = max_val
#            self.num_flag += 1
#        else:
#            min_val_new = self.min_val + (min_val - self.min_val) * self.momentum
#            max_val_new = self.max_val + (max_val - self.max_val) * self.momentum
#
#        self.min_val.copy_(min_val_new.detach())
#        self.max_val.copy_(max_val_new.detach())
#
##################################################################################