import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss

from .pfpt_core_model import PromptedEncoder
from .baseline import BaseFLModel
from encoder_manager import create_encoder_manager

class PFPT(BaseFLModel):
    """Federated model that trains a ``PromptedEncoder`` around a shared encoder.

    The backbone comes from ``create_encoder_manager(cfg, device).model``.
    Only state-dict entries whose names appear in ``self.model.trainable_keys``
    are exchanged between server and clients (under the ``"trainable"`` key).

    Config keys read (with defaults):
        model.num_tokens   -> ``num_tokens`` for ``PromptedEncoder`` (10)
        model.num_classes  -> ``num_classes`` for ``PromptedEncoder`` (100)
        train.lr           -> Adam learning rate (5e-4)
        train.local_epochs -> local epochs per ``client_update`` call (5)
        data.batch_size    -> batch size for training and evaluation (16)
    """

    def __init__(self, cfg, device):
        super().__init__(cfg, device)
        model_cfg = cfg.get("model", {})
        train_cfg = cfg.get("train", {})
        # Read ``data`` with .get like the other sections so a config without
        # it falls back to the default batch size instead of raising KeyError.
        data_cfg = cfg.get("data", {})

        self.enc_mgr = create_encoder_manager(cfg, self.device)
        encoder_model = self.enc_mgr.model

        self.model = PromptedEncoder(
            encoder_model=encoder_model,
            num_tokens=int(model_cfg.get("num_tokens", 10)),
            num_classes=int(model_cfg.get("num_classes", 100)),
        ).to(self.device)

        self.lr = float(train_cfg.get("lr", 5e-4))
        self.epochs = int(train_cfg.get("local_epochs", 5))
        self.bs = int(data_cfg.get("batch_size", 16))

    def get_requirements(self):
        """Declare that this model consumes raw images."""
        return {"input_type": "images"}

    def _trainable_state(self):
        """Return CPU copies of the state-dict entries named in ``trainable_keys``."""
        # Build the lookup set once; ``trainable_keys`` may be a list, in which
        # case each membership test would be linear.
        keys = set(self.model.trainable_keys)
        return {
            name: tensor.cpu().clone()
            for name, tensor in self.model.state_dict().items()
            if name in keys
        }

    def init_global(self, enc_info=None):
        """Return the initial global state: ``{"trainable": {name: tensor}}``.

        ``enc_info`` is accepted for interface compatibility and is unused.
        """
        return {"trainable": self._trainable_state()}

    def client_update(self, global_state, client_data, round_idx, enc_mgr=None):
        """Train locally, starting from the global trainable parameters.

        Args:
            global_state: dict whose ``"trainable"`` entry maps parameter names
                to tensors, as produced by ``init_global``. It is loaded with
                ``strict=False`` because it holds only the trainable subset.
            client_data: local data, converted to a loader by ``self._as_loader``.
                ``len(client_data[1])`` is used as the sample count. This
                presumably indexes the labels; confirm against the data pipeline.
            round_idx: federated round index (unused here).
            enc_mgr: unused here.

        Returns:
            ``(update, weights)``. ``update`` is ``{"trainable": params}``,
            holding CPU copies of the locally trained parameters. ``weights`` is
            ``{"trainable": {"scalar": num_samples}}``. Presumably the
            aggregator uses this as a per-client weight; verify against the
            server code.
        """
        self.model.load_state_dict(global_state["trainable"], strict=False)
        self.model.train()

        loader = self._as_loader(client_data, shuffle=True, batch_size=self.bs)

        # Fresh optimizer each call: only parameters with requires_grad=True
        # are optimized, and no optimizer state is carried across rounds.
        opt = optim.Adam(
            (p for p in self.model.parameters() if p.requires_grad),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-6,
        )
        ce_loss = CrossEntropyLoss()

        for _ in range(self.epochs):
            for xb, yb in loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                opt.zero_grad(set_to_none=True)
                loss = ce_loss(self.model(xb), yb)
                loss.backward()
                opt.step()

        num_samples = len(client_data[1])
        return {"trainable": self._trainable_state()}, {"trainable": {"scalar": num_samples}}

    @torch.no_grad()
    def evaluate(self, global_state, testset, enc_mgr=None):
        """Evaluate the global trainable parameters on ``testset``.

        Args:
            global_state: dict with a ``"trainable"`` entry, loaded non-strictly.
            testset: evaluation data, converted to a loader by ``self._as_loader``.
            enc_mgr: unused here.

        Returns:
            ``(mean_loss, accuracy_pct)``. ``mean_loss`` is the cross-entropy
            averaged over samples and ``accuracy_pct`` is top-1 accuracy in
            percent. Both are 0 for an empty test set.
        """
        self.model.load_state_dict(global_state["trainable"], strict=False)
        self.model.eval()

        loader = self._as_loader(testset, shuffle=False, batch_size=self.bs)
        # Sum per-sample losses and divide by the sample count. Averaging
        # batch means over len(loader) would over-weight a smaller final batch.
        ce_loss = CrossEntropyLoss(reduction="sum")
        total_loss, correct, count = 0.0, 0, 0

        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            logits = self.model(xb)
            total_loss += ce_loss(logits, yb).item()
            correct += (logits.argmax(1) == yb).sum().item()
            count += yb.numel()

        denom = max(1, count)
        return total_loss / denom, 100.0 * correct / denom