from dataclasses import dataclass
from typing import Any
import os

from torch import nn
from torch.utils.data import DataLoader, Subset, get_worker_info
from torchvision import datasets, transforms, models

from utils.function import *
from train_classification.resnet import *

def seed_worker(worker_id):
    """Seed Python's ``random`` and NumPy's global RNG for a DataLoader worker.

    Intended as ``worker_init_fn``. The seed is derived from torch's seed for
    the current process, reduced modulo 2**32 because ``np.random.seed``
    only accepts 32-bit seeds while torch seeds may be 64-bit.

    Args:
        worker_id: Worker index supplied by DataLoader (unused; the per-worker
            torch seed already differs between workers).
    """
    info = get_worker_info()
    # Inside a worker, info.seed is that worker's torch seed (equal to
    # torch.initial_seed() there). Outside a worker get_worker_info() returns
    # None, so fall back to torch.initial_seed() instead of failing on None.seed.
    base_seed = info.seed if info is not None else torch.initial_seed()
    worker_seed = base_seed % (2**32)
    np.random.seed(worker_seed)
    random.seed(worker_seed)



@dataclass
class RunCfg:
    """Configuration for one run, consumed by ``build_data`` and ``build_model``.

    Only some fields are read in this module (device, seed, dataset,
    batch_size_local, num_users, split_ratio); the others are presumably
    consumed by the training loop elsewhere -- TODO confirm.
    """
    # Requested torch device; build_model falls back to "cpu" when CUDA is unavailable.
    device: str = "cuda"
    # Seed for the user split and for each user's shuffling generator (seed + user index).
    seed: int = 42
    # Dataset names recognized by build_data: "cifar10", "cifar10_res", "fashionmnist", "mnist".
    dataset: str = "cifar10"
    # Not read in this module; presumably optimizer hyperparameters.
    lr: float = 0.03
    weight_decay: float = 0.0
    # Batch size of each user's training DataLoader.
    batch_size_local: int = 128
    # Number of users, i.e. the number of training DataLoaders build_data creates.
    num_users: int = 10
    # Not read in this module; presumably the training length and evaluation interval.
    iterations: int = 3000
    eval_every: int = 20
    # Per-user split ratios passed to iid_split_indices_ratio; must have num_users
    # entries. None means an equal split ([1.0] * num_users).
    split_ratio: Optional[List[float]] = None
    # Not read in this module; presumably selects the user communication topology.
    topology: str = "static"
    # Not read in this module; presumably Weights & Biases logging options.
    use_wandb: bool = False
    entity: str = ""
    project: str = ""
    run_name: Optional[str] = None

def _load_datasets(name: str, data_dir: str):
    """Return the ``(trainset, testset)`` pair for dataset ``name``.

    ``name`` is matched case-insensitively. Datasets are stored in (and, if
    missing, downloaded to) ``data_dir``.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    dataset = name.lower()
    if dataset in ("cifar10", "cifar10_res"):
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        # RandomRotation and ColorJitter run on tensors here (after ToTensor),
        # which relies on torchvision's tensor-capable transforms.
        train_t = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
            transforms.Normalize(mean, std),
        ])
        test_t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        trainset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_t)
        testset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_t)
        print("CIFAR10 DATASET")
        return trainset, testset

    if dataset == "fashionmnist":
        mean, std = (0.2860,), (0.3530,)
        dataset_cls, banner = datasets.FashionMNIST, "fashion mnist DATASET"
    elif dataset == "mnist":
        mean, std = (0.1307,), (0.3081,)
        dataset_cls, banner = datasets.MNIST, "mnist DATASET"
    else:
        raise ValueError(f"Unknown dataset: {name}")

    t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    # Use the project data directory (not a CWD-relative path) so downloads
    # land in one place regardless of where the script is launched from.
    trainset = dataset_cls(root=data_dir, train=True, download=True, transform=t)
    testset = dataset_cls(root=data_dir, train=False, download=True, transform=t)
    print(banner)
    return trainset, testset


def build_data(cfg: RunCfg) -> Dict[str, Any]:
    """Build one training DataLoader per user plus a shared test DataLoader.

    The training set is split across ``cfg.num_users`` users by
    ``iid_split_indices_ratio`` using ``cfg.split_ratio`` (an equal split when
    it is ``None``). Each user's loader shuffles with its own generator seeded
    with ``cfg.seed + user_index``, and worker RNGs are seeded by
    ``seed_worker``. Data lives under ``<project_root>/data``.

    Args:
        cfg: Run configuration.

    Returns:
        ``{"train_loaders": [one DataLoader per user], "test_loader": DataLoader}``.

    Raises:
        ValueError: if the dataset is unknown or ``cfg.split_ratio`` does not
            have exactly ``cfg.num_users`` entries.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    data_dir = os.path.join(project_root, "data")
    trainset, testset = _load_datasets(cfg.dataset, data_dir)

    split_ratio = cfg.split_ratio if cfg.split_ratio is not None else [1.0] * cfg.num_users
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(split_ratio) != cfg.num_users:
        raise ValueError(
            f"split_ratio has {len(split_ratio)} entries but num_users={cfg.num_users}"
        )
    user_indices_list = iid_split_indices_ratio(trainset, split_ratio, seed=cfg.seed)

    def _user_loader(indices, seed_u):
        # A dedicated generator per loader makes each user's shuffle order
        # reproducible from (cfg.seed + user index) alone.
        g = torch.Generator().manual_seed(seed_u)
        return DataLoader(
            Subset(trainset, indices), batch_size=cfg.batch_size_local, shuffle=True,
            generator=g, num_workers=2, pin_memory=True,
            worker_init_fn=seed_worker, persistent_workers=True,
        )

    train_loaders = [_user_loader(user_indices_list[u], cfg.seed + u) for u in range(cfg.num_users)]
    test_loader = DataLoader(
        testset, batch_size=1024, shuffle=False, num_workers=2, pin_memory=True,
        worker_init_fn=seed_worker, persistent_workers=False,
    )
    return {"train_loaders": train_loaders, "test_loader": test_loader}


def build_model(cfg: RunCfg) -> torch.nn.Module:
    """Instantiate the model for ``cfg.dataset`` and load its initial weights.

    The dataset name is matched case-insensitively, as in ``build_data``.
    If ``<project_root>/initial_weights/<file>`` exists, its state dict is
    loaded into the model; otherwise the model keeps its random
    initialization and a warning is printed.

    Args:
        cfg: Run configuration; ``dataset`` and ``device`` are read.

    Returns:
        The model, moved to ``cfg.device`` (or CPU when CUDA is unavailable).

    Raises:
        ValueError: if ``cfg.dataset`` is not a supported dataset name.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    save_path = os.path.join(project_root, "initial_weights")
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")

    # Lower-case to agree with build_data; previously e.g. "CIFAR10" loaded
    # CIFAR data but silently built the FashionMNIST model.
    dataset = cfg.dataset.lower()
    if dataset == "mnist":
        model = MnistCnn().to(device)
        wpath = os.path.join(save_path, "MnistCnn_init.pth")
    elif dataset == "cifar10":
        model = VGG9().to(device)
        wpath = os.path.join(save_path, "VGG9_init.pth")
    elif dataset == "cifar10_res":
        model = ResNet18WithBN().to(device)
        wpath = os.path.join(save_path, "ResNet18_init.pth")
    elif dataset == "fashionmnist":
        model = FashionMnist_Cnn().to(device)
        wpath = os.path.join(save_path, "FashionMnist_Cnn_init.pth")
    else:
        raise ValueError(f"Unknown dataset: {cfg.dataset}")

    if os.path.exists(wpath):
        # Weights file comes from the project's own initial_weights directory.
        state = torch.load(wpath, map_location=device)
        model.load_state_dict(state)
        print(f"[build_model] Loaded initial weights: {wpath}")
    else:
        print(f"[build_model] Warning: {wpath} not found, using random init")
    return model
