import math
import time
import torch
import torch.nn as nn
import torchvision
import numpy as np
import argparse
import torch.backends.cudnn as cudnn
import os
import torch.nn.functional as F

# distributed training
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP

from modules import get_resnet, get_resnet_spiking, get_vgg, get_vgg_spiking, SynctdBatchNorm, modify_resnet_model
from loss import BarlowTwinsLoss, BarlowTwinsTemporalLoss
from modules.spike_layer import MixedLIF, LIFt

from BartonTwins import BartonTwins, BartonTwins_imagenet
from BartonTwins_spiking import BartonTwinsSpiking, BartonTwinsSpiking_imagenet
from utils import yaml_config_hook
from model import load_optimizer, save_model, save_model_top_k
from modules.transformations import DataTransforms, DataTransforms_imagenet

# TensorBoard
from torch.utils.tensorboard import SummaryWriter
# amp
from torch.cuda.amp import GradScaler, autocast
# spikingjelly
# from spikingjelly.activation_based import functional


def compare_gradients_improved(grads_1, grads_2, args, n_bins=50):
    """Compare two gradient snapshots layer by layer and report similarity metrics.

    For every parameter name present in both ``grads_1`` and ``grads_2`` whose
    shapes match, computes: cosine similarity, absolute and relative L2
    difference, absolute difference of means and of standard deviations,
    maximum element-wise difference, Pearson correlation, and a
    histogram-based KL divergence (plus its symmetrised form).

    Layers whose gradients contain NaN/Inf (e.g. an AMP overflow step) are
    skipped: ``torch.histc`` rejects a non-finite ``[min, max]`` range and
    every other metric would be NaN/Inf anyway.

    Per-layer and global summaries are printed only when ``args.rank`` exists
    and equals 0; the shape-mismatch message is printed on every rank.

    Args:
        grads_1: mapping ``parameter name -> gradient tensor`` (first run).
        grads_2: mapping ``parameter name -> gradient tensor`` (second run).
        args: namespace; only ``args.rank`` is read.
        n_bins: number of histogram bins used for the KL estimate.

    Returns:
        dict with keys ``layer_count``, ``total_params``, ``avg_cos_sim``,
        ``global_kl`` (both NaN when no layer was compared) and ``skipped``
        (names of layers skipped because of non-finite values).
    """
    verbose = getattr(args, 'rank', None) == 0
    if verbose:
        print("=== Improved Version ===")

    eps = 1e-8  # guards divisions by zero and log(0)
    total_params = 0
    total_kl = 0.0
    total_cos_sim = 0.0
    layer_count = 0
    skipped = []

    for name, grad_1 in grads_1.items():
        if name not in grads_2:
            continue
        g1 = grad_1.detach()
        g2 = grads_2[name].detach()

        # Shapes must match for any element-wise comparison.
        if g1.shape != g2.shape:
            print(f"{name}: Shape mismatch: {g1.shape} vs {g2.shape}")
            continue

        if not (torch.isfinite(g1).all() and torch.isfinite(g2).all()):
            skipped.append(name)
            if verbose:
                print(f"{name}: non-finite gradient values, skipped")
            continue

        # Flattened views for vector-style metrics.
        g1_flat = g1.flatten()
        g2_flat = g2.flatten()

        # 1. Cosine similarity.
        cos_sim = F.cosine_similarity(g1_flat.unsqueeze(0), g2_flat.unsqueeze(0)).item()

        # 2. Absolute L2 difference.
        l2_diff = torch.norm(g1_flat - g2_flat).item()

        # 3. Relative L2 difference (independent of the gradient magnitude).
        g1_norm = torch.norm(g1_flat).item()
        g2_norm = torch.norm(g2_flat).item()
        rel_l2_diff = l2_diff / (max(g1_norm, g2_norm) + eps)

        # 4. Differences of mean and standard deviation.
        mean_diff = torch.abs(g1.mean() - g2.mean()).item()
        std_diff = torch.abs(g1.std() - g2.std()).item()

        # 5. Largest element-wise difference.
        max_diff = torch.max(torch.abs(g1 - g2)).item()

        # 6. KL divergence between the value histograms. Both histograms share
        # one [min, max] range so their bins are aligned.
        min_val = min(g1.min().item(), g2.min().item())
        max_val = max(g1.max().item(), g2.max().item())
        if max_val <= min_val:
            # Constant gradients: widen to a valid interval (histc does the same).
            min_val -= 1.0
            max_val += 1.0
        hist1 = torch.histc(g1, bins=n_bins, min=min_val, max=max_val)
        hist2 = torch.histc(g2, bins=n_bins, min=min_val, max=max_val)

        # Normalise to probability distributions; eps avoids log(0).
        p = hist1 / hist1.sum() + eps
        q = hist2 / hist2.sum() + eps
        kl_pq = torch.sum(p * torch.log(p / q))
        kl_qp = torch.sum(q * torch.log(q / p))
        kl_div = kl_pq.item()
        # Symmetrised KL: 0.5 * (KL(P||Q) + KL(Q||P)).
        kl_sym = (0.5 * (kl_pq + kl_qp)).item()

        # 7. Pearson correlation (undefined for a single element).
        if g1_flat.numel() > 1:
            corr_coef = torch.corrcoef(torch.stack([g1_flat, g2_flat]))[0, 1].item()
        else:
            corr_coef = float('nan')

        # Running totals for the global summary.
        total_params += g1_flat.numel()
        total_kl += kl_div
        total_cos_sim += cos_sim
        layer_count += 1

        if verbose:
            print(f"{name}:")
            print(f"  L2_diff: {l2_diff:.4e}, Rel_L2: {rel_l2_diff:.4f}")
            print(f"  CosSim: {cos_sim:.4f}, Corr: {corr_coef:.4f}")
            print(f"  Mean_diff: {mean_diff:.4e}, Std_diff: {std_diff:.4e}")
            print(f"  Max_diff: {max_diff:.4e}")
            print(f"  KL_div: {kl_div:.4f}, KL_sym: {kl_sym:.4f}")
            print()

    # Global statistics (averages over compared layers).
    if layer_count > 0:
        avg_cos_sim = total_cos_sim / layer_count
        global_kl_div = total_kl / layer_count
        if verbose:
            print(f"Global Statistics:")
            print(f"  Total parameters: {total_params}")
            print(f"  Average cosine similarity: {avg_cos_sim:.4f}")
            print(f"  Global kl: {global_kl_div:.4e}")
    else:
        avg_cos_sim = float('nan')
        global_kl_div = float('nan')

    return {
        'layer_count': layer_count,
        'total_params': total_params,
        'avg_cos_sim': avg_cos_sim,
        'global_kl': global_kl_div,
        'skipped': skipped,
    }


def adjust_learning_rate(optimizer, epoch, args, warmup_epochs=10, total_epochs=1000):
    """Set per-group learning rates: linear warmup, then half-cycle cosine decay.

    ``epoch`` may be fractional (the training loop passes
    ``current_epoch + step / len(loader)``). A unitless multiplier is computed:

    * ``epoch < warmup_epochs``: ``epoch / warmup_epochs`` (linear 0 -> 1);
    * otherwise ``0.5 * (1 + cos(pi * progress))`` with
      ``progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)``
      clamped to ``[0, 1]``, so the rate stays at 0 past ``total_epochs``
      instead of climbing back up.

    The multiplier scales ``args.lr`` for param group 0 (weights) and
    ``args.lr_bias`` for param group 1 (biases / 1-D params) when present.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: (fractional) epoch index.
        args: namespace providing ``lr`` and ``lr_bias``.
        warmup_epochs: length of the linear warmup, in epochs.
        total_epochs: epoch at which the cosine reaches 0.

    Returns:
        The multiplier that was applied.
    """
    if epoch < warmup_epochs:
        # Unitless ramp; the base learning rates are applied exactly once below.
        factor = epoch / warmup_epochs
    else:
        decay_span = max(total_epochs - warmup_epochs, 1)
        progress = min(max((epoch - warmup_epochs) / decay_span, 0.0), 1.0)
        factor = 0.5 * (1. + math.cos(math.pi * progress))

    base_lrs = (args.lr, args.lr_bias)
    for group, base_lr in zip(optimizer.param_groups, base_lrs):
        group['lr'] = factor * base_lr
    return factor


def train(args, train_loader, model, criterion, optimizer, writer, scaler):
    """Run one epoch of (spiking) Barlow Twins pre-training with AMP.

    On diagnostic steps (every step by default, or every
    ``args.grad_compare_freq`` steps when that attribute exists; 0 disables
    it), the gradients obtained through each projection branch alone are
    compared with ``compare_gradients_improved`` before the real update.

    Args:
        args: run configuration; reads ``device``, ``current_epoch``,
            ``spiking``, ``temporal_loss``, ``clip_grad_norm`` and ``rank``;
            increments ``global_step`` on rank 0.
        train_loader: yields ``((x_i, x_j), label)`` batches.
        model: called as ``model(a, b)`` and unpacked into ``(h_i, h_j, z_i, z_j)``.
        criterion: loss called as ``criterion(z_i, z_j)``.
        optimizer: optimizer with the param groups set by ``adjust_learning_rate``.
        writer: TensorBoard ``SummaryWriter`` (used on rank 0 only).
        scaler: AMP ``GradScaler``.

    Returns:
        Sum of the per-step loss values over the epoch.
    """
    compare_freq = int(getattr(args, 'grad_compare_freq', 1))
    # clip_grad_norm comes from the YAML config: both the string "None" and a
    # real None mean "no clipping".
    clip_norm = args.clip_grad_norm
    clip_enabled = clip_norm is not None and clip_norm != "None"
    n_steps = len(train_loader)

    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.cuda(args.device, non_blocking=True)
        x_j = x_j.cuda(args.device, non_blocking=True)

        # Fractional epoch so the schedule advances every step.
        adjust_learning_rate(optimizer, args.current_epoch + step / n_steps, args)

        do_compare = compare_freq > 0 and step % compare_freq == 0

        with autocast():
            # NOTE(review): both inputs are x_i, so x_j is loaded but unused.
            # Presumably deliberate for the gradient-comparison experiment;
            # regular Barlow Twins training needs model(x_i, x_j). Confirm.
            h_i, h_j, z_i, z_j = model(x_i, x_i)
            # Without the temporal loss, spiking outputs are averaged over
            # dim 0 (presumably the timestep axis; confirm).
            if args.spiking and not args.temporal_loss:
                z_i = z_i.mean(0)
                z_j = z_j.mean(0)
            loss = criterion(z_i, z_j)
            if do_compare:
                # Same loss value, but gradients flow through one branch only.
                loss_1 = criterion(z_i, z_j.detach())
                loss_2 = criterion(z_i.detach(), z_j)

        if do_compare:
            # NOTE(review): several backward passes over one DDP forward (with
            # find_unused_parameters=True) may be rejected by DDP; verify in
            # multi-GPU runs.
            scaler.scale(loss_1).backward(retain_graph=True)
            # .grad holds gradients multiplied by the AMP loss scale; divide it
            # out so absolute metrics (L2 / mean / max diffs) are in real units.
            inv_scale = 1.0 / scaler.get_scale()
            grads_1 = {name: p.grad.detach() * inv_scale
                       for name, p in model.named_parameters() if p.grad is not None}

            model.zero_grad()
            scaler.scale(loss_2).backward(retain_graph=True)
            grads_2 = {name: p.grad.detach() * inv_scale
                       for name, p in model.named_parameters() if p.grad is not None}

            compare_gradients_improved(grads_1, grads_2, args)
            model.zero_grad()

        # Real update with the AMP loss scaler.
        scaler.scale(loss).backward()

        if clip_enabled:
            scaler.unscale_(optimizer)  # clip in real (unscaled) units
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(clip_norm))

        scaler.step(optimizer)
        scaler.update()

        # The loss is not all-reduced, so each rank logs its local value.
        loss_value = loss.item()
        if args.rank == 0:
            if step % 10 == 0:
                print(f"Step [{step}/{n_steps}]\t Loss: {loss_value}")
            writer.add_scalar("Loss/train_epoch", loss_value, args.global_step)
            args.global_step += 1

        loss_epoch += loss_value
    return loss_epoch

def _build_train_dataset(args):
    """Return the self-supervised training dataset named by ``args.dataset``.

    Raises:
        NotImplementedError: for an unsupported dataset name.
    """
    if args.dataset == "CIFAR10":
        return torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=True,
            download=True,
            transform=DataTransforms(size=32),
        )
    if args.dataset == "CIFAR100":
        return torchvision.datasets.CIFAR100(
            args.dataset_dir,
            train=True,
            download=True,
            transform=DataTransforms(size=32),
        )
    if args.dataset in ("Tiny-ImageNet", "ImageNet"):
        # Both are read as an ImageFolder rooted at <dataset_dir>/train.
        return torchvision.datasets.ImageFolder(
            os.path.join(args.dataset_dir, 'train'),
            transform=DataTransforms_imagenet(size=args.image_size),
        )
    raise NotImplementedError(f"Unsupported dataset: {args.dataset}")


def _build_model(args):
    """Build the (spiking) Barlow Twins model around a backbone whose fc head is removed.

    Raises:
        NotImplementedError: if ``args.model`` names neither a ResNet nor a VGG.
    """
    # Activation for the spiking layers; anything but 'MixedLIF' uses LIFt.
    act_func = MixedLIF if args.act_func == 'MixedLIF' else LIFt

    if "resnet" in args.model:
        if args.spiking:
            backbone = get_resnet_spiking(args.model, args.timestep, args.sync_norm, act_func, args.n_classes)
        else:
            backbone = get_resnet(args.model, args.n_classes)
        n_features = backbone.fc.in_features  # input width of the removed fc layer
        backbone.fc = nn.Identity()
        if "CIFAR" in args.dataset:
            # Convert to the CIFAR-fitted structure.
            backbone = modify_resnet_model(backbone)
    elif "vgg" in args.model:
        if args.spiking:
            backbone = get_vgg_spiking(args.model, args.timestep, args.sync_norm, act_func, args.n_classes)
        else:
            backbone = get_vgg(args.model, args.n_classes)
        n_features = backbone.fc[0].in_features
        backbone.fc = nn.Identity()
    else:
        raise NotImplementedError(f"Unsupported model: {args.model}")

    # "Tiny-ImageNet" also matches this check.
    is_imagenet = "ImageNet" in args.dataset
    if args.spiking:
        model_cls = BartonTwinsSpiking_imagenet if is_imagenet else BartonTwinsSpiking
        return model_cls(backbone, in_dim=n_features, out_dim=args.projection_dim, act_func=act_func,
                         timestep=args.timestep)
    model_cls = BartonTwins_imagenet if is_imagenet else BartonTwins
    return model_cls(backbone, in_dim=n_features, out_dim=args.projection_dim)


def main(gpu, args):
    """Per-process entry point: build data, model, optimizer and loss, then train.

    Args:
        gpu: local GPU index of this process (the ``mp.spawn`` process index).
        args: parsed configuration; ``rank``, ``device``, ``global_step`` and
            ``current_epoch`` are set on it here.
    """
    args.rank = args.nr * args.gpus + gpu
    args.device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")
    distributed = args.world_size > 1

    if distributed:
        dist.init_process_group("nccl", init_method="env://", rank=args.rank, world_size=args.world_size)
        torch.cuda.set_device(gpu)

    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Start Dataset Verification")
    train_dataset = _build_train_dataset(args)
    print("Dataset Verification Finished")

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True
        )
    else:
        train_sampler = None

    dataloader_train_ssl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.num_workers,
        sampler=train_sampler,
        pin_memory=args.pin_memory,
    )

    model = _build_model(args)
    model.to(args.device)

    # AMP loss scaler.
    scaler = GradScaler()
    if args.rank == 0:
        print('Using native Torch AMP. Training in mixed precision.')

    # Group 0: weights (args.lr); group 1: 1-D params and biases (args.lr_bias).
    # adjust_learning_rate sets learning rates by this group order.
    param_weights = []
    param_biases = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 1 or name.endswith(".bias"):
            param_biases.append(param)
        else:
            param_weights.append(param)

    parameters = [
        {'params': param_weights, 'lr': args.lr},
        {'params': param_biases, 'lr': args.lr_bias},
    ]

    optimizer, scheduler = load_optimizer(args, parameters, len(dataloader_train_ssl))
    if args.temporal_loss and args.spiking:
        criterion = BarlowTwinsTemporalLoss(device=args.device, world_size=args.world_size,
                                            cross_temporal=args.cross_temporal,
                                            simplified_loss=args.simplified_loss,
                                            proj_dim=args.projection_dim)
    else:
        criterion = BarlowTwinsLoss(device=args.device, world_size=args.world_size, proj_dim=args.projection_dim)
    criterion = criterion.to(args.device)

    # Resume: only the model and AMP scaler states are restored, not the optimizer.
    if args.reload:
        model_fp = os.path.join(
            args.model_path, "checkpoint_epoch_{}.tar".format(args.epoch_num)
        )
        checkpoint = torch.load(model_fp, map_location=args.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Model and scaler loaded from checkpoint '{model_fp}' at epoch {checkpoint['epoch']}.")

    if distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.spiking:
            criterion = torch.nn.SyncBatchNorm.convert_sync_batchnorm(criterion)
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    writer = SummaryWriter() if args.rank == 0 else None

    args.global_step = 0
    args.current_epoch = args.start_epoch
    top_models = []  # filled by save_model_top_k; only the top-k checkpoints are kept
    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        if args.rank == 0:
            print('===================================')

        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        lr = optimizer.param_groups[0]["lr"]
        loss_epoch = train(args, dataloader_train_ssl, model, criterion, optimizer, writer, scaler)

        if scheduler:
            scheduler.step()

        if args.rank == 0:
            epoch_time = time.time() - start_time
            save_model_top_k(args, model, optimizer, scaler, loss_epoch, top_models)
            mean_loss = loss_epoch / len(dataloader_train_ssl)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 7)}\t Time: {epoch_time:.2f}s"
            )

        # Advance on EVERY rank: train() derives the per-step learning rate from
        # args.current_epoch, so ranks must agree or their LRs diverge under DDP.
        args.current_epoch += 1

    if args.rank == 0:
        writer.close()
        print("Training complete. Top 5 models:")
        for _, saved_epoch, model_path in sorted(top_models, reverse=True):
            print(f"Epoch {saved_epoch}: {model_path}")

    if distributed:
        dist.destroy_process_group()


def str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0".

    argparse's ``type=bool`` would turn any non-empty string, including
    "False", into ``True``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def config_arg_type(default):
    """Return the argparse ``type`` callable to use for a YAML config default."""
    if isinstance(default, bool):
        return str2bool
    if default is None:
        # type(None) is not callable on a string; accept the raw string instead.
        return str
    return type(default)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v, type=config_arg_type(v))

    args = parser.parse_args()

    # Rendezvous address for init_method="env://"; the environment may override it.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "8090")

    os.makedirs(args.model_path, exist_ok=True)

    # The device is assigned per process in main().
    args.device = None
    args.world_size = args.gpus * args.nodes
    # YAML may load literals such as 1e-3 as strings; coerce the learning rates.
    args.lr = float(args.lr)
    if getattr(args, 'lr_bias', None) is not None:
        args.lr_bias = float(args.lr_bias)
    print(vars(args))

    cudnn.benchmark = True

    if args.world_size > 1:
        print(
            f"Training with {args.world_size} , waiting until all world_size join before starting training"
        )
        mp.spawn(main, args=(args,), nprocs=args.gpus, join=True)
    else:
        main(0, args)
