import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset

from FS_ddro import DS_FedDRO_M, eval_loss_acc


def set_all_seeds(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with *seed*.

    When CUDA is available, every visible CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def build_st_indices(targets, num_classes, minority_keep):
    """Return sorted sample indices for a step-imbalanced subset.

    Classes ``0 .. num_classes // 2 - 1`` (the first half) keep only their
    first ``minority_keep`` samples, in dataset order. All remaining classes
    keep every sample.

    Args:
        targets: per-sample integer labels (anything ``np.array`` accepts).
        num_classes: number of label values to scan (``0 .. num_classes - 1``).
        minority_keep: cap applied to each minority class (used as a slice stop).

    Returns:
        ``np.int64`` array of the kept indices, in ascending order.
    """
    labels = np.array(targets)
    cutoff = num_classes // 2
    chosen = []
    for cls in range(num_classes):
        members = np.where(labels == cls)[0]
        # A slice stop of None keeps everything for the majority classes.
        limit = minority_keep if cls < cutoff else None
        chosen += members[:limit].tolist()
    chosen.sort()
    return np.array(chosen, dtype=np.int64)


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved. Bias is omitted because
    the callers in this file follow each conv with BatchNorm.
    """
    conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv


class BasicBlock(nn.Module):
    """Two-convolution residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    The shortcut is the identity unless the block changes resolution
    (``stride != 1``) or width (``in_planes != planes``). In that case a 1x1
    strided conv followed by BN projects the input to the output shape.
    """

    # Channel multiplier attribute; not read anywhere else in this file.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Main path: only the first conv may downsample.
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)

        projection_needed = stride != 1 or in_planes != planes
        if projection_needed:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)), inplace=True)
        residual = self.bn2(self.conv2(hidden))
        merged = residual + self.shortcut(x)
        return F.relu(merged, inplace=True)


class ResNetCIFAR(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three stages of BasicBlocks, linear head.

    Stage widths are 16, 32 and 64 channels. Stages 2 and 3 start with a
    stride-2 block. The head uses global average pooling, so any input size
    works.

    Args:
        num_blocks: three positive ints, the BasicBlock count for each stage.
        num_classes: number of output logits.
    """

    def __init__(self, num_blocks, num_classes):
        super().__init__()
        # Running input width, read and advanced by _make_layer.
        self.in_planes = 16
        self.conv1 = conv3x3(3, 16, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(64, num_blocks[2], stride=2)
        self.fc = nn.Linear(64, num_classes)
        self._init_weights()

    def _make_layer(self, planes, n, stride):
        """Build one stage of *n* BasicBlocks; only the first uses *stride*.

        Raises:
            ValueError: if ``n < 1``. Without this check, ``[stride] + [1] * (n - 1)``
                would still yield one block.
        """
        if n < 1:
            raise ValueError(f"each stage needs at least one block, got n={n}")
        strides = [stride] + [1] * (n - 1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_planes, planes, s))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm, small-normal/zero Linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool to 1x1. A fixed square kernel of width
        # out.size(3) would leave a spatial grid on non-square maps and
        # break the Linear head.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.fc(out)


def ResNet20(num_classes):
    """Build a 20-layer CIFAR ResNet (three stages of 3 BasicBlocks each)."""
    blocks_per_stage = [3] * 3
    return ResNetCIFAR(blocks_per_stage, num_classes=num_classes)


def ResNet56(num_classes):
    """Build a 56-layer CIFAR ResNet (three stages of 9 BasicBlocks each)."""
    blocks_per_stage = [9] * 3
    return ResNetCIFAR(blocks_per_stage, num_classes=num_classes)


def ResNet110(num_classes):
    """Build a 110-layer CIFAR ResNet (three stages of 18 BasicBlocks each)."""
    blocks_per_stage = [18] * 3
    return ResNetCIFAR(blocks_per_stage, num_classes=num_classes)


def centralized_pretrain(
    model,
    train_loader,
    train_eval_loader,
    test_loader,
    device,
    epochs=1,
    lr=0.0012,
    *,
    momentum=0.9,
    weight_decay=5e-4,
    grad_clip=5.0,
):
    """Warm up *model* in place with centralized SGD for *epochs* epochs.

    After each epoch, ``eval_loss_acc`` is run on the train-eval and test
    loaders and a one-line summary is printed.

    Args:
        model: network to train (modified in place).
        train_loader: yields ``(x, y)`` batches for the SGD updates.
        train_eval_loader: loader used to report training loss/accuracy.
        test_loader: loader used to report test loss/accuracy.
        device: target device for the batches.
        epochs: number of passes over ``train_loader``; ``<= 0`` is a no-op.
        lr: SGD learning rate (constant; no scheduler is used).
        momentum: SGD momentum. Nesterov is enabled only when this is > 0,
            because torch rejects Nesterov without momentum.
        weight_decay: L2 penalty passed to SGD.
        grad_clip: max total gradient norm; ``None`` or ``<= 0`` disables clipping.
    """
    if epochs <= 0:
        return

    opt = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=momentum > 0,
    )
    ce = nn.CrossEntropyLoss()
    clip_enabled = grad_clip is not None and grad_clip > 0

    for ep in range(1, epochs + 1):
        model.train()
        skipped = 0
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            loss = ce(model(x), y)
            # Skip NaN/Inf batches rather than corrupting the weights, but count them.
            if not torch.isfinite(loss):
                skipped += 1
                continue
            loss.backward()
            if clip_enabled:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

        tr_loss, tr_acc = eval_loss_acc(model, train_eval_loader, device)
        te_loss, te_acc = eval_loss_acc(model, test_loader, device)
        print(
            f"[PRETRAIN {ep:02d}/{epochs:02d}] "
            f"TrainAcc {tr_acc:6.2f} | TestAcc {te_acc:6.2f} "
            f"| TrainLoss {tr_loss:.4f} | TestLoss {te_loss:.4f} "
            f"| lr={opt.param_groups[0]['lr']:.5f}",
            flush=True,
        )
        if skipped:
            print(
                f"[PRETRAIN {ep:02d}/{epochs:02d}] WARNING: skipped {skipped} batch(es) with non-finite loss",
                flush=True,
            )


def _build_arg_parser():
    """Return the argparse parser holding every command-line option of this script."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, default="cifar100st", choices=["cifar10", "cifar100", "cifar10st", "cifar100st"])
    p.add_argument("--model_name", type=str, default="resnet110", choices=["resnet20", "resnet56", "resnet110"])
    p.add_argument("--data_root", type=str, default="./data")
    p.add_argument("--epochs", type=int, default=120)
    p.add_argument("--workers", type=int, default=2)
    # Batch size of the train-eval and test loaders built in main().
    p.add_argument("--batch_size", type=int, default=256)

    p.add_argument("--num_users", type=int, default=32)
    p.add_argument("--local_bs", type=int, default=128)
    p.add_argument("--local_ep", type=int, default=4)

    p.add_argument("--K", type=int, default=1)
    p.add_argument("--I", type=int, default=1)

    p.add_argument("--base_lr", type=float, default=0.065)
    p.add_argument("--min_lr", type=float, default=0.002)
    p.add_argument("--warmup_rounds", type=int, default=20)
    p.add_argument("--momentum", type=float, default=0.9)
    p.add_argument("--weight_decay", type=float, default=5e-4)
    p.add_argument("--grad_clip", type=float, default=5.0)

    p.add_argument("--lamda", type=float, default=1.0)
    p.add_argument("--beta_y", type=float, default=0.035)
    p.add_argument("--gamma_x", type=float, default=0.85)
    p.add_argument("--gamma_y", type=float, default=0.12)
    p.add_argument("--y_clip", type=float, default=8.0)

    p.add_argument("--label_smoothing", type=float, default=0.0)
    p.add_argument("--mixup_alpha", type=float, default=0.04)
    p.add_argument("--mixup_start_round", type=int, default=85)
    p.add_argument("--mixup_full_round", type=int, default=110)

    p.add_argument("--server_rehearsal_steps", type=int, default=6)
    p.add_argument("--server_rehearsal_lr", type=float, default=0.004)
    p.add_argument("--server_rehearsal_start_round", type=int, default=20)
    p.add_argument("--server_label_smoothing", type=float, default=0.0)

    p.add_argument("--eval_num_clients", type=int, default=32)
    p.add_argument("--pretrain_epochs", type=int, default=1)
    p.add_argument("--pretrain_lr", type=float, default=0.0012)
    # Batch size of the shuffled loader used by centralized pretraining.
    p.add_argument("--pretrain_bs", type=int, default=128)
    p.add_argument("--random_seed", type=int, default=40)

    p.add_argument("--wandb", action="store_true")
    p.add_argument("--wandb_project", type=str, default="avg")
    p.add_argument("--wandb_entity", type=str, default="hq1351-wayne-state-university")
    return p


def main():
    """Parse CLI args, load CIFAR data, build the ResNet, optionally pretrain, run DS_FedDRO_M.

    The parsed ``args`` namespace is extended with ``device`` and
    ``num_classes`` before it is handed to ``DS_FedDRO_M``.
    """
    args = _build_arg_parser().parse_args()
    set_all_seeds(args.random_seed)
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] device={args.device}", flush=True)

    # "cifar10" is a prefix of "cifar100", so a substring test such as
    # `"cifar10" in args.dataset` would route CIFAR-100 runs to CIFAR-10.
    if args.dataset.startswith("cifar100"):
        args.num_classes = 100
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
        dataset_cls = torchvision.datasets.CIFAR100
    else:
        args.num_classes = 10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
        dataset_cls = torchvision.datasets.CIFAR10

    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    eval_tf = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    # Two views of the same training split: augmented for SGD, plain for evaluation.
    base_train_aug = dataset_cls(root=args.data_root, train=True, download=True, transform=train_tf)
    base_train_eval = dataset_cls(root=args.data_root, train=True, download=True, transform=eval_tf)
    test_data = dataset_cls(root=args.data_root, train=False, download=True, transform=eval_tf)

    if args.dataset.endswith("st"):
        # Step imbalance: the first half of the classes keep only `minority_keep` samples each.
        minority_keep = 500 if args.num_classes == 10 else 100
        st_indices = build_st_indices(
            base_train_aug.targets, num_classes=args.num_classes, minority_keep=minority_keep
        ).tolist()
        train_data = Subset(base_train_aug, st_indices)
        train_eval_data = Subset(base_train_eval, st_indices)
    else:
        train_data = base_train_aug
        train_eval_data = base_train_eval

    print(f"[INFO] dataset={args.dataset}", flush=True)
    print(f"[INFO] train size={len(train_data)}", flush=True)
    print(f"[INFO] test size={len(test_data)}", flush=True)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    loader_kw = dict(num_workers=args.workers, pin_memory=args.device == "cuda")
    train_loader = DataLoader(train_data, batch_size=args.pretrain_bs, shuffle=True, **loader_kw)
    train_eval_loader = DataLoader(train_eval_data, batch_size=args.batch_size, shuffle=False, **loader_kw)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, **loader_kw)

    # argparse `choices` guarantees model_name is one of these keys.
    model_factories = {"resnet20": ResNet20, "resnet56": ResNet56, "resnet110": ResNet110}
    model = model_factories[args.model_name](num_classes=args.num_classes).to(args.device)

    if args.pretrain_epochs > 0:
        centralized_pretrain(
            model=model,
            train_loader=train_loader,
            train_eval_loader=train_eval_loader,
            test_loader=test_loader,
            device=args.device,
            epochs=args.pretrain_epochs,
            lr=args.pretrain_lr,
        )

    DS_FedDRO_M(
        args=args,
        model=model,
        train_data=train_data,
        train_eval_data=train_eval_data,
        test_loader=test_loader,
    )


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()