import copy
from collections import defaultdict

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from src.core.base import BaseClient
from src.core.utils import get_dataloader 
from src.algorithms.utils import (
    # ce_loss,
    # consistency_loss, 
    # ova_ulb, 
    # ova_ent, 
    # ova_socr, 
    # proto_contrastive_ulb_loss, 
    # masking, 
    # inlier_masking, 
    # compute_pseudo_accuracy,
    beta_pdf,
    estimate_beta
    )

from src.algorithms.network import ProSubNet


# ==============================
# backbone ----> projection ---> proj
#            |
#       logit,feat(embed)
# ==============================

class ProSubClient(BaseClient):
    def __init__(self,
                 cid,
                 config,
                 net_builder,
                 train_loader):
        """Read the ProSub hyper-parameters from ``config`` and set up the client.

        Args:
            cid: client identifier, forwarded to the base constructor.
            config: nested configuration mapping; the ``Model`` and
                ``Training`` sections are read here.
            net_builder: forwarded to the base constructor.
            train_loader: forwarded to the base constructor.
        """
        self.config = config

        # These are assigned before the base constructor runs.
        # NOTE(review): presumably because BaseClient.__init__ ends up calling
        # set_model(), which reads the head sizes -- confirm against BaseClient.
        model_cfg = self.config['Model']
        self.cls_hidden = model_cfg['cls_hidden']
        self.proj_hidden = model_cfg['proj_hidden']
        self.proj_size = model_cfg['proj_size']
        self.tau = self.config['Training']['Client']['tau']

        super().__init__(cid, config, net_builder, train_loader)

        training_cfg = self.config['Training']
        client_cfg = training_cfg['Client']

        # Loss weights used in train_step: pseudo-label cross-entropy (wpl),
        # weak/strong similarity (ws), subspace score (wsub) and L2 penalty (wd).
        self.wpl = client_cfg['wpl']
        self.ws = client_cfg['ws']
        self.wsub = client_cfg['wsub']
        self.wd = client_cfg['wd']

        # Parameters of the two Beta components evaluated in p_id
        # (alpha1/beta1 for the ID component, alpha2/beta2 for the OOD one).
        self.alpha1 = training_cfg['alpha1']
        self.alpha2 = training_cfg['alpha2']
        self.beta1 = training_cfg['beta1']
        self.beta2 = training_cfg['beta2']

        # Mixture weight of the ID component, confidence cut-off and the
        # decay used when alpha/beta are updated by moving average.
        self.pi = training_cfg['pi']
        self.p_cutoff = client_cfg['p_cutoff']
        self.beta_ema_decay = training_cfg['beta_ema_decay']

        # One zero-initialised, non-trainable vector of length proj_size per class.
        self.class_prototypes = torch.nn.Parameter(
            torch.zeros(self.num_classes, self.proj_size, dtype=torch.float32),
            requires_grad=False,
        )

        # === epoch-level buffers for alpha/beta update ===
        self._epoch_subspacescore = []
        self._epoch_id_weights = []
        self._epoch_ood_weights = []

    def set_model(self):
        """Build the base-class model and wrap it in a ``ProSubNet``.

        Returns:
            The ``ProSubNet`` constructed around the model returned by
            ``super().set_model()``, configured with this client's class count
            and head sizes.
        """
        backbone = super().set_model()
        head_kwargs = {
            'num_classes': self.num_classes,
            'cls_hidden': self.cls_hidden,
            'proj_hidden': self.proj_hidden,
            'proj_size': self.proj_size,
        }
        return ProSubNet(base=backbone, **head_kwargs)

    
  
    def p_id(self, sims):
        """Return the ID probability of each score under the two-Beta mixture.

        The ID component uses (alpha1, beta1) with mixture weight ``pi``; the
        OOD component uses (alpha2, beta2) with weight ``1 - pi``. The result
        is the ID term divided by the sum of both terms.

        Args:
            sims: scores passed straight to ``beta_pdf``.

        Returns:
            ``pi * pdf_id / (pi * pdf_id + (1 - pi) * pdf_ood + 1e-8)``.
        """
        prior_id = self.pi
        weighted_id = prior_id * beta_pdf(sims, self.alpha1, self.beta1)
        weighted_ood = (1 - prior_id) * beta_pdf(sims, self.alpha2, self.beta2)

        # NOTE(review): the previous comment on this line said 1e-1 is used "as
        # in the original", but the code adds 1e-8 -- confirm which is intended.
        return weighted_id / (weighted_id + weighted_ood + 1e-8)
    

    def train_step(self, x_ulb, x_ulb_w, x_ulb_s, y_ulb, p, p_mask, prototypes, q, r, qt):
        """Run one optimisation step on a batch of unlabeled samples.

        The three input views are concatenated and forwarded together. The
        loss is a weighted sum of a pseudo-label cross-entropy term, a
        subspace-score term, a weak/strong similarity term and an L2 penalty.
        After the optimizer step the Beta parameters (alpha1/beta1 and
        alpha2/beta2) are refreshed by exponential moving average.

        Args:
            x_ulb, x_ulb_w, x_ulb_s: the three views of the batch; they are
                concatenated along dim 0 and the outputs split back in thirds.
            y_ulb: only used for the debug printout; may be None.
            p: precomputed class probabilities; their argmax is the target of
                the cross-entropy on the ``x_ulb_s`` logits.
            p_mask: precomputed mask applied to that cross-entropy.
            prototypes: unused here; kept for interface compatibility.
            q: basis matrix; ``q @ qt`` is applied to the ``x_ulb_w`` features.
            r: unused here; kept for interface compatibility.
            qt: transpose of ``q``.

        Returns:
            Dict with the individual loss values, the total loss and the
            learning rate of the first optimizer parameter group.

        Raises:
            ValueError: if ``q`` or ``qt`` is None.
        """
        # Fail early with a readable message instead of a TypeError on `q @ qt`.
        if q is None or qt is None:
            raise ValueError(
                "train_step requires the subspace basis (q, qt); "
                "fit() must be given a 'prototype_path' in its config."
            )

        self.model.train()
        self.optimizer.zero_grad()

        # === DEBUG: Print y_ulb distribution ===
        if y_ulb is not None:
            print("[DEBUG][Client] y_ulb shape:", y_ulb.shape)
            print("[DEBUG][Client] y_ulb unique:", torch.unique(y_ulb, return_counts=True))

        # Concatenate the three views so the model is run once.
        inputs = torch.cat([x_ulb, x_ulb_w, x_ulb_s], dim=0)

        # Forward; only the chunks used below are named.
        outputs = self.model(inputs)
        _, _, logits_strong = outputs['logits'].chunk(3)
        _, embeds_weak, _ = outputs['feat'].chunk(3)
        _, _, projs_strong = outputs['proj'].chunk(3)

        # === 1. Pseudo-labeling (use precomputed probabilities and mask) ===
        pseudo_targets = p.argmax(dim=1)
        fixmatch_mask = p_mask.float()

        xep_loss = F.cross_entropy(logits_strong, pseudo_targets, reduction='none')

        # --- Subspace score: cosine similarity between each weak-view feature
        # and its image under q @ qt (the image is treated as a constant). ---
        proj2subspace = (embeds_weak @ (q @ qt)).detach()
        norm_projs = F.normalize(proj2subspace, dim=-1, p=2)
        norm_embeds = F.normalize(embeds_weak, dim=-1, p=2)
        subspacescore = torch.sum(norm_projs * norm_embeds, dim=-1)

        norm_embeds_weak = F.normalize(embeds_weak.detach(), dim=-1, p=2)

        # Sample a hard ID/OOD assignment from the ID probability.
        p_id = self.p_id(subspacescore.detach())
        rand = torch.rand_like(p_id)
        id_mask = (p_id >= rand).float()
        pseudo_mask = fixmatch_mask * id_mask
        ood_mask = 1.0 - id_mask

        xep_loss = torch.mean(xep_loss * pseudo_mask)

        # Sign is -1 for samples drawn as ID and +1 for samples drawn as OOD.
        sub_loss = torch.mean((ood_mask - id_mask) * subspacescore)

        # One minus the cosine similarity between the strong-view projection
        # and the (detached) weak-view feature.
        norm_projs_strong = F.normalize(projs_strong, dim=-1, p=2)
        doublematch_similarity = torch.sum(norm_projs_strong * norm_embeds_weak, dim=-1)
        us_loss = torch.mean(-doublematch_similarity + 1)

        # === L2 regularization over every parameter whose name contains 'weight'
        variables = [param for name, param in self.model.named_parameters() if 'weight' in name]
        l2_loss = self.wd * (0.5 * sum((param ** 2).sum() for param in variables))

        # === Total loss
        # l2_loss already carries the wd factor; the previous code multiplied
        # by wd a second time here, shrinking the penalty to wd**2.
        loss = self.ws * us_loss + self.wpl * xep_loss + self.wsub * sub_loss + l2_loss
        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optimizer.step()

        # === EMA update of alpha/beta ===
        with torch.no_grad():
            scores = subspacescore.detach()

            # Responsibilities: same formula and same parameters as the p_id
            # computed above, so that value is reused instead of recomputed.
            id_weights = p_id
            ood_weights = 1.0 - p_id
            print(f"Client {self.cid} - id_weights: {id_weights.mean().item():.4f}, ood_weights: {ood_weights.mean().item():.4f}")

            if id_weights.sum() > 0:
                alpha1_new, beta1_new = estimate_beta(scores, id_weights)
                self.alpha1 = self._ema(self.alpha1, alpha1_new.detach().cpu().item())
                self.beta1 = self._ema(self.beta1, beta1_new.detach().cpu().item())

            if ood_weights.sum() > 0:
                alpha2_new, beta2_new = estimate_beta(scores, ood_weights)
                print(f"Client {self.cid} - alpha2_new: {alpha2_new}, beta2_new: {beta2_new}")
                self.alpha2 = self._ema(self.alpha2, alpha2_new.detach().cpu().item())
                self.beta2 = self._ema(self.beta2, beta2_new.detach().cpu().item())

            # NOTE: the original code had a disabled "stability correction"
            # that clamped alpha/beta to [0.1, 100]; no bounds are enforced.

        res_dict = {
            "train/xep_loss": xep_loss.item(),
            "train/sub_loss": sub_loss.item(),
            # float() also works if no parameter matched and the sum was a plain number.
            "train/l2_loss": float(l2_loss),
            "train/us_loss": us_loss.item(),
            "train/total_loss": loss.item(),
            "train/lr": self.optimizer.param_groups[0]['lr'],
        }

        return res_dict

    def _ema(self, previous, estimate):
        """Blend ``estimate`` into ``previous`` with ``beta_ema_decay``; returns a float."""
        decay = self.beta_ema_decay
        return float(decay * float(previous) + (1 - decay) * estimate)


    def fit(self, parameters, config):
        """Train locally for one pass over the data and return the new weights.

        Steps:
            1. Load the parameters received from the server.
            2. With a frozen copy of that model, precompute class
               probabilities and a confidence mask for every sample.
            3. Load the prototypes (if ``config['prototype_path']`` is given)
               and compute an orthonormal basis of their span by QR.
            4. Run ``train_step`` on every batch and average its metrics.

        Args:
            parameters: model parameters, passed to ``set_parameters``.
            config: mapping; only the optional key ``prototype_path`` is read.

        Returns:
            Tuple ``(parameters, num_samples, metrics)`` where ``metrics``
            holds the averaged ``train_step`` values plus the current
            alpha1/beta1/alpha2/beta2.
        """
        # Load the model parameters received from the server.
        self.set_parameters(parameters)

        # === Precompute pseudo-labels with a frozen copy of the server model ===
        server_model = copy.deepcopy(self.model)
        server_model.eval()

        dataset_dict = defaultdict(list)
        seen = 0  # running sample count, used only for the index fallback

        with torch.no_grad():
            for data in self.train_loader:
                batch = self.process_batch(**data)
                x_ulb = batch['x_ulb']
                batch_len = len(data['x_ulb'])

                idx_ulb = data.get('idx_ulb')
                if idx_ulb is None:
                    # Fallback: consecutive positions across batches. The
                    # previous per-batch arange restarted at 0 each time, so
                    # every batch pointed at the first samples of the dataset.
                    # NOTE(review): only correct when train_loader iterates
                    # the dataset in order -- confirm, or always supply idx_ulb.
                    idx_ulb = torch.arange(seen, seen + batch_len)
                seen += batch_len

                logits = server_model(x_ulb)['logits']

                # Temperature-scaled probabilities; the mask keeps samples
                # whose top probability reaches the configured cut-off.
                # (Previously the temperature `tau` was reused as threshold
                # and `p_cutoff` was never read.)
                pseudo_labels = torch.softmax(logits / self.tau, dim=-1)
                p_mask = pseudo_labels.max(dim=1)[0] >= self.p_cutoff

                dataset_dict['idx_ulb'].append(idx_ulb)
                dataset_dict['p'].append(pseudo_labels.detach())
                dataset_dict['p_mask'].append(p_mask.detach())

        # The frozen copy is not needed during training; release it early.
        del server_model

        final_dict = {k: torch.cat(v, dim=0).cpu() for k, v in dataset_dict.items()}

        # === Build a loader that yields the original samples plus p / p_mask ===
        prosub_dataset = PrecomputedDataset(self.train_loader.dataset, final_dict)

        prosub_loader = DataLoader(
            prosub_dataset,
            batch_size=self.data_cfgs['bs'],
            shuffle=True,
            num_workers=self.data_cfgs['num_workers'],
            drop_last=True
        )

        # === Load prototypes and prepare the QR decomposition ===
        proto_path = config.get("prototype_path", None)
        if proto_path is not None:
            prototypes = torch.load(proto_path, map_location=self.device, weights_only=True)
            print(f"[DEBUG][Client] Loaded prototypes from {proto_path}, prototypes.T.shape={prototypes.T.shape}")

            # Reduced QR: the columns of q span the same space as the columns
            # of prototypes.T, so q @ q.T projects onto that span. With
            # mode='complete' q is a square orthogonal matrix and q @ q.T is
            # the identity, which makes the subspace score constant.
            q, r = torch.linalg.qr(prototypes.T, mode='reduced')
            qt = q.T
        else:
            prototypes, q, r, qt = None, None, None, None

        results = defaultdict(list)

        # === Iterate over the data ===
        for batch in prosub_loader:
            batch = self.process_batch(**batch)

            res_dict = self.train_step(
                x_ulb=batch['x_ulb'],
                x_ulb_w=batch['x_ulb_w'],
                x_ulb_s=batch['x_ulb_s'],
                y_ulb=batch['y_ulb'],
                p=batch['p'],
                p_mask=batch['p_mask'],
                prototypes=prototypes,
                q=q,
                r=r,
                qt=qt
            )

            for k, v in res_dict.items():
                results[k].append(v)

        # === Average the per-step results ===
        res_dict = {k: sum(v) / len(v) for k, v in results.items()}

        metrics = {
            **res_dict,
            "alpha1": self.alpha1,
            "beta1": self.beta1,
            "alpha2": self.alpha2,
            "beta2": self.beta2,
        }

        return self.get_parameters(), len(self.train_loader.dataset), metrics


    def __len__(self):
        """Return the size of the first dimension of ``self.x_ulb_w``.

        NOTE(review): ``x_ulb_w`` is not assigned anywhere in this class, so
        this looks like a leftover from a dataset class -- confirm whether
        BaseClient provides it or whether the method can be removed.
        """
        weak_views = self.x_ulb_w
        return weak_views.shape[0]

    def __getitem__(self, idx):
        """Return entry ``idx`` of each per-sample attribute as a dict.

        NOTE(review): none of the attributes read here are assigned in this
        class, so this looks like a leftover from a dataset class -- confirm
        whether BaseClient provides them or whether the method can be removed.
        """
        field_names = ('x_ulb', 'x_ulb_w', 'x_ulb_s', 'y_ulb', 'p', 'p_mask')
        return {name: getattr(self, name)[idx] for name in field_names}


class PrecomputedDataset(Dataset):
    """View of a dataset that attaches precomputed values to each sample.

    Item ``i`` is the sample ``orig_dataset[final_dict['idx_ulb'][i]]`` with
    the keys ``idx_ulb``, ``p`` and ``p_mask`` taken from row ``i`` of
    ``final_dict`` added to it.
    """

    def __init__(self, orig_dataset, final_dict):
        """Store the wrapped dataset and the precomputed tensors.

        Args:
            orig_dataset: indexable dataset whose items are mappings.
            final_dict: mapping with the tensors ``idx_ulb``, ``p`` and
                ``p_mask``, aligned along their first dimension.
        """
        self.orig_dataset = orig_dataset
        self.final_dict = final_dict

    def __len__(self):
        """Return the number of precomputed rows."""
        return len(self.final_dict['idx_ulb'])

    def __getitem__(self, i):
        """Return a new dict: the original sample plus its precomputed values."""
        idx_ulb = self.final_dict['idx_ulb'][i]

        # Copy before merging so the dict handed out by the wrapped dataset is
        # never mutated (it may be a stored object rather than a fresh one).
        sample = dict(self.orig_dataset[idx_ulb.item()])
        sample.update({
            'idx_ulb': idx_ulb,
            'p': self.final_dict['p'][i],
            'p_mask': self.final_dict['p_mask'][i],
        })
        return sample
