import abc
import torch
from collections import defaultdict
from layers.registry import BINARY_BATCHNORMS

class BaseModel(abc.ABC, torch.nn.Module):
    """The base class used by all models in this codebase.

    ``prunable_modules`` and ``prev_module`` start out empty and must be
    populated (presumably by subclasses) before :meth:`get_model_flops` is
    called. Subclasses must implement the three abstract static methods.
    """

    def __init__(self):
        super().__init__()
        # Iterable of prunable layers. get_model_flops expects each entry to
        # expose ``_conv_module``, ``num_features`` and ``n_remaining``.
        self.prunable_modules = None
        # Maps each prunable layer to what feeds it: ``None`` for the network
        # input, a single ``torch.nn.BatchNorm2d``, or a left-nested tuple of
        # BatchNorm2d layers whose outputs are merged (presumably residual
        # additions -- confirm against the code that fills this in).
        # NOTE: a defaultdict without a default_factory behaves like a plain
        # dict, so looking up a missing key still raises KeyError.
        self.prev_module = defaultdict()

    @staticmethod
    def _union_remaining(prev):
        """Element-wise max of per-channel remaining counts of merged layers.

        ``prev`` is either a pair ``(bn_a, bn_b)`` of BatchNorm2d layers or a
        left-nested tuple ``(nested, bn)``. Each layer's
        ``n_remaining(reduction='none')`` is combined with ``torch.maximum``;
        if those are 0/1 keep masks this is the union of surviving channels.
        """
        if isinstance(prev[0], torch.nn.BatchNorm2d):
            return torch.maximum(prev[0].n_remaining(reduction='none'),
                                 prev[1].n_remaining(reduction='none'))
        last = prev[-1].n_remaining(reduction='none')
        return torch.maximum(BaseModel._union_remaining(prev[0]), last)

    @staticmethod
    def _conv_bn_relu_flops(kernel_area, in_channels, out_channels, area,
                            has_bias, binary=False):
        """Count FLOPs of one conv (+ optional bias) -> batch norm -> ReLU stage.

        Args:
            kernel_area: product of the conv kernel's two spatial sizes.
            in_channels: number of input channels (pruned or full).
            out_channels: number of output channels (pruned or full).
            area: number of output positions the conv is evaluated at.
            has_bias: whether the conv adds a bias (one op per output element).
            binary: if True, divide the convolution cost by 64 (presumably
                64 binary ops packed per machine word -- confirm).

        Returns:
            The FLOP count. It may be a tensor if any channel count is one.
        """
        per_position = kernel_area * in_channels * out_channels
        if binary:
            per_position = per_position / 64.
        # Out-of-place adds throughout: in-place ``+=`` on a tensor fails when
        # the result dtype must be promoted (e.g. integer counts + the float
        # produced by the binary discount).
        flops = per_position * area
        if has_bias:
            flops = flops + out_channels * area
        elementwise = out_channels * area
        # One op per output element for ReLU, two for batch norm.
        return flops + elementwise + elementwise * 2

    def get_model_flops(self, input_channels: int = 3):
        """Return the fraction of FLOPs that remain after channel pruning.

        For every layer in ``self.prunable_modules`` the FLOPs of its
        convolution (plus bias), batch norm and ReLU are counted twice. The
        first count uses the channels still active
        (``n_remaining(reduction='sum')``) and applies a 64x convolution
        discount to layers in ``BINARY_BATCHNORMS``. The second is a dense
        full-precision baseline that uses ``num_features``.

        Args:
            input_channels: channel count used for a layer whose entry in
                ``self.prev_module`` is ``None`` (i.e. it reads the network
                input). Defaults to 3, presumably RGB images.

        Returns:
            ``remaining_flops / total_flops``. It is a tensor when
            ``n_remaining`` returns tensors.

        Raises:
            ValueError: if ``self.prunable_modules`` is unset or empty.
        """
        if not self.prunable_modules:
            raise ValueError(
                'prunable_modules must be a non-empty collection before '
                'calling get_model_flops')

        n_rem = 0
        n_total = 0
        for l_block in self.prunable_modules:
            conv = l_block._conv_module
            kernel_area = conv.kernel_size[0] * conv.kernel_size[1]
            # Presumably the number of spatial output positions
            # (H_out * W_out), recorded on the conv elsewhere -- TODO confirm.
            area = conv.output_area
            has_bias = conv.bias is not None

            prev = self.prev_module[l_block]
            if prev is None:
                # The layer reads the network input directly.
                prev_total = prev_remaining = input_channels
            elif isinstance(prev, torch.nn.BatchNorm2d):
                prev_total = prev.num_features
                prev_remaining = prev.n_remaining(reduction='sum')
            else:
                # Several layers feed this one; combine their per-channel
                # counts with an element-wise max before summing.
                prev_total = prev[-1].num_features
                prev_remaining = self._union_remaining(prev).sum()
            curr_remaining = l_block.n_remaining(reduction='sum')

            # Pruned cost (binary layers discounted on the conv only).
            n_rem = n_rem + self._conv_bn_relu_flops(
                kernel_area, prev_remaining, curr_remaining, area, has_bias,
                binary=isinstance(l_block, BINARY_BATCHNORMS))
            # Dense baseline: full channel counts, always full precision.
            n_total = n_total + self._conv_bn_relu_flops(
                kernel_area, prev_total, l_block.num_features, area, has_bias)
        return n_rem / n_total

    @staticmethod
    @abc.abstractmethod
    def is_valid_model_name(model_name: str) -> bool:
        """Return True if ``model_name`` is a valid name for models in this class."""

    @staticmethod
    @abc.abstractmethod
    def get_model_from_config(config):
        """Return an instance of this class as described by ``config``."""

    @staticmethod
    @abc.abstractmethod
    def get_default_args():
        """Return the default hyperparameters for training this model."""