import numpy as np
import torch
import torch.nn.functional as F
import os
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

from src.core.base import BaseServer
from src.algorithms.utils import (
    ce_loss, 
    beta_pdf,
    estimate_beta
)
from src.algorithms.network import ProSubNet


class ProSubServer(BaseServer):
    """Server-side trainer for ProSub: supervised training plus prototype-subspace tracking.

    Besides the supervised update, the server keeps an EMA estimate of per-class feature
    means (``self.prototypes``, shape [C, D]). The span of those prototypes defines a
    subspace; the "subspace score" of an embedding is the fraction of its L2 norm that
    lies inside that subspace. The ID Beta parameters (``alpha1``, ``beta1``) are
    EMA-updated from the subspace scores of labeled data every step; ``alpha2``/``beta2``
    and ``pi`` are read from config and not updated in this class.

    NOTE(review): ``num_classes``, ``device``, ``model``, ``ema``, ``ema_model``, ``mode``,
    ``round``, ``epoch``, ``clip_grad``, ``print_fn``, ``save_model``, ``use_wandb``/``run``
    and the best-accuracy bookkeeping are assumed to be provided by ``BaseServer``.
    """

    def __init__(self, config, net_builder, train_loader, test_loader, logger):
        self.config = config
        self.logger = logger

        # Head sizes are read before BaseServer.__init__ because set_model()/set_ema_model()
        # need them (presumably invoked from the base constructor — TODO confirm).
        self.cls_hidden = self.config['Model']['cls_hidden']
        self.proj_hidden = self.config['Model']['proj_hidden']
        self.proj_size = self.config['Model']['proj_size']

        super().__init__(config, net_builder, train_loader, test_loader, logger)

        # Training hyperparameters.
        self.wd = self.config['Training']['Server']['wd']
        self.alpha1 = self.config['Training']['alpha1']
        self.beta1 = self.config['Training']['beta1']
        self.alpha2 = self.config['Training']['alpha2']
        self.beta2 = self.config['Training']['beta2']
        self.pi = self.config['Training']['pi']
        self.prototype_ema_decay = self.config['Training']['prototype_ema_decay']
        self.beta_ema_decay = self.config['Training'].get('beta_ema_decay', 0.999)
        self.use_ema = self.config['Model']['use_ema']

        # Class prototypes [C, D]; created lazily on the first update_feat_mean() call.
        self.prototypes = None
        # When True, prototypes are built from EMA-model features of the weak view.
        self.use_ema_feat = self.config['Training']['Server']['use_ema_feat']

        # Directory where prototype snapshots are written.
        self.ckpt_dir = os.path.join(self.config['save_dir'], self.config['save_name'])
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def set_model(self):
        """Wrap the base network from ``BaseServer.set_model`` in a ``ProSubNet``."""
        model = super().set_model()
        model = ProSubNet(
            base=model,
            num_classes=self.num_classes,
            cls_hidden=self.cls_hidden,
            proj_hidden=self.proj_hidden,
            proj_size=self.proj_size
        )
        return model

    def set_ema_model(self):
        """Build a fresh ``ProSubNet`` and initialise it with ``self.model``'s weights."""
        ema_model = self.net_builder(num_classes=self.num_classes)
        ema_model = ProSubNet(
            base=ema_model,
            num_classes=self.num_classes,
            cls_hidden=self.cls_hidden,
            proj_hidden=self.proj_hidden,
            proj_size=self.proj_size
        )
        ema_model.load_state_dict(self.model.state_dict())
        return ema_model

    @torch.no_grad()
    def update_feat_mean(self, f, labels):
        """Fold this batch's per-class feature means into the EMA prototypes and snapshot them.

        Only classes that actually occur in ``labels`` are updated; prototypes of absent
        classes are left untouched instead of being pulled toward a zero "mean".

        Args:
            f: feature tensor of shape [N, D].
            labels: integer class indices of shape [N], each in ``[0, num_classes)``.

        Side effects:
            Sets ``self.prototypes`` ([C, D]) and writes it to
            ``<ckpt_dir>/prototypes_r{round}_e{epoch}.pt`` (overwritten every call within
            the same round/epoch); stores that path in ``self.current_proto_path``.
        """
        # Accumulate in float32 so half-precision features do not break index_add_.
        f = f.float()
        labels = labels.long()

        feat_sums = torch.zeros(self.num_classes, f.size(1), device=f.device, dtype=f.dtype)
        feat_sums.index_add_(0, labels, f)
        counts = torch.bincount(labels, minlength=self.num_classes).to(f.dtype)
        present = counts > 0
        batch_means = feat_sums / counts.clamp(min=1).unsqueeze(1)  # [C, D]; zero rows for absent classes

        if self.prototypes is None:
            self.prototypes = batch_means.clone()
        else:
            prototypes = self.prototypes.to(f.device)
            decay = self.prototype_ema_decay
            blended = decay * prototypes + (1 - decay) * batch_means
            # Move only the rows of classes observed in this batch.
            self.prototypes = torch.where(present.unsqueeze(1), blended, prototypes)

        proto_path = os.path.join(self.ckpt_dir, f"prototypes_r{self.round}_e{self.epoch}.pt")
        torch.save(self.prototypes.detach().cpu(), proto_path)
        self.current_proto_path = proto_path

    def _prototype_basis(self):
        """Return an orthonormal basis (columns) of the span of the current class prototypes."""
        q, _ = torch.linalg.qr(self.prototypes.to(self.device).float().T)
        return q

    @staticmethod
    def _subspace_score(embeds, basis):
        """Fraction of each embedding's L2 norm lying in the subspace spanned by ``basis``."""
        embeds = embeds.float()
        projected = (embeds @ basis) @ basis.T
        return projected.norm(p=2, dim=-1) / (embeds.norm(p=2, dim=-1) + 1e-8)

    def p_id(self, x):
        """Posterior probability that score ``x`` belongs to the ID component.

        Uses a two-component mixture with prior ``pi`` on the ID component, assuming
        ``beta_pdf(x, a, b)`` evaluates the Beta(a, b) density at ``x``.
        """
        pdf_id = beta_pdf(x, self.alpha1, self.beta1)
        pdf_ood = beta_pdf(x, self.alpha2, self.beta2)
        pdf_joint = self.pi * pdf_id + (1 - self.pi) * pdf_ood
        return self.pi * pdf_id / (pdf_joint + 1e-8)

    def train_step(self, optimizer, x_lb, x_lb_w, x_lb_s0, x_lb_s1, y_lb):
        """Run one supervised step, then refresh the prototypes and the ID Beta parameters.

        Args:
            optimizer: optimizer that steps ``self.model``'s parameters.
            x_lb, x_lb_w, x_lb_s0, x_lb_s1: four views of the same labeled batch
                (presumably raw / weak / two strong augmentations — TODO confirm). They are
                concatenated and split back with ``chunk(4)``, so batch sizes must match.
            y_lb: integer class labels for the batch.

        Returns:
            dict of scalar metrics keyed ``"{mode}_train/..."``.
        """
        self.model.train()
        # Clear gradients on the same optimizer that is stepped below.
        optimizer.zero_grad()

        outputs = self.model(torch.cat((x_lb, x_lb_w, x_lb_s0, x_lb_s1)))
        # Only the first view (x_lb) feeds the loss and the subspace statistics.
        logits_lb = outputs['logits'].chunk(4)[0]
        embeds_lb = outputs['feat'].chunk(4)[0]

        xe_loss = ce_loss(logits_lb.float(), y_lb, reduction='mean')

        if self.mode == 'finetune':
            # Explicit L2 penalty over every parameter whose name contains 'weight'.
            weights = [p for name, p in self.model.named_parameters() if 'weight' in name]
            l2_loss = self.wd * sum((p ** 2).sum() for p in weights)
            loss = xe_loss + l2_loss
        else:
            l2_loss = torch.tensor(0.0, device=xe_loss.device)
            loss = xe_loss

        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad)
        optimizer.step()

        lr = optimizer.param_groups[0]['lr']

        # Prototype update: exactly one EMA step per batch, from the source selected by
        # use_ema_feat (EMA-model weak-view features, or live labeled-view features).
        if self.use_ema_feat:
            self.ema.apply_shadow()
            try:
                # NOTE(review): the model is still in train mode here; whether BatchNorm
                # statistics touched by this pass are reverted depends on ema.restore().
                with torch.no_grad():
                    proto_feats = self.model(x_lb_w)['feat']
            finally:
                # Always put the live weights back, even if the forward pass raises.
                self.ema.restore()
        else:
            proto_feats = embeds_lb
        self.update_feat_mean(f=proto_feats.detach(), labels=y_lb)

        # Fit the ID Beta to the labeled subspace scores and blend it into the running estimate.
        with torch.no_grad():
            subspacescore = self._subspace_score(embeds_lb, self._prototype_basis())
        id_weights = torch.ones_like(subspacescore)  # labeled data are all ID
        alpha1_new, beta1_new = estimate_beta(subspacescore, id_weights)

        decay = self.beta_ema_decay
        self.alpha1 = float(np.clip(decay * float(self.alpha1) + (1 - decay) * float(alpha1_new), 0.1, 100.0))
        self.beta1 = float(np.clip(decay * float(self.beta1) + (1 - decay) * float(beta1_new), 0.1, 100.0))

        res_dict = {
            f"{self.mode}_train/supervised_loss": xe_loss.item(),
            f"{self.mode}_train/l2_loss": l2_loss.item(),
            f"{self.mode}_train/lr": lr,
            f"{self.mode}_train/alpha1": self.alpha1,
            f"{self.mode}_train/beta1": self.beta1,
        }
        return res_dict

    @torch.no_grad()
    def evaluate(self, mode="warmup"):
        """Evaluate closed-set accuracy and OOD-detection AUROCs on ``self.test_loader``.

        Samples with label ``>= num_classes`` are treated as OOD (positive class for AUROC).
        Every OOD score is oriented so that larger means "more likely OOD".

        Args:
            mode: prefix for the metric keys; ``'agg'`` and ``'finetune'`` additionally
                save latest/best checkpoints and track the best top-1 accuracy.

        Returns:
            dict of evaluation metrics (also printed and, if enabled, logged to wandb).
        """
        self.print_fn(f">> eval round: {self.round}")
        self.print_fn(f">> eval epochs: {self.epoch}")

        model = self.ema_model if self.use_ema else self.model
        model.eval()

        # Prototypes are fixed during evaluation, so their basis is computed once.
        basis = self._prototype_basis() if self.prototypes is not None else None

        total_loss = 0.0
        total_num = 0.0
        y_in_true = []
        y_pred = []
        ood_scores_msp = []
        ood_scores_entropy = []
        ood_scores_energy = []
        ood_scores_subspace = []
        ood_labels = []

        for data in self.test_loader:
            x = data['x_lb']
            y = data['y_lb'].to(self.device)
            if isinstance(x, dict):
                x = {k: v.to(self.device) for k, v in x.items()}
            else:
                x = x.to(self.device)

            outputs = model(x)
            logits = outputs['logits']
            embeds = outputs['feat']

            # Closed-set loss/accuracy over in-distribution samples only.
            in_idx = torch.where(y < self.num_classes)[0]
            if len(in_idx) > 0:
                in_logits = logits[in_idx]
                in_y = y[in_idx]
                loss = F.cross_entropy(in_logits, in_y, reduction='mean')
                total_loss += loss.item() * in_idx.shape[0]
                total_num += in_idx.shape[0]
                y_in_true.extend(in_y.cpu().tolist())
                y_pred.extend(torch.max(in_logits, dim=-1)[1].cpu().tolist())

            ood_labels.extend((y >= self.num_classes).int().cpu().tolist())

            probs = F.softmax(logits, dim=1)
            max_prob, _ = probs.max(dim=1)
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
            energy = -torch.logsumexp(logits, dim=1)
            ood_scores_msp.extend((1.0 - max_prob).cpu().tolist())
            ood_scores_entropy.extend(entropy.cpu().tolist())
            ood_scores_energy.extend(energy.cpu().tolist())

            if basis is not None:
                # The subspace score is high for ID samples; flip it so it is oriented like
                # the other OOD scores (larger = more OOD) before computing AUROC.
                ood_scores_subspace.extend((1.0 - self._subspace_score(embeds, basis)).cpu().tolist())

        # Accuracy
        y_in_true = np.array(y_in_true)
        y_pred = np.array(y_pred)
        top1 = accuracy_score(y_in_true, y_pred) if len(y_in_true) > 0 else 0.0

        if self.num_classes <= 10 and len(y_in_true) > 0:
            cf_mat = confusion_matrix(y_in_true, y_pred, normalize='true')
            self.print_fn("confusion matrix:\n" + np.array_str(cf_mat))

        # AUROC (roc_auc_score raises ValueError when only one class is present).
        try:
            auroc_msp = roc_auc_score(ood_labels, ood_scores_msp)
            auroc_entropy = roc_auc_score(ood_labels, ood_scores_entropy)
            auroc_energy = roc_auc_score(ood_labels, ood_scores_energy)
            auroc_subspace = roc_auc_score(ood_labels, ood_scores_subspace) if len(ood_scores_subspace) > 0 else 0.0
        except ValueError:
            auroc_msp = auroc_entropy = auroc_energy = auroc_subspace = 0.0

        model.train()

        eval_dict = {
            mode+'/loss': total_loss / total_num if total_num > 0 else 0.0,
            mode+'/top-1-acc': top1,
            mode+'/auroc_msp': auroc_msp,
            mode+'/auroc_entropy': auroc_entropy,
            mode+'/auroc_energy': auroc_energy,
            mode+'/auroc_subspace': auroc_subspace,
            mode+'/round': self.round
            }
        if mode == 'agg':
            self.save_model(filename="agg_latest_model.pth")
            if top1 > self.best_acc:
                self.save_model(filename="agg_best_model.pth")
                self.best_acc = top1
                self.best_round = self.round
                self.print_fn("[*] Save best agg. model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_acc
            eval_dict[f"{mode}/best_round"] = self.best_round
        elif mode == 'finetune':
            self.save_model(filename="finetune_latest_model.pth")
            if top1 > self.best_fine_acc:
                self.save_model(filename="finetune_best_model.pth")
                self.best_fine_acc = top1
                self.best_fine_round = self.round
                self.print_fn("[*] Save best finetune model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_fine_acc
            eval_dict[f"{mode}/best_round"] = self.best_fine_round

        if self.use_wandb:
            self.run.log(eval_dict, step=self.epoch + self.round)

        self.print_fn(eval_dict)
        return eval_dict