import copy
from collections import defaultdict

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from src.core.base import BaseClient
from src.core.utils import get_dataloader, set_seed
from src.algorithms.utils import (
    consistency_loss, 
    masking, 
    compute_pseudo_accuracy
)


class FixMatchClient(BaseClient):
    """Client that trains on cached pseudo-labels, one server round at a time.

    Each ``fit`` call:
      1. Pseudo-labels the client's unlabeled data with a frozen copy of the
         model (optionally loaded from an EMA checkpoint).
      2. Caches the soft targets ``p`` and confidence mask ``p_mask`` per sample.
      3. Trains the local model on ``x_ulb_s`` inputs against those cached targets.
    """

    def __init__(self,
                 cid,
                 config,
                 net_builder,
                 train_loader):

        super().__init__(cid, config, net_builder, train_loader)

        # Temperature dividing the logits before the softmax.
        self.T = self.train_cfgs['T']
        # Cutoff forwarded to ``masking`` to build the confidence mask.
        self.p_cutoff = self.train_cfgs['p_cutoff']
        # If truthy, only samples with p_mask == 1 are kept for local training.
        self.use_mask_only = self.train_cfgs['use_mask_only']
        # If truthy, the aggregation weight is the number of masked samples.
        self.filtered_weight = self.train_cfgs['filtered_weight']

    def train_step(self, x_ulb_s, p, p_mask):
        """Run one optimizer step of the consistency loss on a batch.

        Args:
            x_ulb_s: Model input batch (by its name, presumably the strongly
                augmented unlabeled view — TODO confirm).
            p: Cached pseudo-label targets for the batch.
            p_mask: Cached confidence mask for the batch. Ignored when
                ``use_mask_only`` is set, because the dataset was already
                filtered down to masked samples.

        Returns:
            dict with the scalar ``'consistency_loss'``.
        """
        self.model.train()
        # Under use_mask_only every remaining sample passed the mask, so the
        # loss is computed without per-sample masking.
        if self.use_mask_only:
            p_mask = None

        logits_x_ulb_s = self.model(x_ulb_s)['logits']
        loss = consistency_loss(logits_x_ulb_s, p, 'ce', mask=p_mask)

        self.optimizer.zero_grad()
        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optimizer.step()

        return {'consistency_loss': float(loss.item())}

    def _build_pseudo_label_model(self):
        """Return an eval-mode deep copy of the model used only for pseudo-labeling.

        The EMA checkpoint is loaded when ``use_ema_pseudo`` is enabled and the
        current round is past ``start_ema_pseudo``. ``self.model`` itself is
        never modified.
        """
        model = copy.deepcopy(self.model)
        if self.use_ema_pseudo and (self.server_round > self.start_ema_pseudo):
            ema_ckpt = torch.load(self.ema_save_path, map_location=self.device, weights_only=True)
            model.load_state_dict(ema_ckpt)
            print("==> Load ema ckpt for pseudo-labeling **")
        model.eval()
        return model

    def _pseudo_label(self, model):
        """Pseudo-label every batch of the client's original train loader.

        Args:
            model: Frozen model used to produce the logits.

        Returns:
            ``(final_dict, stats)``:
              - ``final_dict`` maps ``'idx_ulb'``, ``'p'`` and ``'p_mask'`` to
                CPU tensors concatenated over all batches. It is empty if the
                loader yields no batches.
              - ``stats`` holds the pseudo-label counters from
                ``compute_pseudo_accuracy``, summed over all batches.
        """
        dataset_dict = defaultdict(list)
        stats = {
            'tot_samples': 0,
            'tot_correct': 0,
            'pmasked_samples': 0,
            'pmasked_correct': 0,
        }

        for data in self.train_loader:
            batch = self.process_all_batch(**data)

            with torch.no_grad():
                logits = model(batch['x_ulb_w'])['logits']
                p = F.softmax(logits / self.T, dim=-1)
                p_mask = masking(p, softmax_x_ulb=False, cutoff=self.p_cutoff)
                tot_samples, tot_correct, pmasked_samples, pmasked_correct = \
                    compute_pseudo_accuracy(p, batch['y_ulb'], p_mask)

            dataset_dict['idx_ulb'].append(batch['idx_ulb'])
            dataset_dict['p'].append(p.detach())
            dataset_dict['p_mask'].append(p_mask.detach())

            stats['tot_samples'] += tot_samples
            stats['tot_correct'] += tot_correct
            stats['pmasked_samples'] += pmasked_samples
            stats['pmasked_correct'] += pmasked_correct

        final_dict = {k: torch.cat(v, dim=0).cpu() for k, v in dataset_dict.items()}
        return final_dict, stats

    def fit(self, parameters, config):
        """Run one round: pseudo-label the local data, then train locally.

        Args:
            parameters: Global parameters, forwarded to ``set_parameters``.
            config: Round config. Must contain ``"server_round"``. May contain
                ``"current_lr"``; otherwise ``train_cfgs['lr']`` is used.

        Returns:
            ``(parameters, num_data, res_dict)``:
              - ``parameters``: the updated local parameters.
              - ``num_data``: the aggregation weight. With ``filtered_weight``
                this is the total number of masked samples this round (at
                least 1); otherwise it is the size of the pseudo-labeled
                training set.
              - ``res_dict``: pseudo-label accuracy statistics.
        """
        self.set_parameters(parameters)

        self.server_round = config["server_round"]
        set_seed(self.seed + self.server_round + self.cid)

        model = self._build_pseudo_label_model()
        final_dict, stats = self._pseudo_label(model)
        del model  # release the pseudo-labeling copy before training

        curr_lr = float(config.get("current_lr", self.train_cfgs['lr']))
        for g in self.optimizer.param_groups:
            g['lr'] = curr_lr

        num_train_samples = 0
        # final_dict is empty when the loader produced no batches; then there
        # is nothing to train on this round.
        if final_dict:
            if self.use_mask_only:
                keep = final_dict['p_mask'] == 1
                final_dict = {k: v[keep] for k, v in final_dict.items()}

            new_dataset = PrecomputedDataset(
                orig_dataset=self.train_loader.dataset,  # BasicDataset instance
                orig_idx=final_dict['idx_ulb'],
                p=final_dict['p'],
                p_mask=final_dict['p_mask']
            )
            num_train_samples = len(new_dataset)

            # Round-local loader. self.train_loader must NOT be replaced:
            # - the next round has to pseudo-label the original dataset again;
            # - PrecomputedDataset indexes orig_dataset by the original
            #   indices, which only match the original dataset.
            pseudo_loader = get_dataloader(new_dataset,
                                           batch_size=self.data_cfgs['bs'],
                                           shuffle=True,
                                           num_workers=self.data_cfgs['num_workers'],
                                           drop_last=True)

            self.model.train()
            for _ in range(self.local_epochs):
                for data in pseudo_loader:
                    self.train_step(**self.process_batch(**data))

        self.res_dict = {
            'tot_samples': stats['tot_samples'],
            'tot_pseudo_acc': stats['tot_correct'] / stats['tot_samples']
                                if stats['tot_samples'] > 0 else 0.0,
            'pmasked_samples': stats['pmasked_samples'],
            'pmasked_pseudo_acc': stats['pmasked_correct'] / stats['pmasked_samples']
                                if stats['pmasked_samples'] > 0 else 0.0,
            'masked_samples': 0,
            'masked_pseudo_acc': 0.0,
        }

        if self.filtered_weight:
            # Use the count summed over ALL batches (not the last batch's
            # local value). Never report a zero weight.
            num_data = stats['pmasked_samples']
            if num_data == 0:
                num_data = 1
        else:
            num_data = num_train_samples

        return self.get_parameters(), num_data, self.res_dict
    

class PrecomputedDataset(Dataset):
    """Wrap a dataset with cached per-sample pseudo-labels.

    Item ``i`` is the sample ``orig_dataset[orig_idx[i]]`` (a mapping),
    returned as a new dict extended with:
      - ``'p'`` = ``p[i]``
      - ``'p_mask'`` = ``p_mask[i]``

    Args:
        orig_dataset: Indexable dataset whose items are mappings.
        orig_idx: 1-D tensor or sequence of integer indices into ``orig_dataset``.
        p: Pseudo-label targets, indexed by position ``i``.
        p_mask: Confidence mask, indexed by position ``i``.
    """

    def __init__(self, orig_dataset, orig_idx, p, p_mask):
        self.orig_dataset = orig_dataset
        self.orig_idx = orig_idx
        self.p = p
        self.p_mask = p_mask

    def __len__(self):
        return len(self.orig_idx)

    def __getitem__(self, i):
        # int() accepts both 0-d tensors and plain Python ints.
        idx = int(self.orig_idx[i])
        # Shallow-copy so the wrapped dataset's sample object is never mutated
        # (matters if orig_dataset caches or shares its sample dicts).
        sample = dict(self.orig_dataset[idx])
        sample['p'] = self.p[i]
        sample['p_mask'] = self.p_mask[i]
        return sample