from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset

from models import ResNet18

def _split_classes_into_K_groups(K:int) -> List[List[int]]:
    all_classes = list(range(10))
    groups = []
    start = 0
    for i in range(K):
        end = round((i+1)*10.0/K)
        groups.append(all_classes[start:end])
        start = end
    return groups

class IncrementalCollectorMulti:
    """Grow per-pool index selections monotonically from pre-shuffled pools.

    Each of the ``K`` pools is copied and shuffled once at construction.
    ``ensure_size`` then takes indices from the front of each shuffled pool.
    A larger target therefore always extends, and never replaces, what was
    selected before.
    """

    def __init__(self, pools: List[np.ndarray], seed: Optional[int] = None):
        """Copy and shuffle the candidate pools.

        Args:
            pools: one array of candidate indices per group (assumed 1-D;
                shuffling permutes along the first axis). Each pool is
                copied, so the caller's arrays are not modified.
            seed: if given, shuffle with a private
                ``np.random.default_rng(seed)`` for reproducible selections.
                If ``None`` (the default), use the global ``np.random`` state,
                as before.
        """
        self.K = len(pools)
        self.pools = [np.array(p) for p in pools]  # np.array copies by default
        self.ptrs = [0] * self.K  # next unread position in each shuffled pool
        self.selected: List[List[int]] = [[] for _ in range(self.K)]
        shuffle = np.random.shuffle if seed is None else np.random.default_rng(seed).shuffle
        for pool in self.pools:
            shuffle(pool)  # in place, on our private copy

    def ensure_size(self, q_targets: List[int]) -> List[int]:
        """Grow each selection up to its target size.

        Selections never shrink: a target at or below the current size is a
        no-op for that pool. When a pool is exhausted, fewer indices than
        requested are added. Only the first ``K`` entries of ``q_targets`` are
        read.

        Returns:
            The number of indices actually added to each pool.
        """
        added = [0] * self.K
        for k, (pool, chosen) in enumerate(zip(self.pools, self.selected)):
            need = q_targets[k] - len(chosen)
            if need <= 0:
                continue
            start = self.ptrs[k]
            take = pool[start:start + need].tolist()  # slice truncates at pool end
            self.ptrs[k] = start + len(take)
            chosen.extend(take)
            added[k] = len(take)
        return added

    def indices_flat(self) -> List[int]:
        """Return all selected indices, pool 0 first, each in selection order."""
        return [idx for chosen in self.selected for idx in chosen]

    def current_q_vec(self) -> List[int]:
        """Return the current selection size of each pool."""
        return [len(chosen) for chosen in self.selected]

@dataclass
class TrainCfg:
    """Bundle of training hyper-parameters handed to the CIFAR-10 oracle.

    Field order defines the positional constructor, so keep it stable.
    """

    epochs: int = 50
    batch_size: int = 128
    lr: float = 0.1
    momentum: float = 0.9
    weight_decay: float = 5e-4
    # LR schedule selector: cosine annealing when True, otherwise a step decay
    # by ``gamma`` at each epoch in ``milestones``.
    use_cosine: bool = True
    milestones: tuple = (30, 40)
    gamma: float = 0.1
    num_workers: int = 0

class CIFAR10ResNetOracleMulti:
    """Accuracy oracle over class-grouped CIFAR-10 training subsets.

    The 10 CIFAR-10 classes are split into ``K`` groups by
    ``_split_classes_into_K_groups``. The training indices of each group form
    one pool of an ``IncrementalCollectorMulti``. Calling the oracle with a
    per-group size vector does three things:

    * grows the collected subset (it never shrinks);
    * trains a fresh ``ResNet18`` from scratch on all collected indices;
    * returns top-1 accuracy, in percent, on the full CIFAR-10 test set.
    """

    def __init__(self, device: str, cfg: TrainCfg, K: int, root: str = "./data/cifar10"):
        """Load CIFAR-10 (downloading into ``root`` if needed) and build the pools.

        Args:
            device: torch device string that the model and batches are moved to.
            cfg: training hyper-parameters used by every oracle call.
            K: number of class groups, i.e. per-group index pools.
            root: dataset directory (defaults to the previously hard-coded path).
        """
        self.device, self.cfg = device, cfg
        # Per-channel mean/std passed to T.Normalize.
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        # Random crop + horizontal flip augment the training split only.
        self.tf_train = T.Compose([
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])
        self.tf_test = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
        self.train_full = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=self.tf_train)
        self.test_set = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=self.tf_test)

        targets = np.array(self.train_full.targets)
        class_groups = _split_classes_into_K_groups(K)
        # One pool per group: positions in train_full whose label is in that group.
        pools = [np.where(np.isin(targets, g))[0] for g in class_groups]
        self.collector = IncrementalCollectorMulti(pools)

    def ensure_collected(self, q_vec: List[int]) -> List[int]:
        """Grow each group's selection to ``q_vec``; return per-group counts added."""
        return self.collector.ensure_size(q_vec)

    def current_q_vec(self) -> List[int]:
        """Return how many indices are currently collected per group."""
        return self.collector.current_q_vec()

    def _pin_memory(self) -> bool:
        """Pin host memory only when batches will be copied to a CUDA device."""
        return str(self.device).startswith("cuda")

    def _build_scheduler(self, opt: optim.Optimizer):
        """Return the per-epoch LR scheduler selected by ``cfg.use_cosine``."""
        if self.cfg.use_cosine:
            return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=self.cfg.epochs)
        return torch.optim.lr_scheduler.MultiStepLR(
            opt, milestones=list(self.cfg.milestones), gamma=self.cfg.gamma
        )

    def _evaluate(self, model: nn.Module, loader: DataLoader) -> float:
        """Return the top-1 accuracy of ``model`` on ``loader``, in percent."""
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                pred = model(x).argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        return 100.0 * correct / total

    def _train_eval(self, idxs: List[int]) -> float:
        """Train a fresh ResNet18 on ``train_full[idxs]`` and return test accuracy.

        The batch size and learning rate now come from ``self.cfg`` (they were
        previously hard-coded to 128 and 0.1, the ``TrainCfg`` defaults).

        Raises:
            ValueError: if ``idxs`` is empty; nothing can be trained on.
        """
        if not idxs:
            raise ValueError(
                "cannot train on an empty subset: every entry of q_vec is 0 or all pools are empty"
            )
        pin = self._pin_memory()
        train_loader = DataLoader(Subset(self.train_full, idxs), batch_size=self.cfg.batch_size, shuffle=True,
                                  num_workers=self.cfg.num_workers, pin_memory=pin)
        test_loader = DataLoader(self.test_set, batch_size=100, shuffle=False,
                                 num_workers=self.cfg.num_workers, pin_memory=pin)
        model = ResNet18().to(self.device)
        criterion = nn.CrossEntropyLoss()
        opt = optim.SGD(model.parameters(), lr=self.cfg.lr, momentum=self.cfg.momentum,
                        weight_decay=self.cfg.weight_decay)
        sch = self._build_scheduler(opt)

        for _ in range(self.cfg.epochs):
            model.train()
            for x, y in train_loader:
                x, y = x.to(self.device), y.to(self.device)
                opt.zero_grad()
                criterion(model(x), y).backward()
                opt.step()
            sch.step()  # the scheduler is stepped once per epoch
        return self._evaluate(model, test_loader)

    def __call__(self, q_vec: List[int]) -> float:
        """Grow the subset to ``q_vec`` (one size per group), then train and evaluate."""
        self.ensure_collected([int(x) for x in q_vec])
        return self._train_eval(self.collector.indices_flat())

def get_cifar10_oracle_multi(K: int, epochs=50, lr=0.1):
    """Construct a ``CIFAR10ResNetOracleMulti`` with ``K`` class groups.

    Uses ``"cuda"`` when ``torch.cuda.is_available()``, otherwise ``"cpu"``.
    ``epochs`` and ``lr`` are stored in a fresh ``TrainCfg``. Every other
    field keeps its ``TrainCfg`` default.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    cfg = TrainCfg(epochs=epochs, lr=lr)
    return CIFAR10ResNetOracleMulti(device, cfg, K=K)
