
"""Implement the neural network models and training functions."""
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.optim import SGD, Optimizer, AdamW, Adam
from torch.utils.data import DataLoader
from torchvision.models import vgg11, mobilenet_v2   # PyTorch ≥2.0
import numpy as np

from utils_helper import reconstruct_parameters, segment_resnet_parameters, flatten_resnet_parameters

import torchvision.models as M
import torch.nn as nn

def make_resnet(model_name="resnet34", num_classes=200, pretrained=True):
    """Build a torchvision ResNet whose final ``fc`` layer is replaced.

    Parameters
    ----------
    model_name : str
        Architecture to build; either ``"resnet34"`` or ``"resnet50"``.
    num_classes : int
        Output size of the replacement ``fc`` layer.
    pretrained : bool
        If true, load ``IMAGENET1K_V1`` weights for ResNet-34 or
        ``IMAGENET1K_V2`` weights for ResNet-50; otherwise pass ``weights=None``.

    Returns
    -------
    nn.Module
        The torchvision model with a new ``nn.Linear`` head.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported architectures.
    """
    if model_name == "resnet34":
        weights = M.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        model = M.resnet34(weights=weights)
    elif model_name == "resnet50":
        weights = M.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        model = M.resnet50(weights=weights)
    else:
        # Tell the caller which value was rejected instead of raising a bare error.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'resnet34' or 'resnet50'."
        )

    # The new head is always freshly initialised, even when the backbone
    # weights are pretrained.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

#ResNet9
def conv_block(in_channels, out_channels, pool=False):
    """Return a 3x3 conv -> batch-norm -> in-place ReLU stack.

    When ``pool`` is truthy a ``MaxPool2d(2)`` is appended as a fourth layer.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    # Optional down-sampling stage; an empty tuple leaves the stack at 3 layers.
    tail = (nn.MaxPool2d(2),) if pool else ()
    return nn.Sequential(conv, norm, activation, *tail)

class ResNet9(nn.Module):
    """ResNet9-style classifier assembled from ``conv_block`` stages.

    Parameters
    ----------
    input_dim : int
        Number of channels fed to the first convolution block.
    hidden_dims
        Not used by this model (presumably kept so the constructor matches
        the other models in this file -- confirm with callers).
    num_classes : int
        Output size of the final linear layer.
    """

    def __init__(self, input_dim, hidden_dims, num_classes):
        super().__init__()

        # Stage 1: widen to 128 channels, halve the resolution, then a residual pair.
        self.conv1 = conv_block(input_dim, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = self._residual_pair(128)

        # Stage 2: widen to 512 channels with two poolings, then a residual pair.
        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = self._residual_pair(512)

        # Head: global max-pool, flatten, dropout, linear classifier.
        head_layers = [
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    @staticmethod
    def _residual_pair(channels):
        """Return two channel-preserving conv blocks used as a residual branch."""
        return nn.Sequential(
            conv_block(channels, channels), conv_block(channels, channels)
        )

    def forward(self, xb):
        """Run the forward pass and return the class logits."""
        feats = self.conv2(self.conv1(xb))
        feats = self.res1(feats) + feats
        feats = self.conv4(self.conv3(feats))
        feats = self.res2(feats) + feats
        return self.classifier(feats)

import torch.nn as nn

'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Two-convolution residual block with an optional projection shortcut.

    Parameters
    ----------
    in_planes : int
        Number of input channels.
    planes : int
        Number of channels produced by both 3x3 convolutions.
    stride : int
        Stride of the first convolution (and of the projection shortcut).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # A 1x1 projection is only needed when the residual branch changes
        # the spatial size or the channel count; otherwise use an identity.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_planes))

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut, then a final ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return F.relu(residual)


class ResNet(nn.Module):
    """Four-stage ResNet with a 3x3, stride-1 stem.

    Parameters
    ----------
    block
        Callable invoked as ``block(in_planes, planes, stride)`` that returns
        a module; it must expose an ``expansion`` attribute.
    num_blocks
        Indexable with four entries: the block count of each stage.
    in_channels : int
        Number of channels of the input images.
    num_classes : int
        Output size of the final linear layer.
    dropout_prob : float
        Dropout probability applied just before the final linear layer.
    """

    def __init__(self, block, num_blocks, in_channels=3, num_classes=100, dropout_prob=0.0):
        super().__init__()
        # Running channel count, advanced by ``_make_layer`` as stages are built.
        self.in_planes = 64

        # Stem: single 3x3 convolution that keeps the input resolution.
        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)

        # (width, stride of the first block) for layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[index], stride=first_stride)
            setattr(self, f"layer{index + 1}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout_prob)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``."""
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(self.dropout(flat))

def ResNet18(in_channels=3, num_classes=100, dropout_prob=0.0):
    """Build a ``ResNet`` of ``BasicBlock`` units with two blocks per stage."""
    blocks_per_stage = [2, 2, 2, 2]
    options = {
        "in_channels": in_channels,
        "num_classes": num_classes,
        "dropout_prob": dropout_prob,
    }
    return ResNet(BasicBlock, blocks_per_stage, **options)

class CNN(nn.Module):
    """Small two-convolution CNN for CIFAR-10.

    The flatten step in ``forward`` is hard-coded to ``16 * 5 * 5`` features,
    which is what 3x32x32 inputs produce after the two conv/pool stages.

    Parameters
    ----------
    input_dim : int
        Input width of the first fully connected layer.
    hidden_dims : List[int]
        Widths of the two hidden fully connected layers (entries 0 and 1).
    num_classes : int
        Number of output classes.
    """

    def __init__(self, input_dim, hidden_dims, num_classes):
        super().__init__()
        # Feature extractor: two 5x5 convolutions sharing one 2x2 max-pool.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Classifier: input_dim -> hidden_dims[0] -> hidden_dims[1] -> num_classes.
        widths = (input_dim, hidden_dims[0], hidden_dims[1], num_classes)
        self.fc1 = nn.Linear(widths[0], widths[1])
        self.fc2 = nn.Linear(widths[1], widths[2])
        self.fc3 = nn.Linear(widths[2], widths[3])

    def forward(self, x):
        """Return class logits for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

class CNN500k(nn.Module):
    """Four-convolution CNN followed by a three-layer MLP head.

    Parameters
    ----------
    num_channels : int
        Number of channels of the input images.
    image_size : int
        Side length of the input images; after the three 2x2 max-pools the
        head expects feature maps of side ``int(image_size / 8)``.
    num_classes : int
        Number of output classes.
    """

    def __init__(self, num_channels, image_size, num_classes):
        super().__init__()

        # Convolutional trunk: the last three convolutions are each followed
        # by a pooling step that halves the resolution.
        trunk = [
            nn.Conv2d(num_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        for in_ch, out_ch in ((32, 64), (64, 64), (64, 32)):
            trunk += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]

        # Classifier head over the flattened 32-channel feature maps.
        pooled_side = int(image_size / 8)
        head = [
            nn.Flatten(),
            nn.Linear(32 * pooled_side * pooled_side, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]

        self.layer_stack = nn.Sequential(*trunk, *head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.layer_stack(x)


class CNNMnist(nn.Module):
    """Small two-convolution CNN for MNIST and Fashion-MNIST.

    The flatten step in ``forward`` is hard-coded to ``16 * 4 * 4`` features,
    which is what 1x28x28 inputs produce after the two conv/pool stages.

    Parameters
    ----------
    input_dim : int
        Input width of the first fully connected layer.
    hidden_dims : List[int]
        Widths of the two hidden fully connected layers (entries 0 and 1).
    num_classes : int
        Number of output classes.
    """

    def __init__(self, input_dim, hidden_dims, num_classes) -> None:
        super().__init__()
        # Feature extractor for single-channel images.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Classifier: input_dim -> hidden_dims[0] -> hidden_dims[1] -> num_classes.
        widths = (input_dim, hidden_dims[0], hidden_dims[1], num_classes)
        self.fc1 = nn.Linear(widths[0], widths[1])
        self.fc2 = nn.Linear(widths[1], widths[2])
        self.fc3 = nn.Linear(widths[2], widths[3])

    def forward(self, x):
        """Return class logits for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        flat = x.view(-1, 16 * 4 * 4)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

class ScaffoldOptimizer(SGD):
    """SGD variant implementing the SCAFFOLD local update step.

    Parameters
    ----------
    grads
        Parameters (or parameter groups) to optimise; forwarded to ``SGD``
        as its ``params`` argument.
    step_size : float
        Learning rate.
    momentum : float
        SGD momentum.
    weight_decay : float
        SGD weight decay.
    """

    def __init__(self, grads, step_size, momentum, weight_decay):
        super().__init__(
            grads, lr=step_size, momentum=momentum, weight_decay=weight_decay
        )

    def step_custom(self, server_cv, client_cv):
        """Run one SGD step, then apply the control-variate correction.

        The SCAFFOLD update ``y <- y - lr * (g + c - c_i)`` is split into the
        regular SGD step (which handles ``g``, momentum and weight decay)
        followed by ``y <- y - lr * (c - c_i)``.

        Parameters
        ----------
        server_cv
            Iterable of server control-variate tensors, one per parameter,
            in the order the parameters appear across ``param_groups``.
        client_cv
            Iterable of client control-variate tensors in the same order.
        """
        self.step()
        # Use ONE shared iterator so that the second and later parameter
        # groups continue where the previous group stopped instead of
        # re-reading the control variates from the start.
        cv_pairs = zip(server_cv, client_cv)
        with torch.no_grad():
            for group in self.param_groups:
                # ``group["params"]`` must be zip's first argument: zip stops
                # as soon as it is exhausted, without consuming an extra
                # control-variate pair that belongs to the next group.
                for par, (s_cv, c_cv) in zip(group["params"], cv_pairs):
                    # Control variates may live on another device (e.g. CPU).
                    s_cv = s_cv.to(par.device)
                    c_cv = c_cv.to(par.device)
                    par.add_(s_cv - c_cv, alpha=-group["lr"])


def train_scaffold(
    net: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    epochs: int,
    learning_rate: float,
    momentum: float,
    weight_decay: float,
    server_cv: torch.Tensor,
    client_cv: torch.Tensor,
) -> None:
    # pylint: disable=too-many-arguments
    """Run SCAFFOLD local training on ``net`` for ``epochs`` epochs.

    Parameters
    ----------
    net : nn.Module
        Model to train; it is updated in place.
    trainloader : DataLoader
        Loader yielding the training batches.
    device : torch.device
        Device the batches are moved to.
    epochs : int
        Number of passes over ``trainloader``.
    learning_rate : float
        Step size given to the optimizer.
    momentum : float
        Momentum given to the optimizer.
    weight_decay : float
        Weight decay given to the optimizer.
    server_cv : torch.Tensor
        Server control variate, forwarded unchanged to every epoch.
    client_cv : torch.Tensor
        Client control variate, forwarded unchanged to every epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    scaffold_opt = ScaffoldOptimizer(
        net.parameters(), learning_rate, momentum, weight_decay
    )
    net.train()

    model = net
    for _ in range(epochs):
        # Each epoch hands back the model that the next epoch continues from.
        model = _train_one_epoch_scaffold(
            model, trainloader, device, loss_fn, scaffold_opt, server_cv, client_cv
        )


def _train_one_epoch_scaffold(
    net: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    criterion: nn.Module,
    optimizer: ScaffoldOptimizer,
    server_cv: torch.Tensor,
    client_cv: torch.Tensor,
) -> nn.Module:
    # pylint: disable=too-many-arguments
    """Make one pass over ``trainloader`` using the SCAFFOLD step.

    Every batch is moved to ``device``, the loss is back-propagated and the
    parameters are updated via ``optimizer.step_custom(server_cv, client_cv)``.
    The same ``net`` object is returned.
    """
    for batch_inputs, batch_labels in trainloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(net(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step_custom(server_cv, client_cv)
    return net

from torch.amp import autocast, GradScaler

import torch
import torch.nn as nn
from typing import Tuple

def _bn_train_dropout_eval(net: nn.Module):
    """Put dropout layers in eval mode and batch-norm layers in train mode.

    Batch-norm then uses batch statistics while dropout adds no noise.
    Modules of any other type keep their current mode.
    """
    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    for module in net.modules():
        if isinstance(module, dropout_types):
            module.eval()
            continue
        if isinstance(module, batchnorm_types):
            module.train()

def train_fedavg(
    net: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    epochs: int,
    learning_rate: float,
    momentum: float,
    weight_decay: float,
) -> None:
    # pylint: disable=too-many-arguments
    """Train the network on the training set using FedAvg.

    Parameters
    ----------
    net : nn.Module
        The neural network to train; it is updated in place.
    trainloader : DataLoader
        The training set dataloader object.
    device : torch.device
        The device on which to train the network.
    epochs : int
        The number of epochs to train the network.
    learning_rate : float
        The learning rate.
    momentum : float
        The momentum for SGD optimizer.
    weight_decay : float
        The weight decay for SGD optimizer.

    Returns
    -------
    None
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(
        net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
    )
    net.train()

    # Loss scaling only makes sense for CUDA mixed precision. For any other
    # device the scaler is created disabled, so ``scale``/``step``/``update``
    # become pass-throughs instead of running a CUDA scaler on non-CUDA tensors.
    use_amp = torch.device(device).type == "cuda"
    scaler = GradScaler(device="cuda", enabled=use_amp)

    for _ in range(epochs):
        net = _train_one_epoch(net, trainloader, device, criterion, optimizer, scaler)

def _train_one_epoch(
    net: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    criterion: nn.Module,
    optimizer: Optimizer,
    scaler: GradScaler,
) -> nn.Module:
    """Train the network on the training set for one epoch with AMP."""
    net.train()
    running_loss = 0.0

    for data, target in trainloader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        with autocast(device_type='cuda'): 
            output = net(data)
            loss = criterion(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    return net

def set_lr_by_round(optimizer, base_lr, round_num, total_rounds, min_lr=1e-6):
    """Set every param group's learning rate from a cosine schedule.

    The rate is ``base_lr`` at ``round_num == 0`` and ``min_lr`` at
    ``round_num == total_rounds``.
    """
    half_span = 0.5 * (base_lr - min_lr)
    cosine_term = 1 + np.cos(np.pi * round_num / total_rounds)
    new_lr = min_lr + half_span * cosine_term
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def train_fedavg_adam(
    net: nn.Module,
    trainloader: DataLoader,
    device: torch.device,
    epochs: int = 5,                       
    global_round: int = 0,
    lr: float   = 1.95e-05,                   
    weight_decay: float = 0.05,
    total_rounds: int = 150,
) -> None:
    """
    Local training with AdamW + warm-up + cosine.
    Called once per client before each global round.
    """

    # ------ optimiser -------------------------------------------------------
    # (optionally exclude biases & LayerNorm from weight decay)
    decay, no_decay = [], []
    for n, p in net.named_parameters():
        if not p.requires_grad:
            continue
        (no_decay if p.ndim == 1 or n.endswith('.bias') or n.endswith('.weight_g') or \
   'pos_embed' in n or 'cls_token' in n else decay).append(p)

    param_groups = [
        {"params": decay,     "weight_decay": weight_decay},
        {"params": no_decay,  "weight_decay": 0.0},
    ]

    initial_lr = 3e-4
    optimizer = AdamW(param_groups, lr=initial_lr, betas=(0.9,0.999))
    # set_lr_by_round(optimizer, initial_lr, global_round, total_rounds, min_lr=1e-6)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs*len(trainloader))

    # ------ loss + AMP scaffold ---------------------------------------------
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(device='cuda')

    net.train()
    for _ in range(epochs):
        for data, target in trainloader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type=device.type):
                logits = net(data)
                loss   = criterion(logits, target)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()  
            sched.step()
    return net.state_dict()

def test(
    net: nn.Module, testloader: DataLoader, device: torch.device
) -> Tuple[float, float]:
    """Evaluate the network on the test set.

    Parameters
    ----------
    net : nn.Module
        The neural network to evaluate.
    testloader : DataLoader
        The test set dataloader object.
    device : torch.device
        The device on which to evaluate the network.

    Returns
    -------
    Tuple[float, float]
        The loss and accuracy of the network on the test set.
    """
    # print("Testing the model...")
    criterion = nn.CrossEntropyLoss(reduction="sum")
    net.eval()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            loss += criterion(output, target).item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    loss = loss / total
    acc = correct / total
    return loss, acc