import numpy as np

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

from src.core.base import BaseServer
from src.algorithms.utils import ce_loss, ova_loss, ova_ent, ova_socr
from src.algorithms.network import OpenNet


class OpenMatchServer(BaseServer):
    """Server-side trainer/evaluator for OpenMatch-style open-set learning.

    The backbone from ``BaseServer`` is wrapped in an ``OpenNet`` whose
    forward output provides closed-set logits (``outputs['logits']``) and
    one-vs-all (OVA) outlier logits (``outputs['logits_out']``).  Training on
    labelled data combines cross-entropy on the closed-set head with the OVA
    loss; evaluation reports closed-set accuracy on inlier samples and the
    AUROC of the OVA outlier score.
    """

    def __init__(self, config, net_builder, train_loader, test_loader, logger):
        """Read OpenMatch-specific settings and initialise the base server.

        Args:
            config: Nested config mapping; ``config['Model']`` must provide
                ``cls_hidden``, ``out_hidden`` and ``mlp``.
            net_builder: Backbone factory, forwarded to ``BaseServer``.
            train_loader: Training data loader, forwarded to ``BaseServer``.
            test_loader: Evaluation data loader, forwarded to ``BaseServer``.
            logger: Logger, forwarded to ``BaseServer``.
        """
        self.config = config
        # The OpenNet head settings are read before BaseServer.__init__ runs,
        # presumably because it builds the model through set_model(), which
        # needs them.
        model_cfg = self.config['Model']
        self.cls_hidden = model_cfg['cls_hidden']
        self.out_hidden = model_cfg['out_hidden']
        self.mlp = model_cfg['mlp']

        super().__init__(config, net_builder, train_loader, test_loader, logger)

        # self.train_cfgs is expected to be populated by BaseServer.__init__.
        self.ova_w_neg_ratio = self.train_cfgs['ova_w_neg_ratio']
        self.lambda_ova = self.train_cfgs['lambda_ova']
        self.hard_neg = self.train_cfgs['hard_neg']

    def set_model(self):
        """Build the training model: the base backbone wrapped in ``OpenNet``."""
        model = super().set_model()  # backbone
        model = OpenNet(base=model,
                        num_classes=self.num_classes,
                        cls_hidden=self.cls_hidden,
                        out_hidden=self.out_hidden,
                        mlp=self.mlp)
        return model

    def set_ema_model(self):
        """Build an ``OpenNet`` EMA model initialised from ``self.model``.

        The backbone is created directly via ``self.net_builder`` rather than
        ``super().set_model()``.  Its weights are then overwritten with
        ``self.model.state_dict()``, so ``self.model`` must already exist.
        """
        ema_model = self.net_builder(num_classes=self.num_classes)
        ema_model = OpenNet(base=ema_model,
                            num_classes=self.num_classes,
                            cls_hidden=self.cls_hidden,
                            out_hidden=self.out_hidden,
                            mlp=self.mlp)
        ema_model.load_state_dict(self.model.state_dict())
        return ema_model

    def train_step(self, optimizer,
                   x_lb, x_lb_w, x_lb_s0, x_lb_s1, y_lb):
        """Run one optimisation step on a labelled batch.

        The four views go through the model in one forward pass, concatenated
        in the order ``[x_lb_w, x_lb, x_lb_s0, x_lb_s1]``.  The slicing below
        assumes every view holds ``len(y_lb)`` samples.

        Args:
            optimizer: Optimizer stepping ``self.model``'s parameters.
            x_lb: Labelled inputs (plain view).
            x_lb_w: Weakly augmented view of the same samples.
            x_lb_s0: First strongly augmented view.
            x_lb_s1: Second strongly augmented view.
            y_lb: Class labels for the batch.

        Returns:
            dict: Supervised loss, OVA loss and the learning rate of the first
            param group, keyed ``"{self.mode}_train/s_loss"``,
            ``"{self.mode}_train/ova_loss"`` and ``"{self.mode}_train/lr"``.
        """
        self.model.train()

        num_lb = y_lb.shape[0]
        inputs = torch.cat((x_lb_w, x_lb, x_lb_s0, x_lb_s1))
        outputs = self.model(inputs)

        # Closed-set logits of the weak and plain views (first 2 * num_lb rows).
        logits_lbs = outputs['logits'][:2 * num_lb]

        # OVA logits: the weak + plain pair, and the two strong views.
        logits_out_w = outputs['logits_out'][:2 * num_lb]
        logits_out_s = outputs['logits_out'][2 * num_lb:4 * num_lb]

        # Each slice above spans two views, so the labels are tiled twice.
        targets = y_lb.repeat(2)

        # 1) Inlier classifier
        Lx = ce_loss(logits_lbs, targets, reduction='mean')

        # 2) Outlier detector
        Lo = ova_loss(logits_out_w, logits_out_s, targets,
                      w_neg_ratio=self.ova_w_neg_ratio,
                      use_hard_negative=self.hard_neg)

        loss = Lx + self.lambda_ova * Lo

        optimizer.zero_grad()
        loss.backward()
        if self.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad)
        optimizer.step()

        lr = optimizer.param_groups[0]['lr']

        res_dict = {f"{self.mode}_train/s_loss": Lx.item(),
                    f"{self.mode}_train/ova_loss": Lo.item(),
                    f"{self.mode}_train/lr": lr}

        return res_dict

    @torch.no_grad()
    def evaluate(self, mode="warmup"):
        """Evaluate the EMA model (if ``self.use_ema``) or ``self.model``.

        Test samples with label ``< self.num_classes`` are treated as inliers
        and feed the closed-set loss and accuracy.  Every sample feeds the
        outlier AUROC, where labels ``>= self.num_classes`` are positives.

        Checkpoint handling depends on ``mode``:

        - ``'agg'`` and ``'finetune'``: the latest checkpoint is saved, and the
          best checkpoint is refreshed when top-1 accuracy improves.
        - Any other mode: metrics are logged only.

        The evaluated model is switched back to train mode before returning.

        Args:
            mode: Prefix for the metric keys and selector for the checkpoint
                handling described above.

        Returns:
            dict: ``{mode}/cls_loss``, ``{mode}/top-1-acc``,
            ``{mode}/auroc_ova`` and ``{mode}/round``.  For ``'agg'`` and
            ``'finetune'`` it also contains ``{mode}/best_acc`` and
            ``{mode}/best_round``.
        """
        self.print_fn(f">> eval round: {self.round}")
        self.print_fn(f">> eval epochs: {self.epoch}")

        model = self.ema_model if self.use_ema else self.model
        model.eval()

        total_loss = 0.0
        total_num = 0.0

        y_true = []       # labels of every test sample (inliers and outliers)
        y_in_true = []    # labels of inlier samples only
        y_pred = []       # closed-set predictions for the inlier samples
        y_out_score = []  # OVA outlier score for every test sample

        # Autograd is already disabled by the @torch.no_grad() decorator.
        for data in self.test_loader:
            x = data['x_lb']
            y = data['y_lb']

            if isinstance(x, dict):
                x = {k: v.to(self.device) for k, v in x.items()}
            else:
                x = x.to(self.device)
            y = y.to(self.device)

            outputs = model(x)
            logits = outputs['logits']
            logits_out = outputs['logits_out']

            # Standard classification, restricted to inlier samples.
            in_idx = torch.where(y < self.num_classes)[0]
            if len(in_idx) > 0:
                preds = torch.max(logits[in_idx], dim=-1)[1]
                loss = F.cross_entropy(logits[in_idx], y[in_idx], reduction='mean')
                # Weight the batch mean by its size to get a dataset-level mean.
                total_loss += loss.item() * in_idx.shape[0]
                total_num += in_idx.shape[0]

                y_in_true.extend(y[in_idx].cpu().tolist())
                y_pred.extend(preds.cpu().tolist())

            y_true.extend(y.cpu().tolist())

            # Outlier detection from the OVA head.  The logits are reshaped to
            # [B, 2, C] and softmaxed over the 2-way axis.  Channel 0 is then
            # read at each sample's predicted closed-set class.  Channel 0 is
            # treated as the outlier probability, following the reference
            # OpenMatch ova_loss convention.
            # NOTE(review): confirm this convention matches
            # src.algorithms.utils.ova_loss.
            logits_out = logits_out.view(logits_out.size(0), 2, -1)  # [B, 2, C]
            probs_out = F.softmax(logits_out, dim=1)
            pred_class = torch.argmax(logits, dim=1)
            batch_indices = torch.arange(logits.size(0), device=logits.device)
            unk_score = probs_out[batch_indices, 0, pred_class]  # higher = more outlier-like
            y_out_score.extend(unk_score.cpu().tolist())

        # metrics
        y_true = np.array(y_true)
        y_in_true = np.array(y_in_true)
        y_pred = np.array(y_pred)
        y_out_score = np.array(y_out_score)

        top1 = accuracy_score(y_in_true, y_pred) if len(y_in_true) > 0 else 0.0

        # AUROC using outlier scores (1 = outlier).
        roc_labels = (y_true >= self.num_classes).astype(int)
        try:
            auroc_ova = roc_auc_score(roc_labels, y_out_score)
        except ValueError:
            # roc_auc_score raises when only one class is present
            # (e.g. no outliers in the test set); report 0.0 in that case.
            auroc_ova = 0.0

        if self.num_classes <= 10 and len(y_in_true) > 0:
            cf_mat = confusion_matrix(y_in_true, y_pred, normalize='true')
            self.print_fn("confusion matrix:\n" + np.array_str(cf_mat))

        model.train()

        eval_dict = {
            f"{mode}/cls_loss": total_loss / total_num if total_num > 0 else 0.0,
            f"{mode}/top-1-acc": top1,
            f"{mode}/auroc_ova": auroc_ova,
            f"{mode}/round": self.round
        }

        # Checkpointing and best-model tracking.
        if mode == 'agg':
            self.save_model(filename="agg_latest_model.pth")
            if top1 > self.best_acc:
                self.save_model(filename="agg_best_model.pth")
                self.best_acc = top1
                self.best_round = self.round
                self.print_fn("[*] Save best agg. model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_acc
            eval_dict[f"{mode}/best_round"] = self.best_round
        elif mode == 'finetune':
            self.save_model(filename="finetune_latest_model.pth")
            if top1 > self.best_fine_acc:
                # NOTE(review): "fnetune" looks like a typo of "finetune".
                # It is kept as-is in case other code loads this exact filename.
                self.save_model(filename="fnetune_best_model.pth")
                self.best_fine_acc = top1
                self.best_fine_round = self.round
                self.print_fn("[*] Save best finetune model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_fine_acc
            eval_dict[f"{mode}/best_round"] = self.best_fine_round

        if self.use_wandb:
            self.run.log(eval_dict, step=self.epoch + self.round)

        self.print_fn(eval_dict)
        return eval_dict