import itertools
import torch
import torch.nn as nn
import math

import logging

txt_logger = logging.getLogger("sfda_reg")


def fea_extractor_bn_setup(model, fea_bn_mode="eval", e=0):
    """"
    set feature extractor's bn layer fixed
    """
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            match fea_bn_mode:
                case 'eval':
                    module.eval()
                    for name, param in module.named_parameters():
                        # print(f"param name is {name}.")
                        param.requires_grad = False
                        param.grad = None
                case 'train':
                    module.train()
                    for name, param in module.named_parameters():
                        # print(f"param name is {name}.")
                        param.requires_grad = True
                case 'train_m':
                    module.train()
                    module.momentum += 0.05
                    for param in module.parameters():
                        param.requires_grad = True
            momentum = module.momentum
    if e == 0:
        txt_logger.info(f"Feature Extractor bn mode is [{fea_bn_mode}], momentum is {momentum:.4f}.")


def grad_update_switch(net: nn.Module, flag: bool, keywords: list = None, e: int = 0):
    """
    Turn gradient tracking on or off for ``net`` or for a subset of its sub-modules.

    Args:
        net: module whose parameters are switched.
        flag: value written to ``requires_grad`` of every selected parameter.
            When it is False, the stored ``.grad`` of each selected parameter
            is reset to None as well.
        keywords: if None, every parameter of ``net`` is selected. Otherwise
            the parameters of each sub-module whose name (as yielded by
            ``net.named_modules()``) contains at least one keyword as a
            substring are selected.
        e: the setup line is logged only when ``e == 0``.
    """
    if e == 0:
        txt_logger.info(f"module param grad setup: {keywords} = {flag}")

    if keywords is None:
        selected = net.parameters()
    else:
        # Lazily walk the matching sub-modules; a parameter reachable through
        # several matching modules is simply assigned the same value again.
        selected = (
            weight
            for mod_name, sub_module in net.named_modules()
            if any(key in mod_name for key in keywords)
            for weight in sub_module.parameters()
        )

    for weight in selected:
        weight.requires_grad = flag
        if not flag:
            weight.grad = None


def create_optimizer(net, config, name="optimizer") -> tuple[torch.optim.Optimizer, str, dict]:
    """
    Build one optimizer over the feature extractor, regressor and classifier of ``net``.

    ``config[name]`` is read as a mapping with these keys:
        ``'name'``              - attribute name of an optimizer class in
                                  ``torch.optim`` (e.g. ``"SGD"``);
        ``'feature_extractor'`` - ``'freeze'`` (default True), ``'train_bn'``
                                  (default True, only read when frozen),
                                  ``'lr'``, ``'weight_decay'``;
        ``'regressor'``         - ``'freeze'`` (default True), ``'lr'``,
                                  ``'weight_decay'``;
        ``'classifier'``        - ``'freeze'`` (default False), ``'lr'``,
                                  ``'weight_decay'``.
    ``'lr'`` / ``'weight_decay'`` are only read for components that contribute
    a parameter group (and for the log line of the frozen-with-BN case).

    Side effects on ``net``: parameters of frozen components get
    ``requires_grad = False``; with a frozen feature extractor and
    ``train_bn`` enabled, its BatchNorm parameters get ``requires_grad = True``
    again. Trainable components are NOT explicitly re-enabled here.
    NOTE(review): if several optimizers are built over the same net, the
    freezes of the last call win - confirm callers reset grads as needed.

    Args:
        net: object exposing ``get_feature_extractor()``, ``get_regressor()``,
            ``get_classifier_fea()`` and ``get_classifier_cls()``, each
            returning an ``nn.Module``.
        config: mapping that contains the optimizer section under ``name``.
        name: key of the optimizer section inside ``config``.

    Returns:
        ``(optimizer, bn_mode, modules_states)`` where ``bn_mode`` is
        ``"train"`` when the feature extractor or at least its BN layers are
        trained and ``"eval"`` otherwise, and ``modules_states`` maps
        ``"train"`` / ``"frozen"`` to lists of component names.

    Raises:
        KeyError: a required config key is missing.
        AttributeError: ``torch.optim`` has no attribute called
            ``config[name]['name']``.
        Whatever the optimizer constructor raises, e.g. when every component
        is frozen and no parameter group is left.
    """
    optimizer_config = config[name]

    param_groups = []
    txt_logger.info(f"Creating Optimizer: [{name}]...")
    modules_states = {"train": [], "frozen": []}

    # ---- feature extractor ----
    fea_cfg = optimizer_config['feature_extractor']
    if not fea_cfg.get('freeze', True):
        txt_logger.info(
            f"-> Feature_extractor is fully trainable. lr is {fea_cfg['lr']}, w_d is {fea_cfg['weight_decay']}."
        )
        param_groups.append(
            {
                'params': net.get_feature_extractor().parameters(),
                'lr': fea_cfg['lr'],
                'weight_decay': fea_cfg['weight_decay'],
                'name': 'feature_extractor',
            })
        bn_mode = "train"
        modules_states["train"].append("feature_extractor")
    else:
        feature_extractor = net.get_feature_extractor()
        for param in feature_extractor.parameters():
            param.requires_grad = False
        bn_mode = "eval"
        modules_states["frozen"].append("feature_extractor")
        # Feature extractor frozen, but its BN affine parameters stay trainable.
        if fea_cfg.get('train_bn', True):
            bn_mode = "train"
            bn_params = []
            for module in feature_extractor.modules():
                if isinstance(module, nn.modules.batchnorm._BatchNorm):
                    for param in module.parameters():
                        param.requires_grad = True
                        bn_params.append(param)

            # No group is added when the extractor has no BN parameters:
            # an empty 'params' list would be useless to the optimizer.
            if bn_params:
                param_groups.append(
                    {
                        'params': bn_params,
                        'lr': fea_cfg['lr'],
                        'weight_decay': fea_cfg['weight_decay'],
                        'name': 'feature_extractor',
                    })
            txt_logger.info(
                f"-> Feature_extractor is frozen, but BN is trainable. "
                f"lr is {fea_cfg['lr']}, "
                f"w_d is {fea_cfg['weight_decay']}.")
        else:
            txt_logger.info(f"-> Feature_extractor is frozen.")

    # ---- regressor ----
    reg_cfg = optimizer_config['regressor']
    if not reg_cfg.get('freeze', True):
        modules_states["train"].append("regressor")
        txt_logger.info(
            f"-> Regressor is fully trainable. "
            f"lr is {reg_cfg['lr']}, "
            f"w_d is {reg_cfg['weight_decay']}")
        param_groups.append(
            {
                'params': net.get_regressor().parameters(),
                'lr': reg_cfg['lr'],
                'weight_decay': reg_cfg['weight_decay'],
                'name': 'regressor',
            })
    else:
        modules_states["frozen"].append("regressor")
        txt_logger.info(f"-> Regressor is frozen.")
        for param in net.get_regressor().parameters():
            param.requires_grad = False

    # ---- classifier (feature part + classification head share one group) ----
    cls_cfg = optimizer_config['classifier']
    if not cls_cfg.get('freeze', False):
        modules_states["train"].append("classifier")
        txt_logger.info(
            f"-> Classifier is fully trainable. "
            f"lr is {cls_cfg['lr']}, "
            f"w_d is {cls_cfg['weight_decay']}.")
        param_groups.append(
            {
                'params': itertools.chain(
                    net.get_classifier_fea().parameters(),
                    net.get_classifier_cls().parameters(),
                ),
                'lr': cls_cfg['lr'],
                'weight_decay': cls_cfg['weight_decay'],
                'name': "classifier",
            })
    else:
        txt_logger.info(f"-> Classifier is frozen.")
        modules_states["frozen"].append("classifier")
        for param in itertools.chain(
                net.get_classifier_fea().parameters(),
                net.get_classifier_cls().parameters()):
            param.requires_grad = False

    txt_logger.info(f"-> bn_mode is {bn_mode}, train components are {modules_states['train']}, frozen components are {modules_states['frozen']}")

    opt_name = optimizer_config['name']
    if opt_name == "SGD":
        # SGD is always run with Nesterov momentum 0.9, set per group.
        for param_dict in param_groups:
            param_dict['momentum'] = 0.9
            param_dict['nesterov'] = True

    # The class name comes from the config file: look it up as an attribute of
    # torch.optim instead of eval()-ing a string built from config content.
    optimizer_cls = getattr(torch.optim, opt_name)
    opt = optimizer_cls(param_groups)
    txt_logger.info(f"-> optimizer [{opt_name}] is initialized.")
    return opt, bn_mode, modules_states


def create_group_schedulers(optimizer, config, name="optimizer_cls") -> "torch.optim.lr_scheduler.LRScheduler | None":
    """
    Create the learning-rate scheduler described by ``config[name]['scheduler']``.

    A single scheduler is attached to the whole ``optimizer`` (i.e. it drives
    all of its parameter groups); despite the function name, no per-group
    collection is built.

    Supported ``scheduler['type']`` values and the keys read for each:
        ``'StepLR'``            - ``'step_size'``, ``'gamma'``;
        ``'ExponentialLR'``     - ``'gamma'``;
        ``'CosineAnnealingLR'`` - ``'T_max'``, ``'eta_min'`` (default 1e-5).

    Args:
        optimizer: optimizer to schedule. The ``'name'`` entry of its
            parameter groups is used for logging only.
        config: mapping that contains the optimizer section under ``name``.
        name: key of the optimizer section inside ``config``.

    Returns:
        The scheduler instance, or None when the section has no
        ``'scheduler'`` entry (or it is None).

    Raises:
        ValueError: ``scheduler['type']`` is not one of the supported types.
        KeyError: a required scheduler key is missing.
    """
    scheduler_config = config[name].get("scheduler", None)

    # Names are informational: tolerate optimizers whose groups carry none.
    component_name = [
        group.get('name', f"group_{idx}")
        for idx, group in enumerate(optimizer.param_groups)
    ]
    txt_logger.info(f"Creating Scheduler for optimizer: [{name}]...")
    txt_logger.info(f"Components are: [{component_name}]")

    if scheduler_config is None:
        txt_logger.info(
            f"-> No scheduler for optimizer: [{name}] with components [{component_name}].")
        return None

    scheduler_type = scheduler_config['type']

    if scheduler_type == 'StepLR':
        step_size = scheduler_config['step_size']
        gamma = scheduler_config['gamma']
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=step_size, gamma=gamma)
        txt_logger.info(
            f"-> params of opt {name}: Using StepLR (step_size={step_size}, gamma={gamma})."
        )
    elif scheduler_type == 'ExponentialLR':
        gamma = scheduler_config['gamma']
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=gamma)
        txt_logger.info(
            f"-> params of opt {name}: Using ExponentialLR (gamma={gamma})."
        )
    elif scheduler_type == 'CosineAnnealingLR':
        t_max = scheduler_config['T_max']
        # Resolve the default once so the logged value is the one actually used.
        eta_min = scheduler_config.get('eta_min', 1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=t_max, eta_min=eta_min)
        txt_logger.info(
            f"-> params of opt {name}: Using CosineAnnealingLR (T_max={t_max}, eta_min={eta_min})."
        )
    else:
        raise ValueError(
            f"-> Unsupported scheduler type for params of {name}: {scheduler_type}")

    txt_logger.info(f"-> Init lr: {scheduler.get_last_lr()}.")

    return scheduler


def create_optimizers_schedulers_main(
        net, config, name_opt="optimizer") -> tuple[dict, dict, dict, dict]:
    """
    Build an optimizer and a scheduler for every entry of ``config[name_opt]``.

    Each key ``m`` of ``config[name_opt]`` is treated as one optimizer section:
    ``create_optimizer(net, config[name_opt], name=m)`` builds the optimizer
    and ``create_group_schedulers(...)`` builds its scheduler from the same
    section.

    Args:
        net: model handed through to ``create_optimizer``.
        config: mapping that contains the optimizer sections under ``name_opt``.
        name_opt: key of the mapping of optimizer sections inside ``config``.

    Returns:
        ``(optimizers, bn_modes, schedulers, modules_states)`` - four dicts
        keyed by section name. ``schedulers[m]`` is None when section ``m``
        defines no scheduler.
    """
    optimizer_config = config[name_opt]
    optimizers = {}
    bn_modes = {}
    schedulers = {}
    modules_states = {}
    txt_logger.info(
        f"Creating [Optimizer, Scheduler, BN Mode Dict and Module Grad Dict] for main training with yaml: =={name_opt}=="
    )
    modules = list(optimizer_config)
    txt_logger.info(f"Modules are: {modules}")
    for m in modules:
        # Sections are processed in config order; each call may change the
        # requires_grad flags of net's parameters.
        optimizers[m], bn_modes[m], modules_states[m] = create_optimizer(
            net, optimizer_config, name=m)
        schedulers[m] = create_group_schedulers(
            optimizers[m], optimizer_config, name=m)
        txt_logger.info(f"------------- [{m}] is finished. -------------")
    return optimizers, bn_modes, schedulers, modules_states
