import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp

from scipy.spatial.distance import pdist, squareform

from spikingjelly.clock_driven import functional, surrogate, layer, neuron
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.utils.data import DataLoader, RandomSampler, Subset
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
import numpy as np
import math
from sklearn.metrics import mutual_info_score
from sklearn.preprocessing import KBinsDiscretizer
import seaborn as sns
import matplotlib.pyplot as plt
from layers import VotingLayer, TCJA

# Nominal experiment seed. The seeding calls below are commented out, so at
# module level this value is only defined, never applied to any RNG.
_seed_ = 2020
# torch.manual_seed(_seed_)  # would seed the RNG for all devices (both CPU and CUDA)
# np.random.seed(_seed_)

# cuDNN configuration: no algorithm autotuning, deterministic kernels requested.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True




def compute_distances(x):
    """Return the matrix of pairwise squared Euclidean distances between rows.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b`` on a 2-D tensor
    ``x`` (one sample per row) and clamps the result at zero so that rounding
    error cannot produce small negative distances.
    """
    row_sq_norms = x.pow(2).sum(dim=1)
    gram = torch.mm(x, x.t())
    sq_dist = row_sq_norms.unsqueeze(1) + row_sq_norms.unsqueeze(0) - 2.0 * gram
    return sq_dist.clamp(min=0.0)

def KDE_IXT_estimation(logvar_t, mean_t):
    """Kernel-density estimate of I(X;T), in nats.

    Parameters:
    - logvar_t (Tensor) : log-variance used for the Gaussian kernel
    - mean_t (Tensor) : 2-D tensor, one representation per row

    The estimate is ``log(n) - mean_i logsumexp_j(-0.5 * d_ij / var)`` where
    ``d_ij`` are the pairwise squared distances between rows of ``mean_t``.
    """
    # Unpacking into two names also rejects inputs that are not 2-D.
    num_rows, _ = mean_t.shape
    variance = logvar_t.exp() + 1e-10  # keep the divisor strictly positive

    pairwise_sq_dist = compute_distances(mean_t)
    log_kernel_sums = torch.logsumexp(-0.5 * pairwise_sq_dist / variance, dim=1)

    # Constant term log(n) minus the averaged kernel contribution.
    return math.log(num_rows) - log_kernel_sums.mean()


class CextNet(nn.Module):  # TODO: kernel_size parameter passing
    """Spiking convolutional classifier with information-theoretic helpers.

    The network is a stack of five conv3x3 + LIF + 2x2 max-pool stages (the
    last two also contain a ``TCJA`` block), followed by a two-layer spiking
    fully-connected head with 110 outputs and a ``VotingLayer(10)``.

    Besides ``forward``, the class exposes helpers used to build regularisers:
    a KDE estimate of I(X;T) (``get_IXT``), a soft first-spike-time map
    (``get_first_spike_time``), a kernel-based Renyi entropy
    (``renyi_entropy_loss``) and histogram-based entropy / total-correlation
    estimators that operate on NumPy arrays.
    """

    def __init__(self, channels: int):
        """Build the conv trunk, the FC head and the learnable ``logvar_t``.

        Parameters:
        - channels (int) : number of feature channels of every conv stage
        """
        super().__init__()
        conv = []

        conv.extend(CextNet.conv3x3(2, channels))
        conv.append(layer.SeqToANNContainer(nn.MaxPool2d(2, 2)))

        conv.extend(CextNet.conv3x3(channels, channels))
        conv.append(layer.SeqToANNContainer(nn.MaxPool2d(2, 2)))

        conv.extend(CextNet.conv3x3(channels, channels))
        conv.append(layer.SeqToANNContainer(nn.MaxPool2d(2, 2)))

        for _ in range(2):
            conv.extend(CextNet.conv3x3(channels, channels))  # TODO: kernel size must be equal
            # NOTE(review): these literals do not follow `channels` (or any
            # time-step setting); presumably 20 is a time-step count and 128 a
            # channel count -- confirm against layers.TCJA before changing
            # `channels` or the number of simulated steps.
            conv.append(TCJA(4, 4, 20, 128))
            conv.append(layer.SeqToANNContainer(nn.MaxPool2d(2, 2)))

        self.conv = nn.Sequential(*conv)

        # The first Linear expects channels * 4 * 4 features, i.e. a 4x4
        # spatial map after the five 2x2 poolings (128x128 input frames).
        self.fc = nn.Sequential(
            nn.Flatten(2),
            layer.MultiStepDropout(0.5),
            layer.SeqToANNContainer(nn.Linear(channels * 4 * 4, channels * 2 * 2, bias=False)),
            neuron.MultiStepLIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                    backend='cupy'),
            layer.MultiStepDropout(0.5),
            layer.SeqToANNContainer(nn.Linear(channels * 2 * 2, 110, bias=False)),
            neuron.MultiStepLIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True, backend='cupy')
        )
        # Presumably pools groups of 10 outputs (110 -> 11 classes); confirm
        # against layers.VotingLayer.
        self.vote = VotingLayer(10)
        # Learnable log-variance of the KDE kernel used by get_IXT.
        self.logvar_t = torch.nn.Parameter(torch.Tensor([-1.0]))

    def get_IXT(self, mean_t):
        '''
        Obtains the mutual information between the input and the bottleneck variable.
        Parameters:
        - mean_t (Tensor) : deterministic transformation of the input
        Returns the estimate in bits.
        '''

        IXT = KDE_IXT_estimation(self.logvar_t, mean_t)  # in nats
        IXT = IXT / np.log(2)  # in bits
        return IXT

    # Obtain T, i.e. the network output before it is decoded.

    def encode_features(self, x):
        """Flatten every dimension of ``x`` except the first into one axis."""
        mean_t = x.view(-1, self.num_flat_features(x))
        return mean_t

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` after the first."""
        size = x.size()[1:]  # x.size() is a tuple; keep everything after dim 0
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def first_spike_times(self, input_tensor):
        """Return, per position, the index along dim 1 of the first value equal to 1.

        ``input_tensor`` may have any rank >= 2 as long as the time axis is
        dim 1; positions that never equal 1 yield ``inf``. The time axis is
        removed from the result.
        """
        num_steps = input_tensor.size(1)
        # Shape [1, T, 1, ..., 1] so the step index broadcasts over every
        # non-time dimension regardless of the input rank.
        index_shape = [1, num_steps] + [1] * (input_tensor.dim() - 2)
        step_index = torch.arange(num_steps, device=input_tensor.device).float()
        step_index = step_index.view(index_shape).expand_as(input_tensor)
        never_fired = float('inf') * torch.ones_like(input_tensor)
        indices = torch.where(input_tensor == 1, step_index, never_fired)
        # The smallest index along the time axis is the first spike.
        first_spike_times, _ = torch.min(indices, dim=1)
        return first_spike_times

    def get_first_spike_time(self, x):
        """Differentiable surrogate for the first spike time along dim 1.

        ``x`` must be 5-D and is unpacked as (batch, time_steps, c, h, w).
        Each value is scaled by ``time_steps - 1 - t`` so earlier steps get a
        larger weight, a softmax is taken over dim 1, and the result is the
        softmax-weighted average of the step indices (dim 1 is reduced).
        """
        batch, time_steps, c, h, w = x.shape
        device = x.device

        # Step indices shaped for broadcasting over the other four axes.
        time_indices = torch.arange(time_steps, device=device, dtype=torch.float)
        time_indices = time_indices.view(1, time_steps, 1, 1, 1)

        # Emphasise early spikes: the later the step, the smaller the factor.
        scaled_x = x * (time_steps - 1 - time_indices)

        # Softmax weights along the time axis.
        weights = torch.softmax(scaled_x, dim=1)

        # Weighted sum of the step indices.
        first_spike_time = torch.sum(weights * time_indices, dim=1)

        return first_spike_time

    def renyi_entropy_loss(self, first_spike_time, alpha=1, sigma=0.1, subset_size=1000):
        """Kernel-based entropy of the values in ``first_spike_time``.

        Parameters:
        - first_spike_time (Tensor) : values to analyse; flattened internally
        - alpha : Renyi order; ``alpha == 1`` uses the Shannon form
        - sigma : bandwidth of the Gaussian kernel
        - subset_size : at most this many randomly chosen values are used,
          which bounds the pairwise kernel matrix at subset_size^2 entries

        Returns a scalar tensor; gradients flow back to the sampled values.
        """
        device = first_spike_time.device

        # Keep sigma on the same device as the data.
        sigma = torch.tensor(sigma, device=device)

        # Flatten and draw a random subset to limit the cost.
        samples = first_spike_time.flatten()  # [n]
        n = samples.numel()
        subset_size = min(n, subset_size)

        # Indexing with a random permutation keeps the autograd graph intact.
        indices = torch.randperm(n, device=device)[:subset_size]
        samples_subset = samples[indices]  # [m]

        # Pairwise differences via broadcasting.
        x = samples_subset.unsqueeze(1)  # [m, 1]
        y = samples_subset.unsqueeze(0)  # [1, m]
        pairwise_diffs = x - y  # [m, m]

        # Gaussian kernel, with the normalising constant held as a tensor.
        sqrt_2pi = torch.sqrt(2 * torch.tensor(torch.pi, device=device))
        kernel = torch.exp(-pairwise_diffs ** 2 / (2 * sigma ** 2)) / (sigma * sqrt_2pi)

        # Density estimate at every sampled point.
        p = kernel.mean(dim=1)
        if alpha == 1:
            # Shannon form; the epsilon guards against log(0).
            entropy = -torch.mean(torch.log(p + 1e-10))
        else:
            entropy = (1 / (1 - alpha)) * torch.log(torch.mean(p ** alpha + 1e-10))

        return entropy

    def estimate_entropy(self, x, n_bins=10):
        """Estimate the entropy (bits) of one variable after uniform binning.

        ``x`` is reshaped to a single column and discretised into ``n_bins``
        equal-width bins; empty bins are ignored.
        """
        discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
        x_binned = discretizer.fit_transform(x.reshape(-1, 1)).astype(int).flatten()
        p_x = np.bincount(x_binned) / len(x_binned)
        p_x = p_x[p_x > 0]  # drop zero-probability bins
        return -np.sum(p_x * np.log2(p_x))

    def estimate_joint_entropy(self, X, n_bins=10):
        """Estimate the joint entropy (bits) of X with shape (n_samples, n_features).

        Every feature is binned into ``n_bins`` bins and each sample's bin
        tuple is mapped to a single flat index over n_bins ** n_features cells.
        """
        discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
        X_binned = discretizer.fit_transform(X).astype(int)
        joint_bins = np.ravel_multi_index(X_binned.T, dims=[n_bins] * X.shape[1])
        p_joint = np.bincount(joint_bins) / len(joint_bins)
        p_joint = p_joint[p_joint > 0]
        return -np.sum(p_joint * np.log2(p_joint))

    def total_correlation(self, X, n_bins=10):
        """
        Estimate the total correlation TC(X) = sum_i H(X_i) - H(X).
        Parameters:
            X: shape=(n_samples, n_features)
            n_bins: number of bins used for discretisation
        Returns:
            The TC estimate (sum of marginal entropies minus joint entropy).
        """
        n_features = X.shape[1]
        marginal_entropy = sum(self.estimate_entropy(X[:, i], n_bins) for i in range(n_features))
        joint_entropy = self.estimate_joint_entropy(X, n_bins)
        return marginal_entropy - joint_entropy

    def forward(self, x: torch.Tensor):
        """Run the network and return ``(out, T, H, z)``.

        - out : voted output of the FC head, averaged over dim 0 first
        - T   : conv features flattened by ``encode_features``
        - H   : ``renyi_entropy_loss`` of the soft first-spike-time map
        - z   : detached copy of the conv features
        """
        x = x.permute(1, 0, 2, 3, 4)  # [N, T, 2, H, W] -> [T, N, 2, H, W]
        x = self.conv(x)
        z = x
        z = z.detach()

        # NOTE(review): assuming self.conv keeps the leading [T, N] layout,
        # dim 1 here is the sample axis, while get_first_spike_time treats
        # dim 1 as time, and encode_features yields one row per dim-0 entry
        # (i.e. per time step). Confirm this is intended before relying on
        # H or T as per-sample quantities.
        spike_matrix = self.get_first_spike_time(x)
        T = self.encode_features(x)
        H = self.renyi_entropy_loss(spike_matrix)
        out_spikes = self.fc(x)  # shape = [T, N, 110]
        out = self.vote(out_spikes.mean(0))
        return out, T, H, z

    @staticmethod
    def conv3x3(in_channels: int, out_channels):
        """Return ``[Conv2d(3x3, padding=1, no bias) + BatchNorm2d, LIF node]``."""
        return [
            layer.SeqToANNContainer(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels)
            ),
            neuron.MultiStepLIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                    backend='cupy'),
        ]


def Gaussian_noise(img, mean, sigma):
    """Return ``img`` with additive Gaussian noise, capped at 255.

    Parameters:
    - img (Tensor) : input of any shape, on any device
    - mean : mean of the noise
    - sigma : standard deviation of the noise (must be >= 0)

    The noise is drawn with NumPy's global RNG using the shape of ``img``
    (previously a fixed (16, 20, 2, 128, 128) shape was used). Values above
    255 are clipped; there is no lower clipping. The result is a float32
    tensor on the same device as ``img`` (previously always 'cuda:1').
    """
    device = img.device
    pixels = img.detach().cpu().numpy()
    noise = np.random.normal(loc=mean, scale=sigma, size=pixels.shape)
    # Upper clip only, matching the original behaviour.
    noisy = np.minimum(noise + pixels, 255)
    return torch.tensor(noisy).float().to(device)

def fgsm_attack(model, images, labels, epsilon):
    """Return FGSM-perturbed copies of ``images``.

    Parameters:
    - model : callable returning a 4-tuple whose first element is the output
    - images (Tensor) : clean inputs; not modified
    - labels (Tensor) : targets compared to the output with an MSE loss
    - epsilon : step size applied to the sign of the input gradient

    Only the gradient with respect to the input is computed, so the model's
    parameter ``.grad`` fields are left untouched.

    NOTE(review): the forward pass run here is not followed by any state
    reset; if ``model`` is stateful (e.g. a spiking network), the caller
    presumably has to reset it before the next forward pass -- confirm.
    """
    adv_input = images.clone().detach().to(torch.float32).requires_grad_(True)
    outputs, _, _, _ = model(adv_input)
    loss = F.mse_loss(outputs, labels)
    # Differentiate w.r.t. the input only instead of back-propagating into
    # every parameter of the model.
    (input_grad,) = torch.autograd.grad(loss, adv_input)
    perturbed_images = images + epsilon * input_grad.sign()
    # perturbed_images = torch.clamp(perturbed_images, 0, 1)
    return perturbed_images

def black_box_gaussian_noise_attack(frame, sigma):
    """Perturb ``frame`` with additive zero-mean Gaussian noise.

    NOTE(review): ``main`` calls this name, but the original file neither
    defined nor imported it, so the first evaluation pass raised
    ``NameError``. The behaviour implemented here (i.i.d. noise with standard
    deviation ``sigma``, no clipping) is an assumption inferred from the name
    and the call site -- confirm against the intended attack.
    """
    return frame + sigma * torch.randn_like(frame)


def main():
    """Train and evaluate ``CextNet`` on DVS128 Gesture from the command line.

    Parses the CLI options, builds the network, optimizer and LR scheduler,
    trains on a random subset of the training split and, after every epoch,
    evaluates on test frames perturbed by
    ``black_box_gaussian_noise_attack``. Scalars go to TensorBoard,
    checkpoints ('checkpoint_latest.pth', 'checkpoint_max.pth') go to the run
    directory under ``-out_dir``, and the best test accuracy so far is
    appended to 'training_accuracy_log_002.txt' once per epoch.

    Raises NotImplementedError for an unknown ``-opt`` or ``-lr_scheduler``.
    """
    parser = argparse.ArgumentParser(description='Classify DVS128 Gesture')
    parser.add_argument('-T', default=16, type=int, help='simulating time-steps')
    # The previous default 'cuda:1s' was not a valid device string.
    parser.add_argument('-device', default='cuda:1', help='device')
    parser.add_argument('-b', default=16, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-channels', default=128, type=int, help='channels of Conv2d in SNN')
    # Both directories are used unconditionally below, so fail at parse time
    # instead of with a TypeError later on.
    parser.add_argument('-data_dir', type=str, required=True, help='root dir of DVS128 Gesture dataset')
    parser.add_argument('-out_dir', type=str, required=True, help='root dir for saving logs and checkpoint')

    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')

    parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam')
    parser.add_argument('-lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. StepLR or CosALR')
    parser.add_argument('-step_size', default=32, type=float, help='step_size for StepLR')
    parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
    parser.add_argument('-T_max', default=32, type=int, help='T_max for CosineAnnealingLR')
    parser.add_argument('-beta', default=0.001, type=float, help='beta for the loss function')
    parser.add_argument('-alpha', default=0.001, type=float, help='alpha for the loss function')
    parser.add_argument('-sub_trainset_ratio', default=0.5, type=float, help='ratio of sub-trainset')

    args = parser.parse_args()
    print(args)

    net = CextNet(channels=args.channels)
    print(net)

    net.to(args.device)

    optimizer = None
    if args.opt == 'SGD':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = None
    if args.lr_scheduler == 'StepLR':
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    elif args.lr_scheduler == 'CosALR':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max)
    else:
        raise NotImplementedError(args.lr_scheduler)

    train_set = DVS128Gesture(args.data_dir, train=True, data_type='frame', split_by='number', frames_number=args.T)
    test_set = DVS128Gesture(args.data_dir, train=False, data_type='frame', split_by='number', frames_number=args.T)

    # Train on a random subset (without replacement) of the training split.
    print(f'Sub-trainset ratio: {args.sub_trainset_ratio}')
    sub_trainset_size = int(args.sub_trainset_ratio * len(train_set))
    indices = list(range(len(train_set)))
    sampler = RandomSampler(indices, replacement=False, num_samples=sub_trainset_size)
    subset_indices = list(sampler)
    train_set = Subset(train_set, subset_indices)

    train_data_loader = DataLoader(
        dataset=train_set,
        batch_size=args.b,
        shuffle=True,
        num_workers=args.j,
        drop_last=True,
        pin_memory=True)

    test_data_loader = DataLoader(
        dataset=test_set,
        batch_size=args.b,
        shuffle=False,
        num_workers=args.j,
        drop_last=False,
        pin_memory=True)

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = 0

    if args.resume:
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(args.out_dir, f'T_{args.T}_b_{args.b}_c_{args.channels}_{args.opt}_lr_{args.lr}_max')
    if args.lr_scheduler == 'CosALR':
        out_dir += f'CosALR_{args.T_max}'
    elif args.lr_scheduler == 'StepLR':
        out_dir += f'StepLR_{args.step_size}_{args.gamma}'
    else:
        raise NotImplementedError(args.lr_scheduler)

    if args.amp:
        out_dir += '_amp'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    writer = SummaryWriter(os.path.join(out_dir, 'dvsg_logs'), purge_step=start_epoch)

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.float().to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 11).float()

            # One loss definition for both precision modes. The non-AMP path
            # used to drop the beta/alpha terms, which made -beta and -alpha
            # silently ineffective without -amp.
            with amp.autocast(enabled=args.amp):
                out_fr, T, H_alpha_X, _ = net(frame)
                I_XT = net.get_IXT(T)
                loss = F.mse_loss(out_fr, label_onehot) + args.beta * I_XT + args.alpha * H_alpha_X

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Clear the neurons' state before the next batch.
            functional.reset_net(net)

        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0

        for frame, label in test_data_loader:
            frame = frame.float().to(args.device)
            # Test accuracy is measured on noise-perturbed frames.
            # frame = Gaussian_noise(frame, 0, 10)
            frame = black_box_gaussian_noise_attack(frame, 1)

            label = label.to(args.device)
            label_onehot = F.one_hot(label, 11).float()
            # frame = fgsm_attack(net, frame, label_onehot, epsilon=1)

            with torch.no_grad():
                out_fr, _, _, _ = net(frame)
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)

        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(
            f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={time.time() - start_time}')

        with open('training_accuracy_log_002.txt', 'a') as f:
            f.write(f'{max_test_acc}\n')

    # Flush and release the TensorBoard event file.
    writer.close()


if __name__ == '__main__':
    # Create the accuracy log with its header line on the first run only;
    # an existing log is left untouched so later runs keep appending to it.
    accuracy_log_path = 'training_accuracy_log_002.txt'
    log_already_present = os.path.exists(accuracy_log_path)
    if not log_already_present:
        with open(accuracy_log_path, 'w') as header_file:
            header_file.write('Training Accuracy Log\n')
    main()
