import copy
import numpy as np

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

from src.core.base import BaseServer
from src.algorithms.utils import (
    ce_loss, 
    ova_loss, 
    supervised_contrastive_loss,
    proto_contrastive_lb_loss
)
from src.algorithms.network import OursNet
from src.core.utils import set_seed


class OursServer(BaseServer):
    def __init__(self, config, net_builder, train_loader, test_loader, logger):
        """Configure head sizes, loss weights and prototype accumulators.

        The head dimensions are copied from ``config['Model']`` *before* the
        base constructor runs, and everything read from ``self.train_cfgs``,
        ``self.num_classes`` or ``self.device`` is set *after* it.
        NOTE(review): presumably the base constructor builds the model through
        ``set_model()`` (which needs the head sizes) and is what defines
        ``train_cfgs`` / ``num_classes`` / ``device`` -- confirm in BaseServer.

        Args:
            config: Mapping with at least a ``'Model'`` section holding
                ``cls_hidden``, ``proj_hidden``, ``proj_size`` and ``out_hidden``.
            net_builder: Forwarded to the base constructor.
            train_loader: Forwarded to the base constructor.
            test_loader: Forwarded to the base constructor.
            logger: Forwarded to the base constructor.
        """
        self.config = config

        # model architecture
        arch = config['Model']
        self.cls_hidden = arch['cls_hidden']
        self.proj_hidden = arch['proj_hidden']
        self.proj_size = arch['proj_size']
        self.out_hidden = arch['out_hidden']

        super().__init__(config, net_builder, train_loader, test_loader, logger)

        hparams = self.train_cfgs
        self.ova_w_neg_ratio = hparams['ova_w_neg_ratio']

        # for prototypes
        self.use_ema_feat = False
        self.ema_feat_round = hparams['ema_feat_round']
        self.cont_T = hparams['cont_T']
        self.fm_it = hparams['fm_it']
        self.it = 0

        # Running per-class feature sums / sample counts; they are turned into
        # class means (prototypes) by update_feat_mean / update_feat_mean_it.
        self.l_feat_add = torch.zeros(
            self.num_classes, self.proj_size, device=self.device
        )
        self.l_sample_cnt = torch.zeros(self.num_classes, device=self.device)
        self.feat_means = None
        self.prototypes = None

        self.lambda_ova = hparams['lambda_ova']

        self.hard_neg = hparams['hard_neg']


    def set_model(self):
        """Build the network: the base-class backbone wrapped in ``OursNet``.

        Returns:
            An ``OursNet`` whose ``base`` is the model returned by the parent
            ``set_model()`` and whose head sizes come from this instance.
        """
        backbone = super().set_model()
        head_sizes = dict(
            num_classes=self.num_classes,
            cls_hidden=self.cls_hidden,
            proj_hidden=self.proj_hidden,
            proj_size=self.proj_size,
            out_hidden=self.out_hidden,
        )
        return OursNet(base=backbone, **head_sizes)


    def set_ema_model(self):
        """Create an ``OursNet`` copy initialised from the current model weights.

        A fresh backbone is obtained from ``self.net_builder`` and wrapped with
        the same head sizes as the training model, then ``self.model``'s
        state dict is loaded into it.

        Returns:
            The newly built network carrying ``self.model``'s parameters.
        """
        fresh_backbone = self.net_builder(num_classes=self.num_classes)
        head_sizes = dict(
            num_classes=self.num_classes,
            cls_hidden=self.cls_hidden,
            proj_hidden=self.proj_hidden,
            proj_size=self.proj_size,
            out_hidden=self.out_hidden,
        )
        ema_net = OursNet(base=fresh_backbone, **head_sizes)
        current_weights = self.model.state_dict()
        ema_net.load_state_dict(current_weights)
        return ema_net
    
    
    # Prototypes ====================================================================
    
    @torch.no_grad()
    def update_feat_mean_it(self, f, labels):
        feats_l = torch.zeros(self.num_classes, f.size(1), device=f.device)  
        cnts_l = torch.zeros(self.num_classes, device=f.device)

        unique_labels_l, local_cnts_l = labels.unique(return_counts=True)
        feats_l.index_add_(0, labels, f)
        cnts_l.index_add_(0, unique_labels_l, local_cnts_l.float())

        self.l_sample_cnt.add_(cnts_l)
        self.l_feat_add.add_(feats_l)

        if (self.it + 1) % self.fm_it == 0:
            l_valid_cnt = self.l_sample_cnt.unsqueeze(1).clone()
            l_valid_cnt[l_valid_cnt == 0] = 1  
            self.l_feat_means = self.l_feat_add.div(l_valid_cnt)
            
            self.feat_means = self.l_feat_means
            self.prototypes = self.feat_means.clone().detach()
            
            # reset
            self.l_feat_add.zero_()
            self.l_sample_cnt.zero_()   
            
            
    @torch.no_grad()
    def update_feat_mean(self, f, labels, update=False):
        if not update:
            feats_l = torch.zeros(self.num_classes, f.size(1), device=f.device)  
            cnts_l = torch.zeros(self.num_classes, device=f.device)

            unique_labels_l, local_cnts_l = labels.unique(return_counts=True)
            feats_l.index_add_(0, labels, f)
            cnts_l.index_add_(0, unique_labels_l, local_cnts_l.float())

            self.l_sample_cnt.add_(cnts_l)
            self.l_feat_add.add_(feats_l)

        if update:
            l_valid_cnt = self.l_sample_cnt.unsqueeze(1).clone()
            l_valid_cnt[l_valid_cnt == 0] = 1  
            self.l_feat_means = self.l_feat_add.div(l_valid_cnt)
            
            self.feat_means = self.l_feat_means
            self.prototypes = self.feat_means.clone().detach()
            
            # reset
            self.l_feat_add.zero_()
            self.l_sample_cnt.zero_()   


    def update_prototype(self):
        if self.use_ema_feat:
            model = self.ema_model
        else:
            model = self.model

        model.eval()
        with torch.no_grad():
            for data in self.train_loader:
                x = data['x_lb_w']
                y_lb = data['y_lb']
                if isinstance(x, dict):
                    x = {k: v.to(self.device) for k, v in x.items()}
                    y_lb = {k: v.to(self.device) for k, v in y_lb.items()}
                else:
                    x = x.to(self.device)
                    y_lb = y_lb.to(self.device)
                feats_x_w = model(x, only_feat=False)['feat_proj']
                self.update_feat_mean(f=feats_x_w.detach(), labels=y_lb)
        model.train()
        
        self.update_feat_mean(f=feats_x_w.detach(), labels=y_lb, update=True)
        
    # ===============================================================================
            
            
    def train_step(self, optimizer,
                   x_lb, x_lb_w, x_lb_s0, x_lb_s1, y_lb):
        """Run one optimisation step on a labelled batch.

        The four input views are concatenated in the order
        ``(x_lb_w, x_lb, x_lb_s0, x_lb_s1)`` and passed through the model once.
        The classifier loss uses the first two views; the OVA loss receives the
        first two views and the last two views of ``'logits_out'`` separately.

        Args:
            optimizer: Optimizer that is zeroed, stepped and whose first param
                group's learning rate is reported.
            x_lb, x_lb_w, x_lb_s0, x_lb_s1: Four views of the same labelled
                batch (presumably raw / weak / two strong augmentations --
                confirm against the data pipeline).
            y_lb: Labels of the batch; its first dimension is the batch size.

        Returns:
            Dict of scalars keyed by ``"{mode}_train/s_loss"``,
            ``"{mode}_train/ova_loss"`` (already scaled by ``lambda_ova``) and
            ``"{mode}_train/lr"``.
        """
        self.it += 1
        batch_size = y_lb.shape[0]
        first_half = 2 * batch_size   # rows of (x_lb_w, x_lb)
        all_rows = 4 * batch_size     # plus rows of (x_lb_s0, x_lb_s1)

        self.model.train()

        stacked_views = torch.cat((x_lb_w, x_lb, x_lb_s0, x_lb_s1))
        net_out = self.model(stacked_views)

        cls_logits = net_out['logits'][:first_half]
        ova_logits = net_out['logits_out']
        ova_logits_first = ova_logits[:first_half]
        ova_logits_second = ova_logits[first_half:all_rows]

        # 1) Inlier classifier
        Lx = ce_loss(cls_logits, y_lb.repeat(2), reduction='mean')

        # 2) Outlier detector
        Lo = ova_loss(
            ova_logits_first,
            ova_logits_second,
            y_lb.repeat(2),
            w_neg_ratio=self.ova_w_neg_ratio,
            use_hard_negative=self.hard_neg,
        )

        loss = Lx + self.lambda_ova * Lo

        optimizer.zero_grad()
        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=self.clip_grad
            )
        optimizer.step()

        current_lr = optimizer.param_groups[0]['lr']

        return {
            f"{self.mode}_train/s_loss": Lx.item(),
            f"{self.mode}_train/ova_loss": self.lambda_ova * Lo.item(),
            f"{self.mode}_train/lr": current_lr,
        }


    def _train_loop(self, epochs, optimizer, scheduler, mode):
        """Train for ``epochs`` passes over the training loader.

        Seeds with ``self.seed + self.round``, records ``mode`` on the
        instance, and switches prototype extraction to the EMA model once
        ``self.round`` exceeds ``self.ema_feat_round``. In ``'warmup'`` mode
        each epoch additionally refreshes/saves the EMA model (if enabled) and
        periodically evaluates and checkpoints.

        Args:
            epochs: Number of epochs to run.
            optimizer: Optimizer handed to ``train_step``.
            scheduler: Stepped once after all epochs; may be ``None``.
            mode: Tag stored in ``self.mode`` (e.g. ``'warmup'``/``'finetune'``).
        """
        set_seed(self.seed + self.round)
        self.model.train()
        self.mode = mode

        if self.round > self.ema_feat_round:
            self.use_ema_feat = True

        self.print_fn(f"Use training ema: {self.use_ema}")
        for epoch_idx in range(epochs):
            self.epoch += 1

            for batch in self.train_loader:
                # train step
                step_inputs = self.process_batch(**batch)
                self.res_dict = self.train_step(optimizer, **step_inputs)

            # wandb logging (only the last step's metrics of the epoch)
            if self.use_wandb:
                self.run.log(self.res_dict, step=self.epoch+self.round)

            # Everything below is warm-up only and runs once per epoch,
            # outside the batch loop.
            if self.mode != 'warmup':
                continue

            if self.use_ema:
                self.ema.update()
                self.ema_model.load_state_dict(self.ema.shadow, strict=False)
                torch.save(self.ema_model.state_dict(), self.ema_save_path)

            if epoch_idx % self.eval_epoch == 0:
                _ = self.evaluate(mode=self.mode)
                self.save_model(filename='warmup_latest_model.pth')

        # The scheduler advances once per call, not once per epoch.
        if scheduler is not None:
            scheduler.step()
        
        
    # =========================================================================
    
    def warm_up(self, epochs=500):
        self.print_fn(f"[Server] Warm-up for {epochs} epochs before federated training")
        warm_up_optimizer, _ = self.set_optimizer(mode='warmup')
        self._train_loop(epochs, 
                         optimizer=warm_up_optimizer, scheduler=None, mode="warmup")
        if self.static_bn:
            self.apply_static_bn()
            self.print_fn("[Server] Updating static BN before sending parameters")
        self.update_prototype()
            
            
    def fine_tune(self, epochs=5):
        self.print_fn(f"[Server] Fine-tuning for {epochs} epochs")
        if self.static_bn:
            self.freeze_bn_stats()
        self._train_loop(epochs, 
                         optimizer=self.optimizer, scheduler=self.scheduler, mode="finetune")

        if self.use_ema:
            self.ema.update()
            self.ema_model.load_state_dict(self.ema.shadow, strict=False)
            torch.save(self.ema_model.state_dict(), self.ema_save_path)
            self.print_fn("[!] Update and Save EMA model")
            
        if self.static_bn:
            self.apply_static_bn()
            self.print_fn("[Server] Updating static BN before sending parameters")
            
        self.update_prototype()
            
                
    @torch.no_grad()
    def evaluate(self, mode="warmup"):
        """Evaluate closed-set accuracy and outlier detection on the test loader.

        Uses ``self.ema_model`` when ``self.use_ema`` is set, otherwise
        ``self.model``. Samples whose label is ``< num_classes`` count as
        inliers; labels ``>= num_classes`` are the positive class for AUROC.

        Two predictors are scored:
          * the classifier head (``'logits'``) with the one-vs-all head
            (``'logits_out'``) providing the unknown score, and
          * if ``self.prototypes`` is available, nearest-prototype
            classification by cosine similarity on ``'feat_proj'``, with
            ``1 - max similarity`` as the unknown score.

        Side effects: prints the metrics, logs them to wandb when enabled,
        saves "latest"/"best" checkpoints for ``mode`` in
        ``{'agg', 'finetune'}``, and leaves the evaluated model in train mode.

        Args:
            mode: Prefix for the metric keys; also selects checkpoint handling.

        Returns:
            Dict with ``cls_loss``, ``top-1-acc``, ``auroc_ova``,
            ``top-1-acc_proto``, ``auroc_proto`` and ``round`` (all prefixed by
            ``"{mode}/"``), plus ``best_acc`` / ``best_round`` for the
            ``'agg'`` and ``'finetune'`` modes.
        """
        self.print_fn(f">> eval round: {self.round}")
        self.print_fn(f">> eval epochs: {self.epoch}")

        self.print_fn(f"== use eval ema: {self.use_ema}")

        model = self.ema_model if self.use_ema else self.model

        model.eval()

        total_loss = 0.0
        total_num = 0

        y_true = []
        y_in_true = []

        y_pred = []
        y_out = []
        y_pred_p = []
        y_out_p = []

        use_prototypes = self.prototypes is not None
        if use_prototypes:
            # Unit-norm prototypes so the dot product with unit-norm features
            # below is a cosine similarity.
            prototypes = F.normalize(self.prototypes, dim=1)

        for data in self.test_loader:
            x = data['x_lb']
            y = data['y_lb']

            if isinstance(x, dict):
                x = {k: v.to(self.device) for k, v in x.items()}
            else:
                x = x.to(self.device)
            y = y.to(self.device)

            in_idx = torch.where(y < self.num_classes)[0]
            y_true.extend(y.cpu().tolist())

            outputs = model(x)
            logits = outputs['logits']
            logits_out = outputs['logits_out']
            feats = outputs['feat_proj']

            batch_indices = torch.arange(logits.size(0), device=logits.device)
            pred_all = torch.argmax(logits, dim=1)

            if use_prototypes:
                feat_norm = F.normalize(feats, dim=1)
                # Must use the normalised prototypes: with raw prototypes the
                # argmax is biased by prototype norm and 1 - sim_max is not a
                # cosine distance.
                sim = torch.matmul(feat_norm, prototypes.T)
                sim_max, pred_all_p = torch.max(sim, dim=1)

            # Classification (inliers only)
            if len(in_idx) > 0:
                y_in_true.extend(y[in_idx].cpu().tolist())

                # == classifier
                y_pred.extend(pred_all[in_idx].cpu().tolist())
                loss = F.cross_entropy(logits[in_idx], y[in_idx], reduction='mean')
                # Re-weight the batch mean by the inlier count so the final
                # average is per-sample.
                total_loss += loss.item() * in_idx.shape[0]
                total_num += in_idx.shape[0]

                # == prototype
                if use_prototypes:
                    y_pred_p.extend(pred_all_p[in_idx].cpu().tolist())

            # OOD Detection
            # == ova classifier: channel 0 of the softmax over the size-2 axis,
            #    taken at the predicted class, is the unknown score.
            logits_out = logits_out.view(logits_out.size(0), 2, -1)   # [B, 2, C]
            probs_out = F.softmax(logits_out, dim=1)
            unk_score = probs_out[batch_indices, 0, pred_all]
            y_out.extend(unk_score.cpu().tolist())

            # == prototype
            if use_prototypes:
                unk_score_p = 1 - sim_max
                y_out_p.extend(unk_score_p.cpu().tolist())

        # Metrics
        y_true = np.array(y_true)
        y_in_true = np.array(y_in_true)
        y_pred = np.array(y_pred)
        y_out = np.array(y_out)

        if use_prototypes and len(y_in_true) > 0:
            y_pred_p = np.array(y_pred_p)
            y_out_p = np.array(y_out_p)
        else:
            y_pred_p = np.zeros_like(y_pred)
            y_out_p = np.zeros_like(y_out)

        # == Accuracy
        top1 = accuracy_score(y_in_true, y_pred) if len(y_in_true) > 0 else 0.0
        top1_p = accuracy_score(y_in_true, y_pred_p) if (use_prototypes and len(y_in_true) > 0) else 0.0

        # AUROC
        roc_labels = (y_true >= self.num_classes).astype(int)

        def _auroc_or_zero(scores):
            # roc_auc_score raises ValueError e.g. when only one class is
            # present; report 0.0 in that case. Each score is guarded on its
            # own so one failure does not wipe out the other metric.
            try:
                return roc_auc_score(roc_labels, scores)
            except ValueError:
                return 0.0

        auroc_ova = _auroc_or_zero(y_out)
        auroc_proto = _auroc_or_zero(y_out_p) if use_prototypes else 0.0

        # Confusion matrix
        cf_mat = None
        if self.num_classes <= 10 and len(y_in_true) > 0:
            cf_mat = confusion_matrix(y_in_true, y_pred, normalize='true')
            self.print_fn("confusion matrix:\n" + np.array_str(cf_mat))

        model.train()

        eval_dict = {
            f"{mode}/cls_loss": total_loss / total_num if total_num > 0 else 0.0,
            f"{mode}/top-1-acc": top1,
            f"{mode}/auroc_ova": auroc_ova,
            f"{mode}/top-1-acc_proto": top1_p,
            f"{mode}/auroc_proto": auroc_proto,
            f"{mode}/round": self.round,
        }

        if mode == 'agg':
            self.save_model(filename="agg_latest_model.pth")
            if top1 > self.best_acc:
                self.save_model(filename="agg_best_model.pth")
                self.best_acc = top1
                self.best_round = self.round
                self.print_fn("[*] Save best agg. model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_acc
            eval_dict[f"{mode}/best_round"] = self.best_round
        elif mode == 'finetune':
            self.save_model(filename="finetune_latest_model.pth")
            if top1 > self.best_fine_acc:
                # NOTE(review): "fnetune" looks like a typo for "finetune";
                # kept as-is because other code may load this exact filename.
                self.save_model(filename="fnetune_best_model.pth")
                self.best_fine_acc = top1
                self.best_fine_round = self.round
                self.print_fn("[*] Save best finetune model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_fine_acc
            eval_dict[f"{mode}/best_round"] = self.best_fine_round

        if self.use_wandb:
            self.run.log(eval_dict, step=self.epoch + self.round)

        self.print_fn(eval_dict)
        return eval_dict