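# OpenMatch-style client: pseudo-labels local unlabeled data with the received server model (or a local EMA checkpoint), then trains on those pseudo-labels.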
import copy
from collections import defaultdict

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from src.core.base import BaseClient
from src.core.utils import get_dataloader, set_seed
from src.algorithms.utils import (
    consistency_loss, 
    ova_ulb, 
    ova_ent, 
    ova_socr, 
    masking, 
    inlier_p, 
    compute_pseudo_accuracy
    )

from src.algorithms.network import OpenNet


class OpenMatchClient(BaseClient):
    """Federated client that trains an OpenMatch-style open-set SSL model.

    Each round (:meth:`fit`) the client pseudo-labels its unlabeled data with
    the received global weights (or a saved EMA checkpoint), then trains on
    all pseudo-labeled samples and, from ``start_filtered_epochs_round`` on,
    additionally on a confidence-filtered subset.
    """

    def __init__(self,
                 cid,
                 config,
                 net_builder,
                 train_loader):

        self.config = config
        # model architecture -- set before super().__init__() because
        # set_model() reads them (presumably the base constructor builds the
        # model -- TODO confirm)
        self.cls_hidden = self.config['Model']['cls_hidden']
        self.out_hidden = self.config['Model']['out_hidden']
        self.mlp = self.config['Model']['mlp']

        super().__init__(cid, config, net_builder, train_loader)

        # Loader over the ORIGINAL dataset. fit() replaces self.train_loader
        # with a pseudo-labeled view each round. Pseudo-labeling always runs
        # on this loader, so the data neither shrinks (drop_last) nor nests
        # wrappers across rounds.
        self._pl_train_loader = self.train_loader

        self.T = self.train_cfgs['T']
        self.p_cutoff = self.train_cfgs['p_cutoff']
        self.in_cutoff = self.train_cfgs['in_cutoff']
        self.neg_cutoff = self.train_cfgs['neg_cutoff']

        # NOTE(review): read from the Server section, unlike the other loss
        # weights -- confirm this is intended.
        self.lambda_ova = self.config['Training']['Server']['lambda_ova']
        self.lambda_oem = self.train_cfgs['lambda_oem']
        self.lambda_socr = self.train_cfgs['lambda_socr']
        self.lambda_neg = self.train_cfgs['lambda_neg']

        self.local_filtered_epochs = self.train_cfgs['local_filtered_epochs']
        self.start_filtered_epochs = self.train_cfgs['start_filtered_epochs_round']

        self.filtered_weight = self.train_cfgs['filtered_weight']

        self.cls_lr_factor = self.train_cfgs['cls_lr_factor']

    def set_model(self):
        """Wrap the base backbone in an ``OpenNet`` (classifier + OVA outlier heads)."""
        model = super().set_model()  # backbone
        model = OpenNet(base=model,
                        num_classes=self.num_classes,
                        cls_hidden=self.cls_hidden,
                        out_hidden=self.out_hidden,
                        mlp=self.mlp)
        return model

    def train_step(self,
                   x_ulb, x_ulb_w, x_ulb_s,
                   p, p_mask, final_mask, neg_mask):
        """Take one optimizer step on a batch of pseudo-labeled samples.

        Args:
            x_ulb, x_ulb_w, x_ulb_s: three views of the same unlabeled batch
                (presumably plain / weak / strong augmentations -- confirm
                against ``process_batch``). They are concatenated for a single
                forward pass.
            p: precomputed classifier probabilities, used as targets.
            p_mask: confidence mask from ``masking`` (used when ``use_pmask``).
            final_mask: confidence AND inlier mask (used otherwise).
            neg_mask: per-class mask passed to ``ova_ulb``.

        Returns:
            dict with float ``'loss'``, ``'L_fix'`` and ``'L_ova'``.
        """
        self.model.train()

        # TRUE for SSB / FALSE for OpenMatch
        mask = p_mask if self.use_pmask else final_mask

        inputs = torch.cat((x_ulb, x_ulb_w, x_ulb_s))
        outputs = self.model(inputs)

        _, _, logits_x_ulb_s = outputs['logits'].chunk(3)
        logits_out_w0, logits_out_w1, logits_out_s = outputs['logits_out'].chunk(3)

        # 1) Inlier classifier
        L_fix = consistency_loss(logits_x_ulb_s, p, 'ce', mask=mask)

        # 2) Outlier detector
        # -- Open-set entropy minimization
        Lo_oem = (ova_ent(logits_out_w0) + ova_ent(logits_out_w1)) / 2.
        # -- Soft consistency regularization (SOCR)
        Lo_socr = ova_socr(logits_out_w0, logits_out_w1)
        # -- Negative
        Lo_neg = ova_ulb(logits_out_s, neg_mask)

        L_ova = self.lambda_oem * Lo_oem + self.lambda_socr * Lo_socr + self.lambda_neg * Lo_neg

        # TOTAL LOSS
        loss = L_fix + self.lambda_ova * L_ova

        self.optimizer.zero_grad()
        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optimizer.step()

        return {'loss': float(loss.item()),
                'L_fix': float(L_fix.item()),
                'L_ova': float(L_ova.item()),
                }

    def _pseudo_label(self):
        """Pseudo-label the original training set.

        Returns:
            ``(final_dict, stats)``. ``final_dict`` maps ``'idx_ulb'``, ``'p'``,
            ``'in_p'``, ``'p_mask'``, ``'final_mask'`` and ``'neg_mask'`` to
            CPU tensors concatenated over all batches. ``stats`` holds
            pseudo-label accuracy counts and ID/OOD counts.
        """
        # Eval-mode copy so self.model's mode and weights are untouched. It is
        # optionally overwritten with the EMA checkpoint.
        model = copy.deepcopy(self.model)
        if self.use_ema_pseudo and (self.server_round > self.start_ema_pseudo):
            ema_ckpt = torch.load(self.ema_save_path, map_location=self.device, weights_only=True)
            model.load_state_dict(ema_ckpt)
            print("==> Load ema ckpt for pseudo-labeling **")
        model.eval()

        dataset_dict = defaultdict(list)
        stats = {
                'tot_samples': 0,
                'tot_correct': 0,
                'pmasked_samples': 0,
                'pmasked_correct': 0,
                'masked_samples': 0,
                'masked_correct': 0,
                'id_count': 0,
                'ood_count': 0
                }

        for data in self._pl_train_loader:
            batch = self.process_all_batch(**data)
            idx_ulb = batch['idx_ulb']
            x_ulb = batch['x_ulb']
            y_ulb = batch['y_ulb']

            with torch.no_grad():
                # Use the eval/EMA copy prepared above, not self.model.
                output = model(x_ulb)

                # classifier pseudo labels (temperature-scaled softmax)
                logits = output['logits']
                p = F.softmax(logits / self.T, dim=-1)
                p_mask = masking(p, softmax_x_ulb=False, cutoff=self.p_cutoff)

                # outlier detector: reshape to (batch, 2, num_classes)
                logits_out = output['logits_out']
                logits_out = logits_out.view(logits_out.shape[0], 2, -1)
                in_p = inlier_p(logits_out)

                # accept = confident classifier AND inlier score of the
                # predicted class above in_cutoff
                max_probs, pseudo_label = torch.max(p, dim=1)
                selected_in_p = in_p[torch.arange(len(p)), pseudo_label]
                final_mask = ((max_probs >= self.p_cutoff) & (selected_in_p >= self.in_cutoff)).float()
                neg_mask = (in_p < self.neg_cutoff).float()

                # stats
                tot_samples, tot_correct, pmasked_samples, pmasked_correct = compute_pseudo_accuracy(p, y_ulb, p_mask)
                _, _, masked_samples, masked_correct = compute_pseudo_accuracy(p, y_ulb, final_mask)

            stats['tot_samples'] += tot_samples
            stats['tot_correct'] += tot_correct
            stats['pmasked_samples'] += pmasked_samples
            stats['pmasked_correct'] += pmasked_correct
            stats['masked_samples'] += masked_samples
            stats['masked_correct'] += masked_correct

            # labels >= num_classes are treated as out-of-distribution
            id_mask = (y_ulb < self.num_classes)
            stats['id_count'] += id_mask.sum().item()
            stats['ood_count'] += (~id_mask).sum().item()

            dataset_dict['idx_ulb'].append(idx_ulb)
            dataset_dict['p'].append(p.detach())
            dataset_dict['in_p'].append(in_p.detach())
            dataset_dict['p_mask'].append(p_mask.detach())
            dataset_dict['final_mask'].append(final_mask.detach())
            dataset_dict['neg_mask'].append(neg_mask.detach())

        final_dict = {k: torch.cat(v, dim=0).cpu() for k, v in dataset_dict.items()}
        return final_dict, stats

    def _precomputed_loader(self, base_dataset, tensors, batch_size):
        """Build a shuffled, drop_last loader pairing ``base_dataset`` samples with ``tensors``."""
        return get_dataloader(
            PrecomputedDataset(
                base_dataset, tensors['idx_ulb'],
                tensors['p'], tensors['in_p'],
                tensors['p_mask'], tensors['final_mask'], tensors['neg_mask']),
            batch_size=batch_size, shuffle=True,
            num_workers=self.data_cfgs['num_workers'], drop_last=True)

    def _train_epochs(self, loader, epochs, totals):
        """Run ``train_step`` over ``loader`` for ``epochs`` epochs.

        Adds each step's ``'loss'``/``'L_fix'``/``'L_ova'`` into ``totals``
        in place and returns the number of steps taken.
        """
        steps = 0
        for _ in range(epochs):
            for data in loader:
                res = self.train_step(**self.process_batch(**data))
                for key in totals:
                    totals[key] += res[key]
                steps += 1
        return steps

    def fit(self, parameters, config):
        """Run one local round: pseudo-label, then train on the pseudo-labels.

        Args:
            parameters: global weights passed to ``set_parameters``.
            config: round config; must contain ``"server_round"`` and may
                contain ``"current_lr"`` (defaults to ``train_cfgs['lr']``).

        Returns:
            ``(weights, num_data, metrics)``. With ``filtered_weight``,
            ``num_data`` is the number of samples accepted by ``final_mask``
            over the whole pass (at least 1). Otherwise it is the size of the
            pseudo-labeled dataset.
        """
        self.set_parameters(parameters)
        self.server_round = config["server_round"]
        set_seed(self.seed + self.server_round + self.cid)

        final_dict, stats = self._pseudo_label()

        # idx_ulb are indices into the ORIGINAL dataset, so both views must
        # wrap that dataset (not a previous PrecomputedDataset).
        base_dataset = self._pl_train_loader.dataset
        self.train_loader = self._precomputed_loader(
            base_dataset, final_dict, self.data_cfgs['bs'])

        keep = (final_dict['p_mask'] if self.use_pmask else final_dict['final_mask']) == 1
        filtered_train_loader = self._precomputed_loader(
            base_dataset, {k: v[keep] for k, v in final_dict.items()},
            self.data_cfgs['f_bs'])

        curr_lr = float(config.get("current_lr", self.train_cfgs['lr']))
        for g in self.optimizer.param_groups:
            g['lr'] = curr_lr * self.cls_lr_factor

        # TRAIN STEP
        totals = {'loss': 0.0, 'L_fix': 0.0, 'L_ova': 0.0}
        self.model.train()
        # use all data
        batch_count = self._train_epochs(self.train_loader, self.local_epochs, totals)
        # then the confidence-filtered subset, once enabled
        if self.server_round >= self.start_filtered_epochs:
            batch_count += self._train_epochs(
                filtered_train_loader, self.local_filtered_epochs, totals)

        def ratio(num, den):
            return num / den if den > 0 else 0.0

        self.res_dict = {
            'tot_samples': stats['tot_samples'],
            'tot_pseudo_acc': ratio(stats['tot_correct'], stats['tot_samples']),
            'pmasked_samples': stats['pmasked_samples'],
            'pmasked_pseudo_acc': ratio(stats['pmasked_correct'], stats['pmasked_samples']),
            'masked_samples': stats['masked_samples'],
            'masked_pseudo_acc': ratio(stats['masked_correct'], stats['masked_samples']),
            'loss': ratio(totals['loss'], batch_count),
            'L_fix': ratio(totals['L_fix'], batch_count),
            'L_ova': ratio(totals['L_ova'], batch_count),
            'id_count': stats['id_count'],
            'ood_count': stats['ood_count']
        }

        if self.filtered_weight:
            # Total accepted samples over the whole pass, not just the last batch.
            num_data = max(int(stats['masked_samples']), 1)
        else:
            num_data = len(self.train_loader.dataset)

        return self.get_parameters(), num_data, self.res_dict
    

class PrecomputedDataset(Dataset):
    """Dataset view that attaches precomputed pseudo-label tensors to samples.

    Item ``i`` is the dict ``orig_dataset[orig_idx[i]]``, shallow-copied and
    extended with row ``i`` of each precomputed tensor under the keys
    ``'p'``, ``'in_p'``, ``'p_mask'``, ``'final_mask'`` and ``'neg_mask'``.
    """

    def __init__(self, orig_dataset, orig_idx, p, in_p, p_mask, final_mask, neg_mask):
        self.orig_dataset = orig_dataset
        # indices into orig_dataset, aligned row-for-row with the tensors below
        self.orig_idx = orig_idx
        self.p = p
        self.in_p = in_p
        self.p_mask = p_mask
        self.final_mask = final_mask
        self.neg_mask = neg_mask

    def __len__(self):
        return len(self.orig_idx)

    def __getitem__(self, i):
        # int() accepts single-element tensors, numpy ints and plain ints alike
        idx = int(self.orig_idx[i])
        # Shallow-copy so the wrapped dataset's own dict is never mutated
        # (matters if it caches or shares sample dicts).
        sample = dict(self.orig_dataset[idx])

        sample.update({
            'p': self.p[i],
            'in_p': self.in_p[i],
            'p_mask': self.p_mask[i],
            'final_mask': self.final_mask[i],
            'neg_mask': self.neg_mask[i],
        })
        return sample