import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset, Sampler
import argparse
import os
import math
import time
import wandb
from timm.data import Mixup
import yaml
import sys
import numpy as np
import hashlib
import shutil
import timm               
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from models import *
import random



class AverageMeter(object):
    """Track the latest value, weighted sum, sample count and running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.sum = self.cnt = self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.cnt += n
        self.sum += val * n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages.

    Args:
        output: 2-D score tensor; the k largest entries along dim 1 are the predictions.
        target: tensor with one ground-truth class index per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, each holding the
        percentage of rows whose target is among the k highest scores.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Ranked class indices, transposed so that row r holds everyone's rank-r guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
        for k in topk
    ]

def seed_everything(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def worker_init_fn(worker_id):
    """Seed NumPy, ``random`` and PyTorch with ``42 + worker_id``.

    Passed to the DataLoaders as their ``worker_init_fn``.
    """
    per_worker_seed = worker_id + 42
    for apply_seed in (np.random.seed, random.seed, torch.manual_seed):
        apply_seed(per_worker_seed)

# Seed the global RNGs once at import time and create the generator that is
# handed to the DataLoaders further down.
seed_everything(42)
g = torch.Generator()
g.manual_seed(42)
print("Random seed set to 42")

def is_special_epoch(epoch, total_epochs):
    """Return True for selected epochs in the final fifth of training.

    An epoch qualifies when it is at or past ``int(total_epochs * 0.8)`` and is
    either the last epoch or has an index ending in 9.
    """
    if epoch < int(total_epochs * 0.8):
        return False
    is_last = epoch == total_epochs - 1
    return (epoch % 10 == 9) or is_last


def make_batch_seed_from_indices(indices_tensor, base_seed=1337):
    """Derive a 31-bit seed from a batch's sample indices, independent of their order.

    The indices are sorted, serialised as int64 bytes and hashed with SHA-1; the
    low 31 bits of the digest are XOR-ed with ``base_seed`` and masked to 31 bits.
    """
    mask31 = 0x7fffffff
    canonical = indices_tensor.cpu().to(torch.long).sort().values.numpy()
    digest_hex = hashlib.sha1(canonical.tobytes()).hexdigest()
    low_bits = int(digest_hex, 16) & mask31
    return (low_bits ^ base_seed) & mask31

def rand_bbox_here(size, lam, rng):
    """Draw a CutMix box whose side lengths scale with ``sqrt(1 - lam)``.

    Args:
        size: indexable whose entries 2 and 3 are read as height and width.
        lam: mixing coefficient; the box covers roughly ``1 - lam`` of the area
            before clipping.
        rng: object exposing ``randint(high)``; the x centre is drawn first,
            then the y centre.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``, each clipped to ``[0, W]`` / ``[0, H]``.
    """
    height, width = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_w = int(width * side_ratio) // 2
    half_h = int(height * side_ratio) // 2

    center_x = rng.randint(width)
    center_y = rng.randint(height)

    x_lo = np.clip(center_x - half_w, 0, width)
    y_lo = np.clip(center_y - half_h, 0, height)
    x_hi = np.clip(center_x + half_w, 0, width)
    y_hi = np.clip(center_y + half_h, 0, height)
    return x_lo, y_lo, x_hi, y_hi

def deterministic_cutmix_images(x, index, bbox):
    """Paste the ``bbox`` patch of ``x[index]`` into ``x`` in place and return ``x``.

    ``bbox`` is ``(bbx1, bby1, bbx2, bby2)``; the bbx range slices dim 2 and the
    bby range slices dim 3 of ``x``. ``index`` selects, for every sample, which
    sample the patch is copied from.
    """
    lo2, lo3, hi2, hi3 = bbox
    span2 = slice(lo2, hi2)
    span3 = slice(lo3, hi3)
    # The indexed read on the right-hand side is materialised before the write.
    x[:, :, span2, span3] = x[index, :, span2, span3]
    return x


class IndexedImageFolder(datasets.ImageFolder):
    """``ImageFolder`` whose samples additionally carry their dataset index."""

    def __getitem__(self, idx):
        """Return ``(image, target, idx)`` for the sample at position ``idx``."""
        image, label = super().__getitem__(idx)
        return (image, label, idx)


class BatchShuffleSampler(Sampler):
    """Sampler with fixed batch membership that only shuffles the batch order.

    The dataset indices are permuted once (seeded by ``seed``) and cut into
    consecutive groups of ``batch_size``. Every epoch yields those groups,
    flattened, in an order seeded by ``seed + epoch``.

    The sampler emits a flat index stream, and a consumer that regroups the
    stream into chunks of ``batch_size`` only recovers the original groups if
    every group before the last one is full. A short trailing group is
    therefore never shuffled into the middle; it is always emitted last.

    Args:
        data_source: sized collection; only ``len(data_source)`` is used.
        batch_size: number of indices per fixed group.
        seed: base seed for both the membership and the per-epoch order.
    """

    def __init__(self, data_source, batch_size, seed=42):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_samples = len(data_source)
        self.seed = seed
        self.epoch = 0

        g0 = torch.Generator().manual_seed(seed)
        indices = torch.randperm(self.num_samples, generator=g0).tolist()
        self.batches = [
            indices[i:i + batch_size]
            for i in range(0, self.num_samples, batch_size)
        ]

    def set_epoch(self, epoch: int):
        """Select the epoch whose batch order the next iteration will use."""
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator().manual_seed(self.seed + self.epoch)

        # Only full groups take part in the shuffle; see the class docstring.
        has_partial = bool(self.batches) and len(self.batches[-1]) < self.batch_size
        n_full = len(self.batches) - 1 if has_partial else len(self.batches)

        for i in torch.randperm(n_full, generator=g).tolist():
            yield from self.batches[i]
        if has_partial:
            yield from self.batches[-1]

    def __len__(self):
        return self.num_samples



def load_config_from_yaml(args):
    """Populate ``args`` in place from the YAML file at ``args.config_path``.

    The file is expected to map dataset names to a settings mapping. Every key
    of the selected dataset entry except ``hyperparams`` is copied onto ``args``
    as an attribute. ``hyperparams[model][ipc]`` (with ``ipc`` looked up as a
    string key) is unpacked into ``args.adamw_lr`` and ``args.eta``.

    Args:
        args: namespace providing ``config_path``, ``dataset_name``, ``model``
            and ``ipc``.

    Raises:
        ValueError: if the dataset is missing from the file, or no hyperparams
            exist for the dataset/model/ipc combination.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform locale.
    with open(args.config_path, 'r', encoding='utf-8') as f:
        # An empty file parses to None; treat it as "no datasets configured".
        all_config = yaml.safe_load(f) or {}

    dataset = args.dataset_name
    model = args.model
    ipc = str(args.ipc)

    if dataset not in all_config:
        raise ValueError(f"Dataset {dataset} not found in config")

    cfg = all_config[dataset]

    # Copy the dataset's basic settings onto args.
    for key, value in cfg.items():
        if key != "hyperparams":
            setattr(args, key, value)

    # Set the training hyperparameters (lr and eta).
    try:
        args.adamw_lr, args.eta = cfg['hyperparams'][model][ipc]
    except KeyError as err:
        # Chain the cause so the traceback shows which key was missing.
        raise ValueError(
            f"No hyperparams found for dataset={dataset}, model={model}, ipc={ipc}"
        ) from err


def _replace_classifier_head(model, num_classes):
    """Swap the final linear layer of a torchvision classifier for a ``num_classes``-way one."""
    # ResNet and ShuffleNetV2 expose the head as ``fc``.
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return

    # MobileNetV2 (Sequential ending in a Linear) and DenseNet (a bare Linear)
    # expose it as ``classifier`` instead.
    head = getattr(model, 'classifier', None)
    if isinstance(head, nn.Linear):
        model.classifier = nn.Linear(head.in_features, num_classes)
    elif isinstance(head, nn.Sequential) and len(head) > 0 and isinstance(head[len(head) - 1], nn.Linear):
        last = len(head) - 1
        head[last] = nn.Linear(head[last].in_features, num_classes)
    else:
        raise ValueError(f"Cannot locate the classification head of {type(model).__name__}")


def build_student_model(args):
    """Build the student network selected by ``args.model``.

    Args:
        args: namespace providing ``model`` (architecture name), ``input_size``
            and ``ncls`` (number of classes).

    Returns:
        For ``input_size <= 64`` the project's small-resolution variant built
        with ``ncls`` classes; otherwise an untrained torchvision model whose
        classification head is replaced when ``ncls != 1000``, or an untrained
        timm SwinV2-tiny for ``'SwinTransformer'``.

    Raises:
        ValueError: if the architecture is not supported at this input size.
    """
    model_dict = {
        'ResNet18': (ResNet18, models.resnet18),
        'ResNet50': (ResNet50, models.resnet50),
        'ResNet101': (ResNet101, models.resnet101),
        'MobileNetV2': (MobileNet_V2, models.mobilenet_v2),
        'ShuffleNetV2': (ShuffleNet_V2, models.shufflenet_v2_x1_0),
        'Densenet121': (DenseNet121, models.densenet121),
    }

    # SwinTransformer has no entry in the table (there is no small-resolution
    # variant), so it must be handled before the membership check; otherwise
    # this branch can never run.
    if args.model == 'SwinTransformer' and args.input_size > 64:
        # NOTE(review): this timm model is the 256-px variant — confirm that the
        # data pipeline feeds matching image sizes before relying on it.
        return timm.create_model(
            'swinv2_tiny_window8_256', pretrained=False, num_classes=args.ncls
        )

    if args.model not in model_dict:
        raise ValueError(f"Unsupported model: {args.model}")

    small_res_model, imagenet_model_fn = model_dict[args.model]

    if args.input_size <= 64:
        return small_res_model(args.ncls)

    model = imagenet_model_fn(pretrained=False)
    if args.ncls != 1000:
        _replace_classifier_head(model, args.ncls)
    return model


def parse_args():
    """Parse the command-line options of the training script from ``sys.argv``."""
    parser = argparse.ArgumentParser("Soft-Hard Alternating Training Script")

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = [
        ('--SLC', dict(type=int, default=100)),
        ('--ipc', dict(default=10, type=int)),
        ('--train-dir', dict(type=str, default='')),
        ('--crop-dir', dict(type=str, default='')),
        ('--test-dir', dict(type=str, default='')),
        ('--img-mode', dict(type=str, default='bssl', help='lpld or ours')),
        ('--train-mode', dict(type=str, default='lpld', choices=['lpld', 'ours'])),
        ('--model', dict(type=str, default='ResNet18')),
        ('--config-path', dict(type=str, default="")),
        ('--exp-name', dict(type=str, default="lpld_lpld_slc100_ipc10")),
        ('--dataset-name', dict(default='imagenet1k', type=str)),
        ('--output-dir', dict(type=str, help='save dir')),
        ('--simple', dict(action='store_true')),
        # wandb
        ('--wandb-api-key', dict(type=str, default=None)),
        ('--wandb-project', dict(type=str, default='')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


# Parse the CLI, then merge the per-dataset YAML settings (and adamw_lr / eta) into args.
args = parse_args()
load_config_from_yaml(args)
print(args)

# ncls is not a CLI option; it has to be set on args by the YAML dataset entry.
num_classes = args.ncls
batch_size = 16
T = 20  # distill temperature

# Start the Weights & Biases run and name it after the experiment.
wandb.login(key=args.wandb_api_key)
wandb.init(project=args.wandb_project)
wandb.run.name = args.exp_name

# Stage lengths in epochs. The training loop treats even-indexed stages as
# soft-label stages and odd-indexed stages as hard-label stages.
if args.train_mode == 'lpld':
    args.epoch = [300]
else:
    args.epoch = [115, 70, 115] if args.SLC >= 200 else [75, 150, 75]
    
total_epochs = sum(args.epoch)

# Per-channel normalisation statistics shared by all three pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Hard-label training images: random crop + flip augmentation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Soft-label images: converted and normalised only, no spatial augmentation.
crop_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation images: deterministic resize and centre crop.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Datasets: hard-label training images, soft-label images (with indices), validation images.
hard_dataset = datasets.ImageFolder(root=args.train_dir, transform=train_transforms)
soft_dataset = IndexedImageFolder(root=args.crop_dir, transform=crop_transforms)
val_dataset = datasets.ImageFolder(root=args.test_dir, transform=val_transforms)

# Worker / seeding options common to all three loaders.
_loader_kwargs = dict(
    worker_init_fn=worker_init_fn,
    generator=g,
    num_workers=10,
    persistent_workers=True,
    pin_memory=True,
)

# The soft loader draws its order from the batch-shuffling sampler, so shuffle stays off.
sampler = BatchShuffleSampler(soft_dataset, batch_size=batch_size, seed=42)
soft_loader = DataLoader(
    soft_dataset,
    batch_size=batch_size,
    sampler=sampler,
    shuffle=False,
    **_loader_kwargs,
)

train_loader = DataLoader(
    hard_dataset,
    batch_size=batch_size,
    shuffle=True,
    **_loader_kwargs,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    **_loader_kwargs,
)


# Batch augmentation used by hard_train(); it returns mixed inputs together with
# soft label targets. With mixup_alpha=0.0 and cutmix_alpha=1.0 this presumably
# applies CutMix only — confirm against the installed timm version.
# NOTE(review): label_smoothing=0.8 is unusually high; confirm it is intended.
mixup_fn = Mixup(mixup_alpha=0.0, cutmix_alpha=1.0, label_smoothing=0.8, num_classes=num_classes)

# ------------------ Utils ------------------
def soft_cross_entropy(logits, soft_targets):
    """Cross-entropy between ``softmax(logits)`` and ``soft_targets``, averaged over dim 0.

    Both arguments are indexed as (sample, class): the softmax and the sum over
    classes run along dim 1.
    """
    weighted_log_probs = soft_targets * F.log_softmax(logits, dim=1)
    per_sample = torch.sum(weighted_log_probs, dim=1)
    return per_sample.mean().neg()

def validate(model, val_loader, epoch=None):
    """Evaluate ``model`` on ``val_loader`` and record the results.

    Puts the model in eval mode (it is NOT switched back to train mode here),
    prints a summary line and merges ``val/loss``, ``val/top1``, ``val/top5``
    and ``val/epoch`` into the module-level ``wandb_metrics`` dict.

    Args:
        model: network to evaluate; must have at least one parameter.
        val_loader: iterable yielding ``(data, target)`` batches.
        epoch: zero-based epoch index used for logging, or None.

    Returns:
        ``(average loss, average top-1 accuracy in percent)``.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_function = nn.CrossEntropyLoss()

    # Follow the model's own device instead of hard-coding CUDA, so that the
    # CPU fallback selected at start-up also works during validation.
    model_device = next(model.parameters()).device

    model.eval()
    t1 = time.time()
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(model_device)
            target = target.long().to(model_device)

            output = model(data)
            loss = loss_function(output, target)

            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = data.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

    # The printed iteration number is one-based; -1 marks "no epoch given".
    logInfo = 'TEST Iter {}: loss = {:.6f},\t'.format((epoch+1) if epoch is not None else -1, objs.avg) + \
              'Top-1 err = {:.6f},\t'.format(100 - top1.avg) + \
              'Top-5 err = {:.6f},\t'.format(100 - top5.avg) + \
              'val_time = {:.6f}'.format(time.time() - t1)
    print(logInfo)

    metrics = {
        'val/loss': objs.avg,
        'val/top1': top1.avg,
        'val/top5': top5.avg,
        'val/epoch': epoch,
    }
    wandb_metrics.update(metrics)

    return objs.avg, top1.avg

def get_parameters(model):
    """Split a model's parameters into weight-decay and no-weight-decay groups.

    Parameters whose name ends in ``weight`` and that have more than one
    dimension go into the first group; all others go into the second group,
    which carries ``weight_decay=0.``.

    Returns:
        A two-element list of optimizer parameter-group dicts.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        takes_decay = name.endswith('weight') and param.ndim > 1
        bucket = decay if takes_decay else no_decay
        bucket.append(param)
    return [{'params': decay}, {'params': no_decay, 'weight_decay': 0.}]

def save_checkpoint(state, is_best=True, output_dir=None, epoch=None):
    """Serialize ``state`` to ``<output_dir>/checkpoint.pth.tar``.

    Args:
        state: object passed to ``torch.save``.
        is_best: when true, the file is also copied to
            ``<output_dir>/model_best.pth.tar``.
        output_dir: existing directory to write into.
        epoch: accepted for call-site compatibility; the file name does not
            depend on it.
    """
    # The file name is the same whether or not an epoch is given, so a single
    # path is built here instead of branching on ``epoch``.
    path = os.path.join(output_dir, 'checkpoint.pth.tar')
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(output_dir, 'model_best.pth.tar'))

def hard_train(model, loader, optimizer, epoch):
    """Train ``model`` for one pass over ``loader`` with hard labels.

    Each batch is augmented by the module-level ``mixup_fn`` and optimised with
    ``soft_cross_entropy`` against the mixed targets. Accuracy is measured
    against the original (unmixed) labels. Training metrics are merged into the
    module-level ``wandb_metrics`` dict and a summary line is printed. The
    learning rate is read from the module-level ``scheduler``, which is NOT
    stepped here.

    Args:
        model: network to train; switched to train mode.
        loader: iterable yielding ``(inputs, labels)`` batches.
        optimizer: optimizer updating ``model``.
        epoch: epoch index used for logging only.

    Returns:
        The average training loss over the pass.
    """
    objs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.train()
    t1 = time.time()
    for inputs, labels in loader:
        # Drop one sample from odd-sized batches before calling mixup_fn
        # (presumably because it requires an even batch size — confirm).
        if len(inputs) % 2 != 0:
            inputs = inputs[:-1]
            labels = labels[:-1]
        # A single-sample batch is empty after trimming; there is nothing to train on.
        if len(inputs) == 0:
            continue
        inputs, labels_mixed = mixup_fn(inputs, labels)
        inputs, labels_mixed = inputs.to(device), labels_mixed.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = soft_cross_entropy(outputs, labels_mixed)
        loss.backward()
        n = inputs.size(0)
        # Use the same device as the inputs rather than hard-coding CUDA.
        prec1, prec5 = accuracy(outputs, labels.to(device), topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        optimizer.step()

    metrics = {
        "train/loss": objs.avg,
        "train/Top1": top1.avg,
        "train/Top5": top5.avg,
        "train/lr": scheduler.get_last_lr()[0],
        "train/epoch": epoch,
    }
    wandb_metrics.update(metrics)

    printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(epoch, scheduler.get_last_lr()[0], objs.avg) + \
                'Top-1 err = {:.6f},\t'.format(100 - top1.avg) + \
                'Top-5 err = {:.6f},\t'.format(100 - top5.avg) + \
                'train_time = {:.6f}'.format((time.time() - t1))
    print(printInfo)
    return objs.avg


# ------------------ Model & Optim ------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_student_model(args).to(device)
# The teacher is a pretrained torchvision ResNet-18; it is never handed to the optimizer.
teacher = models.resnet18(pretrained=True).to(device)
# NOTE(review): the teacher stays in train mode unless img_mode == 'rded' —
# presumably so its BatchNorm layers use batch statistics; confirm this is intended.
if args.img_mode == 'rded':
    teacher.eval()
else:
    teacher.train()

# Only the student is wrapped for multi-GPU execution.
if torch.cuda.device_count() > 1:
    print(f"🚀 Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

# AdamW with weight decay applied only to the first group returned by get_parameters.
optimizer = torch.optim.AdamW(get_parameters(model), lr=args.adamw_lr, weight_decay=1e-4)
# Cosine learning-rate factor over scheduler steps, stretched by args.eta; the factor is 0 past total_epochs.
scheduler = LambdaLR(optimizer, lambda step: 0.5 * (1 + math.cos(math.pi * step / total_epochs / args.eta)) if step <= total_epochs else 0)

# ------------------ Training ------------------
curr_epoch = 0
# Per-epoch metric accumulator; validate() and hard_train() write into it by name.
wandb_metrics = {}
sampler_epoch = 0

print(args.epoch)
for stage_idx, stage_epochs in enumerate(args.epoch):
    model.train()
    if stage_idx % 2 == 0:
        # Soft stage: distil from the teacher on deterministically CutMix-ed batches.
        print(f"Soft Training (Stage {stage_idx+1})")
        # One "epoch" is the number of optimizer steps needed to cover ipc * num_classes images.
        per_step_epoch = (args.ipc * num_classes + batch_size - 1) // batch_size
        max_steps = per_step_epoch * stage_epochs
        step = 0
        criterion = nn.KLDivLoss(reduction='batchmean')
        objs = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        t1 = time.time()
        while step < max_steps:
            sampler_epoch += 1
            sampler.set_epoch(sampler_epoch)
            for inputs, labels, idxs in soft_loader:
                if step >= max_steps:
                    break
                n_cur = inputs.size(0)
                if n_cur % 2 != 0:
                    inputs = inputs[:-1]
                    labels = labels[:-1]
                    idxs   = idxs[:-1]
                    n_cur  = inputs.size(0)
                # A single-sample batch is empty after trimming; skip it.
                if n_cur == 0:
                    continue

                # The augmentation RNG is derived from the batch's sample indices,
                # so the same set of samples always gets the same lam, box and pairing.
                seed = make_batch_seed_from_indices(idxs, base_seed=20250918)
                rng = np.random.RandomState(seed)
                lam = rng.beta(1.0, 1.0)
                bbox = rand_bbox_here((n_cur, 3, args.input_size, args.input_size), lam, rng)

                pair_index = torch.from_numpy(rng.permutation(n_cur)).to(inputs.device)

                inputs = deterministic_cutmix_images(inputs, pair_index, bbox)

                inputs = inputs.to(device)

                optimizer.zero_grad()
                student_out = model(inputs) / T
                # The teacher only supplies targets: run it without autograd so no
                # graph is kept and no gradients pile up on its parameters.
                with torch.no_grad():
                    teacher_out = teacher(inputs) / T
                    probs_T = F.softmax(teacher_out, dim=1)
                loss = criterion(F.log_softmax(student_out, dim=1), probs_T)
                loss.backward()
                n = inputs.size(0)
                # Use the same device as the inputs rather than hard-coding CUDA.
                prec1, prec5 = accuracy(student_out, labels.to(device), topk=(1, 5))
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)
                optimizer.step()
                step += 1

                if step % per_step_epoch == 0:
                    # End of a soft "epoch": report, step the schedule, maybe validate.
                    metrics = {
                        "train/loss": objs.avg,
                        "train/Top1": top1.avg,
                        "train/Top5": top5.avg,
                        "train/lr": scheduler.get_last_lr()[0],
                        "train/epoch": curr_epoch,
                    }
                    wandb_metrics.update(metrics)

                    printInfo = 'TRAIN Iter {}: lr = {:.6f},\tloss = {:.6f},\t'.format(
                        curr_epoch, scheduler.get_last_lr()[0], objs.avg
                    ) + 'Top-1 err = {:.6f},\t'.format(100 - top1.avg) + \
                        'Top-5 err = {:.6f},\t'.format(100 - top5.avg) + \
                        'train_time = {:.6f}'.format((time.time() - t1))
                    print(printInfo)
                    t1 = time.time()

                    scheduler.step()
                    should_validate = (curr_epoch % 10 == 0 or curr_epoch == total_epochs - 1) if not args.simple else is_special_epoch(curr_epoch, total_epochs)
                    if should_validate:
                        validate(model, val_loader, epoch=curr_epoch)
                        # validate() leaves the model in eval mode.
                        model.train()

                    curr_epoch += 1
                    wandb.log(wandb_metrics)
                    wandb_metrics = {}

                    objs = AverageMeter()
                    top1 = AverageMeter()
                    top5 = AverageMeter()
                    t1 = time.time()

    else:
        # Hard stage: ordinary epochs over the hard-label loader.
        print(f"Hard Training (Stage {stage_idx+1})")
        for _ in range(stage_epochs):
            train_loss = hard_train(model, train_loader, optimizer, curr_epoch)
            scheduler.step()
            should_validate = (curr_epoch % 10 == 0 or curr_epoch == total_epochs - 1) if not args.simple else is_special_epoch(curr_epoch, total_epochs)
            if should_validate:
                validate(model, val_loader, epoch=curr_epoch)
            wandb.log(wandb_metrics)
            # Start each epoch with a fresh dict; otherwise the last validation
            # metrics would be logged again on every following epoch.
            wandb_metrics = {}
            curr_epoch += 1

# ------------------ Save ------------------
# --output-dir has no default. Fall back to the current directory so that a
# finished run is not lost to a TypeError when the option was omitted.
output_root = args.output_dir if args.output_dir is not None else "."
args.output_dir = os.path.join(output_root, args.dataset_name, args.exp_name)
os.makedirs(args.output_dir, exist_ok=True)
# Persist the final student together with the optimizer and scheduler state.
save_checkpoint({
    'epoch': total_epochs,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, is_best=True, output_dir=args.output_dir, epoch=total_epochs)

wandb.finish()
