import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import numpy as np
import logging
from torch.nn.modules import loss
from models.layers import LIFSpike

def seed_all(seed=1029):
    """Seed every random number generator this project relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (the CPU generator
    and the CUDA generators), exports ``PYTHONHASHSEED`` and switches cuDNN
    to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it
    # here can only influence processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, no benchmark-based algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to both ``filename`` and the console.

    ``logging.getLogger(name)`` hands back the same logger object on every
    call, so handlers installed by a previous ``get_logger`` call on that
    logger are removed (and closed) first. Without this, every repeated call
    would stack another pair of handlers and each record would be emitted
    several times. Handlers attached by other code are left untouched.

    Args:
        filename: Path of the log file; it is opened in ``"w"`` mode, so any
            existing content is truncated.
        verbosity: 0 for DEBUG, 1 for INFO, 2 for WARNING.
        name: Logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: If ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    # Look the level up before touching handlers so that an invalid
    # verbosity leaves the existing configuration intact.
    logger.setLevel(level_dict[verbosity])

    # Drop only the handlers a previous get_logger call installed.
    stale = [h for h in logger.handlers
             if getattr(h, "_installed_by_get_logger", False)]
    for handler in stale:
        logger.removeHandler(handler)
        handler.close()

    fh = logging.FileHandler(filename, "w")
    sh = logging.StreamHandler()
    for handler in (fh, sh):
        handler.setFormatter(formatter)
        handler._installed_by_get_logger = True
        logger.addHandler(handler)

    return logger


def TET_loss(outputs, labels, criterion, means, lamb):
    """Blend a per-time-step criterion with an MSE pull towards ``means``.

    ``criterion`` is evaluated on every slice ``outputs[:, t, ...]`` along
    dimension 1 and averaged over those slices (the L_TET term). When
    ``lamb`` is non-zero, the mean squared error between ``outputs`` and a
    tensor filled with ``means`` is added as a regulariser (the L_mse term).

    Args:
        outputs: Tensor whose dimension 1 indexes time steps
            (presumably (batch, T, classes) - confirm against the caller).
        labels: Targets forwarded unchanged to ``criterion``.
        criterion: Callable ``criterion(step_output, labels)``.
        means: Fill value of the MSE target.
        lamb: Mixing weight; 0 disables the MSE term entirely.

    Returns:
        ``(1 - lamb) * L_TET + lamb * L_mse``.
    """
    num_steps = outputs.size(1)

    # L_TET: average the criterion over the time steps.
    step_loss_sum = sum(
        criterion(outputs[:, step, ...], labels) for step in range(num_steps)
    )
    tet_term = step_loss_sum / num_steps

    # L_mse: only built when it actually contributes to the result.
    if lamb == 0:
        mse_term = 0
    else:
        target = torch.empty_like(outputs).fill_(means)
        mse_term = F.mse_loss(outputs, target)

    return (1 - lamb) * tet_term + lamb * mse_term  # L_Total

def att_loss_r2b(student_mem, teacher_mem):
    """Distance between the normalised batch Gram matrices of two tensors.

    Both inputs are averaged over dimension 1, flattened to ``(B, -1)`` and
    turned into ``B x B`` Gram matrices. Each Gram matrix is divided by its
    Frobenius norm and the Frobenius norm of their difference is returned.

    The norm is floored at a tiny positive value before the division, so an
    all-zero input yields a zero matrix instead of NaN. Inputs with a
    non-degenerate norm produce exactly the same value as an unguarded
    division.

    Args:
        student_mem: Student tensor with at least two dimensions
            (assumed (B, T, ...) with dimension 1 the time axis - confirm).
        teacher_mem: Teacher tensor of identical shape; it is moved to the
            student's device.

    Returns:
        A 0-dim tensor holding the loss.

    Raises:
        AssertionError: If the two shapes differ.
    """
    assert student_mem.shape == teacher_mem.shape, \
        f"Shape mismatch: student {student_mem.shape} vs teacher {teacher_mem.shape}"

    def _unit_frobenius(gram):
        # torch.norm(..., p=2) without ``dim`` treats the matrix as a flat
        # vector, i.e. it is the Frobenius norm.
        norm = torch.norm(gram, p=2)
        floor = max(1e-12, torch.finfo(gram.dtype).tiny)
        return gram / norm.clamp_min(floor)

    teacher_mem = teacher_mem.to(student_mem.device)

    batch_size = student_mem.size(0)
    stu_flat = student_mem.mean(1).view(batch_size, -1)
    tea_flat = teacher_mem.mean(1).view(batch_size, -1)  # (B, C*H*W) per the original note

    # Sample-to-sample similarity within the batch.
    stu_gram = torch.mm(stu_flat, stu_flat.t())
    tea_gram = torch.mm(tea_flat, tea_flat.t())

    # L2 norm of the difference between the normalised Gram matrices.
    difference = _unit_frobenius(stu_gram) - _unit_frobenius(tea_gram)
    return torch.norm(difference, p=2)

# Earlier variant of att_loss_r2b, kept for reference only: it normalises and
# compares the raw tensors directly instead of their batch Gram matrices.
#
# def att_loss_r2b(student_mem, teacher_mem):
#     assert student_mem.shape == teacher_mem.shape, \
#         f"Shape mismatch: student {student_mem.shape} vs teacher {teacher_mem.shape}"
#
#     device = student_mem.device
#     teacher_mem = teacher_mem.to(device)
#
#     # Normalise each tensor by its L2 norm
#     student_mem_norm = student_mem / torch.norm(student_mem, p=2)
#     teacher_mem_norm = teacher_mem / torch.norm(teacher_mem, p=2)
#
#     # L2 norm of the difference
#     tmp = student_mem_norm - teacher_mem_norm
#     loss = torch.norm(tmp, p=2)
#     return loss

def FR_loss(tea_mem, stu_mem):
    """Average MSE between paired student and teacher features.

    Args:
        tea_mem: Sequence of teacher feature tensors (used as MSE targets).
        stu_mem: Sequence of student feature tensors (used as predictions),
            with the same length as ``tea_mem``.

    Returns:
        The per-pair ``nn.MSELoss`` values summed and divided by the number
        of pairs.

    Raises:
        AssertionError: If the two sequences differ in length.
        ZeroDivisionError: If both sequences are empty.
    """
    assert len(tea_mem) == len(stu_mem), "教师和学生特征表示长度必须相同"

    mse = nn.MSELoss()
    # Student features are the prediction, teacher features the target.
    pair_losses = (
        mse(student_feat, teacher_feat)
        for teacher_feat, student_feat in zip(tea_mem, stu_mem)
    )
    return sum(pair_losses) / len(tea_mem)

def Mem_loss(tea_mem, stu_mem):
    """Sum ``att_loss_r2b`` over paired student/teacher membrane records.

    Args:
        tea_mem: Sequence of teacher membrane-potential records.
        stu_mem: Sequence of student membrane-potential records, same length
            as ``tea_mem`` (entries look like (B, T, C, H, W) tensors
            according to the original author's note - unverified here).

    Returns:
        The accumulated loss (``0.0`` when both sequences are empty), or
        ``None`` as soon as any pair contains an empty record.

    Raises:
        AssertionError: If the two sequences differ in length.
    """
    assert len(tea_mem) == len(stu_mem), \
            "Student and teacher models must have the same number of LIF neurons"

    # Membrane-potential distillation loss, accumulated pair by pair.
    accumulated = 0.0
    for stu_record, tea_record in zip(stu_mem, tea_mem):
        # An empty record on either side aborts the whole computation.
        if len(stu_record) == 0 or len(tea_record) == 0:
            return None
        accumulated += att_loss_r2b(stu_record, tea_record)

    return accumulated


class BlockwiseDistillation:
    """Compare the ``layer1``-``layer3`` outputs of a student with a teacher.

    Forward hooks on both models store the latest output of each of the three
    blocks; ``compute_distillation_loss`` then measures how far the student's
    features are from the teacher's.

    The models may be wrapped (anything exposing the real network as
    ``.module``, e.g. ``nn.DataParallel``) or passed in unwrapped.
    """

    # Names of the sub-modules whose outputs are captured and compared.
    _LAYER_NAMES = ('layer1', 'layer2', 'layer3')

    def __init__(self, teacher_model, student_model):
        """Freeze the teacher, keep the student, and install the hooks.

        Args:
            teacher_model: Teacher network; it is put in eval mode and all of
                its parameters get ``requires_grad = False``.
            student_model: Student network; left untouched apart from the
                forward hooks.
        """
        self.teacher = teacher_model.eval()
        self.student = student_model

        # Freeze the teacher's parameters.
        for param in self.teacher.parameters():
            param.requires_grad = False

        # Latest block outputs, keyed by layer name, filled by the hooks.
        self.teacher_features = {}
        self.student_features = {}

        self._register_hooks()

    @staticmethod
    def _unit_l2(x):
        """Divide ``x`` by its flattened L2 norm, flooring the norm so that an
        all-zero tensor maps to zeros instead of NaN."""
        norm = torch.norm(x, p=2)
        floor = max(1e-12, torch.finfo(x.dtype).tiny)
        return x / norm.clamp_min(floor)

    def att_loss_r2b(self, Q_s, Q_t):
        """Return the L2 norm of the difference of the L2-normalised inputs.

        Args:
            Q_s: Student tensor.
            Q_t: Teacher tensor; moved to the device of ``Q_s``. Its shape
                must be compatible with ``Q_s`` for subtraction.

        Returns:
            A 0-dim tensor.
        """
        Q_t = Q_t.to(Q_s.device)
        return torch.norm(self._unit_l2(Q_s) - self._unit_l2(Q_t), p=2)

    def _register_hooks(self):
        """Attach forward hooks that record each block's output."""

        def make_hook(store_attr, name):
            def hook(module, inputs, output):
                # Resolve the dict at call time so a reassigned
                # ``*_features`` attribute is still honoured.
                getattr(self, store_attr)[name] = output

            return hook

        for model, store_attr in ((self.teacher, 'teacher_features'),
                                  (self.student, 'student_features')):
            # Wrapped models expose the network as ``.module``; plain models
            # are used directly.
            network = getattr(model, 'module', model)
            for name in self._LAYER_NAMES:
                getattr(network, name).register_forward_hook(
                    make_hook(store_attr, name))

    def compute_distillation_loss(self, student_output, teacher_output, labels):
        """Sum, over the three blocks, the time-averaged feature distance.

        Both models must have run a forward pass beforehand so that the hooks
        have populated the feature dictionaries. Only 5-D student features
        (indexed as [T, B, C, H, W]) contribute: each time slice is compared
        with the teacher feature of the same block and the per-block result
        is averaged over T. Student features of any other rank are skipped.

        Args:
            student_output: Only used to pick the device of the accumulator.
            teacher_output: Unused.
            labels: Unused.

        Returns:
            The loss as a plain Python ``float``; it therefore carries no
            gradient.

        Raises:
            KeyError: If a block's feature has not been recorded yet.
        """
        device = student_output.device
        total = torch.zeros((), device=device)

        # The result leaves as a float, so no autograd graph is needed.
        with torch.no_grad():
            for layer_name in self._LAYER_NAMES:
                student_feat = self.student_features[layer_name]  # [T, B, C, H, W]
                teacher_feat = self.teacher_features[layer_name]  # [B, C, H, W]

                if student_feat.dim() != 5:
                    continue

                # Align every time step with the teacher feature.
                T = student_feat.size(0)
                layer_sum = torch.zeros((), device=device)
                for t in range(T):
                    step_loss = self.att_loss_r2b(student_feat[t], teacher_feat)
                    layer_sum = layer_sum + step_loss.to(device)
                total += layer_sum / T

        return total.item()

class DistributionLoss(loss._Loss):
    """Soft-target cross-entropy between student logits and teacher logits.

    Both arguments of ``forward`` are NxC tensors of pre-softmax scores:
    ``model_output`` goes through ``log_softmax`` and ``real_output`` through
    ``softmax`` inside ``forward``, so neither should be normalised by the
    caller. ``real_output`` must not require gradients.

    The per-sample loss is ``-sum_c softmax(real)[c] * log_softmax(model)[c]``
    and is reduced according to ``self.reduction`` (``'mean'`` by default).
    """

    def forward(self, model_output, real_output):
        """Compute the reduced soft-target cross-entropy.

        Args:
            model_output: NxC student logits.
            real_output: NxC teacher logits, detached from autograd.

        Returns:
            A 0-dim tensor for ``'mean'`` / ``'sum'``, or a length-N tensor
            for ``'none'``.

        Raises:
            ValueError: If ``real_output`` requires gradients or
                ``self.reduction`` is not a supported mode.
        """
        if real_output.requires_grad:
            raise ValueError("real network output should not require gradients.")

        model_output_log_prob = F.log_softmax(model_output, dim=1)
        real_output_soft = F.softmax(real_output, dim=1)

        # Row-wise dot product via batched matmul: (N,1,C) x (N,C,1) -> (N,1,1).
        per_sample = -torch.bmm(
            real_output_soft.unsqueeze(1), model_output_log_prob.unsqueeze(2)
        )

        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'none':
            return per_sample.view(-1)
        raise ValueError(f"Unsupported reduction: {self.reduction!r}")

class MembraneDistiller:

    def __init__(self, student_model, teacher_model):
        self.student_model = student_model
        self.teacher_model = teacher_model

    def collect_lif_neurons(self, model):
        """收集模型中的所有LIF神经元"""
        lif_neurons = []
        for module in model.modules():
            if isinstance(module, LIFSpike):
                lif_neurons.append(module)
        return lif_neurons

    def att_loss_r2b(self, student_mem, teacher_mem):
        assert student_mem.shape[0] == teacher_mem.shape[0], \
            f"Shape mismatch: student {student_mem.shape} vs teacher {teacher_mem.shape}"

        device = student_mem.device
        teacher_mem = teacher_mem.to(device)

        # 计算L2范数归一化
        student_mem_norm = student_mem / torch.norm(student_mem, p=2)
        teacher_mem_norm = teacher_mem / torch.norm(teacher_mem, p=2)

        # 计算差异的R2范数
        tmp = student_mem_norm - teacher_mem_norm
        loss = torch.norm(tmp, p=2)
        return loss

    def compute_distillation_loss(self, criterion, logis_stu, logis_tea):
        """计算所有对应LIF神经元之间的蒸馏损失"""
        device = logis_tea.device
        student_neurons = self.collect_lif_neurons(self.student_model)
        teacher_neurons = self.collect_lif_neurons(self.teacher_model)

        assert len(student_neurons) == len(teacher_neurons), \
            "Student and teacher models must have the same number of LIF neurons"

        # 计算膜电位的蒸馏损失
        fea_loss = 0.0  # 初始化损失
        for student_lif, teacher_lif in zip(student_neurons, teacher_neurons):
            student_mem_history = student_lif.get_mem_history()
            teacher_mem_history = teacher_lif.get_mem_history()

            if student_mem_history is not None and teacher_mem_history is not None:
                student_mem_history = student_mem_history.to(device)
                teacher_mem_history = teacher_mem_history.to(device)
                fea_loss += self.att_loss_r2b(student_mem_history, teacher_mem_history)

        # 计算logits损失
        T = logis_stu.size(1)
        logis_loss = torch.mean(
            torch.stack([criterion(logis_stu[:, t], logis_tea[:, t]) for t in range(T)]), dim=0
        )

        # 重置所有神经元的膜电位
        for student_lif, teacher_lif in zip(student_neurons, teacher_neurons):
            student_lif.reset_mem()
            teacher_lif.reset_mem()

        # 返回蒸馏损失总和
        return logis_loss + fea_loss

def split_weights(net):
    """Split a network's parameters into weight-decay and no-decay groups.

    The ``weight`` of every ``nn.Conv2d`` / ``nn.Linear`` goes into the decay
    group. Everything else goes into the no-decay group: conv/linear biases
    and ``sign`` attributes, plus the ``weight``, ``bias`` and ``clip_val``
    attributes of all other modules. Attributes that exist but are ``None``
    (for example on ``BatchNorm2d(affine=False)``) are skipped, so no
    ``None`` entry ends up in an optimizer group.

    Args:
        net: Network whose ``modules()`` are scanned.

    Returns:
        A list of two optimizer parameter-group dicts:
        ``[{'params': decay}, {'params': no_decay, 'weight_decay': 0}]``.

    Raises:
        AssertionError: If the two groups together do not account for exactly
            ``len(list(net.parameters()))`` entries.
    """
    decay = []
    no_decay = []

    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            decay.append(m.weight)

            if hasattr(m, 'sign'):
                no_decay.append(m.sign)

            if m.bias is not None:
                no_decay.append(m.bias)

        else:
            # Same order as before: weight, bias, clip_val.
            for attr in ('weight', 'bias', 'clip_val'):
                param = getattr(m, attr, None)
                if param is not None:
                    no_decay.append(param)

    # Sanity check: every parameter must have landed in exactly one group.
    assert len(list(net.parameters())) == len(decay) + len(no_decay)

    return [dict(params=decay), dict(params=no_decay, weight_decay=0)]