from typing import Any, Optional
from collections.abc import Iterator

from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.feature_extraction import create_feature_extractor
import timm
import torch
from torch.autograd import Function
import logging

txt_logger = logging.getLogger("sfda_reg")

def create_regressor(config: dict[str, Any], dropout=False):
    """Build the regressor network described by ``config``.

    Args:
        config: Settings dict. ``config['data_type']`` selects the model family.
            ``"image"`` additionally reads ``net_backbone`` and
            ``net_pretrained``; ``"table"`` reads ``net_info`` and uses its
            entries 0, 1 and 2 as input width, hidden width and number of
            repeated hidden blocks.
        dropout: When truthy, the ``*_DROPOUT`` variant of the model is built.

    Returns:
        The newly constructed regressor module.

    Raises:
        ValueError: If ``config['data_type']`` is neither ``"image"`` nor
            ``"table"``.
    """
    kind = config['data_type']

    if kind == "image":
        arch = config['net_backbone']
        use_pretrained = config['net_pretrained']
        cnn_cls = CNNRegressor_DROPOUT if dropout else CNNRegressor
        return cnn_cls(arch, use_pretrained)

    if kind == "table":
        spec = config['net_info']
        in_dims, h_dims, n_rep = spec[0], spec[1], spec[2]
        mlp_cls = MLPRegressor_DROPOUT if dropout else MLPRegressor
        return mlp_cls(in_dims, h_dims, n_rep)

    # Anything else is an unsupported data type.
    raise ValueError(f"Invalid type: {kind!r}")


def extract_bn_layers(mod: nn.Module) -> Iterator[_BatchNorm]:
    """Yield every batch-norm layer nested below ``mod``, depth-first.

    Only descendants are inspected: ``mod`` itself is never yielded, and the
    search does not descend into a batch-norm layer once one is found.
    """
    for child in mod.children():
        if not isinstance(child, _BatchNorm):
            # Not a batch-norm layer: keep searching inside it.
            yield from extract_bn_layers(child)
            continue
        yield child


# =====================================================
# ================== basic regressors =================
# =====================================================
class Regressor(nn.Module):
    """Interface for regressors split into a feature extractor and a head.

    The four hooks below raise ``NotImplementedError`` here and are meant to
    be overridden. ``forward`` simply chains ``feature`` and
    ``predict_from_feature``.
    """

    # Final prediction layer; assigned by concrete subclasses.
    regressor: nn.Linear

    def forward(self, x: Tensor) -> Tensor:
        """Return the prediction for ``x`` by extracting features and applying the head."""
        return self.predict_from_feature(self.feature(x))

    def feature(self, x: Tensor) -> Tensor:
        """Map an input batch to its feature representation (abstract)."""
        raise NotImplementedError

    def predict_from_feature(self, z: Tensor) -> Tensor:
        """Map features ``z`` to predictions (abstract)."""
        raise NotImplementedError

    def get_feature_extractor(self) -> nn.Module:
        """Return the feature-extractor sub-module (abstract)."""
        raise NotImplementedError

    def get_regressor(self) -> nn.Module:
        """Return the regression-head sub-module (abstract)."""
        raise NotImplementedError


class Regressor_DROPOUT(nn.Module):
    regressor: nn.Linear

    def feature(self, x: Tensor) -> Tensor:
        raise NotImplementedError

    def predict_from_feature(self, z: Tensor) -> Tensor:
        raise NotImplementedError

    def get_regressor(self) -> nn.Module:
        raise NotImplementedError

    def get_feature_extractor(self) -> nn.Module:
        raise NotImplementedError

    def forward(self, x: Tensor, dropout_prob_temps: float | None=None, training: Any=None) -> Tensor:
        z = self.feature(x)
        y_pred = self.predict_from_feature(z, dropout_prob_temps, training)
        return y_pred



# SSA basic models
class CNNRegressor(Regressor):
    """Image regressor: a ResNet trunk tapped at its pooling node plus a linear head.

    Args:
        backbone: ``"resnet26"`` (built through ``timm``) or ``"resnet50"``
            (built through ``torchvision``).
        pretrained: Whether to request pretrained weights for the trunk.
        in_channels: Number of input channels; when it is not 3 the trunk's
            first convolution is replaced by a freshly initialised one.
        reg_num: Output width of the linear head.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone: str, pretrained: bool, in_channels: int=3, reg_num: int=1):
        super().__init__()

        # Pick the trunk and the name of the pooling node to tap.
        if backbone == "resnet26":
            trunk = timm.create_model("resnet26", pretrained=pretrained)
            pool_node = "global_pool"
        elif backbone == "resnet50":
            trunk = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)
            pool_node = "avgpool"
        else:
            raise ValueError(f"Invalid backbone: {backbone!r}")

        if in_channels != 3:
            # Swap the stem convolution so the trunk accepts the requested channel count.
            trunk.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7,
                                    stride=2, padding=3, bias=False)

        self.feature_extractor = create_feature_extractor(trunk, {pool_node: "feature"})

        self.fea_dim = 2048
        self.reg_num = reg_num
        self.regressor = nn.Linear(2048, reg_num)

    def feature(self, x: Tensor) -> Tensor:
        """Return the pooled trunk output for ``x``, flattened from dim 1 onwards."""
        pooled: Tensor = self.feature_extractor(x)["feature"]
        return pooled.flatten(start_dim=1)

    def predict_from_feature(self, z: Tensor) -> Tensor:
        """Apply the linear head to ``z`` and return the result flattened to 1-D."""
        head_out: Tensor = self.regressor(z)
        return head_out.flatten()

    def get_feature_extractor(self) -> nn.Module:
        """Return the feature-extractor sub-module."""
        return self.feature_extractor

    def get_regressor(self) -> nn.Module:
        """Return the linear regression head."""
        return self.regressor


class MLPRegressor(Regressor):
    """Tabular regressor: stacked Linear/BatchNorm1d/ReLU stages plus a scalar head.

    Args:
        in_dims: Width of the input features.
        h_dims: Width of every hidden stage (also stored as ``fea_dim``).
        n_rep: Number of additional hidden stages after the input stage.
    """

    def __init__(self, in_dims: int, h_dims: int, n_rep: int):
        super().__init__()

        def fc_bn_relu(width_in: int) -> list[nn.Module]:
            # One bias-free linear layer followed by batch norm and ReLU.
            return [
                nn.Linear(width_in, h_dims, bias=False),
                nn.BatchNorm1d(h_dims),
                nn.ReLU(),
            ]

        # The input stage sits flat in the outer Sequential; every repeated
        # stage is wrapped in its own nested Sequential.
        stages = fc_bn_relu(in_dims)
        for _ in range(n_rep):
            stages.append(nn.Sequential(*fc_bn_relu(h_dims)))
        self.feature_extractor = nn.Sequential(*stages)

        self.fea_dim = h_dims
        self.reg_num = 1
        self.regressor = nn.Linear(h_dims, 1)

    def feature(self, x: Tensor) -> Tensor:
        """Return hidden features for ``x``; inputs with 3+ dims are flattened from dim 1 first."""
        flat = x.flatten(start_dim=1) if x.ndim >= 3 else x
        return self.feature_extractor(flat)

    def predict_from_feature(self, z: Tensor) -> Tensor:
        """Apply the linear head to ``z`` and return the result flattened to 1-D."""
        head_out: Tensor = self.regressor(z)
        return head_out.flatten()

    def get_feature_extractor(self) -> nn.Module:
        """Return the feature-extractor sub-module."""
        return self.feature_extractor

    def get_regressor(self) -> nn.Linear:
        """Return the linear regression head."""
        return self.regressor


# ===================================================
# ================== Dropout Models =================
# ===================================================
class CNNRegressor_DROPOUT(Regressor_DROPOUT):
    """Image regressor whose linear head is preceded by call-time dropout.

    Args:
        backbone: ``"resnet26"`` (built through ``timm``) or ``"resnet50"``
            (built through ``torchvision``).
        pretrained: Whether to request pretrained weights for the trunk.
        in_channels: Number of input channels; when it is not 3 the trunk's
            first convolution is replaced by a freshly initialised one.
        reg_num: Output width of the linear head.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone: str, pretrained: bool, in_channels: int=3, reg_num: int=1):
        super().__init__()

        # Pick the trunk and the name of the pooling node to tap.
        if backbone == "resnet26":
            trunk = timm.create_model("resnet26", pretrained=pretrained)
            pool_node = "global_pool"
        elif backbone == "resnet50":
            trunk = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)
            pool_node = "avgpool"
        else:
            raise ValueError(f"Invalid backbone: {backbone!r}")

        if in_channels != 3:
            # Swap the stem convolution so the trunk accepts the requested channel count.
            trunk.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7,
                                    stride=2, padding=3, bias=False)

        self.feature_extractor = create_feature_extractor(trunk, {pool_node: "feature"})

        self.fea_dim = 2048
        self.reg_num = reg_num
        self.regressor = nn.Linear(2048, reg_num)

    def feature(self, x: Tensor) -> Tensor:
        """Return the pooled trunk output for ``x``, flattened from dim 1 onwards."""
        pooled: Tensor = self.feature_extractor(x)["feature"]
        return pooled.flatten(start_dim=1)

    def predict_from_feature(self, z: Tensor, dropout_prob_temps: float | None=None, training: Any=None) -> Tensor:
        """Apply dropout to ``z`` and then the linear head; return a 1-D tensor.

        Args:
            z: Feature batch.
            dropout_prob_temps: Dropout probability; 0.8 is used when ``None``.
            training: Passed to ``F.dropout`` as its ``training`` flag;
                ``None`` means ``False``, i.e. dropout is a no-op.
        """
        drop_p = dropout_prob_temps if dropout_prob_temps is not None else 0.8
        drop_active = training if training is not None else False

        dropped = F.dropout(z, p=drop_p, training=drop_active)
        head_out: Tensor = self.regressor(dropped)
        return head_out.flatten()

    def get_feature_extractor(self) -> nn.Module:
        """Return the feature-extractor sub-module."""
        return self.feature_extractor

    def get_regressor(self) -> nn.Module:
        """Return the linear regression head."""
        return self.regressor


class MLPRegressor_DROPOUT(Regressor_DROPOUT):
    """Tabular regressor whose scalar head is preceded by call-time dropout.

    Args:
        in_dims: Width of the input features.
        h_dims: Width of every hidden stage (also stored as ``fea_dim``).
        n_rep: Number of additional hidden stages after the input stage.
    """

    def __init__(self, in_dims: int, h_dims: int, n_rep: int):
        super().__init__()

        def fc_bn_relu(width_in: int) -> list[nn.Module]:
            # One bias-free linear layer followed by batch norm and ReLU.
            return [
                nn.Linear(width_in, h_dims, bias=False),
                nn.BatchNorm1d(h_dims),
                nn.ReLU(),
            ]

        # The input stage sits flat in the outer Sequential; every repeated
        # stage is wrapped in its own nested Sequential.
        stages = fc_bn_relu(in_dims)
        for _ in range(n_rep):
            stages.append(nn.Sequential(*fc_bn_relu(h_dims)))
        self.feature_extractor = nn.Sequential(*stages)

        self.fea_dim = h_dims
        self.reg_num = 1
        self.regressor = nn.Linear(h_dims, 1)

    def feature(self, x: Tensor) -> Tensor:
        """Return hidden features for ``x``; inputs with 3+ dims are flattened from dim 1 first."""
        flat = x.flatten(start_dim=1) if x.ndim >= 3 else x
        return self.feature_extractor(flat)

    def predict_from_feature(self, z: Tensor, dropout_prob_temps: float | None=None, training: Any=None) -> Tensor:
        """Apply dropout to ``z`` and then the linear head; return a 1-D tensor.

        Args:
            z: Feature batch.
            dropout_prob_temps: Dropout probability; 0.8 is used when ``None``.
            training: Passed to ``F.dropout`` as its ``training`` flag;
                ``None`` means ``False``, i.e. dropout is a no-op.
        """
        drop_p = dropout_prob_temps if dropout_prob_temps is not None else 0.8
        drop_active = training if training is not None else False

        dropped = F.dropout(z, p=drop_p, training=drop_active)
        head_out: Tensor = self.regressor(dropped)
        return head_out.flatten()

    def get_feature_extractor(self) -> nn.Module:
        """Return the feature-extractor sub-module."""
        return self.feature_extractor

    def get_regressor(self) -> nn.Linear:
        """Return the linear regression head."""
        return self.regressor

        

class CNNRegressorClassifier_DROPOUT(CNNRegressor_DROPOUT):
    """CNN dropout regressor with an extra classification branch on the shared features.

    The classification branch is ``Linear(fea_dim, 256) -> BatchNorm1d(256)``
    (``classifier_fea``) followed by ``Linear(256, num_classes)``
    (``classifier_cls``).

    Args:
        backbone: Backbone name accepted by ``CNNRegressor_DROPOUT``.
        pretrained: Whether to request pretrained trunk weights. Ignored when
            ``base_model`` is given, because its weights are loaded instead.
        in_channels: Number of input channels.
        reg_num: Output width of the regression head.
        num_classes: Output width of the classification head.
        base_model: Optional model whose feature-extractor and regressor state
            dicts are copied into this model.
    """

    def __init__(self, backbone: str, pretrained: bool=False, in_channels: int=3,
                 reg_num: int=1, num_classes: int=10,
                 base_model: Optional[CNNRegressor_DROPOUT]=None):
        # Pretrained weights would be overwritten by base_model's, so do not request them.
        use_pretrained = False if base_model is not None else pretrained
        super().__init__(backbone, use_pretrained, in_channels, reg_num)

        if base_model is not None:
            self.feature_extractor.load_state_dict(
                base_model.get_feature_extractor().state_dict()
            )
            self.regressor.load_state_dict(
                base_model.get_regressor().state_dict()
            )

        cls_fea_dim = 256
        self.classifier_fea = nn.Sequential(nn.Linear(self.fea_dim, cls_fea_dim),
                                            nn.BatchNorm1d(cls_fea_dim))
        self.classifier_cls = nn.Linear(cls_fea_dim, num_classes)
        self.num_classes = num_classes

    def predict_classification_from_feature(self, z: Tensor, dropout_prob_temps: float | None=None,
                                            training: Any=None) -> Tensor:
        """Return class logits for backbone features ``z``.

        ``dropout_prob_temps`` and ``training`` are accepted for signature
        parity with the regression head but are not used: no dropout is
        applied on the classification branch.
        """
        class_fea: Tensor = self.classifier_fea(z)
        class_pred: Tensor = self.classifier_cls(class_fea)
        return class_pred

    def cls_feature(self, x: Tensor) -> Tensor:
        """Return the classification-branch features for raw input ``x``."""
        return self.classifier_fea(self.feature(x))

    def cls_feature_from_feature(self, z: Tensor) -> Tensor:
        """Return the classification-branch features for backbone features ``z``."""
        return self.classifier_fea(z)

    def predict_classification_from_cls_feature(self, cls_z: Tensor, dropout_prob_temps: float | None=None,
                                                training: Any=None) -> Tensor:
        """Return class logits for classification-branch features ``cls_z``.

        ``dropout_prob_temps`` and ``training`` are accepted but not used.
        """
        class_pred: Tensor = self.classifier_cls(cls_z)
        return class_pred

    def forward(self, x: Tensor, dropout_prob_temps: float | None=None,
                training: Any=None) -> tuple[Tensor, Tensor]:
        """Return ``(regression, classification)`` outputs computed from shared features of ``x``."""
        features = self.feature(x)
        regression = self.predict_from_feature(
            features, dropout_prob_temps, training
        )
        classification = self.predict_classification_from_feature(
            features, dropout_prob_temps, training
        )

        return regression, classification

    def get_classifier_fea(self) -> nn.Module:
        """Return the classification feature sub-module."""
        return self.classifier_fea

    def get_classifier_cls(self) -> nn.Module:
        """Return the final classification layer."""
        return self.classifier_cls


class MLPRegressorClassifier_DROPOUT(MLPRegressor_DROPOUT):
    """MLP dropout regressor with an extra classification branch on the shared features.

    The classification branch is ``Linear(fea_dim, 256) -> BatchNorm1d(256)``
    (``classifier_fea``) followed by ``Linear(256, num_classes)``
    (``classifier_cls``).

    Args:
        in_dims: Width of the input features.
        h_dims: Width of every hidden stage.
        n_rep: Number of additional hidden stages.
        num_classes: Output width of the classification head.
        base_model: Optional model whose feature-extractor and regressor state
            dicts are copied into this model.
    """

    def __init__(self,
                 in_dims: int,
                 h_dims: int,
                 n_rep: int,
                 num_classes: int=10,
                 base_model: MLPRegressor | MLPRegressor_DROPOUT | None=None,):
        # The parent is built the same way whether or not a base model is supplied.
        super().__init__(in_dims, h_dims, n_rep)

        if base_model is not None:
            self.feature_extractor.load_state_dict(
                base_model.get_feature_extractor().state_dict()
            )
            self.regressor.load_state_dict(
                base_model.get_regressor().state_dict()
            )

        cls_fea_dim = 256
        self.classifier_fea = nn.Sequential(nn.Linear(self.fea_dim, cls_fea_dim),
                                            nn.BatchNorm1d(cls_fea_dim))
        self.classifier_cls = nn.Linear(cls_fea_dim, num_classes)
        self.num_classes = num_classes

    def predict_classification_from_feature(self, z: Tensor, dropout_prob_temps: float | None=None,
                                            training: Any=None) -> Tensor:
        """Return class logits for hidden features ``z``.

        ``dropout_prob_temps`` and ``training`` are accepted for signature
        parity with the regression head but are not used.
        """
        class_fea: Tensor = self.classifier_fea(z)
        class_pred: Tensor = self.classifier_cls(class_fea)
        return class_pred

    def cls_feature(self, x: Tensor) -> Tensor:
        """Return the classification-branch features for raw input ``x``."""
        return self.classifier_fea(self.feature(x))

    def cls_feature_from_feature(self, z: Tensor) -> Tensor:
        """Return the classification-branch features for hidden features ``z``."""
        return self.classifier_fea(z)

    def predict_classification_from_cls_feature(self, cls_z: Tensor, dropout_prob_temps: float | None=None,
                                                training: Any=None) -> Tensor:
        """Return class logits for classification-branch features ``cls_z``.

        ``dropout_prob_temps`` and ``training`` are accepted but not used.
        """
        class_pred: Tensor = self.classifier_cls(cls_z)
        return class_pred

    def forward(self, x: Tensor, dropout_prob_temps: float | None=None,
                training: Any=None) -> tuple[Tensor, Tensor]:
        """Return ``(regression, classification)`` outputs computed from shared features of ``x``."""
        features = self.feature(x)
        regression = self.predict_from_feature(
            features, dropout_prob_temps, training
        )
        classification = self.predict_classification_from_feature(
            features, dropout_prob_temps, training
        )

        return regression, classification

    def get_classifier_fea(self) -> nn.Module:
        """Return the classification feature sub-module."""
        return self.classifier_fea

    def get_classifier_cls(self) -> nn.Module:
        """Return the final classification layer."""
        return self.classifier_cls
    
# =====================================================
# DANN models with layer grad reverse
# =====================================================
class GradientReversalFunction(Function):
    """Autograd function: identity on the forward pass, negated and scaled gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        """Return a view of ``x`` and remember ``lambda_`` for the backward pass."""
        # Stash the scale on the context; backward reads it back.
        ctx.lambda_ = lambda_
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-grad_output * lambda_`` for ``x`` and ``None`` for ``lambda_``."""
        flipped = grad_output.neg()
        # One entry per forward input: a gradient for x, none for lambda_.
        return flipped * ctx.lambda_, None


def grad_reverse(x, lambda_=1.0):
    """Apply ``GradientReversalFunction`` to ``x`` with gradient scale ``lambda_``.

    The forward value equals ``x``; on the backward pass the gradient reaching
    ``x`` is multiplied by ``-lambda_``.
    """
    reversed_view = GradientReversalFunction.apply(x, lambda_)
    return reversed_view

class CNNRegressorClassifier_DANN(CNNRegressor):
    """CNN regressor with a classification branch fed through a gradient-reversal layer.

    The classification branch is ``Linear(fea_dim, 256) -> BatchNorm1d(256)``
    (``classifier_fea``) followed by ``Linear(256, num_classes)``
    (``classifier_cls``).

    Args:
        backbone: Backbone name accepted by ``CNNRegressor``.
        pretrained: Whether to request pretrained trunk weights. Ignored when
            ``base_model`` is given, because its weights are loaded instead.
        in_channels: Number of input channels.
        reg_num: Output width of the regression head.
        num_classes: Output width of the classification head.
        base_model: Optional model whose feature-extractor and regressor state
            dicts are copied into this model.
        lambda_grl: Default gradient-reversal scale used by ``forward``.
    """

    def __init__(self, backbone: str, pretrained: bool=False, in_channels: int=3,
                 reg_num: int=1, num_classes: int=2,
                 base_model: Optional[CNNRegressor]=None,
                 lambda_grl: float = 1.0):
        # Pretrained weights would be overwritten by base_model's, so do not request them.
        use_pretrained = False if base_model is not None else pretrained
        super().__init__(backbone, use_pretrained, in_channels, reg_num)

        if base_model is not None:
            self.feature_extractor.load_state_dict(
                base_model.get_feature_extractor().state_dict()
            )
            self.regressor.load_state_dict(
                base_model.get_regressor().state_dict()
            )

        cls_fea_dim = 256
        self.classifier_fea = nn.Sequential(nn.Linear(self.fea_dim, cls_fea_dim),
                                            nn.BatchNorm1d(cls_fea_dim))
        self.classifier_cls = nn.Linear(cls_fea_dim, num_classes)
        self.num_classes = num_classes
        self.lambda_grl: float = lambda_grl

    def predict_classification_from_feature(self, z: Tensor) -> Tensor:
        """Return class logits for backbone features ``z``."""
        class_fea: Tensor = self.classifier_fea(z)
        class_pred: Tensor = self.classifier_cls(class_fea)
        return class_pred

    def cls_feature(self, x: Tensor) -> Tensor:
        """Return the classification-branch features for raw input ``x``."""
        return self.classifier_fea(self.feature(x))

    def cls_feature_from_feature(self, z: Tensor) -> Tensor:
        """Return the classification-branch features for backbone features ``z``."""
        return self.classifier_fea(z)

    def predict_classification_from_cls_feature(self, cls_z: Tensor) -> Tensor:
        """Return class logits for classification-branch features ``cls_z``."""
        class_pred: Tensor = self.classifier_cls(cls_z)
        return class_pred

    def forward(self, x: Tensor, return_domain_pred=False,
                lambda_grl=None) -> Tensor | tuple[Tensor, Tensor]:
        """Return the regression output, optionally with the classification logits.

        Args:
            x: Input batch.
            return_domain_pred: When truthy, also compute the classification
                logits on gradient-reversed features.
            lambda_grl: Gradient-reversal scale for this call; ``None`` falls
                back to ``self.lambda_grl``.

        Returns:
            ``y_pred`` alone when ``return_domain_pred`` is falsy, otherwise
            the pair ``(y_pred, domain_pred)``.
        """
        z = self.feature(x)
        y_pred = self.regressor(z).flatten()
        if not return_domain_pred:
            return y_pred

        grl_scale = self.lambda_grl if lambda_grl is None else lambda_grl
        domain_pred = self.predict_classification_from_feature(grad_reverse(z, grl_scale))
        return y_pred, domain_pred

    def get_classifier_fea(self) -> nn.Module:
        """Return the classification feature sub-module."""
        return self.classifier_fea

    def get_classifier_cls(self) -> nn.Module:
        """Return the final classification layer."""
        return self.classifier_cls


class MLPRegressorClassifier_DANN(MLPRegressor):
    """MLP regressor with a classification branch fed through a gradient-reversal layer.

    The classification branch is ``Linear(fea_dim, 256) -> BatchNorm1d(256)``
    (``classifier_fea``) followed by ``Linear(256, num_classes)``
    (``classifier_cls``).

    Args:
        in_dims: Width of the input features.
        h_dims: Width of every hidden stage.
        n_rep: Number of additional hidden stages.
        num_classes: Output width of the classification head.
        base_model: Optional model whose feature-extractor and regressor state
            dicts are copied into this model.
        lambda_grl: Default gradient-reversal scale used by ``forward``.
    """

    def __init__(self,
                 in_dims: int,
                 h_dims: int,
                 n_rep: int,
                 num_classes: int=2,
                 base_model: Optional[MLPRegressor]=None,
                 lambda_grl: float = 1.0):
        # The parent is built the same way whether or not a base model is supplied.
        super().__init__(in_dims, h_dims, n_rep)

        if base_model is not None:
            self.feature_extractor.load_state_dict(
                base_model.get_feature_extractor().state_dict()
            )
            self.regressor.load_state_dict(
                base_model.get_regressor().state_dict()
            )

        cls_fea_dim = 256
        self.classifier_fea = nn.Sequential(nn.Linear(self.fea_dim, cls_fea_dim),
                                            nn.BatchNorm1d(cls_fea_dim))
        self.classifier_cls = nn.Linear(cls_fea_dim, num_classes)
        self.num_classes = num_classes
        self.lambda_grl: float = lambda_grl

    def predict_classification_from_feature(self, z: Tensor) -> Tensor:
        """Return class logits for hidden features ``z``."""
        class_fea: Tensor = self.classifier_fea(z)
        class_pred: Tensor = self.classifier_cls(class_fea)
        return class_pred

    def cls_feature(self, x: Tensor) -> Tensor:
        """Return the classification-branch features for raw input ``x``."""
        return self.classifier_fea(self.feature(x))

    def cls_feature_from_feature(self, z: Tensor) -> Tensor:
        """Return the classification-branch features for hidden features ``z``."""
        return self.classifier_fea(z)

    def predict_classification_from_cls_feature(self, cls_z: Tensor) -> Tensor:
        """Return class logits for classification-branch features ``cls_z``."""
        class_pred: Tensor = self.classifier_cls(cls_z)
        return class_pred

    def forward(self, x: Tensor, return_domain_pred=False,
                lambda_grl=None) -> Tensor | tuple[Tensor, Tensor]:
        """Return the regression output, optionally with the classification logits.

        Args:
            x: Input batch.
            return_domain_pred: When truthy, also compute the classification
                logits on gradient-reversed features.
            lambda_grl: Gradient-reversal scale for this call; ``None`` falls
                back to ``self.lambda_grl``.

        Returns:
            ``y_pred`` alone when ``return_domain_pred`` is falsy, otherwise
            the pair ``(y_pred, domain_pred)``.
        """
        z = self.feature(x)
        y_pred = self.regressor(z).flatten()
        if not return_domain_pred:
            return y_pred

        grl_scale = self.lambda_grl if lambda_grl is None else lambda_grl
        domain_pred = self.predict_classification_from_feature(grad_reverse(z, grl_scale))
        return y_pred, domain_pred

    def get_classifier_fea(self) -> nn.Module:
        """Return the classification feature sub-module."""
        return self.classifier_fea

    def get_classifier_cls(self) -> nn.Module:
        """Return the final classification layer."""
        return self.classifier_cls
    
# =====================================================
# weight learner
# =====================================================
    
class LossWeightLearner(nn.Module):
    """Learn one log-variance per loss term and combine losses with weights ``exp(-log_var)``.

    The combined objective is ``sum_i w_i * loss_i + 0.5 * sum_i s_i`` where
    ``s_i`` are the entries of ``log_vars`` and ``w_i = exp(-clamp(s_i, -10, 10))``.
    Note that the ``0.5 * s_i`` term uses the unclamped values.

    Args:
        num_losses: Number of loss terms to weight.
        init_value: Initial value of every entry of ``log_vars``.
    """

    def __init__(self, num_losses=3, init_value=1.0):
        super().__init__()
        self.log_vars = nn.Parameter(torch.ones(num_losses) * torch.tensor(init_value))

    @staticmethod
    def _weights_from_log_vars(log_vars):
        """Turn log-variances into weights, falling back to uniform weights on NaN/Inf."""
        # Clamp before exponentiating so the weights stay within [e^-10, e^10].
        weights = torch.exp(-torch.clamp(log_vars, min=-10.0, max=10.0))

        # Guard against NaN/Inf (e.g. log_vars already corrupted by a bad update).
        if torch.isnan(weights).any() or torch.isinf(weights).any():
            txt_logger.warning("warning for NaN/Inf weight")
            return torch.ones_like(weights) / len(weights)
        return weights

    def get_weights(self):
        """Return the current loss weights, attached to the autograd graph of ``log_vars``."""
        return self._weights_from_log_vars(self.log_vars)

    def get_weights_detached(self):
        """Return the current loss weights computed from a detached copy of ``log_vars``."""
        return self._weights_from_log_vars(self.log_vars.detach())

    def weighted_loss(self, *losses):
        """Combine individually passed losses.

        Args:
            *losses: One loss per learned weight, in weight order.

        Returns:
            ``(total, weights)`` where ``total`` is the weighted sum of the
            losses plus ``0.5 * log_vars.sum()``.

        Raises:
            AssertionError: If the number of losses differs from the number of
                weights.
        """
        weights = self.get_weights()
        if len(losses) != len(weights):
            # Raised explicitly (rather than via `assert`) so the check survives
            # `python -O`; without it `zip` below would silently truncate.
            raise AssertionError(
                f"loss len ({len(losses)}) and weight len ({len(weights)}) do not match"
            )

        total_loss = sum(w * l for w, l in zip(weights, losses))
        reg_term = 0.5 * self.log_vars.sum()

        return total_loss + reg_term, weights

    def forward(self, losses):
        """Combine losses given as one tensor, or as a list/tuple of losses.

        Args:
            losses: Either a tensor that multiplies element-wise with the
                weights, or a list/tuple of per-term losses (handled exactly
                like ``weighted_loss(*losses)``).

        Returns:
            ``(total, weights)`` where ``total`` is the weighted sum of the
            losses plus ``0.5 * log_vars.sum()``.
        """
        if isinstance(losses, (list, tuple)):
            return self.weighted_loss(*losses)

        weights = self.get_weights()
        weighted_losses = weights * losses

        # Regularisation term: each loss contributes log_var / 2.
        reg_term = 0.5 * self.log_vars

        return weighted_losses.sum() + reg_term.sum(), weights