import torch
import torch.nn as nn
import torch.optim as optim
import copy
from .baseline import BaseFLModel
from encoder_manager import create_encoder_manager

try:
    from peft import get_peft_model, LoraConfig, PeftModel
except ImportError:
    print("Warning: 'peft' library not found. FLoRA requires peft to run properly.")

class FLoRA(BaseFLModel):
    """Federated fine-tuning strategy that trains a LoRA adapter plus a linear head.

    The encoder backbone comes from the encoder manager.  Each client wraps
    that backbone with a LoRA adapter (optionally with a client-specific
    rank), trains the adapter and a classification head locally, and returns
    the adapter weights, the head weights and the rank/alpha it used.
    """

    def __init__(self, cfg, device):
        """Read model/train/data settings from ``cfg`` and build backbone and head.

        Args:
            cfg: Nested configuration mapping; the ``"model"``, ``"train"`` and
                ``"data"`` sections are read, each with defaults.
            device: Device forwarded to the base class; the head is moved to
                ``self.device``.
        """
        super().__init__(cfg, device)
        m, t = cfg.get("model", {}), cfg.get("train", {})

        self.enc_mgr = create_encoder_manager(cfg, self.device)
        self.backbone = self.enc_mgr.model

        self.num_classes = int(m.get("num_classes", 100))
        self.feature_dim = self.enc_mgr.feature_dim
        self.head = nn.Linear(self.feature_dim, self.num_classes).to(self.device)

        self.lora_r = int(m.get("lora_r", 8))
        # NOTE(review): lora_alpha is read from the config but client_update
        # always uses alpha = 2 * rank -- confirm whether that is intended.
        self.lora_alpha = int(m.get("lora_alpha", 16))
        self.lora_dropout = float(m.get("lora_dropout", 0.05))
        self.target_modules = m.get("target_modules", ["qkv"])
        # Optional list of ranks; clients pick one by ``client_id`` modulo length.
        self.heterogeneous_ranks = m.get("heterogeneous_ranks", [])

        self.lr = float(t.get("lr", 3e-4))
        self.epochs = int(t.get("local_epochs", 1))
        self.bs = int(cfg.get("data", {}).get("batch_size", 16))

        image_encoders = ("vit", "resnet", "convnext", "efficientnet", "dinov2")
        self.input_type = "images" if self.enc_mgr.encoder_type in image_encoders else "text"

    @staticmethod
    def _unwrap_output(output):
        """Return the tensor carried by a backbone output.

        Dict-like outputs yield their ``"logits"`` entry, tuples yield their
        first element, anything else is returned unchanged.
        """
        if isinstance(output, dict):
            return output["logits"]
        if isinstance(output, tuple):
            return output[0]
        return output

    def get_requirements(self):
        """Return the input modality (``"images"`` or ``"text"``) this model expects."""
        return {"input_type": self.input_type}

    def init_global(self, enc_info=None):
        """Build the initial global state from the current backbone and head.

        Both state dicts are copied to CPU.  ``state_dict()`` tensors share
        storage with the live parameters, so without the copy any later local
        training of ``self.head`` would silently rewrite the global head.

        Args:
            enc_info: Unused.

        Returns:
            Dict with ``"model"`` (backbone weights) and ``"head"`` (head weights).
        """
        return {
            "model": {k: v.cpu().clone() for k, v in self.backbone.state_dict().items()},
            "head": {k: v.cpu().clone() for k, v in self.head.state_dict().items()},
        }

    def client_update(self, global_state, client_data, round_idx, enc_mgr=None, client_id=0):
        """Run local LoRA + head training for one client.

        Args:
            global_state: Dict with ``"model"`` and ``"head"`` state dicts.
            client_data: Client dataset handed to ``self._as_loader``; element
                ``[1]`` is used to count samples (presumably the labels --
                verify against the caller).
            round_idx: Unused.
            enc_mgr: Unused.
            client_id: Selects the rank from ``heterogeneous_ranks`` when set.

        Returns:
            Tuple ``(update, info)`` where ``update`` holds the ``"lora"``
            adapter tensors, the ``"head"`` tensors and ``"meta"`` with the
            rank/alpha used, and ``info`` is ``{"scalar": num_samples}``.
        """
        # strict=False tolerates key differences between the stored global
        # weights and the current backbone.
        self.backbone.load_state_dict(global_state["model"], strict=False)
        self.head.load_state_dict(global_state["head"])
        self.head.train()

        if self.heterogeneous_ranks:
            my_rank = self.heterogeneous_ranks[client_id % len(self.heterogeneous_ranks)]
        else:
            my_rank = self.lora_r
        # Alpha is tied to the rank so alpha / r is the same for every client.
        my_alpha = 2 * my_rank

        peft_config = LoraConfig(
            r=my_rank,
            lora_alpha=my_alpha,
            target_modules=self.target_modules,
            lora_dropout=self.lora_dropout,
            bias="none",
        )

        # The adapter is attached to self.backbone itself, so it has to be
        # removed again before the next client or evaluation uses the backbone.
        peft_model = get_peft_model(self.backbone, peft_config)
        try:
            peft_model.train()

            # Only hand parameters that actually receive gradients to the
            # optimizer; the frozen base weights are left out.
            trainable_params = [p for p in peft_model.parameters() if p.requires_grad]
            trainable_params += list(self.head.parameters())
            opt = optim.AdamW(trainable_params, lr=self.lr)
            loss_fn = nn.CrossEntropyLoss()

            loader = self._as_loader(client_data, shuffle=True, batch_size=self.bs)

            for _ in range(self.epochs):
                for xb, yb in loader:
                    xb, yb = xb.to(self.device), yb.to(self.device)
                    opt.zero_grad()

                    features = self._unwrap_output(peft_model(xb))
                    logits = self.head(features)

                    loss = loss_fn(logits, yb)
                    loss.backward()
                    opt.step()

            # Snapshot the results before the adapter layers are removed.
            adapter_state = {
                k: v.cpu().clone()
                for k, v in peft_model.state_dict().items()
                if "lora_" in k
            }
            head_state = {k: v.cpu().clone() for k, v in self.head.state_dict().items()}
        finally:
            # Best-effort cleanup that also runs when training fails, so the
            # shared backbone is not left wrapped for later callers.
            try:
                peft_model.unload()
            except Exception as exc:
                import warnings

                warnings.warn(
                    f"FLoRA: failed to unload LoRA adapter from backbone: {exc!r}",
                    RuntimeWarning,
                )

        num_samples = len(client_data[1])

        return {
            "lora": adapter_state,
            "head": head_state,
            "meta": {"rank": my_rank, "alpha": my_alpha}
        }, {"scalar": num_samples}

    @torch.no_grad()
    def evaluate(self, global_state, testset, enc_mgr=None):
        """Evaluate the global backbone and head on ``testset``.

        Args:
            global_state: Dict with ``"model"`` and ``"head"`` state dicts.
            testset: Dataset handed to ``self._as_loader``.
            enc_mgr: Unused.

        Returns:
            Tuple ``(loss, accuracy)``: the mean of the per-batch losses and
            the accuracy in percent.  Both are 0 for an empty loader.
        """
        self.backbone.load_state_dict(global_state["model"], strict=False)
        self.head.load_state_dict(global_state["head"])
        self.backbone.eval()
        self.head.eval()

        loader = self._as_loader(testset, shuffle=False, batch_size=self.bs)
        loss_fn = nn.CrossEntropyLoss()
        tot_loss, corr, cnt = 0.0, 0, 0

        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)

            features = self._unwrap_output(self.backbone(xb))
            logits = self.head(features)

            tot_loss += loss_fn(logits, yb).item()
            corr += (logits.argmax(1) == yb).sum().item()
            cnt += yb.numel()

        # max(1, ...) guards against division by zero on an empty loader.
        return tot_loss / max(1, len(loader)), 100.0 * corr / max(1, cnt)