import copy
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset


class DatasetSplit(Dataset):
    """Expose a chosen subset of ``dataset`` as a dataset of its own.

    ``indices`` is materialised into a list at construction time. Item ``i``
    of the split is ``dataset[int(indices[i])]``, and the length of the split
    is the number of stored indices.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        # Materialise once so the selection can be measured and indexed.
        self.indices = [*indices]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Cast to a plain int before indexing the wrapped dataset, so entries
        # such as NumPy integers are passed on as ordinary Python ints.
        source_position = int(self.indices[idx])
        return self.dataset[source_position]


def equal_split_indices(n, num_users, seed=0):
    """Shuffle ``range(n)`` and deal it out into ``num_users`` shards.

    The permutation is drawn from ``np.random.default_rng(seed)``, so the same
    ``seed`` yields the same assignment. ``np.array_split`` is used for the
    cut, so shard sizes differ by at most one when ``n`` is not divisible by
    ``num_users``.

    Returns:
        A dict mapping each user id ``0 .. num_users - 1`` to an ``int64``
        array of sample indices.
    """
    generator = np.random.default_rng(seed)
    shuffled = generator.permutation(n)
    shards = np.array_split(shuffled, num_users)

    assignment = {}
    for user, shard in zip(range(num_users), shards):
        assignment[user] = np.array(shard, dtype=np.int64)
    return assignment


def weighted_avg_state_dict(state_dicts, weights):
    """Combine several state dicts into one weighted sum, key by key.

    For every key of ``state_dicts[0]`` the result holds
    ``sum(w_i * state_dicts[i][key])``. The weights are used as given; they
    are not normalised here.

    Integer tensors are accumulated in float64 and then rounded and cast back
    to their original dtype, so the returned dict keeps the dtypes of the
    first state dict instead of silently turning counters into floats.

    Args:
        state_dicts: Sequence of mappings that share the same keys.
        weights: Iterable of scalar weights, one per state dict.

    Returns:
        A shallow copy of ``state_dicts[0]`` (same mapping type) whose values
        are replaced by newly created averaged values.

    Raises:
        ValueError: If the number of weights differs from the number of state
            dicts (``zip`` would otherwise drop the surplus silently).
    """
    weights = [float(w) for w in weights]
    if len(state_dicts) != len(weights):
        raise ValueError(
            f"got {len(state_dicts)} state dicts but {len(weights)} weights"
        )

    reference = state_dicts[0]
    # Every value is overwritten below, so a shallow copy is enough to keep
    # the container type and any attributes attached to it.
    out = copy.copy(reference)
    for k, ref in reference.items():
        # Integer/bool tensors cannot hold a fractional average; compute in
        # float64 and restore the dtype afterwards.
        keep_dtype = torch.is_tensor(ref) and not (
            ref.is_floating_point() or ref.is_complex()
        )
        acc = None
        for sd, w in zip(state_dicts, weights):
            value = sd[k].to(torch.float64) if keep_dtype else sd[k]
            term = value * w
            acc = term if acc is None else acc + term
        if keep_dtype:
            acc = torch.round(acc).to(ref.dtype)
        out[k] = acc
    return out


def step_round_lr(round_idx, base_lr, drop_epoch=90, drop_factor=0.1):
    """Return the learning rate for ``round_idx`` under a single step decay.

    Rounds up to and including ``int(drop_epoch)`` use ``base_lr``; every
    later round uses ``base_lr * drop_factor``. The result is always a float.
    """
    decayed = round_idx > int(drop_epoch)
    rate = float(base_lr)
    if decayed:
        rate *= float(drop_factor)
    return rate


def local_train_fedavg(
    model,
    trainloader,
    lr,
    local_ep=1,
    momentum=0.9,
    weight_decay=5e-4,
    grad_clip=5.0,
):
    """Train ``model`` in place with SGD and return a copy of its weights.

    The model is switched to training mode and optimised with cross-entropy
    loss for ``local_ep`` passes over ``trainloader``. Batches are moved to
    the device of the model's first parameter.

    Args:
        model: Module to train; it is modified in place and left in train mode.
        trainloader: Iterable yielding ``(inputs, targets)`` batches.
        lr: Learning rate for SGD.
        local_ep: Number of passes over ``trainloader``.
        momentum: SGD momentum. Nesterov momentum is used when this is
            positive; with ``momentum=0`` plain SGD is used, because
            ``torch.optim.SGD`` rejects ``nesterov=True`` without momentum.
        weight_decay: L2 penalty passed to SGD.
        grad_clip: Maximum gradient norm applied before each step, or ``None``
            to disable clipping.

    Returns:
        A deep copy of ``model.state_dict()`` taken after training.
    """
    device = next(model.parameters()).device
    model.train()

    momentum = float(momentum)
    max_norm = None if grad_clip is None else float(grad_clip)

    opt = torch.optim.SGD(
        model.parameters(),
        lr=float(lr),
        momentum=momentum,
        weight_decay=float(weight_decay),
        # Nesterov is only valid with a positive momentum term.
        nesterov=momentum > 0,
    )
    ce = nn.CrossEntropyLoss()

    for _ in range(int(local_ep)):
        for x, y in trainloader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = ce(logits, y)

            # A NaN/inf loss would poison the weights; drop this batch and
            # carry on with the next one.
            if not torch.isfinite(loss):
                continue

            loss.backward()
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            opt.step()

    return copy.deepcopy(model.state_dict())


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm pairs.

    The first convolution applies ``stride`` and maps ``in_planes`` to
    ``planes`` channels; the second keeps both resolution and width. The
    input is added back before the final ReLU. When the block changes the
    resolution or the channel count, the skip path is a strided 1x1
    convolution followed by batch norm; otherwise it is an empty
    ``nn.Sequential`` that passes the input through unchanged.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == planes:
            # Shapes already match: identity skip connection.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main branch.
            projection = nn.Conv2d(
                in_planes, planes, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(planes))

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = torch.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        summed = hidden + self.shortcut(x)
        return torch.relu(summed)


class ResNetCIFAR(nn.Module):
    """Three-stage residual network for 3-channel image batches.

    Layout: a 3x3 stem convolution with 16 channels, then three stages of
    ``BasicBlock`` with 16, 32 and 64 channels (the second and third stage
    halve the resolution with stride 2), global average pooling, and a linear
    classifier on the resulting 64 features.

    Args:
        num_blocks: Sequence whose first three entries give the number of
            blocks in stage 1, 2 and 3.
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_blocks, num_classes):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(64, num_blocks[2], stride=2)
        self.fc = nn.Linear(64, num_classes)
        self._init_weights()

    def _make_layer(self, planes, blocks, stride):
        """Build one stage; only its first block applies ``stride``."""
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_planes, planes, s))
            # Later blocks (and the next stage) start from this width.
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Re-initialise every conv, batch-norm and linear layer in the model."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Pool over the full (height, width) extent. Using only the width as
        # a square kernel breaks, or silently drops rows, on non-square
        # inputs; for square inputs this is the same operation.
        out = torch.nn.functional.avg_pool2d(out, (out.size(2), out.size(3)))
        out = out.view(out.size(0), -1)
        return self.fc(out)


def ResNet20(num_classes):
    """Build a ``ResNetCIFAR`` with three blocks in each of its three stages."""
    blocks_per_stage = [3, 3, 3]
    return ResNetCIFAR(blocks_per_stage, num_classes=num_classes)