import torch
import torch.nn.functional as F
from collections import deque, defaultdict
import copy
import random
from torch.utils.data import DataLoader, RandomSampler, Dataset

from src.core.base import BaseClient
from src.core.utils import get_dataloader
from src.algorithms.utils import consistency_loss, masking, compute_pseudo_accuracy, ce_loss


class SCOMatchClient(BaseClient):
    def __init__(self,
                 cid,
                 config,
                 net_builder,
                 train_loader):
        """Initialise SCOMatch-specific client state on top of ``BaseClient``.

        Args:
            cid: Client identifier, forwarded to ``BaseClient``.
            config: Configuration mapping. This constructor reads
                ``config['client_alg']`` directly; every other option is read
                through ``self.config`` (presumably set by ``BaseClient`` from
                the same object -- TODO confirm).
            net_builder: Model factory, forwarded to ``BaseClient``.
            train_loader: Loader over the client's unlabeled data, forwarded
                to ``BaseClient``; ``self.train_loader.dataset`` is used below
                to size the threshold-reset period.
        """
        super().__init__(cid, config, net_builder, train_loader)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.run_name = config['client_alg']

        client_cfg = self.config['Training']['Client']

        # Basic FixMatch settings.
        self.T = client_cfg['T']
        self.threshold = client_cfg['threshold']
        # p_cutoff and threshold intentionally share the same config value.
        self.p_cutoff = client_cfg['threshold']

        # Additional SCOMatch settings.
        self.num_classes = self.config['Dataset']['num_classes']
        self.ood_threshold = client_cfg['ood_threshold']
        self.batch_size = self.config['Dataset']['Client']['bs']
        # Epoch from which the periodic score reset in ``fit`` is enabled.
        self.start_fix = self.config['Training']['start_fix']

        # OOD memory queue: bounded deques, so the oldest entries are evicted
        # automatically once ``selected_ood_maxlength`` items are stored.
        self.selected_ood_maxlength = max(8 * self.num_classes, 256)
        self.selected_ood_update_length = client_cfg['Km']
        self.selected_ood_count = 0
        self.selected_ood_scores = deque(maxlen=self.selected_ood_maxlength)
        self.selected_ood_labels = deque(maxlen=self.selected_ood_maxlength)
        self.selected_ood_images = deque(maxlen=self.selected_ood_maxlength)

        # Per-class confidence scores; the extra last slot is the OOD class.
        self.all_sample_scores = [[] for _ in range(self.num_classes + 1)]
        # Clamp to at least 1: for datasets smaller than two batches the
        # integer division yields 0, and ``fit`` uses this value as a modulus.
        self.threshold_update_freq = max(
            1, len(self.train_loader.dataset) // (self.batch_size * 2)
        )

        self.use_ema = self.config['Model']['use_ema']

        # Separate batch stores (open-set / closed-set unsupervised losses).
        self.all_batches = []
        self.close_batches = []

    def train_step(self, x_ulb_w, x_ulb_s, x_close_w, x_close_s, y_ulb,
                   pseudo_open, targets_open, mask_open,
                   pseudo_close, targets_close, mask_close,
                   ood_samples=None, ood_scores=None, epoch=0):
        """Run one optimisation step on a pair of open/close batches.

        All views (plus the optional OOD queue samples) go through the model
        in a single forward pass. The loss is
        ``L_sup_open + L_unsup_close + L_unsup_open``; the closed-set
        supervised loss is not computed here.

        Args:
            x_ulb_w, x_ulb_s: Weak / strong views of the open batch.
            x_close_w, x_close_s: Weak / strong views of the close batch.
            targets_open, mask_open: Precomputed hard targets and selection
                mask for the open batch.
            targets_close, mask_close: Precomputed hard targets and selection
                mask for the close batch (inlier classes only).
            ood_samples: Optional images drawn from the OOD memory queue;
                they are trained towards class index ``self.num_classes``.
            ood_scores: Scores stored with ``ood_samples``; only samples
                whose score is below ``self.threshold`` contribute.
            y_ulb, pseudo_open, pseudo_close, epoch: Accepted for interface
                compatibility; not used by this step.

        Returns:
            Dict of Python floats: ``loss``, ``L_sup_open``,
            ``L_unsup_close`` and ``L_unsup_open``.
        """
        self.optimizer.zero_grad()

        has_ood = ood_samples is not None and len(ood_samples) > 0

        views = [x_ulb_w, x_ulb_s, x_close_w, x_close_s]
        if has_ood:
            views.insert(0, ood_samples)

        # Move every part to the training device *before* concatenating so
        # queue samples stored on another device cannot break ``torch.cat``.
        inputs = torch.cat([view.to(self.device) for view in views], 0)
        logits = self.model(inputs)['logits']

        # Split by the real size of each part rather than assuming four
        # equal chunks; only the strong-view logits feed the losses.
        parts = torch.split(logits, [len(view) for view in views], dim=0)
        logits_open_s = parts[-3]
        logits_close_s = parts[-1]

        if has_ood:
            logits_ood_lb = parts[0]
            ood_label = torch.full(
                (len(ood_samples),), self.num_classes,
                dtype=torch.long, device=self.device,
            )
            ood_mask = ood_scores.to(self.device) < self.threshold
            # NOTE(review): assumes ce_loss returns per-sample losses so the
            # mask can be applied element-wise -- confirm against its definition.
            L_sup_open = (ce_loss(logits_ood_lb, ood_label) * ood_mask).mean()
        else:
            # No usable OOD samples yet: the term contributes nothing.
            L_sup_open = torch.zeros((), device=self.device)

        targets_open = targets_open.to(self.device)
        mask_open = mask_open.to(self.device)
        targets_close = targets_close.to(self.device)
        mask_close = mask_close.to(self.device)

        # NOTE(review): assumes consistency_loss returns per-sample losses -- confirm.
        L_unsup_open = (consistency_loss(logits_open_s, targets_open) * mask_open).mean()

        # The closed-set loss only looks at the inlier logits.
        logits_p_u_close_s = logits_close_s[:, :self.num_classes]
        L_unsup_close = (consistency_loss(logits_p_u_close_s, targets_close) * mask_close).mean()

        loss = L_sup_open + L_unsup_close + L_unsup_open

        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)

        self.optimizer.step()

        return {
            'loss': float(loss.item()),
            'L_sup_open': float(L_sup_open.item()),
            'L_unsup_close': float(L_unsup_close.item()),
            'L_unsup_open': float(L_unsup_open.item())
        }

    def compute_dynamic_threshold(self):
        """동적 OOD 임계값 계산"""
        max_len = sum([len(self.all_sample_scores[i]) for i in range(self.num_classes)])
        ood_len = len(self.all_sample_scores[-1])
        
        if max_len > 0:
            ratio = ood_len / max_len
            ood_threshold = self.threshold * ratio
            return min(0.95, max(0.75, ood_threshold))
        else:
            return self.ood_threshold

    def fit(self, parameters, config):
        """Run one federated round of local SCOMatch training.

        Steps:
          1. Load ``parameters`` and take a frozen copy of the model.
          2. With that copy, precompute (K+1)-way "open" and K-way "close"
             pseudo-labels and masks over the local dataset, refreshing the
             OOD memory queue with the lowest-MSP samples of each batch.
          3. Train for ``self.local_epochs`` epochs on pairs of independently
             shuffled open / close batches.

        Args:
            parameters: Model weights handed to ``self.set_parameters``.
            config: Mapping read with ``.get`` for ``'server_round'`` and
                ``'epoch'`` (both default to 0).

        Returns:
            Tuple ``(parameters, num_examples, metrics)`` where
            ``num_examples`` is the size of ``self.train_loader.dataset`` and
            ``metrics`` is ``self.res_dict``. The pseudo-accuracy entries of
            ``metrics`` are placeholders fixed to zero.
        """
        self.set_parameters(parameters)
        self.server_round = config.get('server_round', 0)
        current_epoch = config.get('epoch', 0)

        # Frozen snapshot of the freshly received weights; it produces the
        # pseudo-labels for the whole round.
        server_model = copy.deepcopy(self.model)
        server_model.eval()

        # Fresh dataset copies every round, one per branch.
        self.unlabeled_dataset_open = copy.deepcopy(self.train_loader.dataset)
        self.unlabeled_dataset_close = copy.deepcopy(self.train_loader.dataset)

        final_dict_open = self._precompute_open(
            server_model,
            self._make_loader(self.unlabeled_dataset_open, training=False),
        )
        final_dict_close = self._precompute_close(
            server_model,
            self._make_loader(self.unlabeled_dataset_close, training=False),
        )
        # The snapshot is no longer needed; release it before training.
        del server_model

        open_loader = self._make_loader(
            PrecomputedDataset(self.unlabeled_dataset_open, final_dict_open, 'open'),
            training=True,
        )
        close_loader = self._make_loader(
            PrecomputedDataset(self.unlabeled_dataset_close, final_dict_close, 'close'),
            training=True,
        )

        # Guard the modulus below: the stored frequency can be 0 for
        # datasets smaller than two batches.
        update_freq = max(1, self.threshold_update_freq)

        tot_samples = 0
        batch_idx = 0

        self.model.train()

        for _ in range(self.local_epochs):
            # ``zip`` builds fresh iterators at the start of every epoch and
            # stops at the shorter loader, so no spare iterator (and worker
            # pool) is created after the last epoch.
            for open_batch, close_batch in zip(open_loader, close_loader):
                # Draw a new OOD batch from the memory queue for every step.
                ood_samples, ood_scores = self._sample_ood_batch()

                open_batch = self.process_batch(**open_batch)
                close_batch = self.process_batch(**close_batch)

                # Periodically reset the recorded confidence scores.
                if (batch_idx > 0
                        and batch_idx % update_freq == 0
                        and current_epoch >= self.start_fix):
                    self.all_sample_scores = [[] for _ in range(self.num_classes + 1)]

                tot_samples += len(open_batch['pseudo_open'])

                self.train_step(
                    x_ulb_w=open_batch['x_ulb_w'],
                    x_ulb_s=open_batch['x_ulb_s'],
                    x_close_w=close_batch['x_ulb_w'],
                    x_close_s=close_batch['x_ulb_s'],
                    y_ulb=open_batch['y_ulb'],
                    pseudo_open=open_batch['pseudo_open'],
                    targets_open=open_batch['targets_open'],
                    mask_open=open_batch['mask_open'],
                    pseudo_close=close_batch['pseudo_close'],
                    targets_close=close_batch['targets_close'],
                    mask_close=close_batch['mask_close'],
                    ood_samples=ood_samples,
                    ood_scores=ood_scores,
                    epoch=current_epoch,
                )

                batch_idx += 1

        self.res_dict = {
            'tot_samples': tot_samples,
            'tot_pseudo_acc': 0.0,  # placeholder: not computed in this client
            'pmasked_samples': 0,
            'pmasked_pseudo_acc': 0.0,
            'masked_samples': 0,
            'masked_pseudo_acc': 0.0,
            'selected_ood_count': self.selected_ood_count
        }

        return self.get_parameters(), len(self.train_loader.dataset), self.res_dict

    def _make_loader(self, dataset, *, training):
        """Build a loader over ``dataset``; shuffle and drop the last partial batch only when ``training``."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=training,
            num_workers=self.train_loader.num_workers,
            drop_last=training,
        )

    @staticmethod
    def _batch_indices(data, offset):
        """Return the dataset indices of a sequentially loaded batch.

        Uses ``data['idx_ulb']`` when the dataset provides it. Otherwise the
        samples are numbered by their position in the (unshuffled) loader,
        starting at ``offset``, so indices stay unique across batches.
        """
        idx_ulb = data.get('idx_ulb')
        if idx_ulb is None:
            idx_ulb = torch.arange(offset, offset + len(data['x_ulb_w']))
        return idx_ulb

    def _precompute_open(self, server_model, loader):
        """Precompute (K+1)-way pseudo-labels/masks and refresh the OOD memory queue."""
        collected = defaultdict(list)
        # The recorded scores do not change during this pass, so the dynamic
        # threshold is the same for every batch.
        ood_threshold = self.compute_dynamic_threshold()
        offset = 0

        for data in loader:
            batch = self.process_batch(**data)
            idx_ulb = self._batch_indices(data, offset)
            offset += len(data['x_ulb_w'])
            x_ulb_w = batch['x_ulb_w']
            y_ulb = batch['y_ulb']

            with torch.no_grad():
                logits = server_model(x_ulb_w)['logits']

                # Open-set pseudo-labels over all K+1 outputs.
                pseudo_open = torch.softmax(logits / self.T, dim=-1)
                max_probs_open, targets_open = torch.max(pseudo_open, dim=-1)

                # Inlier predictions use the fixed threshold, OOD predictions
                # use the dynamic one.
                mask_open = max_probs_open.ge(self.threshold) & (targets_open < self.num_classes)
                mask_open = mask_open | (
                    max_probs_open.ge(ood_threshold) & (targets_open == self.num_classes)
                )

                # Samples with the lowest maximum inlier probability (MSP)
                # are the OOD candidates of this batch.
                max_probs_msp, _ = torch.max(pseudo_open[:, :self.num_classes], dim=-1)
                _, order = torch.sort(max_probs_msp)
                picked = order[:self.selected_ood_update_length]

                # Count what was actually queued; a short batch can
                # contribute fewer than ``Km`` samples.
                if self.selected_ood_count < self.selected_ood_maxlength:
                    self.selected_ood_count += len(picked)
                self.selected_ood_scores.extend(max_probs_msp[picked].tolist())
                self.selected_ood_images.extend(x_ulb_w[picked])
                self.selected_ood_labels.extend(y_ulb[picked].tolist())

            collected['idx_ulb'].append(idx_ulb)
            collected['pseudo_open'].append(pseudo_open)
            collected['targets_open'].append(targets_open)
            collected['mask_open'].append(mask_open)

        return {key: torch.cat(chunks, dim=0).cpu() for key, chunks in collected.items()}

    def _precompute_close(self, server_model, loader):
        """Precompute K-way (inlier-only) pseudo-labels and masks."""
        collected = defaultdict(list)
        offset = 0

        for data in loader:
            batch = self.process_batch(**data)
            idx_ulb = self._batch_indices(data, offset)
            offset += len(data['x_ulb_w'])
            x_ulb_w = batch['x_ulb_w']

            with torch.no_grad():
                logits = server_model(x_ulb_w)['logits']

                # Closed-set pseudo-labels over the inlier classes only.
                pseudo_close = torch.softmax(logits[:, :self.num_classes] / self.T, dim=-1)
                max_probs_close, targets_close = torch.max(pseudo_close, dim=-1)

                # Keep a sample only if the full (K+1)-way prediction also
                # says it is an inlier.
                _, targets_open_full = torch.max(torch.softmax(logits / self.T, dim=-1), dim=-1)
                id_mask = targets_open_full < self.num_classes
                final_mask_close = max_probs_close.ge(self.threshold).float() * id_mask

            collected['idx_ulb'].append(idx_ulb)
            collected['pseudo_close'].append(pseudo_close)
            collected['targets_close'].append(targets_close)
            collected['mask_close'].append(final_mask_close)

        return {key: torch.cat(chunks, dim=0).cpu() for key, chunks in collected.items()}

    def _sample_ood_batch(self):
        """Draw up to ``batch_size`` random entries from the OOD queue, or ``(None, None)`` if too few were collected."""
        if self.selected_ood_count < self.batch_size:
            return None, None
        chosen = torch.randperm(len(self.selected_ood_images))[:self.batch_size].tolist()
        ood_samples = torch.stack([self.selected_ood_images[i] for i in chosen])
        ood_scores = torch.tensor([self.selected_ood_scores[i] for i in chosen])
        return ood_samples, ood_scores


class PrecomputedDataset(Dataset):
    """Dataset view that attaches precomputed pseudo-label tensors to samples.

    Item ``i`` is the sample ``orig_dataset[final_dict['idx_ulb'][i]]`` with
    three extra entries taken from row ``i`` of ``final_dict``:
    ``pseudo_open`` / ``targets_open`` / ``mask_open`` when ``data_type`` is
    ``'open'``, otherwise ``pseudo_close`` / ``targets_close`` /
    ``mask_close``.

    Args:
        orig_dataset: Indexable dataset whose items are mappings.
        final_dict: Mapping of key to an indexable (e.g. tensor) holding one
            row per sample; must contain ``'idx_ulb'`` plus the three keys
            of the selected ``data_type``.
        data_type: ``'open'`` selects the open-set keys; any other value
            selects the close-set keys.
    """

    _OPEN_KEYS = ('pseudo_open', 'targets_open', 'mask_open')
    _CLOSE_KEYS = ('pseudo_close', 'targets_close', 'mask_close')

    def __init__(self, orig_dataset, final_dict, data_type):
        self.orig_dataset = orig_dataset
        self.final_dict = final_dict
        self.data_type = data_type  # 'open' or 'close'

    def __len__(self):
        """Return the number of precomputed rows."""
        return len(self.final_dict['idx_ulb'])

    def __getitem__(self, i):
        """Return a copy of the wrapped sample extended with its precomputed entries."""
        idx = int(self.final_dict['idx_ulb'][i])
        # Shallow-copy so the extra keys never leak into a mapping that the
        # wrapped dataset may be caching or sharing.
        sample = dict(self.orig_dataset[idx])

        extra_keys = self._OPEN_KEYS if self.data_type == 'open' else self._CLOSE_KEYS
        for key in extra_keys:
            sample[key] = self.final_dict[key][i]

        return sample