from torch import Tensor
import torch
import torch.nn as nn
import bisect
import gc
import os
import numpy as np

class SFNeuron(nn.Module):
    """Stateful signed neuron that emits quantized values each time step.

    Every call adds the input to a membrane potential, quantizes the potential into a
    positive and a negative level count (see :meth:`forward_step`), emits
    ``pos_level * scale_p - neg_level * scale_n`` and subtracts that emission from the
    potential, so the quantization residual carries over to the next call.

    Args:
        scale_p: Value represented by one positive level.
        scale_n: Value represented by one negative level.
        place: Tag describing where the neuron sits in the network (stored only).
        times: Level-count parameter ``n`` used by the quantizers.
    """

    def __init__(self, scale_p = 1., scale_n = 1., place = None, times = None):
        super().__init__()
        self.scale_p = scale_p
        self.scale_n = scale_n
        self.place = place
        self.times = times
        self.t = 0  # number of forward calls since the last reset
        self.neuron = None  # membrane potential, allocated on the first call

    def _split(self, x):
        """Return ``(pos, neg)`` half-offset magnitudes of the positive / negative parts.

        ``pos = x / scale_p + 0.5`` where ``x > 0`` and ``neg = -(x / scale_n - 0.5)``
        where ``x < 0``; both are 0 elsewhere. Flooring them rounds half up.
        """
        zero = torch.tensor(0, device=x.device)
        pos = torch.where(x > 0, x / self.scale_p + 0.5, zero)
        neg = torch.where(x < 0, -(x / self.scale_n - 0.5), zero)
        return pos, neg

    def _emit(self, pos_level, neg_level):
        """Combine positive / negative level counts into the signed emitted value."""
        return pos_level * self.scale_p - neg_level * self.scale_n

    def _step_levels(self, v):
        """Quantize ``v >= 0`` onto ``{0, ..., n-1} U {n + 2^k - 1 : 0 <= k <= n}``.

        ``[0, n)`` is floored; ``[n, +inf)`` maps to ``n + 2^k - 1`` where ``2^k`` is the
        largest power of two ``<= v - n + 1``, capped at ``n + 2^n - 1``.
        """
        n = self.times
        top = n + 2 ** n - 1
        # The mapping is monotone and reaches exactly `top` at v == top, so clamping
        # the input is equivalent to clamping the result. Doing it first keeps the
        # int32 exponent from torch.frexp small, so torch.pow(2, ...) cannot overflow
        # for huge or infinite potentials (which previously produced bogus levels).
        v = torch.clamp(v, max=top)
        exponent = torch.frexp(v - n + 1)[1]  # v - n + 1 lies in [2^(e-1), 2^e)
        log_levels = n + torch.pow(2, exponent - 1) - 1
        return torch.where(v < n, torch.floor(v), log_levels)

    def _linear_levels(self, v):
        """Quantize ``v >= 0`` onto ``{0, 1, ..., n}``: floor below ``n``, saturate at ``n``."""
        return torch.where(
            v < self.times,
            torch.floor(v),
            torch.tensor(self.times, device=v.device),
        )

    def _exp_levels(self, v):
        """Quantize ``v >= 0`` onto ``{0, 1, 2, 4, ..., 2^n}``.

        ``[0, 1)`` -> 0, ``[2^(k-1), 2^k)`` -> ``2^(k-1)`` for ``1 <= k <= n`` and
        ``[2^n, +inf)`` -> ``2^n``.
        """
        cap = 2 ** self.times
        # For v < 1 the frexp exponent is <= 0; integer torch.pow then yields 0 for the
        # resulting negative power, which is what maps [0, 1) to level 0.
        rounded = torch.pow(2, torch.frexp(v)[1] - 1)
        return torch.where(v < cap, rounded, torch.tensor(cap, device=v.device))

    def forward_step(self, x):
        """Quantize ``x`` on the step grid (linear up to ``n``, then powers of two)."""
        pos, neg = self._split(x)
        return self._emit(self._step_levels(pos), self._step_levels(neg))

    def forward_linear(self, x):
        """Quantize ``x`` on a uniform grid of at most ``n`` levels per sign."""
        pos, neg = self._split(x)
        return self._emit(self._linear_levels(pos), self._linear_levels(neg))

    def forward_exp(self, x):
        """Quantize ``x`` onto power-of-two levels, at most ``2^n`` per sign."""
        pos, neg = self._split(x)
        return self._emit(self._exp_levels(pos), self._exp_levels(neg))

    def forward(self, x):
        """Integrate ``x``, fire via :meth:`forward_step`, subtract the emission and return it."""
        if self.t == 0:
            self.neuron = torch.zeros_like(x)
        self.neuron += x

        fire = self.forward_step(self.neuron)

        self.neuron -= fire
        self.t += 1
        return fire

    def reset(self):
        """Forget the membrane potential and the step counter."""
        self.t = 0
        self.neuron = None

def _lambda_scaled(args, p, n):
    """Return ``(p * lam, n * lam)`` where ``lam`` is the ``lambda`` attribute of ``args``.

    ``lambda`` is a reserved word, so ``args.lambda`` is a SyntaxError; an attribute
    with that name (e.g. one created by an argparse ``--lambda`` option) can only be
    read through ``getattr``.
    """
    lam = getattr(args, "lambda")
    return p * lam, n * lam


def replace_testneuron_by_sfneuron(model, args):
    """Recursively replace ``TestNeuron`` children with ``SFNeuron`` modules.

    Each ``TestNeuron``'s ``place`` tag selects the level count and the scales:

    * ``'q'``, ``'k'``, ``'v'``: ``times = args.qkv_num``; scales are the recorded
      ``scale_p`` / ``scale_n`` multiplied by ``args.lambda``.
    * any other place containing ``'s'``: ``times = args.softmax_num``; both scales
      equal ``args.softmax_p`` when it is not None, otherwise recorded * lambda.
    * everything else: ``times = args.linear_num``; scales are recorded * lambda.

    Args:
        model: Module whose children are rewritten in place.
        args: Namespace providing ``qkv_num``, ``softmax_num``, ``linear_num``,
            ``softmax_p`` and ``lambda`` (read only by the branches that need them).

    Returns:
        ``model`` itself, after replacement.
    """
    for name, module in model._modules.items():
        if hasattr(module, "_modules"):
            model._modules[name] = replace_testneuron_by_sfneuron(module, args)
        if module.__class__.__name__.lower() != 'testneuron':
            continue

        place = model._modules[name].place
        p, n = float(model._modules[name].scale_p[0]), float(model._modules[name].scale_n[0])

        # An exact match on q/k/v already implies "fc" is not in place.
        if place in ('q', 'k', 'v'):
            times = args.qkv_num
            sp, sn = _lambda_scaled(args, p, n)
        elif 's' in place:
            times = args.softmax_num
            if args.softmax_p is not None:
                sp = args.softmax_p
                sn = args.softmax_p
            else:
                sp, sn = _lambda_scaled(args, p, n)
        else:
            times = args.linear_num
            sp, sn = _lambda_scaled(args, p, n)

        # NOTE: a disabled override used to live here: for 'vit' in args.model and
        # 'fc2' in place it forced sp = 0.25, sn = 0.08.

        model._modules[name] = SFNeuron(
            scale_p = sp,
            scale_n = sn,
            place = place,
            times = times,
        )
    return model

class MyTestPlace(nn.Module):
    """Identity layer that snapshots activations at one probe location.

    Only an instance whose ``place`` is ``'0_s'`` records anything; all others are pure
    pass-throughs. Captured activations are detached, copied to host memory as numpy
    arrays and held in ``self.buffer`` until :meth:`flush` concatenates them along
    axis 0 and writes one ``.npy`` file. Recording pauses once the buffered byte count
    reaches ``MAX_GB`` GiB; :meth:`flush` resets that count.
    """

    def __init__(self, act_shape=None, place=None, save_dir="./stats", MAX_GB=16):
        super().__init__()
        self.act_shape = act_shape
        self.place = place
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

        self.buffer = []  # numpy snapshots waiting to be written
        self.max_bytes = MAX_GB * 1024 ** 3
        self.buffer_size = 0  # total nbytes currently in self.buffer
        self.step = 0  # index used in the next output file name

    def forward(self, x):
        skip = self.place != '0_s' or self.buffer_size >= self.max_bytes
        if not skip:
            snapshot = x.detach().cpu().numpy()
            self.buffer.append(snapshot)
            self.buffer_size += snapshot.nbytes
        return x

    def flush(self):
        """Write buffered snapshots to ``<save_dir>/<place>_<step:03d>.npy`` and clear them."""
        if not self.buffer:
            return

        data = np.concatenate(self.buffer, axis=0)
        file_path = os.path.join(self.save_dir, f"{self.place}_{self.step:03d}.npy")
        np.save(file_path, data)
        print(f"[MyTestPlace] Saved {file_path} shape={data.shape}")

        self.buffer.clear()
        self.buffer_size = 0
        self.step += 1

def final_flush(model):
    """Call ``flush()`` on every ``MyTestPlace`` found in ``model`` (including itself)."""
    recorders = (sub for sub in model.modules() if isinstance(sub, MyTestPlace))
    for recorder in recorders:
        recorder.flush()

class TestNeuron(nn.Module):
    """Pass-through probe that tracks running-mean activation thresholds.

    On each forward pass the input is flattened and ``k = int((1 - percent) * N)`` is
    computed. The positive threshold is the k-th largest value and the negative
    threshold is minus the k-th smallest value; when ``k == 0`` the max and minus the
    min are used instead. Each threshold is folded into a running mean stored in
    ``scale_p`` / ``scale_n`` (shape ``[1]`` parameters). The input is returned
    unchanged.

    Args:
        place: Tag identifying where the probe sits in the network.
        percent: Fraction controlling ``k`` above; must be set before calling forward.
    """

    def __init__(self, place = None, percent = None):
        super().__init__()
        self.place = place
        self.percent = percent
        self.num = 0  # number of batches folded into the running means
        self.scale_p = torch.nn.Parameter(torch.FloatTensor([0.]))
        self.scale_n = torch.nn.Parameter(torch.FloatTensor([0.]))

    @torch.no_grad()
    def _accumulate(self, param, threshold):
        """Fold ``threshold`` into the running mean stored in ``param``, in place.

        Updating in place (instead of wrapping a new ``nn.Parameter`` every call) keeps
        the parameter object stable and avoids building an autograd graph.
        """
        param.copy_((param * self.num + threshold) / (self.num + 1))

    def forward(self, x):
        flat = x.reshape(-1)
        k = int((1 - self.percent) * flat.numel())

        if k == 0:
            pos_threshold = torch.max(flat).item()
            neg_threshold = -torch.min(flat).item()
        else:
            pos_threshold = torch.topk(flat, k, largest = True).values[-1].item()
            neg_threshold = -torch.topk(flat, k, largest = False).values[-1].item()

        # Both updates must use the same (pre-increment) self.num.
        self._accumulate(self.scale_p, pos_threshold)
        self._accumulate(self.scale_n, neg_threshold)
        self.num += 1
        return x

    def reset(self):
        """No-op: statistics intentionally persist across time-step resets."""
        pass

def replace_test_by_testneuron(model, percent=None):
    """Recursively swap every ``MyTestPlace`` child for a ``TestNeuron`` probe.

    The new probe inherits the replaced module's ``place`` tag and receives
    ``percent``. Children are rewritten in place and ``model`` itself is returned.
    """
    for child_name, child in model._modules.items():
        if hasattr(child, "_modules"):
            model._modules[child_name] = replace_test_by_testneuron(child, percent)
        is_recorder = child.__class__.__name__.lower() == 'mytestplace'
        if is_recorder:
            tag = model._modules[child_name].place
            model._modules[child_name] = TestNeuron(place=tag, percent=percent)
    return model

class exp_comp_neuron(nn.Module):
    """Apply a nonlinearity step by step to a stream of inputs.

    The module keeps the running input sum ``tot`` and step count ``t``. It returns
    ``func(tot / t) * t`` minus the previous step's value of that expression, so after
    ``t`` calls since a reset the returned values sum to
    ``func(mean of the inputs) * t``.

    Args:
        func: Callable (typically an ``nn.Module``) applied to the running mean.
    """

    def __init__(self, func, *args, **keywords):
        super().__init__(*args, **keywords)
        self.tot = None  # running sum of inputs (a private clone of the first input)
        self.func = func
        self.t = 0  # number of inputs seen since the last reset

    def forward(self, x):
        # Identity check: `tensor == None` only works through a fallback in torch's
        # elementwise __eq__; `is None` states the intent directly.
        if self.tot is None:
            self.tot = x.clone()
            self.t += 1
            return self.func(self.tot)

        last = self.func(self.tot / self.t) * self.t
        self.tot += x
        self.t += 1
        return self.func(self.tot / self.t) * self.t - last  # now - last

    def reset(self):
        """Drop the running sum and step count."""
        self.tot = None
        self.t = 0

def replace_nonlinear_by_neuron(model):
    """Recursively wrap Softmax / GELU / LayerNorm-like children in ``exp_comp_neuron``.

    A child matches when its lower-cased class name contains ``softmax``, ``gelu`` or
    ``layernorm``. Children are rewritten in place and ``model`` itself is returned.
    """
    markers = ('softmax', 'gelu', 'layernorm')
    for child_name, child in model._modules.items():
        if hasattr(child, "_modules"):
            model._modules[child_name] = replace_nonlinear_by_neuron(child)
        kind = child.__class__.__name__.lower()
        if any(marker in kind for marker in markers):
            model._modules[child_name] = exp_comp_neuron(func=model._modules[child_name])
    return model

class MyAt(nn.Module):
    """Placeholder module computing the matrix product of its two inputs.

    It exists as a named module so ``replace_at_by_neuron`` can find it by class name.
    """

    def forward(self, x, y):
        return torch.matmul(x, y)

class AtNeuron(nn.Module):
    """Step-wise matrix product for two input streams.

    Keeps the running sums ``A_t = x_1 + ... + x_t`` and ``B_t = y_1 + ... + y_t`` and
    ``T_t = A_t @ B_t`` (updated incrementally). The output at step ``t`` is chosen so
    that the outputs of the first ``t`` calls sum to ``A_t @ B_t / t``.
    """

    def __init__(self):
        super().__init__()
        self.tot_a = None  # running sum of left operands
        self.tot_b = None  # running sum of right operands
        self.tot_t = None  # tot_a @ tot_b, maintained incrementally
        self.t = 0  # number of calls since the last reset

    def forward(self, x, y):
        if self.t == 0:
            # Clone: the running sums are updated in place on later steps, which would
            # otherwise overwrite the caller's step-1 tensors (or storage they view).
            self.tot_a = x.clone()
            self.tot_b = y.clone()
            self.tot_t = x @ y
            self.t = 1
            # Return a copy so the in-place update of tot_t cannot alter this output.
            return self.tot_t.clone()

        # Uses the old tot_a / tot_b: A_t B_t = A B + x B + A y + x y.
        self.tot_t += x @ y + x @ self.tot_b + self.tot_a @ y
        self.tot_a += x
        self.tot_b += y
        self.t += 1
        # x @ B_t + A_t @ y - x @ y == A_t B_t - A_{t-1} B_{t-1}; the expression below
        # equals A_t B_t / t - A_{t-1} B_{t-1} / (t - 1).
        return (x @ self.tot_b + self.tot_a @ y - x @ y) / (self.t - 1) - self.tot_t / (self.t * (self.t - 1))

    def zero_count(self, x):
        """Return the fraction of entries of ``x`` that are exactly zero."""
        return (x == 0).sum() / x.numel()

    def reset(self):
        """Drop all running sums and the step counter."""
        self.tot_a = None
        self.tot_b = None
        self.tot_t = None
        self.t = 0
    
def replace_at_by_neuron(model):
    """Recursively replace every ``MyAt`` child with a fresh ``AtNeuron``.

    Children are rewritten in place and ``model`` itself is returned.
    """
    for child_name, child in model._modules.items():
        if hasattr(child, "_modules"):
            model._modules[child_name] = replace_at_by_neuron(child)
        is_matmul_placeholder = child.__class__.__name__.lower() == "myat"
        if is_matmul_placeholder:
            model._modules[child_name] = AtNeuron()
    return model

def reset_net(model):
    """Call ``reset()`` on every descendant whose class name contains "neuron".

    Matching is case-insensitive and only applies to children (``model`` itself is not
    reset). Returns ``model``.
    """
    for child in model._modules.values():
        if hasattr(child, "_modules"):
            reset_net(child)
        if 'neuron' in child.__class__.__name__.lower():
            child.reset()
    return model

class BaseMonitor:
    """Bookkeeping shared by forward-hook monitors.

    ``records`` is a flat list of recorded values. ``name_records_index`` maps each
    monitored layer name to the positions of its entries in ``records``. ``hooks``
    holds the hook handles so they can be removed, which also happens when the
    monitor is garbage-collected.
    """

    def __init__(self):
        self.hooks = []
        self.monitored_layers = []
        self.records = []
        self.name_records_index = {}
        self._enable = True

    def __getitem__(self, i):
        """Return ``records[i]`` for an int, or all records of layer ``i`` for a str."""
        if isinstance(i, int):
            return self.records[i]
        if isinstance(i, str):
            return [self.records[pos] for pos in self.name_records_index[i]]
        raise ValueError(i)

    def clear_recorded_data(self):
        """Empty ``records`` and every per-layer index list (layer names are kept)."""
        self.records.clear()
        for positions in self.name_records_index.values():
            positions.clear()

    def enable(self):
        self._enable = True

    def disable(self):
        self._enable = False

    def is_enable(self):
        return self._enable

    def remove_hooks(self):
        """Call ``remove()`` on every stored hook handle."""
        for handle in self.hooks:
            handle.remove()

    def __del__(self):
        self.remove_hooks()

class SOPMonitor(BaseMonitor):
    """Forward-hook monitor collecting operation-count and firing statistics.

    Modes, selected by ``type``:

    * ``type == 1``: hooks modules whose lower-cased class name is ``linear`` or
      ``atneuron`` and records ``(out0, sum0)`` per call. ``out0`` is an accumulate
      count weighted by the quantized input magnitudes; ``sum0`` is the same count for
      an all-ones input of the same shape (a dense baseline).
    * ``type == 2``: hooks ``sfneuron`` modules and records
      ``(total_nonzero, elements_per_slice, per_slice_nonzero)`` for their output.

    Recording happens only while the monitor is enabled.
    """

    def __init__(self, net: nn.Module, type=1):
        super().__init__()
        for name, m in net.named_modules():
            cls_name = m.__class__.__name__.lower()
            # operation counts (energy estimate)
            if type == 1 and cls_name == 'linear':
                self._watch(name, m, self.create_hook1(name))
            if type == 1 and cls_name == 'atneuron':
                self._watch(name, m, self.create_hook1_2(name))
            # firing rate of neurons
            if type == 2 and cls_name == 'sfneuron':
                self._watch(name, m, self.create_hook2(name))

    def _watch(self, name, module, hook):
        """Start a record index for ``name`` and register ``hook`` as a forward hook."""
        self.monitored_layers.append(name)
        self.name_records_index[name] = []
        self.hooks.append(module.register_forward_hook(hook))

    def find_scale(self, x):
        """Return the smallest entry of ``x`` that is at least ``x.max() / 263``.

        Expects strictly positive values. Entries below ``max / 263`` are treated as
        noise, so the result approximates the quantization step. The original note
        lists 16, 256 and 263 as the ratios for linear, exp and step neurons. These are
        presumably the top levels of SFNeuron's linear (times=16), exp (2**8) and step
        (8 + 2**8 - 1) modes; TODO confirm.
        """
        xmax = x.max()
        x = torch.where(
            x < xmax / 263,  # 16, 256, 263 for linear, exp, step
            0.0,
            x,
        )
        return x[x > 0].min()

    def quantize_by_min_scale(self, x):
        """Map ``x`` to non-negative integer magnitudes on separate pos / neg grids.

        Positive entries become ``round(x / s_p)`` and negative entries
        ``round(-x / s_n)`` (also non-negative), with ``s_p`` / ``s_n`` from
        :meth:`find_scale`. Zero entries stay zero.
        """
        x_quantized = torch.zeros_like(x)
        pos_mask = x > 0
        neg_mask = x < 0

        if torch.any(pos_mask):
            scale_p = self.find_scale(x[pos_mask])
            x_quantized = torch.where(pos_mask, torch.round(x / scale_p), x_quantized)

        if torch.any(neg_mask):
            scale_n = self.find_scale(-x[neg_mask])
            x_quantized = torch.where(neg_mask, torch.round(-x / scale_n), x_quantized)

        return x_quantized

    def cal_sop1(self, x: Tensor, m: nn.Linear):
        """Return ``(out0, sum0)`` for one call of Linear ``m`` on input ``x``.

        ``out0`` is the sum of ``F.linear(q, ones)`` where ``q`` is the quantized input,
        i.e. each quantized magnitude counted once per output feature. ``sum0`` is the
        same with ``q`` replaced by ones. Both are float64 0-dim tensors.
        """
        # NOTE: a binary variant counts (x != 0) instead of quantized magnitudes.
        with torch.no_grad():
            y = self.quantize_by_min_scale(x).to(torch.float64)
            weight = torch.ones_like(m.weight, dtype=torch.float64)
            # No bias: a zero bias adds nothing, and m.bias is None for Linear(bias=False)
            # (torch.zeros_like(None) used to raise here).
            out0 = torch.nn.functional.linear(y, weight).sum()
            sum0 = torch.nn.functional.linear(torch.ones_like(y), weight).sum()
        return out0, sum0

    def create_hook1(self, name):
        def hook1(m: nn.Linear, x: tuple, y):
            # x is the tuple of positional inputs passed to m.forward.
            if self.is_enable():
                self.name_records_index[name].append(len(self.records))
                self.records.append(self.cal_sop1(x[0], m))
        return hook1

    def cal_sop1_2(self, A: Tensor, B: Tensor, m: AtNeuron):
        """Return ``(out0, sum0)`` for one AtNeuron call with operands ``A`` and ``B``.

        Left operands (``A`` and ``m.tot_a``) are quantized magnitudes; right operands
        (``B`` and ``m.tot_b``) are binary nonzero masks. ``out0`` sums the three
        products ``A@B``, ``A@tot_b`` and ``tot_a@B`` that AtNeuron evaluates. ``sum0``
        is ``ones(A) @ ones(B)`` summed.
        """
        # NOTE: alternatives (binary A / tot_a, quantized B / tot_b) were tried and are
        # not used here.
        with torch.no_grad():
            dense_A = torch.ones_like(A, dtype=torch.float64)
            dense_B = torch.ones_like(B, dtype=torch.float64)
            sum0 = (dense_A @ dense_B).sum()

            q_A = self.quantize_by_min_scale(A).to(torch.float64)
            bin_B = (B != 0).to(torch.float64)
            q_As = self.quantize_by_min_scale(m.tot_a).to(torch.float64)
            bin_Bs = (m.tot_b != 0).to(torch.float64)

            out01 = (q_A @ bin_B).sum()
            out02 = (q_A @ bin_Bs).sum()
            out03 = (q_As @ bin_B).sum()
            out0 = out01 + out02 + out03
        return out0, sum0

    def create_hook1_2(self, name):
        def hook1_2(m: AtNeuron, x: tuple, y):
            # Runs after forward, so m.tot_a / m.tot_b already include this step.
            if self.is_enable():
                self.name_records_index[name].append(len(self.records))
                self.records.append(self.cal_sop1_2(x[0], x[1], m))
        return hook1_2

    def cal_sop2(self, x):
        """Count nonzero entries of ``x`` per slice along its first dimension.

        Returns ``(total, elements_per_slice, per_slice_counts)``; the first dimension
        is presumably the batch (TODO confirm).
        """
        num_elements = x[0].numel()
        tmp = [(sample != 0).sum() for sample in x]
        return sum(tmp), num_elements, tmp

    def create_hook2(self, name):
        def hook2(m: nn.Module, x: tuple, y):
            # y is the module output; firing statistics are taken from it.
            if self.is_enable():
                self.name_records_index[name].append(len(self.records))
                self.records.append(self.cal_sop2(y))
        return hook2