import torch
from torch import nn
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.models import resnet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.fcn import FCN, FCNHead
from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
from .system_def import get_device
from typing import Dict
import math
import os
from pprint import pprint

import pdb
import sys
sys.path.append('../../')
from defaults import transformers

class BaseModel(nn.Module):
    """
    Base module bundling freezing, initialisation and layer-replacement helpers.

    Every helper that takes ``submodel`` works on ``self`` when ``submodel``
    is ``None``.
    """

    def __init__(self):
        super().__init__()
        self.device = get_device()

    def attr_from_dict(self, param_dict):
        """Set every key of ``param_dict`` as an attribute of this module."""
        for key in param_dict:
            setattr(self, key, param_dict[key])

    def get_out_channels(self, m):
        """
        Return ``num_features`` of the last layer inside ``m`` that defines it.

        ``m`` is either a module (its direct children are inspected) or a
        list/tuple of modules. Children are visited last-to-first and searched
        recursively; ``None`` is returned when nothing exposes ``num_features``.
        """
        children = m if isinstance(m, (list, tuple)) else list(m.children())
        for layer in reversed(children):
            if hasattr(layer, 'num_features'):
                return layer.num_features
            found = self.get_out_channels(layer)
            if found is not None:
                return found
        return None

    def get_submodel(self, m, min_layer=None, max_layer=None):
        """Return the direct children of ``m`` sliced as ``[min_layer:max_layer]``."""
        return list(m.children())[min_layer:max_layer]

    def freeze_bn(self, submodel=None):
        """Put every ``nn.BatchNorm2d`` layer of ``submodel`` in eval mode."""
        submodel = self if submodel is None else submodel
        for layer in submodel.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()

    def unfreeze_bn(self, submodel=None):
        """Put every ``nn.BatchNorm2d`` layer of ``submodel`` back in train mode."""
        submodel = self if submodel is None else submodel
        for layer in submodel.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.train()

    def freeze_submodel(self, submodel=None):
        """Disable gradients for every parameter of ``submodel``."""
        submodel = self if submodel is None else submodel
        for param in submodel.parameters():
            param.requires_grad = False

    def unfreeze_submodel(self, submodel=None):
        """Enable gradients for every parameter of ``submodel``."""
        submodel = self if submodel is None else submodel
        for param in submodel.parameters():
            param.requires_grad = True

    def initialize_norm_layers(self, submodel=None):
        """
        Reset BatchNorm2d / GroupNorm affine parameters to weight=1, bias=0.

        Layers created with ``affine=False`` carry ``None`` for ``weight`` and
        ``bias``; those are skipped instead of raising.
        """
        submodel = self if submodel is None else submodel
        for layer in submodel.modules():
            if isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                if layer.weight is not None:
                    layer.weight.data.fill_(1)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def freeze_norm_layers(self, submodel=None):
        """Put every BatchNorm2d / GroupNorm layer of ``submodel`` in eval mode."""
        submodel = self if submodel is None else submodel
        for layer in submodel.modules():
            if isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                layer.eval()

    def init_weights(self, submodel=None):
        """Re-initialise every ``nn.Conv2d`` weight with Kaiming-normal init."""
        submodel = self if submodel is None else submodel
        for layer in submodel.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight.data)

    def BN_to_GN(self, submodel=None, num_groups=32, keep_stats=True):
        """
        Replace, in place, every ``nn.BatchNorm2d`` in ``submodel`` by ``nn.GroupNorm``.

        Args:
            submodel: module to edit; defaults to ``self``.
            num_groups: ``num_groups`` passed to each new ``nn.GroupNorm``
                (``nn.GroupNorm`` requires it to divide the channel count).
            keep_stats: when true the new layer reuses the very same ``weight``
                and ``bias`` parameter objects as the layer it replaces. Only
                these two are carried over; nothing else is copied.
        """
        def child_of(module, key):
            # Named children resolve through getattr; fall back to positional
            # indexing for containers that only support ``module[i]``.
            try:
                return getattr(module, key)
            except AttributeError:
                return module[int(key)]

        submodel = self if submodel is None else submodel
        # Collect the names first so the tree is not modified while
        # ``named_modules()`` is still being consumed.
        bn_names = [name for name, module in submodel.named_modules()
                    if isinstance(module, nn.BatchNorm2d)]
        for name in bn_names:
            *parent_path, leaf = name.split('.')
            parent = submodel
            for key in parent_path:
                parent = child_of(parent, key)
            old_layer = getattr(parent, leaf)
            new_layer = nn.GroupNorm(num_groups=num_groups,
                                     num_channels=old_layer.num_features)
            if keep_stats:
                new_layer.weight = old_layer.weight
                new_layer.bias = old_layer.bias
            setattr(parent, leaf, new_layer)

    def print_trainable_params(self, submodel=None):
        """Print the name of every parameter of ``submodel`` that requires grad."""
        submodel = self if submodel is None else submodel
        for name, param in submodel.named_parameters():
            if param.requires_grad:
                print(name)
                
                
class TransConvIntermediateLayerGetter(IntermediateLayerGetter):
    """
    Intermediate-layer getter that also handles ``transformers.VisionTransformer``.

    For any other model the behaviour is the plain layer-by-layer pass that
    records the outputs named in ``return_layers``.

    FROM https://github.com/pytorch/vision/blob/1cb85abe1e35ea83c8082877502f1d2007b21c7d/torchvision/models/_utils.py#L7
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str], single_head=True) -> None:
        super().__init__(model, return_layers)
        self.is_transformer = isinstance(model, transformers.VisionTransformer)
        if not self.is_transformer:
            return
        # These are not child modules of ``model``, so they are picked up
        # explicitly from it here.
        self.single_head = single_head
        self.cls_token = model.cls_token
        self.pos_embed = model.pos_embed
        self.interpolate_pos_encoding = model.interpolate_pos_encoding

    def forward(self, x):
        """Return an ``OrderedDict`` mapping output names to feature tensors."""
        features = OrderedDict()

        if not self.is_transformer:
            # Run the kept layers in order, recording the requested outputs.
            for layer_name, layer in self.items():
                x = layer(x)
                if layer_name in self.return_layers:
                    features[self.return_layers[layer_name]] = x
            return features

        # Size of the patch grid along the last two input axes.
        patch = self.patch_embed.patch_size
        grid_rows = x.shape[-2] // patch
        grid_cols = x.shape[-1] // patch
        batch_size = x.shape[0]

        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = tokens + self.interpolate_pos_encoding(tokens, self.pos_embed)
        tokens = self.pos_drop(tokens)

        last_idx = len(self.blocks) - 1
        for idx, block in enumerate(self.blocks):
            if idx < last_idx:
                tokens = block(tokens)
                continue
            # The final block is asked for its attention instead of its output.
            tokens = block(tokens, return_attention=True)
            if self.single_head:
                # we keep only the output patch attention
                # (presumably attention is (batch, heads, tokens, tokens) -- TODO confirm)
                tokens = tokens[:, :, 0, 1:]
                tokens = tokens.reshape(-1, tokens.size(1), grid_rows, grid_cols)

        features['out'] = tokens
        return features
                
        
def state_dict_from_path(path):
    """
    Returns the state dict from a checkpoint path

    The checkpoint must be a file whose loaded object has a ``'state_dict'``
    entry. Tensors are mapped to CPU while loading so that a checkpoint saved
    on a GPU can also be read on a machine without one.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
    """
    path = os.path.abspath(path)
    mname = os.path.basename(path)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            "Model \"{}\" is not present in \"{}\"".format(mname, os.path.dirname(path)))

    print("Loading weights from \"{}\"".format(mname))
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu")
    return checkpoint['state_dict']

def load_from_checkpoint(model, path, strict=False):
    """
    Loads weights into the input model from a pretrained model path

    Every checkpoint key is reduced to the part after its last "backbone."
    occurrence before loading. Unexpected keys reported by ``load_state_dict``
    are summarised by their first two dotted components and printed.
    """
    remapped = OrderedDict()
    for full_key, tensor in state_dict_from_path(path).items():
        remapped[full_key.split("backbone.")[-1]] = tensor

    load_result = model.load_state_dict(remapped, strict=strict)
    unmatched = {" : ".join(key.split(".")[:2]) for key in load_result.unexpected_keys}
    if unmatched:
        print("Unmatched pretrained modules")
        pprint(unmatched)
        
def tansconv_segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True, 
                        single_head=True, transformers_params=None, transfer_learning_params=None):
    """
    Build an FCN / DeepLabV3 segmentation model over a ResNet or DeiT backbone.

    FROM https://github.com/pytorch/vision/blob/1cb85abe1e35ea83c8082877502f1d2007b21c7d/torchvision/models/segmentation/segmentation.py#L27

    Args:
        name: ``'deeplabv3'`` or ``'fcn'``; selects the head and wrapper classes.
        backbone_name: a name containing ``'resnet'`` (looked up in
            ``torchvision.models.resnet``) or ``'deit'`` (looked up in the
            project's ``transformers`` module).
        num_classes: number of output channels of the classifier head(s).
        aux: whether to add an auxiliary ``FCNHead``; only available for
            ResNet backbones.
        pretrained_backbone: forwarded as ``pretrained`` to the backbone factory.
        single_head: forwarded to ``TransConvIntermediateLayerGetter``.
        transformers_params: extra keyword arguments for the DeiT factory;
            ``None`` means no extra arguments.
        transfer_learning_params: accepted for call compatibility; not used here.

    Raises:
        NotImplementedError: for an unsupported ``backbone_name``.
        ValueError: when ``aux`` is requested for a backbone without an
            auxiliary layer.
        KeyError: for an unknown ``name``.
    """
    transformers_params = {} if transformers_params is None else transformers_params

    if 'resnet' in backbone_name:
        backbone = resnet.__dict__[backbone_name](
            pretrained=pretrained_backbone,
            replace_stride_with_dilation=[False, True, True])
        out_layer = 'layer4'
        out_inplanes = 2048
        aux_layer = 'layer3'
        aux_inplanes = 1024
    elif 'deit' in backbone_name:
        backbone = transformers.__dict__[backbone_name](**transformers_params, 
                                                        pretrained=pretrained_backbone)
        out_layer = 'blocks' 
        # The head receives one channel per attention head, with or without
        # ``single_head``.
        out_inplanes = backbone.blocks[0].attn.num_heads
        aux_layer = None
        aux_inplanes = None
    else:
        raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))

    return_layers = {out_layer: 'out'}
    if aux:
        if aux_layer is None:
            # Fail with a clear message instead of registering a ``None`` layer name.
            raise ValueError(
                'aux classifier is not available for backbone {}'.format(backbone_name))
        return_layers[aux_layer] = 'aux'
    backbone = TransConvIntermediateLayerGetter(backbone, return_layers=return_layers, single_head=single_head)

    aux_classifier = None
    if aux:
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    head_cls, segmentation_cls = model_map[name]
    classifier = head_cls(out_inplanes, num_classes)

    return segmentation_cls(backbone, classifier, aux_classifier)

class SegNet(BaseModel):
    """
    Segmentation network built by ``tansconv_segm_model`` and configured from a dict.

    ``model_params`` entries become attributes through ``attr_from_dict``. The
    constructor reads ``segmentation_type``, ``backbone_type``, ``n_classes``,
    ``pretrained``, ``img_channels``, ``freeze_backbone``, ``goup_norm``
    (keys ``replace_with_goup_norm``, ``num_groups``, ``keep_stats``),
    ``transformers_params`` and ``transfer_learning_params`` (keys
    ``use_pretrained``, ``pretrained_model_name``, ``pretrained_path``).
    """

    def __init__(self, model_params, single_head=True):
        super().__init__()
        
        self.model_params = model_params
        self.attr_from_dict(model_params)         
        transfer_params = self.model_params["transfer_learning_params"]

        fcn_backbone = tansconv_segm_model(name=self.segmentation_type, 
                                    backbone_name=self.backbone_type, 
                                    num_classes=self.n_classes, 
                                    aux=False, 
                                    pretrained_backbone=self.pretrained,
                                    single_head=single_head,
                                    transformers_params=self.model_params["transformers_params"],
                                    transfer_learning_params=transfer_params)  
        
        # Adapt the input layer first so a checkpoint loaded below sees the
        # final input-layer shape.
        self.modify_first_layer(fcn_backbone, self.backbone_type, self.img_channels, self.pretrained)
        if transfer_params["use_pretrained"]:
            pretrained_model_name = transfer_params["pretrained_model_name"]
            pretrained_path = os.path.join(transfer_params["pretrained_path"], pretrained_model_name)
            print("\033[1mLoading pretrained model : {}\033[0m".format(pretrained_model_name))
            load_from_checkpoint(fcn_backbone.backbone, pretrained_path, strict=False)  
        
        if self.freeze_backbone:
            self.freeze_submodel(fcn_backbone.backbone)
        self.backbone = fcn_backbone

        if self.goup_norm['replace_with_goup_norm']:
            self.BN_to_GN(num_groups=self.goup_norm['num_groups'],
                          keep_stats=self.goup_norm['keep_stats'])

    @staticmethod
    def _tile_input_channels(weight, img_channels):
        """Tile ``weight`` along dim 1 and crop it to exactly ``img_channels`` channels."""
        # Repeat as often as needed rather than a fixed 4 times, so any
        # channel count is covered; results match for counts the fixed factor handled.
        repeats = max(1, math.ceil(img_channels / weight.shape[1]))
        return weight.repeat(1, repeats, 1, 1)[:, :img_channels]
            
    def modify_first_layer(self, backbone, backbone_type, img_channels, pretrained):
        """
        Rebuild the backbone's input layer for ``img_channels`` input channels.

        Nothing is changed when ``img_channels == 3``. Otherwise ``conv1``
        (ResNet) or ``patch_embed`` (DeiT) is replaced by a new layer with the
        same settings. When ``pretrained`` is true the previous weights are
        tiled along the input-channel axis and cropped to ``img_channels``;
        for DeiT the projection bias is carried over as well.

        Raises:
            NotImplementedError: for a backbone type other than ResNet / DeiT.
        """
        if img_channels == 3:
            return
        
        if 'resnet' in backbone_type:
            old_conv = backbone.backbone.conv1
            conv_attrs = ['out_channels', 'kernel_size', 'stride', 
                          'padding', 'dilation', "groups", "padding_mode"]
            conv1_defs = {attr: getattr(old_conv, attr) for attr in conv_attrs}
            # The constructor expects a flag, not the bias Parameter itself.
            conv1_defs["bias"] = old_conv.bias is not None

            new_conv = nn.Conv2d(img_channels, **conv1_defs)
            if pretrained:
                new_conv.weight.data = self._tile_input_channels(old_conv.weight.data, img_channels)
            backbone.backbone.conv1 = new_conv
                
        elif 'deit' in backbone_type:
            old_embed = backbone.backbone.patch_embed
            old_proj = old_embed.proj
            patch_embed_attrs = ["img_size", "patch_size"]
            patch_defs = {attr: getattr(old_embed, attr) for attr in patch_embed_attrs}
            patch_defs["embed_dim"] = old_proj.out_channels

            new_embed = transformers.PatchEmbed(in_chans=img_channels, **patch_defs)
            if pretrained:
                new_embed.proj.weight.data = self._tile_input_channels(old_proj.weight.data, img_channels)
                # Copy the bias only when both projections actually have one.
                if old_proj.bias is not None and new_embed.proj.bias is not None:
                    new_embed.proj.bias.data = old_proj.bias.data
            backbone.backbone.patch_embed = new_embed
        
        else:
            raise NotImplementedError("channel modification is not implemented for {}".format(backbone_type))            
            
    def forward(self, x):
        """Run the segmentation model; a frozen backbone is kept in eval mode."""
        if self.freeze_backbone:
            self.backbone.backbone.eval()        
        x = self.backbone(x)
        return x
    