import os
import copy
import wandb
import numpy as np

from inspect import signature
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score

import torch
import torch.nn.functional as F

from src.core.utils import (
    get_optimizer,
    get_custom_cosine_scheduler,
    EMA,
    set_seed
)


class BaseServer:
    def __init__(self, 
                 config, 
                 net_builder, 
                 train_loader, test_loader, logger=None):
        """Configure the server: model, EMA copy, optimizer/scheduler and optional wandb run.

        Args:
            config: nested configuration mapping. Reads the ``Dataset``, ``Model`` and
                ``Training`` sections (including ``Training['Strategy']`` and
                ``Training['Server']``) plus the top-level keys ``seed``, ``save_dir``,
                ``save_name``, ``gpu``, ``use_wandb`` and, when wandb is enabled,
                ``wandb_entity``.
            net_builder: model factory. ``set_model`` calls it with ``num_classes``,
                ``pretrained`` and ``pretrained_path``; ``set_ema_model`` calls it with
                ``num_classes`` only.
            train_loader: iterable of batch dicts used for server-side training.
            test_loader: iterable of batch dicts with ``x_lb``/``y_lb`` used by ``evaluate``.
            logger: optional logger; when given, ``logger.info`` is used instead of
                ``print`` for messages.
        """
        
        # Config
        self.config = config
        self.data_cfgs = self.config['Dataset']
        self.model_cfgs = self.config['Model']
        self.strategy_cfgs = config['Training']['Strategy']
        self.train_cfgs = config['Training']['Server']
        
        self.seed = self.config['seed']
        # Seed before any network is built so initialisation is repeatable
        # (assumes set_seed seeds the torch/numpy RNGs -- confirm in src.core.utils).
        set_seed(self.seed)
        
        self.save_path = os.path.join(self.config['save_dir'], self.config['save_name'])
        
        # Exp. device setting
        self.gpu = self.config['gpu']
        # Falls back to CPU (ignoring `gpu`) when CUDA is unavailable.
        self.device = torch.device(f"cuda:{self.gpu}" if torch.cuda.is_available() else "cpu")
        # Plain print: self.print_fn is only assigned further below.
        print(self.device)
        
        # Set Dataset
        self.train_loader = train_loader
        self.test_loader = test_loader
        
        # Set Model
        self.num_classes = self.data_cfgs['num_classes']
        self.net_builder = net_builder
        self.model = self.set_model()
        self.model.to(self.device)
        
        # The EMA network is always built, initialised from self.model's weights;
        # it is only refreshed/used when use_ema is enabled.
        self.ema_model = self.set_ema_model()
        self.ema_model.to(self.device)
            
        self.use_ema = self.model_cfgs['use_ema']
        # ema_save_path, ema_m and ema are defined only when use_ema is True.
        if self.use_ema:
            self.ema_save_path = os.path.join(self.save_path, 'ema_model.pth')
            self.ema_m = self.model_cfgs['ema_m']
            self.ema = EMA(self.model, self.ema_m)
            self.ema.register()
        
        # Training
        # Default (fine-tuning) optimizer/scheduler; warm_up builds its own optimizer.
        self.optimizer, self.scheduler = self.set_optimizer(mode='finetune')  
        self.use_scheduler = self.config['Training']['use_scheduler']
        self.static_bn = self.config['Training']['static_bn']
        self.clip_grad = self.config['Training']['clip']
        
        self.tot_round = self.config['Training']['total_round']
        self.round = 0
        
        self.warmup_epochs = self.train_cfgs['warmup_epochs']
        self.finetune_epochs = self.train_cfgs['finetune_epochs']
        # Evaluation period (in epochs) used by _train_loop during warm-up.
        self.eval_epoch = self.train_cfgs[f'warmup_eval_epoch']
        self.epoch = 0
        
        # Best top-1 accuracy and the round it was reached, tracked separately
        # for 'agg' and 'finetune' evaluations.
        self.best_acc, self.best_round = 0.0, 0
        self.best_fine_acc, self.best_fine_round = 0.0, 0
        
        # Logging & Results
        self.print_fn = print if logger is None else logger.info
        
        # == wandb init == 
        self.use_wandb = self.config['use_wandb']
        if self.use_wandb:
            name = self.config['save_name']
            # Project name is the last '/'-separated component of save_dir.
            project = self.config['save_dir'].split('/')[-1]
            # tags
            benchmark = f"benchmark: {project}"
            dataset = f"dataset: {self.data_cfgs['dataset']}"
            data_setting = f"setting: {self.data_cfgs['dataset']}_lb{self.data_cfgs['num_labels']}"
            tags = [benchmark, dataset, data_setting] 
            resume = 'never'
            
            # wandb files are written under <save_path>/wandb.
            save_dir = os.path.join(self.save_path, 'wandb')
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)

            self.run = wandb.init(entity=self.config['wandb_entity'],
                                  name=name, 
                                  tags=tags, 
                                  project=project, 
                                  resume=resume,
                                  dir=save_dir)
            
    
    # ============================================================================================================== #    
    def set_model(self):
        """Build the trainable server network.

        The ``Model`` config section supplies the pretraining flag and checkpoint path;
        the class count comes from ``self.num_classes``.
        """
        return self.net_builder(
            num_classes=self.num_classes,
            pretrained=self.config['Model']['use_pretrain'],
            pretrained_path=self.config['Model']['pretrain_path'],
        )
        
        
    def set_ema_model(self):
        """Create a separate network instance holding a copy of ``self.model``'s weights.

        The factory is called with ``num_classes`` only; the weights are then copied
        from the main model's ``state_dict``.
        """
        shadow_net = self.net_builder(num_classes=self.num_classes)
        source_weights = self.model.state_dict()
        shadow_net.load_state_dict(source_weights)
        return shadow_net
    
        
    def set_optimizer(self, mode='finetune'):
        """Build the optimizer (and, for fine-tuning, the LR scheduler) for a server phase.

        Optimizer hyper-parameters are read from ``self.train_cfgs`` using the phase
        name as a key prefix (``warmup_lr`` / ``finetune_lr``, ``..._momentum``,
        ``..._weight_decay``); ``optim`` is shared by both phases.

        Args:
            mode: ``'warmup'`` returns an optimizer and ``None`` as scheduler;
                ``'finetune'`` returns an optimizer plus a custom cosine scheduler.

        Returns:
            ``(optimizer, scheduler)`` tuple.

        Raises:
            ValueError: if ``mode`` is neither ``'warmup'`` nor ``'finetune'``.
        """
        if mode not in ('warmup', 'finetune'):
            # Reject unknown phases explicitly rather than failing later with an
            # UnboundLocalError on ``optimizer``.
            raise ValueError(
                f"Unknown optimizer mode: {mode!r} (expected 'warmup' or 'finetune')")

        cfg = self.train_cfgs
        optimizer = get_optimizer(net=self.model,
                                  optim_name=cfg['optim'],
                                  lr=cfg[f'{mode}_lr'],
                                  momentum=cfg[f'{mode}_momentum'],
                                  weight_decay=cfg[f'{mode}_weight_decay'])
        if mode == 'warmup':
            return optimizer, None

        scheduler = get_custom_cosine_scheduler(optimizer,
                                                first_cycle_step=cfg['first_cycle_step'],
                                                cycle_mult=cfg['cycle_mult'],
                                                max_lr=cfg['finetune_lr'],
                                                min_lr=cfg['min_lr'],
                                                warmup_steps=cfg['warmup_steps'],
                                                gamma=cfg['lr_gamma'],
                                                last_epoch=-1)
        return optimizer, scheduler


    def apply_static_bn(self):
        """Re-estimate BatchNorm running statistics of ``self.model`` on ``self.train_loader``.

        Every BatchNorm layer's running statistics are reset and then recomputed as a
        cumulative average (``momentum=None``) over one no-grad pass of the labeled
        inputs ``data['x_lb']``; no parameters are updated. Each layer's original
        ``momentum`` is restored afterwards. The model is left in training mode.
        """
        bn_layers = [m for m in self.model.modules()
                     if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
        saved_momenta = [(m, m.momentum) for m in bn_layers]

        self.model.train()
        try:
            with torch.no_grad():
                for m in bn_layers:
                    m.momentum = None
                    m.track_running_stats = True
                    # With momentum=None the update factor is 1/num_batches_tracked;
                    # without a reset the old (long-accumulated) statistics would
                    # dominate and the recalibration would barely move them.
                    # Layers built with track_running_stats=False have no buffers.
                    if m.running_mean is not None:
                        m.reset_running_stats()
                for data in self.train_loader:
                    x = data['x_lb']
                    if isinstance(x, dict):
                        x = {k: v.to(self.device) for k, v in x.items()}
                    else:
                        x = x.to(self.device)
                    _ = self.model(x)
        finally:
            # Restore the regular exponential-moving-average behaviour for later training.
            for m, momentum in saved_momenta:
                m.momentum = momentum
        self.print_fn("====> Static BN statistics updated.")
        
        
    def process_batch(self, **kwargs):
        """Keep only the batch entries ``train_step`` accepts and move them to the device.

        An entry is dropped when its key is not a parameter name of ``self.train_step``
        or its value is ``None``. Kept values are moved with ``.to(self.device)``; dict
        values are moved entry by entry.

        Returns:
            dict ready to be splatted into ``self.train_step(optimizer, **result)``.
        """
        accepted = set(signature(self.train_step).parameters)

        def _to_device(value):
            if isinstance(value, dict):
                return {key: item.to(self.device) for key, item in value.items()}
            return value.to(self.device)

        return {
            name: _to_device(value)
            for name, value in kwargs.items()
            if name in accepted and value is not None
        }
        
        
    def warm_up(self, epochs=10):
        """Train the server model on its own before federated rounds start.

        A dedicated warm-up optimizer (without scheduler) is created with
        ``set_optimizer(mode='warmup')``. When static BN is enabled, BatchNorm
        statistics are re-estimated on the training data afterwards.

        Args:
            epochs: number of warm-up passes over ``self.train_loader``.
        """
        self.print_fn(f"[Server] Warm-up for {epochs} epochs before federated training")
        warmup_optim = self.set_optimizer(mode='warmup')[0]
        self._train_loop(epochs, optimizer=warmup_optim, scheduler=None, mode="warmup")
        if not self.static_bn:
            return
        self.apply_static_bn()
        self.print_fn("[Server] Updating static BN before sending parameters")


    def fine_tune(self, epochs=5):
        """Fine-tune the server model on ``self.train_loader``.

        Uses the default ``self.optimizer`` / ``self.scheduler`` created in ``__init__``.
        When EMA is enabled, the EMA shadow is updated once after training, copied into
        ``self.ema_model`` and saved to ``self.ema_save_path``. When static BN is
        enabled, BatchNorm layers are put in eval mode before training and their
        statistics are re-estimated afterwards.

        Args:
            epochs: number of passes over ``self.train_loader``.
        """
        self.print_fn(f"[Server] Fine-tuning for {epochs} epochs")
        if self.static_bn:
            # Keep BatchNorm running statistics fixed while fine-tuning.
            self.freeze_bn_stats()
        self._train_loop(epochs,
                         optimizer=self.optimizer, scheduler=self.scheduler, mode="finetune")

        if self.use_ema:
            self.ema.update()
            # strict=False tolerates keys absent from the EMA shadow
            # (NOTE(review): presumably it tracks a subset of the state dict -- verify EMA).
            self.ema_model.load_state_dict(self.ema.shadow, strict=False)
            torch.save(self.ema_model.state_dict(), self.ema_save_path)
            # Route through print_fn so the message reaches the configured logger.
            self.print_fn("[!] Update and Save EMA model")

        if self.static_bn:
            self.apply_static_bn()
            self.print_fn("[Server] Updating static BN before sending parameters")
        

    def train_step(self, optimizer, **kwargs):
        """Run one algorithm-specific optimisation step on a single batch (abstract hook).

        Subclasses compute the loss from the batch entries in ``kwargs`` (filtered by
        ``process_batch`` to this method's parameter names and moved to
        ``self.device``), update the model with ``optimizer`` and return a results
        dict. ``_train_loop`` stores the last returned value in ``self.res_dict`` and
        logs it to wandb when enabled.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
    
    
    def _train_loop(self, epochs, optimizer, scheduler, mode):
        """Run ``epochs`` passes of ``train_step`` over ``self.train_loader``.

        Args:
            epochs: number of passes over the training loader.
            optimizer: optimizer forwarded to ``train_step``.
            scheduler: stepped once after all epochs (not per epoch); may be ``None``.
            mode: phase name, stored on ``self.mode``. In ``'warmup'`` the EMA model is
                refreshed/saved every epoch (when enabled) and the model is evaluated
                and checkpointed periodically. In ``'finetune'`` with static BN enabled,
                BatchNorm layers are kept in eval mode during training.
        """
        # The seed depends on the round, so each round's randomness differs but is
        # reproducible (assumes set_seed seeds the RNGs used by loaders/augmentation).
        set_seed(self.seed + self.round)
        self.model.train()
        # Module.train() recursively puts every sub-module, BatchNorm included, back
        # into training mode, which would silently undo a BN freeze requested for
        # static-BN fine-tuning; re-apply it after switching modes.
        if mode == 'finetune' and self.static_bn:
            self.freeze_bn_stats()
        self.mode = mode

        self.print_fn(f"[*] Use training ema: {self.use_ema}")
        for e in range(epochs):
            self.epoch += 1

            for data in self.train_loader:
                # train step
                self.res_dict = self.train_step(optimizer, **self.process_batch(**data))

            # wandb logging of the last batch's results; skipped when no batch has
            # produced a result yet (e.g. an empty loader).
            res_dict = getattr(self, 'res_dict', None)
            if self.use_wandb and res_dict is not None:
                self.run.log(res_dict, step=self.epoch + self.round)

            # EMA refresh and periodic evaluation happen per epoch only in warm-up.
            if self.mode == 'warmup':
                if self.use_ema:
                    self.ema.update()
                    self.ema_model.load_state_dict(self.ema.shadow, strict=False)
                    torch.save(self.ema_model.state_dict(), self.ema_save_path)

                # e is 0-based: evaluates after epochs 1, eval_epoch+1, 2*eval_epoch+1, ...
                if e % self.eval_epoch == 0:
                    _ = self.evaluate(mode=self.mode)
                    self.save_model(filename='warmup_latest_model.pth')

        if scheduler is not None:
            scheduler.step()


    def get_save_dict(self):
        """Collect the checkpoint payload.

        Returns:
            dict with the model, EMA model and optimizer ``state_dict``s plus the
            current ``epoch`` and ``round`` counters.
        """
        return dict(
            model=self.model.state_dict(),
            ema_model=self.ema_model.state_dict(),
            optimizer=self.optimizer.state_dict(),
            epoch=self.epoch,
            round=self.round,
        )
    
    
    def save_model(self, filename="server_model.pth"):
        """Write ``get_save_dict()`` to ``<self.save_path>/<filename>`` and log the path.

        The target directory is created when missing. Previously it only existed as a
        side effect of the wandb setup, so saving failed with wandb disabled.

        Args:
            filename: checkpoint file name, relative to ``self.save_path``.
        """
        save_path = os.path.join(self.save_path, filename)
        target_dir = os.path.dirname(save_path)
        if target_dir:  # empty when save_path is relative with no directory part
            os.makedirs(target_dir, exist_ok=True)
        save_dict = self.get_save_dict()
        torch.save(save_dict, save_path)
        self.print_fn(f"[*] Server model saved to: {save_path}")


    @torch.no_grad()
    def evaluate(self, mode="warmup"):
        """Evaluate on ``self.test_loader``, then log and (for some modes) checkpoint.

        The EMA model is evaluated when ``use_ema`` is set, otherwise ``self.model``.
        Samples with label ``< num_classes`` are in-distribution and contribute to the
        cross-entropy loss and top-1 accuracy. All samples contribute to OOD-detection
        AUROC, with label ``>= num_classes`` marking OOD. Three scores are used:
        ``1 - max softmax``, predictive entropy and ``-logsumexp(logits)``.

        Args:
            mode: metric-key prefix. ``'agg'`` and ``'finetune'`` also save
                latest/best checkpoints and track the best accuracy and round.

        Returns:
            dict of metrics keyed ``f"{mode}/<metric>"``.
        """
        self.print_fn(f">> eval round: {self.round}")
        self.print_fn(f">> eval epochs: {self.epoch}")

        model = self.ema_model if self.use_ema else self.model
        model.eval()

        total_loss = 0.0
        total_num = 0.0

        y_in_true = []
        y_pred = []

        ood_scores_msp = []
        ood_scores_entropy = []
        ood_scores_energy = []
        ood_labels = []

        # Autograd is already disabled by the method-level @torch.no_grad().
        for data in self.test_loader:
            x = data['x_lb']
            y = data['y_lb']

            if isinstance(x, dict):
                x = {k: v.to(self.device) for k, v in x.items()}
            else:
                x = x.to(self.device)
            y = y.to(self.device)

            logits = model(x)['logits']

            # Standard classification on in-distribution samples only.
            in_idx = torch.where(y < self.num_classes)[0]
            if len(in_idx) > 0:
                in_logits = logits[in_idx]
                in_labels = y[in_idx]
                preds = torch.max(in_logits, dim=-1)[1]
                loss = F.cross_entropy(in_logits, in_labels, reduction='mean')
                total_loss += loss.item() * in_idx.shape[0]
                total_num += in_idx.shape[0]

                y_in_true.extend(in_labels.cpu().tolist())
                y_pred.extend(preds.cpu().tolist())

            ood_labels.extend((y >= self.num_classes).int().cpu().tolist())

            # OOD scores: for all three, larger means "more likely OOD".
            probs = F.softmax(logits, dim=1)
            max_prob, _ = probs.max(dim=1)
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
            energy = -torch.logsumexp(logits, dim=1)

            ood_scores_msp.extend((1.0 - max_prob).cpu().tolist())
            ood_scores_entropy.extend(entropy.cpu().tolist())
            ood_scores_energy.extend(energy.cpu().tolist())

        # Accuracy
        y_in_true = np.array(y_in_true)
        y_pred = np.array(y_pred)
        top1 = accuracy_score(y_in_true, y_pred) if len(y_in_true) > 0 else 0.0

        if self.num_classes <= 10 and len(y_in_true) > 0:
            cf_mat = confusion_matrix(y_in_true, y_pred, normalize='true')
            self.print_fn("confusion matrix:\n" + np.array_str(cf_mat))

        # AUROC: roc_auc_score raises ValueError when y_true is empty or contains a
        # single class (e.g. no OOD samples); report 0.0 in that case.
        try:
            auroc_msp = roc_auc_score(ood_labels, ood_scores_msp)
            auroc_entropy = roc_auc_score(ood_labels, ood_scores_entropy)
            auroc_energy = roc_auc_score(ood_labels, ood_scores_energy)
        except ValueError:
            auroc_msp = auroc_entropy = auroc_energy = 0.0

        model.train()

        # Logging
        eval_dict = {
            mode+'/loss': total_loss / total_num if total_num > 0 else 0.0,
            mode+'/top-1-acc': top1,
            mode+'/auroc_msp': auroc_msp,
            mode+'/auroc_entropy': auroc_entropy,
            mode+'/auroc_energy': auroc_energy,
            mode+'/round': self.round
            }

        if mode == 'agg':
            self.save_model(filename="agg_latest_model.pth")
            if top1 > self.best_acc:
                self.save_model(filename="agg_best_model.pth")
                self.best_acc = top1
                self.best_round = self.round
                self.print_fn("[*] Save best agg. model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_acc
            eval_dict[f"{mode}/best_round"] = self.best_round
        elif mode == 'finetune':
            self.save_model(filename="finetune_latest_model.pth")
            if top1 > self.best_fine_acc:
                # File name now matches finetune_latest_model.pth (was "fnetune_...").
                self.save_model(filename="finetune_best_model.pth")
                self.best_fine_acc = top1
                self.best_fine_round = self.round
                self.print_fn("[*] Save best finetune model ckpt")
            eval_dict[f"{mode}/best_acc"] = self.best_fine_acc
            eval_dict[f"{mode}/best_round"] = self.best_fine_round

        if self.use_wandb:
            self.run.log(eval_dict, step=self.epoch + self.round)

        self.print_fn(eval_dict)
        return eval_dict


    def load_parameters(self, parameters):
        """Load a flat sequence of arrays into ``self.model`` in ``state_dict`` order.

        Args:
            parameters: iterable of array-likes (e.g. the output of
                ``get_model_parameters``), one per ``state_dict`` entry, same order.

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries. ``zip`` would otherwise silently truncate,
                leaving part of the model un-updated.
        """
        state_dict = self.model.state_dict()
        parameters = list(parameters)
        if len(parameters) != len(state_dict):
            raise ValueError(
                f"load_parameters: expected {len(state_dict)} arrays "
                f"(one per state_dict entry), got {len(parameters)}")
        for k, v in zip(state_dict.keys(), parameters):
            state_dict[k] = torch.tensor(v)
        self.model.load_state_dict(state_dict)


    def get_model_parameters(self):
        """Return every ``self.model.state_dict()`` value as a CPU NumPy array, in ``state_dict`` order."""
        state = self.model.state_dict()
        return [state[name].cpu().numpy() for name in state]
    
    
    def freeze_bn_stats(self):
        """Put every BatchNorm layer of ``self.model`` into eval mode.

        In eval mode BatchNorm normalises with its stored running statistics and
        stops updating them. Other layers keep their current training flag.
        """
        bn_type = torch.nn.modules.batchnorm._BatchNorm
        bn_layers = (layer for layer in self.model.modules() if isinstance(layer, bn_type))
        for layer in bn_layers:
            layer.eval()