from dataclasses import dataclass
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
import torchvision
import torchvision.transforms as T

from models import ResNet18


def _split_classes_into_K_groups(num_classes: int, K: int) -> List[List[int]]:
    all_classes = list(range(num_classes))
    groups = []
    start = 0
    for i in range(K):
        end = round((i + 1) * float(num_classes) / K)
        groups.append(all_classes[start:end])
        start = end
    return groups


class IncrementalCollectorMulti:
    """Grow-only sample collector over ``K`` independent index pools.

    Each pool is copied and shuffled once at construction time (using the
    global NumPy random state). Samples are then handed out from the front of
    each shuffled pool, so indices already selected are never dropped or
    re-drawn when a larger target is requested later.

    Attributes:
        K: Number of pools.
        pools: Shuffled private copies of the input pools.
        ptrs: Per-pool read position into ``pools``.
        selected: Per-pool list of indices collected so far.
    """

    def __init__(self, pools: List[np.ndarray]):
        self.K = len(pools)
        self.pools = [pool.copy() for pool in pools]
        self.ptrs = [0 for _ in range(self.K)]
        self.selected = [[] for _ in range(self.K)]
        # Shuffle the copies in place, pool by pool; the caller's arrays are untouched.
        for shuffled in self.pools:
            np.random.shuffle(shuffled)

    def ensure_size(self, q_targets: List[int]) -> List[int]:
        """Grow each pool's selection up to its target size.

        Args:
            q_targets: Desired number of selected samples for each pool.

        Returns:
            The number of samples newly added per pool. A pool that is already
            at or above its target contributes 0; a pool that runs out of
            samples contributes however many were left.
        """
        added = [0] * self.K
        for k in range(self.K):
            chosen = self.selected[k]
            shortfall = max(0, q_targets[k] - len(chosen))
            if shortfall:
                pool = self.pools[k]
                lo = self.ptrs[k]
                # Clamp to the pool length so an exhausted pool yields fewer items.
                hi = min(lo + shortfall, len(pool))
                fresh = pool[lo:hi].tolist()
                self.ptrs[k] = hi
                chosen.extend(fresh)
                added[k] = len(fresh)
        return added

    def indices_flat(self) -> List[int]:
        """Return every selected index, concatenated in pool order."""
        return [idx for k in range(self.K) for idx in self.selected[k]]

    def current_q_vec(self) -> List[int]:
        """Return the number of samples currently selected from each pool."""
        sizes = []
        for k in range(self.K):
            sizes.append(len(self.selected[k]))
        return sizes


@dataclass
class TrainCfg:
    """Training hyperparameters consumed by ``CIFAR100ResNetOracleMulti``."""
    # Number of training epochs; also the period (T_max) of the cosine schedule.
    epochs: int = 50
    # Intended mini-batch size for the training loader.
    batch_size: int = 128
    # Intended initial learning rate for SGD.
    lr: float = 0.1
    # SGD momentum coefficient.
    momentum: float = 0.9
    # L2 weight decay passed to SGD.
    weight_decay: float = 5e-4
    # True selects CosineAnnealingLR; False selects MultiStepLR.
    use_cosine: bool = True
    # Epoch indices at which MultiStepLR decays the rate (only when use_cosine is False).
    milestones: tuple = (30, 40)
    # Multiplicative decay factor for MultiStepLR (only when use_cosine is False).
    gamma: float = 0.1
    # Worker-process count handed to the DataLoaders.
    num_workers: int = 0


class CIFAR100ResNetOracleMulti:
    """Accuracy oracle: train ResNet-18 on a CIFAR-100 subset and report test accuracy.

    The 100 classes are split into ``K`` contiguous class groups. Each group
    owns a pool of training-set indices from which samples are collected
    incrementally. Calling the oracle with a vector of per-group sample counts
    grows the collected set as needed, trains a fresh model on everything
    collected so far, and returns top-1 test accuracy in percent.

    Collection is cumulative: requesting fewer samples than were already
    collected for a group does not remove any.

    Attributes:
        device: Device identifier the model and batches are moved to.
        cfg: Training hyperparameters.
        tf_train: Augmenting transform applied to training images.
        tf_test: Deterministic transform applied to test images.
        train_full: Full CIFAR-100 training set.
        test_set: Full CIFAR-100 test set.
        collector: Per-group incremental index collector.
    """

    NUM_CLASSES = 100
    # Batch size used for evaluation only; it does not affect the reported accuracy.
    EVAL_BATCH_SIZE = 100

    def __init__(self, device: str, cfg: TrainCfg, K: int):
        """Load CIFAR-100 (downloading if missing) and build the per-group pools.

        Args:
            device: Device to train and evaluate on, e.g. ``"cuda"`` or ``"cpu"``.
            cfg: Training hyperparameters.
            K: Number of class groups to split the 100 classes into.
        """
        self.device, self.cfg = device, cfg

        # Per-channel normalization constants (presumably CIFAR-100 training-set statistics).
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

        self.tf_train = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])
        self.tf_test = T.Compose([
            T.ToTensor(),
            T.Normalize(mean, std),
        ])

        self.train_full = torchvision.datasets.CIFAR100(
            "./data/cifar100", train=True, download=True, transform=self.tf_train
        )
        self.test_set = torchvision.datasets.CIFAR100(
            "./data/cifar100", train=False, download=True, transform=self.tf_test
        )

        targets = np.array(self.train_full.targets)  # class label per training sample
        class_groups = _split_classes_into_K_groups(num_classes=self.NUM_CLASSES, K=K)
        # One pool per group: indices of all training samples whose label is in the group.
        pools = [np.where(np.isin(targets, g))[0] for g in class_groups]
        self.collector = IncrementalCollectorMulti(pools)

    def ensure_collected(self, q_vec: List[int]) -> List[int]:
        """Grow the collected set so group ``k`` holds at least ``q_vec[k]`` samples.

        Args:
            q_vec: Target sample count per class group; entries are cast to ``int``.

        Returns:
            Number of samples newly added per group.
        """
        return self.collector.ensure_size([int(x) for x in q_vec])

    def current_q_vec(self) -> List[int]:
        """Return the number of samples currently collected per class group."""
        return self.collector.current_q_vec()

    def _evaluate(self, model, loader) -> float:
        """Return top-1 accuracy of ``model`` over ``loader``, in percent."""
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                pred = model(x).argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        return 100.0 * correct / total

    def _train_eval(self, idxs: List[int]) -> float:
        """Train a fresh ResNet-18 on the given training indices and evaluate it.

        Args:
            idxs: Indices into the CIFAR-100 training set to train on.

        Returns:
            Top-1 accuracy on the full test set, in percent.
        """
        cfg = self.cfg
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory = torch.device(self.device).type == "cuda"

        train_loader = DataLoader(
            Subset(self.train_full, idxs),
            batch_size=cfg.batch_size,  # was hard-coded to 128, ignoring the config
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=pin_memory,
        )
        test_loader = DataLoader(
            self.test_set,
            batch_size=self.EVAL_BATCH_SIZE,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=pin_memory,
        )

        # A new model is created on every call, so results never depend on earlier runs.
        model = ResNet18(num_classes=self.NUM_CLASSES).to(self.device)
        criterion = nn.CrossEntropyLoss()
        opt = optim.SGD(
            model.parameters(),
            lr=cfg.lr,  # was hard-coded to 0.1, ignoring the config
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay,
        )

        if cfg.use_cosine:
            sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
        else:
            sch = torch.optim.lr_scheduler.MultiStepLR(
                opt, milestones=list(cfg.milestones), gamma=cfg.gamma
            )

        for _ in range(cfg.epochs):
            model.train()
            for x, y in train_loader:
                x, y = x.to(self.device), y.to(self.device)
                opt.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                opt.step()
            # The schedule advances once per epoch, not once per batch.
            sch.step()

        return self._evaluate(model, test_loader)

    def __call__(self, q_vec: List[int]) -> float:
        """Collect up to ``q_vec`` samples per group, then train and evaluate.

        Args:
            q_vec: Target sample count per class group.

        Returns:
            Top-1 test accuracy in percent of a model trained on all samples
            collected so far, including those collected by earlier calls.
        """
        self.ensure_collected(q_vec)
        return self._train_eval(self.collector.indices_flat())


def get_cifar100_oracle_multi(K: int, epochs=50, lr=0.1):
    """Build a ``CIFAR100ResNetOracleMulti`` on CUDA when available, else on CPU.

    Args:
        K: Number of class groups for the oracle.
        epochs: Value stored in ``TrainCfg.epochs``.
        lr: Value stored in ``TrainCfg.lr``.

    Returns:
        A newly constructed ``CIFAR100ResNetOracleMulti``; all other
        ``TrainCfg`` fields keep their defaults.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    train_cfg = TrainCfg(epochs=epochs, lr=lr)
    return CIFAR100ResNetOracleMulti(device, train_cfg, K=K)
