#!/usr/bin/env python3
import os
import csv
import time
import copy
import argparse
import warnings
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset

# -------------------------
# ResNet20 for CIFAR
# -------------------------
def _conv3x3(in_ch, out_ch, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` layer with ``kernel_size=3``, ``padding=1`` and no bias.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_ch, out_ch, **conv_kwargs)

class BasicBlock(nn.Module):
    """Two-convolution residual block (conv-BN-ReLU, conv-BN, add, ReLU).

    The skip path is an identity unless the stride is not 1 or the channel
    count changes, in which case a 1x1 convolution followed by batch norm
    projects the input to the output shape.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        """Create the block.

        Args:
            in_ch: number of input channels.
            out_ch: number of output channels.
            stride: stride of the first convolution (and of the projection).
        """
        super().__init__()
        # Main branch: only the first conv may change resolution/width.
        self.conv1 = _conv3x3(in_ch, out_ch, stride)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = _conv3x3(out_ch, out_ch, 1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Skip branch: identity when shapes already match, projection otherwise.
        shapes_match = stride == 1 and in_ch == out_ch
        if shapes_match:
            self.shortcut = nn.Identity()
        else:
            projection = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_ch))

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        branch = self.conv1(x)
        branch = F.relu(self.bn1(branch))
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        return F.relu(branch + self.shortcut(x))

class ResNet20(nn.Module):
    """ResNet-20 classifier: a 3x3 stem, three stages of three BasicBlocks
    (16, 32 and 64 channels), average pooling and a linear head.
    """

    def __init__(self, num_classes):
        """Build the network.

        Args:
            num_classes: size of the output logit vector.
        """
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_ch = 16
        self.conv1 = _conv3x3(3, 16, 1)
        self.bn1 = nn.BatchNorm2d(16)
        # Stages 2 and 3 halve the resolution via a stride-2 first block.
        self.layer1 = self._make_layer(out_ch=16, n_blocks=3, stride=1)
        self.layer2 = self._make_layer(out_ch=32, n_blocks=3, stride=2)
        self.layer3 = self._make_layer(out_ch=64, n_blocks=3, stride=2)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, out_ch, n_blocks, stride):
        """Stack BasicBlocks into one stage; only the first block uses ``stride``.

        Updates ``self.in_ch`` to ``out_ch`` so the next stage chains correctly.
        """
        block_strides = (stride, *([1] * (n_blocks - 1)))
        stage = []
        for block_stride in block_strides:
            stage.append(BasicBlock(self.in_ch, out_ch, block_stride))
            self.in_ch = out_ch
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        feats = self.conv1(x)
        feats = F.relu(self.bn1(feats))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # Pool with a kernel equal to the feature-map width
        # (global pooling, assuming square feature maps).
        pooled = F.avg_pool2d(feats, feats.size(-1))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

def resnet20(num_classes):
    """Factory: return a freshly initialised ``ResNet20`` with ``num_classes`` outputs."""
    model = ResNet20(num_classes=num_classes)
    return model

# -------------------------
# Dataset wrappers
# -------------------------
class DatasetWithIndex(Dataset):
    """Wrap a dataset of ``(x, y)`` pairs so each item also carries its index.

    Items are returned as ``(idx, x, y)``, where ``idx`` is the index that
    was requested from this wrapper.
    """

    def __init__(self, base):
        """Store the wrapped dataset ``base``."""
        self.base = base

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return ``(idx, x, y)`` for the sample at position ``idx``."""
        sample, label = self.base[idx]
        return (idx, sample, label)

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to (and ordered by) a list of indices."""

    def __init__(self, dataset, indices):
        """Keep a reference to ``dataset`` and a list copy of ``indices``."""
        self.dataset = dataset
        # Materialise so generators and other iterables can be indexed later.
        self.indices = list(indices)

    def __len__(self):
        """Return the number of selected indices."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return the underlying item referenced by the ``i``-th stored index."""
        source_index = self.indices[i]
        return self.dataset[source_index]

def iid_equal_split(n, num_users, seed=0):
    """Randomly partition ``range(n)`` into ``num_users`` near-equal shards.

    Args:
        n: number of samples to distribute.
        num_users: number of shards (clients); converted with ``int``.
        seed: seed for the NumPy generator that shuffles the indices.

    Returns:
        Dict mapping each user id ``0..num_users-1`` to its list of indices.
        Shard sizes follow ``np.array_split``.
    """
    user_count = int(num_users)
    generator = np.random.default_rng(int(seed))
    order = np.arange(n)
    generator.shuffle(order)
    shards = np.array_split(order, user_count)
    return {user: shard.tolist() for user, shard in enumerate(shards)}

def fedavg_weighted(local_weights, sizes):
    """Return the sample-size-weighted average of client state dicts (FedAvg).

    Args:
        local_weights: sequence of state dicts (mapping name -> tensor), one
            per client. All dicts must contain the keys of the first one.
        sizes: sequence of per-client sample counts, aligned with
            ``local_weights``.

    Returns:
        A new dict with the keys of ``local_weights[0]``. Floating-point
        tensors are averaged with weights ``sizes[i] / sum(sizes)``;
        non-floating-point tensors are cloned from the first client.

    Raises:
        ValueError: if the two sequences differ in length, or if the total
            of ``sizes`` is not positive (this includes empty input).
    """
    local_weights = list(local_weights)
    sizes = list(sizes)
    # zip() would silently drop the unmatched tail and bias the average.
    if len(local_weights) != len(sizes):
        raise ValueError(
            f"local_weights and sizes must have the same length "
            f"(got {len(local_weights)} and {len(sizes)})"
        )

    total = float(sum(sizes))
    if total <= 0:
        raise ValueError("Total client sizes must be > 0")

    # Normalised per-client coefficients, computed once for all keys.
    coeffs = [float(n) / total for n in sizes]

    w_avg = {}
    for k, t0 in local_weights[0].items():
        if not torch.is_floating_point(t0):
            # Non-float entries are not averaged; the first client's value is kept.
            w_avg[k] = t0.clone()
            continue
        acc = torch.zeros_like(t0)
        for w, c in zip(local_weights, coeffs):
            acc += w[k] * c
        w_avg[k] = acc
    return w_avg

# -------------------------
# CIFAR transforms + CIFAR-ST construction (paper style)
# CIFAR-ST train split: for the first half of the classes, keep only the
# LAST 100 images of each class; the remaining classes are kept in full.
# Test split: unchanged.
# -------------------------
def cifar_transforms():
    """Build the training and evaluation image pipelines.

    Returns:
        ``(train_tfm, eval_tfm)``. The training pipeline applies a padded
        random 32x32 crop and a random horizontal flip before tensor
        conversion and per-channel normalisation; the evaluation pipeline
        only converts and normalises.
    """
    import torchvision.transforms as T

    # Fixed per-channel normalisation constants shared by both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    def _tensor_and_normalize():
        # Fresh transform objects per pipeline, as in a hand-written Compose.
        return [T.ToTensor(), T.Normalize(channel_mean, channel_std)]

    augmentation = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
    train_tfm = T.Compose(augmentation + _tensor_and_normalize())
    eval_tfm = T.Compose(_tensor_and_normalize())
    return train_tfm, eval_tfm

def build_cifar_st(dataset: str, root: str, train: bool, transform):
    """Load CIFAR-10/100 and, for the training split, apply the ST subsampling.

    Args:
        dataset: ``"cifar10"`` or ``"cifar100"`` (case-insensitive).
        root: directory passed to torchvision as the dataset root.
        train: whether to load the training split.
        transform: transform passed to the torchvision dataset.

    Returns:
        ``(dataset, num_classes)``. For ``train=False`` the torchvision
        dataset is returned unchanged. For ``train=True`` a ``Subset`` is
        returned in which each class in the first half of the label range
        keeps only its last 100 samples (by dataset order) and every other
        class is kept in full; the kept indices are sorted ascending.

    Raises:
        ValueError: if ``dataset`` is not one of the two supported names.
    """
    import torchvision

    name = dataset.lower()
    if name == "cifar10":
        dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
    elif name == "cifar100":
        dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
    else:
        raise ValueError("dataset must be cifar10 or cifar100")
    base = dataset_cls(root=root, train=train, download=True, transform=transform)

    # The test split is never subsampled.
    if not train:
        return base, num_classes

    labels = np.array(base.targets, dtype=np.int64)
    minority_count = num_classes // 2  # classes [0, minority_count) are subsampled

    kept_per_class = []
    for c in range(num_classes):
        members = np.flatnonzero(labels == c)
        # paper: minority classes keep their last 100 samples
        kept_per_class.append(members[-100:] if c < minority_count else members)

    keep = np.sort(np.concatenate(kept_per_class)).tolist()
    return Subset(base, keep), num_classes

# -------------------------
# Eval
# -------------------------
@torch.no_grad()
def eval_acc_loss(model, loader, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: module mapping an input batch to class logits; it is switched
            to eval mode and left in that mode.
        loader: iterable of ``(index, x, y)`` batches; the index is ignored.
        device: device the inputs and labels are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)``: the sample-weighted mean
        cross-entropy loss and the top-1 accuracy in percent. Both are 0.0
        when the loader yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    weighted_loss = 0.0
    for _, inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)
        # The loss is a batch mean, so re-weight by batch size before summing.
        weighted_loss += float(batch_loss) * inputs.size(0)
        predictions = logits.argmax(dim=1)
        n_correct += int(predictions.eq(labels).sum().item())
        n_seen += int(labels.numel())
    # Guard against division by zero for an empty loader.
    denom = max(1, n_seen)
    return weighted_loss / denom, 100.0 * n_correct / denom

def run_fedavg_fig1(args):
    """Train the FedAvg-ERM baseline on CIFAR*-ST and log per-round metrics.

    Each round, every client with data trains a copy of the global model
    with SGD for ``args.local_ep`` epochs on its shard, the client models are
    combined with ``fedavg_weighted`` (weighted by shard size), and the new
    global model is evaluated on the ST training set (without augmentation)
    and on the unchanged test set. One CSV row is written per round and,
    when ``args.wandb`` is set, the same row is logged to Weights & Biases.

    Args:
        args: namespace providing ``seed``, ``device``, ``dataset``,
            ``data_root``, ``num_users``, ``rounds``, ``local_ep``,
            ``local_bs``, ``eval_bs``, ``workers``, ``lr``, ``momentum``,
            ``weight_decay``, ``out_csv``, ``wandb``, ``wandb_project`` and
            ``wandb_entity`` (see ``build_argparser``).

    Returns:
        ``(best_test, best_round)``: the best test accuracy (percent) seen
        and the 1-based round at which it occurred.
    """
    # Paper LR schedule: multiply the base LR by this factor after this round.
    # Named once so the schedule and the logged W&B config cannot drift apart.
    lr_drop_round = 90
    lr_drop_factor = 0.1

    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)

    torch.manual_seed(int(args.seed))
    np.random.seed(int(args.seed))

    dev = torch.device(args.device if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps when batches are copied to a CUDA device.
    pin = dev.type == "cuda"
    train_tfm, eval_tfm = cifar_transforms()

    # Paper setting: CIFAR*-ST train, unchanged test
    train_base_aug, num_classes = build_cifar_st(args.dataset, args.data_root, train=True, transform=train_tfm)
    train_base_eval, _          = build_cifar_st(args.dataset, args.data_root, train=True, transform=eval_tfm)
    test_base, _                = build_cifar_st(args.dataset, args.data_root, train=False, transform=eval_tfm)

    train_ds = DatasetWithIndex(train_base_aug)     # for local SGD
    train_eval_ds = DatasetWithIndex(train_base_eval)  # for train accuracy
    test_ds = DatasetWithIndex(test_base)           # test accuracy

    # Equal IID split of the training indices across args.num_users clients
    dict_users = iid_equal_split(len(train_ds), args.num_users, seed=args.seed)

    train_eval_loader = DataLoader(
        train_eval_ds,
        batch_size=args.eval_bs,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin,
        drop_last=False,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.eval_bs,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin,
        drop_last=False,
    )

    model_global = resnet20(num_classes).to(dev)
    w_global = model_global.state_dict()
    ce = nn.CrossEntropyLoss()

    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(args.out_csv) or ".", exist_ok=True)

    fields = [
        "round", "base_lr", "cur_lr",
        "train_loss", "train_acc",
        "test_loss", "test_acc",
        "best_test", "best_round",
        "elapsed_sec",
    ]

    wb = None
    if args.wandb:
        # Imported lazily so wandb is only required when logging is requested.
        import wandb
        wb = wandb
        wb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            config={
                "algo": "FedAvg-ERM",
                "dataset": f"{args.dataset.upper()}-ST",
                "clients": args.num_users,
                "local_ep": args.local_ep,
                "local_bs": args.local_bs,
                "rounds": args.rounds,
                "lr": args.lr,
                "lr_drop_round": lr_drop_round,
                "lr_drop_factor": lr_drop_factor,
                "weight_decay": args.weight_decay,
                "momentum": args.momentum,
                "seed": args.seed,
            },
        )

    t0 = time.time()
    best_test, best_round = -1.0, -1
    completed = False

    try:
        with open(args.out_csv, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fields)
            writer.writeheader()

            for rnd in range(1, args.rounds + 1):
                # paper LR schedule: /10 after round 90
                cur_lr = args.lr * (lr_drop_factor if rnd > lr_drop_round else 1.0)

                local_weights, local_sizes = [], []

                for u in range(args.num_users):
                    idxs = dict_users[u]
                    if len(idxs) == 0:
                        continue

                    # Each client starts the round from the current global weights.
                    model_local = copy.deepcopy(model_global).to(dev)
                    model_local.load_state_dict(w_global)
                    model_local.train()

                    loader_u = DataLoader(
                        DatasetSplit(train_ds, idxs),
                        batch_size=args.local_bs,
                        shuffle=True,
                        num_workers=args.workers,
                        pin_memory=pin,
                        drop_last=True,
                    )

                    # A fresh optimizer per client and round: momentum state
                    # is not carried across rounds.
                    opt = torch.optim.SGD(
                        model_local.parameters(),
                        lr=cur_lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay,
                    )

                    for _ in range(args.local_ep):
                        for _, x, y in loader_u:
                            x = x.to(dev, non_blocking=True)
                            y = y.to(dev, non_blocking=True)
                            loss = ce(model_local(x), y)
                            opt.zero_grad(set_to_none=True)
                            loss.backward()
                            opt.step()

                    local_weights.append(copy.deepcopy(model_local.state_dict()))
                    local_sizes.append(len(idxs))

                # FedAvg aggregate
                w_global = fedavg_weighted(local_weights, local_sizes)
                model_global.load_state_dict(w_global)

                # Evaluate (train on ST-train, test on standard test)
                train_loss, train_acc = eval_acc_loss(model_global, train_eval_loader, dev)
                test_loss, test_acc = eval_acc_loss(model_global, test_loader, dev)

                if test_acc > best_test:
                    best_test, best_round = test_acc, rnd

                row = {
                    "round": rnd,
                    "base_lr": args.lr,
                    "cur_lr": cur_lr,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "best_test": best_test,
                    "best_round": best_round,
                    "elapsed_sec": time.time() - t0,
                }
                writer.writerow(row)
                # Flush so completed rounds survive a crash or interruption.
                f.flush()

                print(
                    f"[{args.dataset.upper()}-ST | iid(equal)] Round {rnd:3d} | lr={cur_lr:.5f} | "
                    f"TrainAcc {train_acc:6.2f} TestAcc {test_acc:6.2f} | BestTest {best_test:6.2f} (r{best_round})"
                )

                if wb is not None:
                    wb.log(row, step=rnd)

        completed = True
    finally:
        # Always close the W&B run, even if training raised; mark failed
        # runs with a non-zero exit code instead of leaving them open.
        if wb is not None:
            if completed:
                wb.summary["best_test_acc"] = float(best_test)
                wb.summary["best_round"] = int(best_round)
                wb.finish()
            else:
                wb.finish(exit_code=1)

    return best_test, best_round

def build_argparser():
    """Create the command-line parser for the FedAvg-ERM baseline script.

    The descriptive text is passed as ``description`` (not as the program
    name), so the usage line shows the real script name and ``--help``
    prints the description and every option's default.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    p = argparse.ArgumentParser(
        description="FedAvg-ERM baseline in paper Figure-1 setting (CIFAR*-ST)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data
    p.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100"],
                   help="base dataset from which the ST variant is built")
    p.add_argument("--data_root", type=str, default="./data",
                   help="directory where the dataset is stored or downloaded")

    # Federated schedule
    p.add_argument("--rounds", type=int, default=120,
                   help="number of communication rounds")
    p.add_argument("--num_users", type=int, default=8,
                   help="number of clients")
    p.add_argument("--local_ep", type=int, default=1,
                   help="local epochs per round (I=1)")
    p.add_argument("--local_bs", type=int, default=16,
                   help="local batch size (paper: 16 per client)")
    p.add_argument("--eval_bs", type=int, default=256,
                   help="batch size used for evaluation")

    # Optimizer
    p.add_argument("--lr", type=float, default=0.05,
                   help="base learning rate for local SGD")
    p.add_argument("--weight_decay", type=float, default=5e-4,
                   help="SGD weight decay")
    p.add_argument("--momentum", type=float, default=0.9,
                   help="SGD momentum")

    # Runtime
    p.add_argument("--seed", type=int, default=40,
                   help="random seed")
    p.add_argument("--workers", type=int, default=4,
                   help="number of DataLoader worker processes")
    p.add_argument("--device", type=str, default="cuda",
                   help="torch device string to train on")

    p.add_argument("--out_csv", type=str, default="runs/fedavg_fig1.csv",
                   help="path of the per-round metrics CSV")

    # W&B
    p.add_argument("--wandb", action="store_true",
                   help="log metrics to Weights and Biases")
    p.add_argument("--wandb_project", type=str, default="avg",
                   help="W&B project name")
    p.add_argument("--wandb_entity", type=str, default="hq1351-wayne-state-university",
                   help="W&B entity (user or team)")

    return p

if __name__ == "__main__":
    # Script entry point: parse the command line and run the experiment.
    cli_args = build_argparser().parse_args()
    run_fedavg_fig1(cli_args)