import torch
import torch.nn.functional as F
from src.core.base import BaseServer
from src.algorithms.server.supervised import SupervisedServer
from src.algorithms.utils import ce_loss
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
import numpy as np


class SCOMatchServer(BaseServer):
    """Server that trains with a supervised CE step and evaluates open-set metrics.

    Evaluation reports closed-set loss / top-1 accuracy on in-distribution
    (ID) samples (label ``< self.num_classes``) and OOD-detection AUROC for
    four scores, treating samples with label ``>= self.num_classes`` as OOD.
    Attributes such as ``self.model``, ``self.ema``, ``self.mode``,
    ``self.clip_grad`` and ``self.device`` are not set in this class;
    presumably ``BaseServer`` provides them (TODO confirm).
    """

    def __init__(self, config, net_builder, train_loader, test_loader, logger):
        super().__init__(config, net_builder, train_loader, test_loader, logger)

    def train_step(self, optimizer, x_lb, y_lb):
        """Run one supervised optimization step on a labeled batch.

        Args:
            optimizer: optimizer that applies the update. Its gradients are
                also the ones cleared before the backward pass.
            x_lb: labeled inputs, passed directly to ``self.model``. Assumed
                to already be on the model's device (TODO confirm with caller).
            y_lb: class targets for ``x_lb``, consumed by ``ce_loss``.

        Returns:
            dict with the supervised loss and the learning rate of the first
            param group, keyed ``"{self.mode}_train/..."``. The same dict is
            also stored on ``self.res_dict``.
        """
        self.model.train()

        # Zero gradients through the same optimizer that performs the step,
        # so a caller-supplied optimizer never applies stale gradients.
        optimizer.zero_grad()
        logits_x_lb = self.model(x_lb)['logits']
        sup_loss = ce_loss(logits_x_lb, y_lb, reduction='mean')
        sup_loss.backward()

        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad)

        optimizer.step()

        lr = optimizer.param_groups[0]['lr']
        res_dict = {f"{self.mode}_train/s_loss": sup_loss.item(),
                    f"{self.mode}_train/lr": lr}
        self.res_dict = res_dict
        return res_dict

    @torch.no_grad()
    def _collect_eval_outputs(self):
        """Run the model over ``self.test_loader`` and gather per-sample results.

        Only ID samples (``y < self.num_classes``) contribute to the loss and
        accuracy lists. Every sample contributes an OOD label
        (``y >= self.num_classes``) and four OOD scores, each oriented so that
        larger means "more likely OOD".

        Returns:
            tuple ``(total_loss, total_num, y_true, y_pred, ood_labels, ood_scores)``
            where ``total_loss`` is the sample-weighted sum of ID CE loss,
            ``total_num`` the ID sample count, and ``ood_scores`` maps
            ``"msp"``, ``"entropy"``, ``"energy"`` and ``"k1"`` to lists.
        """
        num_classes = self.num_classes
        total_loss = 0.0
        total_num = 0
        y_true, y_pred, ood_labels = [], [], []
        ood_scores = {"msp": [], "entropy": [], "energy": [], "k1": []}

        for data in self.test_loader:
            x = data['x_lb']
            y = data['y_lb']
            if isinstance(x, dict):
                x = {k: v.to(self.device) for k, v in x.items()}
            else:
                x = x.to(self.device)
            y = y.to(self.device)

            logits = self.model(x)['logits']
            preds = torch.argmax(logits, dim=1)

            # Loss / accuracy are computed on ID samples only.
            in_idx = torch.where(y < num_classes)[0]
            if len(in_idx) > 0:
                loss = ce_loss(logits[in_idx], y[in_idx], reduction='mean')
                total_loss += loss.item() * in_idx.shape[0]
                total_num += in_idx.shape[0]
                y_true.extend(y[in_idx].cpu().tolist())
                y_pred.extend(preds[in_idx].cpu().tolist())

            # Softmax over every output (including a possible extra OOD column).
            probs = F.softmax(logits, dim=1)
            # Maximum softmax probability restricted to the K ID classes.
            msp_id = probs[:, :num_classes].max(dim=1).values

            if logits.shape[1] > num_classes:
                # The model has an extra (K+1)-th output: use its probability.
                p_ood = probs[:, num_classes]
            else:
                # No extra output: fall back to 1 - MSP over the ID classes.
                p_ood = 1 - msp_id

            # Predictive entropy: higher means more uncertain (more OOD-like).
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
            # logsumexp of the logits is larger for ID samples, so its
            # negation (the energy score) is larger for OOD samples.
            log_sum_exp = torch.logsumexp(logits, dim=1)

            ood_scores["msp"].extend((1 - msp_id).cpu().tolist())
            ood_scores["entropy"].extend(entropy.cpu().tolist())
            ood_scores["energy"].extend((-log_sum_exp).cpu().tolist())
            ood_scores["k1"].extend(p_ood.cpu().tolist())
            ood_labels.extend((y >= num_classes).int().cpu().tolist())

        return total_loss, total_num, y_true, y_pred, ood_labels, ood_scores

    @staticmethod
    def _auroc_or_zero(labels, scores, name):
        """Return ``roc_auc_score(labels, scores)``, or 0.0 if it is undefined.

        ``roc_auc_score`` raises ``ValueError`` e.g. when ``labels`` contains a
        single class (no ID or no OOD samples). Each score is handled
        separately so one failure does not discard the others.
        """
        try:
            return roc_auc_score(labels, scores)
        except ValueError as e:
            print(f"Debug: AUROC calculation error ({name}): {e}")
            return 0.0

    def _log_eval_to_wandb(self, payload):
        """Log ``payload`` to wandb (if enabled), converting tensor/NumPy scalars.

        Anything exposing ``.item()`` (torch tensors, NumPy scalars) is
        converted to a plain Python number before logging.
        """
        if not self.use_wandb:
            return
        log_payload = {k: (v.item() if hasattr(v, "item") else v) for k, v in payload.items()}
        self.run.log(log_payload, step=int(self.epoch + self.round))

    @torch.no_grad()
    def evaluate(self, mode="warmup"):
        """Evaluate on ``self.test_loader``, then checkpoint and log by mode.

        When ``self.use_ema`` is set, the EMA shadow weights are applied for
        the forward passes. They are always restored afterwards, together
        with training mode, even if evaluation raises.

        Args:
            mode: prefix for the metric keys. ``"agg"`` and ``"finetune"``
                additionally save a latest checkpoint, save a best checkpoint
                when top-1 accuracy improves, and add ``best_acc`` /
                ``best_round`` entries before logging.

        Returns:
            dict of metrics keyed ``"{mode}/..."``. An AUROC entry is 0.0
            when it cannot be computed.
        """
        print(f">> eval round: {self.round}")
        print(f">> eval epochs: {self.epoch}")

        self.model.eval()
        print(f"== use eval ema: {self.use_ema}")
        if self.use_ema:
            self.ema.apply_shadow()
        try:
            (total_loss, total_num, y_true, y_pred,
             ood_labels, ood_scores) = self._collect_eval_outputs()
        finally:
            # Never leave EMA weights installed or the model stuck in eval
            # mode, even if the evaluation loop fails.
            if self.use_ema:
                self.ema.restore()
            self.model.train()

        top1 = accuracy_score(y_true, y_pred) if y_true else 0.0

        if self.num_classes <= 10 and y_true:
            cf_mat = confusion_matrix(y_true, y_pred, normalize='true')
            print("confusion matrix:\n" + np.array_str(cf_mat))

        aurocs = {name: self._auroc_or_zero(ood_labels, scores, name)
                  for name, scores in ood_scores.items()}

        eval_dict = {
            mode + '/loss': total_loss / total_num if total_num > 0 else 0.0,
            mode + '/top-1-acc': top1,
            mode + '/auroc_msp': aurocs["msp"],
            mode + '/auroc_entropy': aurocs["entropy"],
            mode + '/auroc_energy': aurocs["energy"],
            mode + '/auroc_k1': aurocs["k1"],
            mode + '/round': self.round,
        }
        print(eval_dict)

        # Mode-specific checkpointing; every mode is logged under its own prefix.
        if mode == 'agg':
            self.save_model(filename="agg_latest_model.pth")
            if top1 > self.best_acc:
                self.save_model(filename="agg_best_model.pth")
                self.best_acc = top1
                self.best_round = self.round
                print("[*] Save best agg. model ckpt")
            eval_dict[mode + '/best_acc'] = self.best_acc
            eval_dict[mode + '/best_round'] = self.best_round
        elif mode == 'finetune':
            self.save_model(filename='finetune_latest_model.pth')
            if top1 > self.best_fine_acc:
                self.save_model(filename="finetune_best_model.pth")
                self.best_fine_acc = top1
                self.best_fine_round = self.round
                print("[*] Save best finetune model ckpt")
            eval_dict[mode + '/best_acc'] = self.best_fine_acc
            eval_dict[mode + '/best_round'] = self.best_fine_round

        self._log_eval_to_wandb(eval_dict)
        return eval_dict
