#!/usr/bin/env python3
import copy
import csv
import math
import os
import random
import time
from dataclasses import dataclass
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset

try:
    import wandb
except Exception:
    wandb = None


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Integer seed handed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # The CUDA generators are only touched when a CUDA device is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def build_st_indices(targets, num_classes: int = 100, minority_keep: int = 100):
    """Select the sample indices of a step-imbalanced subset.

    Classes whose label is below ``num_classes // 2`` keep only their last
    ``minority_keep`` samples (in dataset order); every other class keeps all
    of its samples.

    Args:
        targets: Sequence of integer class labels, one per sample.
        num_classes: Number of class labels to scan (``0 .. num_classes - 1``).
        minority_keep: Samples retained per reduced class. ``0`` drops the
            reduced classes entirely.

    Returns:
        Sorted 1-D ``np.int64`` array of kept sample indices.

    Raises:
        ValueError: If ``minority_keep`` is negative.
    """
    if minority_keep < 0:
        raise ValueError(f"minority_keep must be >= 0, got {minority_keep}")

    targets = np.asarray(targets)
    half = num_classes // 2
    kept = []
    for c in range(num_classes):
        idx = np.where(targets == c)[0]
        if c < half:
            # An explicit start offset is required: ``idx[-0:]`` would return
            # the whole array instead of an empty one.
            idx = idx[max(len(idx) - minority_keep, 0):]
        kept.append(idx)

    if not kept:
        return np.empty(0, dtype=np.int64)
    return np.sort(np.concatenate(kept)).astype(np.int64, copy=False)


def make_equal_split(num_items: int, num_users: int):
    """Randomly partition ``range(num_items)`` into ``num_users`` shards.

    The positions are shuffled with NumPy's global RNG and cut with
    ``np.array_split``, so shard sizes differ by at most one.

    Returns:
        Dict mapping each user id ``0 .. num_users - 1`` to a list of positions.
    """
    order = np.arange(num_items)
    np.random.shuffle(order)
    shards = np.array_split(order, num_users)

    assignment = {}
    for user in range(num_users):
        assignment[user] = shards[user].tolist()
    return assignment


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the stride
    is not 1 or the channel count changes, in which case it is a bias-free 1x1
    convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            branch = nn.Sequential(
                nn.Conv2d(in_planes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            branch = nn.Sequential()
        self.shortcut = branch

    def forward(self, x):
        # Main path: conv -> bn -> relu -> conv -> bn.
        main = self.conv1(x)
        main = F.relu(self.bn1(main), inplace=True)
        main = self.bn2(self.conv2(main))
        # Add the shortcut branch, then apply the final activation.
        merged = main + self.shortcut(x)
        return F.relu(merged, inplace=True)


class ResNetCIFAR(nn.Module):
    """Three-stage residual network with 16/32/64 channels and a linear head.

    Args:
        num_blocks: Sequence of three ints, the number of ``BasicBlock``s in
            each stage.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_blocks, num_classes=100):
        super().__init__()
        # Running input width consumed and updated by ``_make_layer``.
        self.in_planes = 16
        self.conv1 = conv3x3(3, 16, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, num_blocks[0], 1)
        self.layer2 = self._make_layer(32, num_blocks[1], 2)
        self.layer3 = self._make_layer(64, num_blocks[2], 2)
        self.fc = nn.Linear(64, num_classes)
        self._init_weights()

    def _make_layer(self, planes, n, stride):
        """Stack ``n`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (n - 1)
        layers = []
        for s in strides:
            layers.append(BasicBlock(self.in_planes, planes, s))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling over the full spatial extent. Using the
        # feature-map width as a square kernel only works for square inputs.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.fc(out)


def ResNet20(num_classes=100):
    """Build a ``ResNetCIFAR`` with three blocks in each of its three stages."""
    blocks_per_stage = [3, 3, 3]
    return ResNetCIFAR(blocks_per_stage, num_classes=num_classes)


@dataclass
class ClientState:
    """Per-client bundle: data loader, local model, scalar ``y`` and a batch cursor."""

    # Iterable of mini-batches owned by this client.
    loader: DataLoader
    # Local model replica; assigned after construction.
    model: Optional[nn.Module] = None
    # Client-local running scalar maintained by the trainer.
    y: Optional[torch.Tensor] = None
    # Live iterator over ``loader``; created lazily by ``next_batch``.
    iterator: Optional[object] = None

    def reset(self):
        """Start a fresh pass over ``loader``."""
        self.iterator = iter(self.loader)

    def next_batch(self):
        """Return the next batch, restarting the loader when it is exhausted.

        Raises:
            RuntimeError: If the loader yields no batch even right after a
                restart (for example ``drop_last=True`` with fewer samples
                than one batch). Without this a bare ``StopIteration`` would
                escape to the caller.
        """
        if self.iterator is None:
            self.reset()
        try:
            return next(self.iterator)
        except StopIteration:
            # Current pass finished; fall through and start a new one.
            pass

        self.reset()
        try:
            return next(self.iterator)
        except StopIteration:
            raise RuntimeError(
                "client loader produced no batches; check batch size / drop_last "
                "against the client's shard size"
            ) from None


class CSVLogger:
    """Minimal CSV writer that reopens the file for every row.

    Construction creates the parent directory if needed, truncates ``path``
    and writes the header. Each ``log`` call appends one row.
    """

    def __init__(self, path: str, fieldnames: List[str]):
        self.path = path
        parent = os.path.dirname(path) or "."
        os.makedirs(parent, exist_ok=True)
        self.fieldnames = fieldnames
        with open(self.path, "w", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=self.fieldnames)
            writer.writeheader()

    def log(self, row: Dict):
        """Append ``row`` (a mapping keyed by the configured field names)."""
        with open(self.path, "a", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=self.fieldnames)
            writer.writerow(row)


class FedDROTrainer:
    def __init__(self, args):
        """Set up data, models, output paths and logging for a training run.

        Args:
            args: Configuration namespace. This method reads ``random_seed``,
                ``cpu``, ``pretrain_epochs``, ``out_dir``, ``csv_name``,
                ``grad_plot_name`` and ``ckpt_name``; the helpers it calls
                read further fields.
        """
        self.args = args
        # Seed first: the data split, model init and loaders below consume RNG.
        set_seed(args.random_seed)

        # CUDA is used only when it is available and ``args.cpu`` is falsy.
        self.device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

        self.train_clients, self.train_eval_loader, self.test_loader = self._build_data()
        self.global_model = ResNet20(num_classes=100).to(self.device)

        # Optional warm-up of the global model before it is copied to clients.
        if self.args.pretrain_epochs > 0:
            self._pretrain_global()

        # Gives every client a copy of the global model and an initial ``y``.
        self._init_clients()

        os.makedirs(args.out_dir, exist_ok=True)
        self.csv_path = os.path.join(args.out_dir, args.csv_name)
        self.plot_path = os.path.join(args.out_dir, args.grad_plot_name)
        self.ckpt_path = os.path.join(args.out_dir, args.ckpt_name)

        # Constructing the logger truncates the CSV and writes its header.
        self.csv_logger = CSVLogger(
            self.csv_path,
            [
                "round", "lr", "train_acc_local", "train_acc_global",
                "test_acc", "train_loss_global", "test_loss_global",
                "y_global", "grad_norm_avg", "time_sec",
            ],
        )

        # Per-round series consumed by the gradient-norm plot.
        self.history_rounds = []
        self.history_grad_norm = []
        # Starts below any reachable accuracy so the first round is checkpointed.
        self.best_test_acc = -1.0

        self.wandb_run = None
        self._init_wandb()

    def _init_wandb(self):
        if not self.args.wandb:
            return
        if wandb is None:
            raise RuntimeError("wandb requested but not installed")
        self.wandb_run = wandb.init(
            project=self.args.wandb_project,
            entity=self.args.wandb_entity if self.args.wandb_entity.strip() else None,
            name=self.args.wandb_run_name,
            config=vars(self.args),
            reinit=True,
        )

    def _build_data(self):
        """Create the per-client training loaders and the two evaluation loaders.

        CIFAR-100 is loaded three times: train split with augmentation, train
        split without augmentation, and test split. Both train views are
        restricted to the indices chosen by ``build_st_indices``, and the
        augmented view is split evenly across ``args.num_users`` clients.

        Returns:
            ``(clients, train_eval_loader, test_loader)`` where ``clients`` is
            a list of ``ClientState`` objects whose iterators are already
            started.
        """
        cfg = self.args
        channel_mean = (0.5071, 0.4867, 0.4408)
        channel_std = (0.2675, 0.2565, 0.2761)

        augmented_tf = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ])
        plain_tf = T.Compose([
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ])

        def load_cifar100(is_train, transform):
            # All three dataset views share the root and the download flag.
            return torchvision.datasets.CIFAR100(
                root=cfg.data_root, train=is_train, download=True, transform=transform
            )

        def make_eval_loader(dataset):
            # Evaluation loaders keep dataset order and every sample.
            return DataLoader(
                dataset,
                batch_size=cfg.batch_size,
                shuffle=False,
                num_workers=cfg.workers,
                pin_memory=cfg.pin_memory,
                drop_last=False,
            )

        base_train_aug = load_cifar100(True, augmented_tf)
        base_train_eval = load_cifar100(True, plain_tf)
        test_data = load_cifar100(False, plain_tf)

        st_indices = build_st_indices(base_train_aug.targets, 100, 100)
        train_data = Subset(base_train_aug, st_indices.tolist())
        train_eval_data = Subset(base_train_eval, st_indices.tolist())
        dict_users = make_equal_split(len(train_data), cfg.num_users)

        print(f"[INFO] device={self.device}", flush=True)
        print(f"[INFO] dataset=cifar100st", flush=True)
        print(f"[INFO] ST train size={len(train_data)}", flush=True)
        print(f"[INFO] test size={len(test_data)}", flush=True)

        clients = []
        for user in range(cfg.num_users):
            shard_loader = DataLoader(
                Subset(train_data, dict_users[user]),
                batch_size=cfg.local_bs,
                shuffle=True,
                num_workers=cfg.workers,
                pin_memory=cfg.pin_memory,
                drop_last=True,
            )
            state = ClientState(loader=shard_loader)
            state.reset()
            clients.append(state)

        train_eval_loader = make_eval_loader(train_eval_data)
        test_loader = make_eval_loader(test_data)
        return clients, train_eval_loader, test_loader

    def _pretrain_global(self):
        print("[INFO] starting pretrain", flush=True)
        opt = torch.optim.SGD(
            self.global_model.parameters(),
            lr=self.args.pretrain_lr,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay,
        )
        self.global_model.train()
        for ep in range(self.args.pretrain_epochs):
            total = 0
            correct = 0
            total_loss = 0.0
            for x, y in self.train_eval_loader:
                x = x.to(self.device, non_blocking=True)
                y = y.to(self.device, non_blocking=True)
                opt.zero_grad(set_to_none=True)
                logits = self.global_model(x)
                loss = F.cross_entropy(logits, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.global_model.parameters(), self.args.grad_clip)
                opt.step()
                total_loss += float(loss.item()) * y.size(0)
                correct += int((logits.argmax(dim=1) == y).sum().item())
                total += y.size(0)
            print(
                f"[PRETRAIN {ep+1:02d}/{self.args.pretrain_epochs}] "
                f"TrainAcc {100.0 * correct / max(total, 1):6.2f} | "
                f"TrainLoss {total_loss / max(total, 1):.4f}",
                flush=True,
            )
        self.global_model.eval()

    def _round_lr(self, rnd: int) -> float:
        return self.args.base_lr if rnd < self.args.lr_drop_round else self.args.base_lr * 0.1

    def _g_from_logits(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        losses = F.cross_entropy(logits, y, reduction="none")
        losses = torch.clamp(losses, min=0.0, max=self.args.exp_clip)
        scaled = losses / max(self.args.lamda, 1e-8)
        return torch.exp(torch.logsumexp(scaled, dim=0) - math.log(scaled.numel()))

    def _robust_loss(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        losses = F.cross_entropy(logits, y, reduction="none")
        losses = torch.clamp(losses, min=0.0, max=self.args.exp_clip)
        scaled = losses / max(self.args.lamda, 1e-8)
        return self.args.lamda * (torch.logsumexp(scaled, dim=0) - math.log(scaled.numel()))

    def _init_clients(self):
        ys = []
        for c in self.train_clients:
            c.model = copy.deepcopy(self.global_model).to(self.device)
            x, y = c.next_batch()
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
            with torch.no_grad():
                c.y = self._g_from_logits(c.model(x), y).detach()
            ys.append(c.y)
        self.global_y = torch.stack(ys).mean().clamp(self.args.y_min, self.args.y_max)

    def _grad_norm(self, model: nn.Module) -> float:
        total_sq = 0.0
        for p in model.parameters():
            if p.grad is not None:
                total_sq += float(torch.sum(p.grad.detach() ** 2).item())
        return min(math.sqrt(total_sq), 70.0)

    def _average_models(self):
        with torch.no_grad():
            sd_global = self.global_model.state_dict()
            local_sds = [c.model.state_dict() for c in self.train_clients]
            for k in sd_global.keys():
                sd_global[k] = torch.stack([sd[k].float() for sd in local_sds], dim=0).mean(dim=0)
            self.global_model.load_state_dict(sd_global)
            for c in self.train_clients:
                c.model.load_state_dict(sd_global)

    def _evaluate(self, loader):
        self.global_model.eval()
        total = 0
        correct = 0
        total_loss = 0.0
        with torch.no_grad():
            for bidx, (x, y) in enumerate(loader):
                x = x.to(self.device, non_blocking=True)
                y = y.to(self.device, non_blocking=True)
                logits = self.global_model(x)
                total_loss += float(F.cross_entropy(logits, y, reduction="sum").item())
                correct += int((logits.argmax(dim=1) == y).sum().item())
                total += y.size(0)
                if loader is self.train_eval_loader and bidx + 1 >= self.args.train_eval_batches:
                    break
        return 100.0 * correct / max(total, 1), total_loss / max(total, 1)

    def _save_ckpt(self):
        torch.save(
            {"model": self.global_model.state_dict(), "y_global": self.global_y.detach().cpu(), "args": vars(self.args)},
            self.ckpt_path,
        )

    def _plot_grad_norm_upto70(self):
        """Save a line plot of the first 70 recorded gradient norms to ``plot_path``.

        Both axes are fixed: x to ``[1, 70]`` and y to ``[0, 70]``.
        """
        max_points = 70
        rounds = self.history_rounds[:max_points]
        norms = self.history_grad_norm[:max_points]

        fig, ax = plt.subplots(figsize=(7, 4.5))
        ax.plot(rounds, norms, linewidth=2)
        ax.set_xlabel("Communication Round")
        ax.set_ylabel("Average Gradient Norm")
        ax.set_title("FedDRO Gradient Norm (1..70)")
        ax.set_xlim(1, 70)
        ax.set_ylim(0, 70)
        fig.tight_layout()
        fig.savefig(self.plot_path, dpi=200)
        plt.close(fig)

    def train(self):
        """Run ``args.epochs`` communication rounds and write all artifacts.

        Each round:
          1. every client takes one clipped SGD step on one mini-batch and
             refreshes its running ``y`` (see ``_client_step``);
          2. ``self.global_y`` becomes the clamped mean of the client ``y``s;
          3. client models are averaged when ``rnd % args.I == 0``;
          4. the global model is evaluated on the train-eval and test loaders;
          5. metrics go to the CSV, stdout and (if enabled) wandb, and a
             checkpoint is saved whenever test accuracy improves.

        After the last round the gradient-norm plot is written and, when a
        wandb run is active, the CSV, plot and checkpoint are uploaded.
        """
        for rnd in range(1, self.args.epochs + 1):
            tic = time.time()
            lr = self._round_lr(rnd)

            grad_sum = 0.0
            local_acc_sum = 0.0
            new_ys = []

            for c in self.train_clients:
                grad_norm_before_clip, batch_acc = self._client_step(c, lr)
                new_ys.append(c.y.detach())
                local_acc_sum += batch_acc
                grad_sum += grad_norm_before_clip

            self.global_y = torch.stack(new_ys, dim=0).mean().clamp(self.args.y_min, self.args.y_max)

            if rnd % self.args.I == 0:
                self._average_models()

            train_acc_global, train_loss_global = self._evaluate(self.train_eval_loader)
            test_acc, test_loss_global = self._evaluate(self.test_loader)

            denom = max(len(self.train_clients), 1)
            grad_avg = min(grad_sum / denom, 70.0)
            local_acc_avg = local_acc_sum / denom
            elapsed = time.time() - tic
            # Read the scalar once; it is used for both the CSV row and stdout.
            y_global = float(self.global_y.item())

            row = {
                "round": rnd,
                "lr": lr,
                "train_acc_local": local_acc_avg,
                "train_acc_global": train_acc_global,
                "test_acc": test_acc,
                "train_loss_global": train_loss_global,
                "test_loss_global": test_loss_global,
                "y_global": y_global,
                "grad_norm_avg": grad_avg,
                "time_sec": elapsed,
            }
            self.csv_logger.log(row)
            self.history_rounds.append(rnd)
            self.history_grad_norm.append(grad_avg)

            print(
                f"Round {rnd:03d}/{self.args.epochs} | "
                f"TrainAcc(Local) {local_acc_avg:6.2f} | "
                f"TrainAcc(Global) {train_acc_global:6.2f} | "
                f"TestAcc {test_acc:6.2f} | "
                f"TrainLoss(Global) {train_loss_global:.4f} | "
                f"TestLoss(Global) {test_loss_global:.4f} | "
                f"y_global {y_global:.6f} | "
                f"lr={lr:.5f} | grad(avg)={grad_avg:.4f} | {elapsed/60.0:.2f}m",
                flush=True,
            )

            if self.wandb_run is not None:
                wandb.log(row, step=rnd)

            # Keep only the checkpoint with the best test accuracy so far.
            if test_acc > self.best_test_acc:
                self.best_test_acc = test_acc
                self._save_ckpt()

        self._plot_grad_norm_upto70()

        if self.wandb_run is not None:
            wandb.save(self.csv_path)
            wandb.save(self.plot_path)
            wandb.save(self.ckpt_path)
            try:
                wandb.log({"grad_norm_plot_upto70": wandb.Image(self.plot_path)})
            except Exception as exc:
                # The image upload stays best-effort, but a failure is
                # reported instead of being silently dropped.
                print(f"[WARN] could not log gradient-norm plot to wandb: {exc}", flush=True)
            wandb.finish()

    def _client_step(self, c, lr):
        """Run one local SGD step for client ``c`` and refresh its ``y``.

        Returns:
            ``(grad_norm_before_clip, batch_acc)``: the capped gradient norm
            measured before clipping, and the post-step accuracy (percent) on
            the same mini-batch.
        """
        c.model.train()
        # NOTE(review): a fresh optimizer per call means momentum buffers never
        # carry over between rounds — presumably intentional; confirm.
        opt = torch.optim.SGD(
            c.model.parameters(),
            lr=lr,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay,
        )

        x_batch, y_batch = c.next_batch()
        x_batch = x_batch.to(self.device, non_blocking=True)
        y_batch = y_batch.to(self.device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        logits_cur = c.model(x_batch)
        loss = self._robust_loss(logits_cur, y_batch)
        loss.backward()

        # The norm is recorded before clipping so the log shows the raw value.
        grad_norm_before_clip = self._grad_norm(c.model)
        torch.nn.utils.clip_grad_norm_(c.model.parameters(), self.args.grad_clip)
        opt.step()

        with torch.no_grad():
            # Re-evaluate the same batch with the updated weights.
            logits_next = c.model(x_batch)
            g_next = self._g_from_logits(logits_next, y_batch)
            # stable DS-style averaging for y to prevent collapse
            c.y = ((1.0 - self.args.beta_y) * c.y + self.args.beta_y * g_next).clamp(
                self.args.y_min, self.args.y_max
            )
            batch_acc = float((logits_next.argmax(dim=1) == y_batch).float().mean().item() * 100.0)

        return grad_norm_before_clip, batch_acc
