import os
import math
import argparse
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import datetime
import gc

from BartonTwins import BartonTwins, BartonTwins_imagenet
from BartonTwins_spiking import BartonTwinsSpiking, BartonTwinsSpiking_imagenet
from model import load_optimizer, save_model
from utils import yaml_config_hook

from modules.transformations import DataTransforms, DataTransforms_imagenet
from modules import get_resnet, get_resnet_spiking, modify_resnet_model, get_vgg, get_vgg_spiking, LogisticRegression
from modules.spike_layer import MixedLIF, LIFt

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dataset import TinyImageNetDataset


def compute_kl_divergence(v1, v2, bins=50, eps=1e-8):
    """Histogram-based estimate of KL(P || Q) for the values in two tensors.

    Both tensors are binned into ``bins`` equal-width buckets spanning the
    joint range ``[min(v1, v2), max(v1, v2)]``. Each histogram is divided by
    its total count and then shifted by ``eps`` (without renormalising), so
    empty buckets never produce ``log(0)`` or a division by zero.

    Args:
        v1: Tensor whose value distribution plays the role of P.
        v2: Tensor whose value distribution plays the role of Q.
        bins: Number of histogram buckets.
        eps: Additive smoothing applied to every bucket probability.

    Returns:
        The divergence ``sum(p * log(p / q))`` as a Python float.
    """
    # A shared range keeps bucket k meaning the same interval for both tensors.
    lower = min(v1.min(), v2.min()).item()
    upper = max(v1.max(), v2.max()).item()

    def smoothed_probs(values):
        counts = torch.histc(values, bins=bins, min=lower, max=upper)
        return counts / counts.sum() + eps

    p = smoothed_probs(v1)
    q = smoothed_probs(v2)
    log_ratio = (p / q).log()
    return (p * log_ratio).sum().item()
            

def cleanup():
    """Destroy the default ``torch.distributed`` process group, if one exists.

    ``main`` only calls ``dist.init_process_group`` when ``world_size > 1``.
    ``dist.destroy_process_group()`` raises if no group was initialized, so
    this guard makes the function safe to call from single-process runs too.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def inference_dual_models(args, loader, model1, model2, device):
    """Collect per-sample histogram KL divergences between two models' features.

    Each batch ``(inputs, labels)`` from ``loader`` is moved to ``device`` and
    fed to both models as ``model(inputs, inputs)``; the first of the four
    returned values is used as the feature tensor. When ``args.spiking`` is
    set, features are averaged over dim 0 first (presumably the timestep
    axis; TODO confirm). The KL divergence is then computed row by row.

    Progress is printed every 10 batches, and the overall mean is printed at
    the end.

    Returns:
        list of float: one KL value per sample, in loader order.
    """
    kl_values = []
    for batch_idx, (inputs, _labels) in enumerate(loader):
        inputs = inputs.to(device)

        with torch.no_grad():
            feats_a, _, _, _ = model1(inputs, inputs)
            feats_b, _, _, _ = model2(inputs, inputs)

            if args.spiking:
                feats_a, feats_b = feats_a.mean(0), feats_b.mean(0)

            feats_a, feats_b = feats_a.detach(), feats_b.detach()

            kl_values.extend(
                compute_kl_divergence(feats_a[row], feats_b[row])
                for row in range(feats_a.size(0))
            )

        if batch_idx % 10 == 0:
            print(f"Step [{batch_idx}/{len(loader)}] - Avg KL so far: {np.mean(kl_values):.4f}")

    print(f"Final average KL divergence: {np.mean(kl_values):.4f}")
    return kl_values

def inference_dual_models_per_timestep(args, loader, model1, model2, device):
    """Collect histogram KL divergences between two models' features, per timestep.

    Each batch ``(x, labels)`` from ``loader`` is moved to ``device`` and fed
    to both models as ``model(x, x)``. The first of the four returned values
    is expected to be shaped ``[T, B, C]``. For every timestep ``t`` and
    sample ``i``, the KL divergence between ``h1[t, i]`` and ``h2[t, i]`` is
    recorded.

    A running per-timestep mean is printed every 5 batches. Per-timestep
    mean/std are printed at the end.

    Args:
        args: Run configuration (not read here; kept for interface parity).
        loader: Iterable of ``(inputs, labels)`` batches.
        model1, model2: Models compared against each other.
        device: Device the inputs are moved to.

    Returns:
        list of list of float: ``result[t]`` holds the KL values for timestep ``t``.

    Raises:
        ValueError: if either feature tensor is not 3-D, or if their leading
            ``[T, B]`` dimensions differ. Without this check, a 2-D output
            would be silently misread as ``T = batch`` timesteps of scalars.
    """
    timestep_kl_stats = []  # timestep -> [KL for each sample seen so far]

    for step, (x, _) in enumerate(loader):
        x = x.to(device)

        with torch.no_grad():
            h1, _, _, _ = model1(x, x)  # expected shape: [T, B, C]
            h2, _, _, _ = model2(x, x)

            if h1.dim() != 3 or h2.dim() != 3 or h1.shape[:2] != h2.shape[:2]:
                raise ValueError(
                    "expected spiking features shaped [T, B, C] with matching T and B, "
                    f"got {tuple(h1.shape)} and {tuple(h2.shape)}"
                )

            num_steps, batch_size = h1.shape[0], h1.shape[1]

            # Grow the per-timestep buckets lazily on the first batch.
            while len(timestep_kl_stats) < num_steps:
                timestep_kl_stats.append([])

            for t in range(num_steps):
                for i in range(batch_size):
                    timestep_kl_stats[t].append(compute_kl_divergence(h1[t, i], h2[t, i]))

        if step % 5 == 0:
            # float() so NumPy >= 2 prints plain numbers instead of np.float64(...) reprs.
            avg_kl = [round(float(np.mean(kls)), 4) for kls in timestep_kl_stats]
            print(f"[Step {step}] timestep KL (mean): {avg_kl}")

    print("\n=== Per-Timestep KL Divergence ===")
    for t, kls in enumerate(timestep_kl_stats):
        print(f"Timestep {t}: Mean KL = {np.mean(kls):.6f}, Std = {np.std(kls):.6f}")

    return timestep_kl_stats



# Checkpoints compared by ``main``. They can be overridden through
# ``args.ctl_checkpoint`` / ``args.btl_checkpoint`` (e.g. by adding those keys
# to config.yaml, whose keys the CLI parser turns into flags).
_DEFAULT_CTL_CHECKPOINT = (
    "/home/cxz760/selfsupervised_SpikingNN/barlowtwins_SNN/save/cifar10/"
    "resnet34_spk_c_mixlif/checkpoint_epoch_594.tar"
)
_DEFAULT_BTL_CHECKPOINT = (
    "/home/cxz760/selfsupervised_SpikingNN/barlowtwins_SNN/save/cifar10/"
    "resnet34_spk_1l_mixlif/checkpoint_epoch_629.tar"
)


def _load_spiking_bt_model(args, checkpoint_path, label):
    """Build a spiking ResNet Barlow Twins model, load ``checkpoint_path``, and return it in eval mode on ``args.device``."""
    backbone = get_resnet_spiking(args.model, args.timestep, args.sync_norm, MixedLIF, args.n_classes)
    n_features = backbone.fc.in_features  # input width of the original fc layer
    backbone.fc = nn.Identity()  # drop the classification head
    if "CIFAR" in args.dataset and args.image_size == 32:
        backbone = modify_resnet_model(backbone)
    model = BartonTwinsSpiking_imagenet(
        backbone,
        in_dim=n_features,
        out_dim=args.projection_dim,
        act_func=MixedLIF,
        timestep=args.timestep,
    )

    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    incompatible = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    # strict=False skips mismatched keys silently; surface them so a wrong or
    # partially matching checkpoint does not go unnoticed.
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        print(f"Model {label}: {len(missing)} missing / {len(unexpected)} unexpected keys in checkpoint.")
    print(f"Model {label} loaded from checkpoint.")

    model = model.to(args.device)
    model.eval()
    return model


def main(rank, world_size, args):
    """Compare per-timestep feature KL divergence between the CTL and BTL checkpoints.

    Args:
        rank: Process index (0 in single-process runs).
        world_size: Number of processes; a NCCL process group is created only
            when this is greater than 1.
        args: Parsed configuration. ``args.device`` is set here.

    Returns:
        On rank 0, the per-timestep KL lists from
        ``inference_dual_models_per_timestep``; ``None`` on other ranks.

    Raises:
        NotImplementedError: if ``args.dataset`` is not ``"CIFAR10"``.
    """
    args.device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    if world_size > 1:
        dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=120))
        torch.cuda.set_device(rank)

    try:
        # The test loader has no DistributedSampler, so every rank would repeat
        # the identical full evaluation; run it once, on rank 0.
        if rank != 0:
            return None

        print("Now start loading dataset")
        if args.dataset == "CIFAR10":
            test_dataset = torchvision.datasets.CIFAR10(
                args.dataset_dir,
                train=False,
                download=True,
                transform=DataTransforms(size=args.image_size).test_transform,
            )
        else:
            raise NotImplementedError(f"unsupported dataset: {args.dataset}")

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.logistic_batch_size,
            shuffle=False,
            drop_last=False,  # evaluate every test sample, including the last partial batch
            num_workers=args.num_workers,
        )

        bt_model_1 = _load_spiking_bt_model(
            args, getattr(args, "ctl_checkpoint", _DEFAULT_CTL_CHECKPOINT), "CTL"
        )
        bt_model_2 = _load_spiking_bt_model(
            args, getattr(args, "btl_checkpoint", _DEFAULT_BTL_CHECKPOINT), "BTL"
        )

        print("### Comparing timestep-wise KL divergence between CTL and BTL ###")
        return inference_dual_models_per_timestep(args, test_loader, bt_model_1, bt_model_2, args.device)
    finally:
        # A process group exists only in the multi-process case; destroying an
        # uninitialized group raises.
        if world_size > 1:
            cleanup()


def parse_bool(value):
    """Convert a command-line string such as 'true'/'False'/'1'/'no' to a bool.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def config_arg_type(default):
    """Return the argparse ``type=`` converter for a config default value.

    ``type=bool`` would map any non-empty string (including "False") to True,
    and ``type=NoneType`` cannot be called with an argument, so those two
    cases get dedicated converters. Everything else keeps ``type(default)``.
    """
    if isinstance(default, bool):
        return parse_bool
    if default is None:
        return str
    return type(default)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    config = yaml_config_hook("./config/config.yaml")
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v, type=config_arg_type(v))

    args = parser.parse_args()
    args.lr = float(args.lr)
    print(vars(args))

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "1234"

    world_size = torch.cuda.device_count()
    if world_size > 1:
        torch.multiprocessing.spawn(main, args=(world_size, args), nprocs=world_size, join=True)
    else:
        main(0, 1, args)