import os, argparse, random
from playgroundImageNet import *
from osbc_prune import *
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from compression_cifar100 import FlattenedVGG, vgg16_bn, RepeatTransform
# parallellism
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

BATCH=128

#percentage of the pruned model you want to fine-tune
TARGET_PRUNE = "0.957"

MODEL_SAVE_PATH = "./models/cifar100_vgg"

def loadCIFAR100(data_path, T):
    """Return DataLoader for Cifar100 val/ subset."""
    MEAN = [0.5071, 0.4867, 0.4408]
    STD = [0.2673, 0.2564, 0.2762]
    tf_train_cifar100 = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),
        transforms.Normalize(MEAN, STD),
        RepeatTransform(T)
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
        RepeatTransform(T)
    ])
    traindir = os.path.join(data_path, "train")
    valdir = os.path.join(data_path, "val")
    dataset_train = CIFAR100(root=traindir, transform=tf_train_cifar100, train=True, download=True)
    dataset_test = CIFAR100(root=valdir, transform=valid_transform, train=False, download=True)
    
    return dataset_train, dataset_test
# --------------------------------------------------------------------------- #
#  helpers
# --------------------------------------------------------------------------- #
def set_random_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def freeze_batchnorm(model: nn.Module, *, affine: bool = True) -> None:
    """
    Puts every BatchNorm* layer in eval mode (no running-stat updates).
    Optionally freezes the affine weights/biases as well.
    """
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()                # stop running-stat updates
            if affine:              # optionally freeze γ, β
                m.weight.requires_grad = False
                m.bias.requires_grad  = False

def apply_mask_hook(param: torch.Tensor, mask: torch.Tensor):
    """Block gradients for pruned weights."""
    assert mask.shape == param.shape
    def hook(grad):
        return grad * mask.to(param.device)
    param.register_hook(hook)


def inspect_bns(model):
    for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            print(f"— {name} —")
            print(f"  train mode     : {m.training}")
            print(f"  weight grad    : {m.weight.requires_grad}")
            print(f"  bias   grad    : {m.bias  .requires_grad}")
            # show a summary statistic of the running stats
            print(f"  running_mean   : {m.running_mean.mean().item():.4f} ± {m.running_mean.std().item():.4f}")
            print(f"  running_var    : {m.running_var.mean().item():.4f} ± {m.running_var.std().item():.4f}")
            print(f"  momentum       : {m.momentum}\n")

def inspect_layer_sparsity(flat_model):
    masks = {}
    total_zero_count = 0
    total_weight_count = 0
    for name, layer in find_layers(flat_model.layer).items():
        w = layer.weight
        masks[name] = torch.abs(w) > 5e-7
        layer_weight_nonzero = torch.masked_select(w, masks[name])
        total_zero_count += w.numel()-len(layer_weight_nonzero)
        total_weight_count += w.numel()            
        print(f'Layer {name}, {1.0-(torch.sum(masks[name]) / w.numel()).item():.5f} {torch.mean(layer_weight_nonzero).item():.5f} {torch.var(layer_weight_nonzero).item():.5f}')
    print(f'Total sparsity: {total_zero_count}, {total_weight_count}, {total_zero_count/total_weight_count:.4f}')

# --------------------------------------------------------------------------- #
#  worker - one per GPU
# --------------------------------------------------------------------------- #
def main(local_rank: int, args):
    # ---------- distributed initialisation ----------
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    set_random_seed(TORCH_SEED + local_rank)

    # ---------- build / load model ----------
    model_path = f"./models/cifar100_vgg/{TARGET_PRUNE}_pruned_model.pth"
    model = vgg16_bn(
        num_classes=100
    ).to(device)
    model.load_state_dict(torch.load(model_path, weights_only=False, map_location=DEV).state_dict())
    
    # flatten helper _before_ wrapping in DDP so attributes still resolve
    flat = FlattenedVGG(model).eval()
    # Compute masks once (they’re identical on every rank)
    model_masks = {}
    for name, layer in find_layers(flat.layer).items():
        w = layer.weight
        model_masks[name] = torch.abs(w) > 5e-7
    if local_rank == 0:
        inspect_layer_sparsity(flat)
    
    # ---------- DistributedDataParallel ----------
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    flat  = FlattenedVGG(model.module).eval()   # point at wrapped model

    # attach gradient mask hooks (must run *after* DDP wrapping)
    for name, mask in model_masks.items():
        apply_mask_hook(flat.layer[name].weight, mask)

    # ---------- data ----------
    T = 5
    lr_ = 5e-3
    batch_size_ = 128
    wd_ = 0
    epoch_ = 50
    data_path = "./datas"
    train_ds, test_ds = loadCIFAR100(data_path, T)

    # Each GPU sees a unique shard
    train_sampler = DistributedSampler(train_ds, shuffle=True, drop_last=True)
    test_sampler  = DistributedSampler(test_ds, shuffle=False)

    finetune_loader = DataLoader(
        train_ds, batch_size=batch_size_, sampler=train_sampler,
        num_workers=8, pin_memory=True, persistent_workers=False, prefetch_factor=2,
    )

    if local_rank == 0:
        model.eval()
        pre_finetune = validate_model(
            model.module,                    # unwrap DDP
            T, test_ds, pruning_perc=f"{TARGET_PRUNE}_tuned",
            out_dir="", batch=BATCH,
        )
        print("pre_finetune", pre_finetune)

    # ---------- optimiser / loss ----------
    optimizer  = torch.optim.SGD(model.parameters(), lr=lr_,
                                 momentum=0.9, weight_decay=wd_)
    # scheduler -------------------------------------------
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 40], gamma=0.1)
    criterion  = nn.CrossEntropyLoss(label_smoothing=0.0)
    functional.set_step_mode(model, 'm')
    model.train()
    freeze_batchnorm(model,affine=False)
    
    for e in range(epoch_):
        # ---------- one epoch fine-tune ----------
        model.train()
        train_sampler.set_epoch(e)          # shuffle shards the same way on every rank
        pbar = tqdm(finetune_loader) if local_rank == 0 else finetune_loader
        ep_samples = 0
        ep_correct = 0.0

        for frame, label in pbar:
            frame = frame.cuda(non_blocking=True).transpose(0, 1)   # T,N,C,H,W
            label = label.cuda(non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type="cuda"):
                out_t = model(frame)          # shape: T,N,num_classes
                out   = out_t.mean(0)
                loss  = criterion(out, label)

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                ep_samples += label.size(0)
                ep_correct += (out.argmax(dim=1) == label).float().sum().item()
                functional.reset_net(model)

            if local_rank == 0:
                pbar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "lr":   f"{scheduler.get_last_lr()[0]:.6f}",
                    "acc" : f"{ep_correct / ep_samples:.4f}",
                })
        scheduler.step()


        # ---------- validation (rank-0 only) ----------
        if local_rank == 0:
            model.eval()
            save_dir = ""
            if e == epoch_-1:
                save_dir = MODEL_SAVE_PATH

            if (e+1) % 5 == 0:
                post_finetune = validate_model(
                    model.module,                    # unwrap DDP
                    T, test_ds, pruning_perc=f"{TARGET_PRUNE}_tuned_epoch{e}",
                    out_dir=save_dir, batch=BATCH,
                )
                inspect_layer_sparsity(flat)
                # Optionally save checkpoint
                print(f"{e}: post_finetune", post_finetune)

    dist.barrier()
    dist.destroy_process_group()


# --------------------------------------------------------------------------- #
#  launcher
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (Optional) extra CLI flags go here …
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    if world_size < 2:
        raise RuntimeError("At least 2 CUDA devices are required for DDP.")

    # torchrun sets LOCAL_RANK for each spawned process – honour it if present
    local_rank_env = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank_env >= 0:
        main(local_rank_env, args)
    else:
        # Fallback: manual spawn (rarely needed when using torchrun)
        torch.multiprocessing.spawn(
            main, args=(args,), nprocs=world_size, join=True
        )
