import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np


# --- Model: ResNet18 for CIFAR (Small kernel) ---
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with an optional 1x1 projection shortcut.

    A projection (1x1 conv + BatchNorm) is inserted on the shortcut only when
    the block changes resolution (``stride != 1``) or channel count; otherwise
    the shortcut is an empty ``nn.Sequential``, which returns its input as-is.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))


class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem (no max-pool) followed by four residual stages.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and be constructible as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages (indexed 0..3).
        num_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64  # running input-channel count, advanced by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride`` (downsampling).

        Side effect: updates ``self.in_planes`` to this stage's output channels.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``x`` (3 input channels) to logits of shape (N, num_classes)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs, layer4 yields 4x4 maps, so this
        # equals the former avg_pool2d(out, 4). Unlike a fixed kernel size, it also
        # works for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)


def ResNet18(num_classes=10):
    """Build a CIFAR-style ResNet-18: BasicBlocks, two per stage, across four stages."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)


# --- Dataset Manager ---
class DatasetHandler:
    """Load CIFAR-10/100 and split the training set into equal, disjoint per-agent shards.

    Args:
        dataset_name: ``'cifar10'`` or ``'cifar100'`` (case-insensitive).
        num_agents: number of training shards (one DataLoader per agent).
        batch_size: batch size of every per-agent training loader.
        max_samples: if truthy and smaller than the training-set size, only this
            many randomly chosen training samples are distributed.

    Samples left over after the integer division by ``num_agents`` are dropped.
    The shuffle uses NumPy's global RNG, so seed it for reproducible shards.
    """

    def __init__(self, dataset_name='cifar10', num_agents=16, batch_size=128, max_samples=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_agents = num_agents

        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        name = dataset_name.lower()
        if name == 'cifar10':
            dataset_cls, n_classes = datasets.CIFAR10, 10
        elif name == 'cifar100':
            dataset_cls, n_classes = datasets.CIFAR100, 100
        else:
            raise ValueError("Dataset must be cifar10 or cifar100")

        train_set = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
        self.test_set = dataset_cls(root='./data', train=False, download=True, transform=transform_test)
        self.num_classes = n_classes

        n_train = len(train_set)
        order = np.random.permutation(n_train)
        if max_samples and max_samples < n_train:
            order = order[:max_samples]

        shard = len(order) // num_agents
        self.agent_loaders = {
            agent: DataLoader(
                Subset(train_set, order[agent * shard:(agent + 1) * shard]),
                batch_size=batch_size,
                shuffle=True,
                num_workers=0,
            )
            for agent in range(num_agents)
        }

        self.test_loader = DataLoader(self.test_set, batch_size=256, shuffle=False, num_workers=2)

    def get_train_loader(self, agent_id):
        """Return the training DataLoader for shard ``agent_id`` (KeyError if unknown)."""
        loaders = self.agent_loaders
        return loaders[agent_id]

    def get_test_loader(self):
        """Return the shared, unshuffled test DataLoader."""
        loader = self.test_loader
        return loader


# --- Communication Networks ---
class FastComNetworkPyTorch:
    """Chebyshev-accelerated gossip communication (used by DOC2S).

    The mixing matrix is the all-to-all averaging matrix ``W`` with every
    entry equal to ``1 / num_agents``.
    """

    def __init__(self, num_agents, device):
        self.W = torch.full((num_agents, num_agents), 1.0 / num_agents, device=device)
        self.eta = self._chebyshev_eta(self.W)

    @staticmethod
    def _chebyshev_eta(W):
        """Return the momentum weight eta = (1 - sqrt(1 - b^2)) / (1 + sqrt(1 - b^2)).

        ``b`` is the second-largest eigenvalue *modulus* of the symmetric
        mixing matrix ``W``. Returns 0 for a single agent (no second
        eigenvalue) or when ``b >= 1``.
        """
        if W.shape[0] < 2:
            return 0
        # W is symmetric, so eigvalsh gives real eigenvalues directly. Sorting by
        # modulus keeps a large negative eigenvalue from being ignored.
        moduli = torch.sort(torch.linalg.eigvalsh(W).abs(), descending=True).values
        beta = moduli[1].item()
        if beta >= 1:
            return 0
        root = np.sqrt(1 - beta ** 2)
        return (1 - root) / (1 + root)

    def propagate(self, tensor_list, R):
        """Run ``R`` rounds of accelerated mixing on the stacked tensors.

        Each round computes Z_{k+1} = (1 + eta) W Z_k - eta Z_{k-1}, starting
        from Z_{-1} = Z_0. Returns the rows of the final Z as a list.
        """
        Z = torch.stack(tensor_list)
        Z_prev = Z.clone()
        for _ in range(R):
            Z_m = torch.matmul(self.W, Z)
            Z_new = (1 + self.eta) * Z_m - self.eta * Z_prev
            Z_prev = Z
            Z = Z_new
        return [row for row in Z]

    def get_average(self, tensor_list):
        """Return the exact element-wise mean of ``tensor_list``."""
        return torch.mean(torch.stack(tensor_list), dim=0)


class ComNetworkPyTorch:
    """Standard (unaccelerated) gossip communication over a ring (used by MEDOL, DGFM).

    Each agent mixes its own value with those of its two ring neighbours,
    giving each a weight of 1/3.
    """

    def __init__(self, num_agents, device):
        self.device = device
        self.W = torch.zeros((num_agents, num_agents), device=device)
        for i in range(num_agents):
            # Accumulate rather than assign: for num_agents <= 2 the neighbour
            # indices coincide with i or with each other, and plain assignment
            # would leave rows summing to less than 1. That would make W non-stochastic
            # and shrink the parameters on every round.
            for j in (i - 1, i, i + 1):
                self.W[i, j % num_agents] += 1 / 3

    def propagate(self, tensor_list, R=1):
        """Apply ``R`` gossip rounds (Z <- W Z) and return the rows of the result."""
        Z = torch.stack(tensor_list)
        for _ in range(R):
            Z = torch.matmul(self.W, Z)
        return [row for row in Z]

    def get_average(self, tensor_list):
        """Return the exact element-wise mean of ``tensor_list``."""
        return torch.mean(torch.stack(tensor_list), dim=0)


# --- Agent ---
class ResNetAgent:
    def __init__(self, agent_id, num_classes, lr=0.01, D=1.0, num_agents=16, device='cpu'):
        self.id = agent_id
        self.lr = lr
        self.D = D
        self.num_agents = num_agents
        self.device = device

        self.model = ResNet18(num_classes=num_classes).to(device)
        self.criterion = nn.CrossEntropyLoss()

        self.flat_params_size = sum(p.numel() for p in self.model.parameters())

        # Buffers
        self.action = torch.zeros(self.flat_params_size, device=device)  # For DOC2S/MEDOL
        self.tracker = torch.zeros(self.flat_params_size, device=device)  # For DGFM (Gradient Tracking)
        self.prev_grad = torch.zeros(self.flat_params_size, device=device)  # For DGFM

    def get_flat_params(self):
        return torch.cat([p.data.view(-1) for p in self.model.parameters()])

    def set_flat_params(self, flat_params):
        offset = 0
        for p in self.model.parameters():
            numel = p.numel()
            p.data.copy_(flat_params[offset:offset + numel].view_as(p))
            offset += numel

    # Getters/Setters for buffers
    def get_action(self):
        return self.action

    def set_action(self, a):
        self.action.copy_(a)

    def initialize_action(self):
        self.action.zero_()

    def get_tracker(self):
        return self.tracker

    def set_tracker(self, t):
        self.tracker.copy_(t)

    def get_prev_grad(self):
        return self.prev_grad

    def set_prev_grad(self, g):
        self.prev_grad.copy_(g)

    # Compute Gradient (Standard Backprop)
    def compute_gradient(self, data_batch):
        data, target = data_batch
        data, target = data.to(self.device), target.to(self.device)

        self.model.train()
        self.model.zero_grad()
        output = self.model(data)
        loss = self.criterion(output, target)
        loss.backward()

        grads = []
        for p in self.model.parameters():
            if p.grad is not None:
                grads.append(p.grad.view(-1))
            else:
                grads.append(torch.zeros_as(p).view(-1))
        return torch.cat(grads)

    def get_test_loss(self, test_loader):
        self.model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                outputs = self.model(data)
                loss = self.criterion(outputs, target)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        return test_loss / len(test_loader), 100. * correct / total