import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import faiss
import torchvision.transforms as transforms

import sys

from .backbones.utils import NormalizeByChannelMeanStd

# Per-channel (mean, std) statistics keyed by dataset name: each entry is a
# pair of 3-tuples. get_base_model feeds these to NormalizeByChannelMeanStd
# when it is called with use_normalize=True.
# NOTE(review): the stl10 entry matches the cifar10 statistics to ~3 decimals;
# confirm these are really meant to be STL-10's own statistics.
MEAN_STD_PER_DATASET = dict(
    cifar10=((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    cifar100=((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    stl10=((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
)


def get_base_model(dataset, backbone, hidden_dim, use_normalize, zero_init_residual=True, use_projector=False, multi_bn=False):
    """Build a ResNet encoder for CIFAR-10/100 or STL-10 with an identity head.

    Args:
        dataset: One of ``'cifar10'``, ``'cifar100'`` or ``'stl10'``.
        backbone: Backbone name; must contain ``'resnet'``. With ``multi_bn``
            it must be ``'resnet18'``. Without ``multi_bn`` and without
            ``use_normalize`` it is looked up in ``torchvision.models``; with
            ``use_normalize`` (and no ``multi_bn``) the project's
            ``resnet18_NormalizeInput`` is built regardless of ``backbone``.
        hidden_dim: Forwarded as ``num_classes`` to the backbone constructor.
        use_normalize: If True, attach ``NormalizeByChannelMeanStd`` as
            ``base_model.normalize`` using ``MEAN_STD_PER_DATASET[dataset]``.
        zero_init_residual: Forwarded to the backbone constructor.
        use_projector: Not supported; True raises ``NotImplementedError``.
        multi_bn: If True, build the project's multi-BatchNorm ResNet18 with
            BN branches named ``'normal'`` and ``'pgd'``.

    Returns:
        The backbone module with ``fc`` replaced by ``nn.Identity()``.

    Raises:
        AssertionError: ``backbone`` is not a ResNet, or is not ResNet18 when
            ``multi_bn`` is set.
        NotImplementedError: ``dataset`` is unsupported or ``use_projector``
            is True.
    """
    assert 'resnet' in backbone, "we only support ResNet models!!!"
    if multi_bn:
        assert backbone in ['resnet18'], 'Multi-Batch Normalization is only supported for ResNet18'

    # Validate the configuration before building anything, so unsupported
    # options fail fast with a meaningful exception.
    if dataset not in ('cifar10', 'cifar100', 'stl10'):
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")
    if use_projector:
        raise NotImplementedError(
            "use_projector=True is not supported; the fc head is always replaced by nn.Identity()")

    is_cifar = dataset in ('cifar10', 'cifar100')

    if multi_bn:
        # Each backbone variant lives in its own module; import lazily so only
        # the requested one is loaded.
        if is_cifar:
            if use_normalize:
                from .backbones.resnet_multi_bn_add_normalize import resnet18
            else:
                from .backbones.resnet_multi_bn import resnet18
        else:  # stl10
            if use_normalize:
                from .backbones.resnet_multi_bn_stl10_add_normalize import resnet18
            else:
                from .backbones.resnet_multi_bn_stl10 import resnet18

        bn_names = ['normal', 'pgd']
        base_model = resnet18(pretrained=False,
                              bn_names=bn_names,
                              num_classes=hidden_dim,
                              zero_init_residual=zero_init_residual)
    elif use_normalize:
        from .backbones.resnet_add_normalize import resnet18_NormalizeInput

        base_model = resnet18_NormalizeInput(num_classes=hidden_dim,
                                             zero_init_residual=zero_init_residual)
    else:
        base_model = torchvision.models.__dict__[backbone](num_classes=hidden_dim,
                                                           zero_init_residual=zero_init_residual)

    if use_normalize:
        mean, std = MEAN_STD_PER_DATASET[dataset]
        base_model.normalize = NormalizeByChannelMeanStd(mean=mean, std=std)

    if not multi_bn:
        if is_cifar:
            # Small-image stem: stride-1 3x3 conv and no max-pool.
            # NOTE(review): a 3x3 conv normally uses padding=1; padding=3
            # enlarges a 32x32 input to 36x36. Kept as-is because existing
            # checkpoints were trained with it — confirm before changing.
            conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=3, bias=False)
            nn.init.kaiming_normal_(conv1.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            base_model.conv1 = conv1
            base_model.maxpool = nn.Identity()
        else:
            print("Dataset: STL10 => Use Original ResNet18 Arch")

    # The classification head is dropped; the model outputs pooled features.
    base_model.fc = nn.Identity()

    return base_model


# Model keys that may legitimately be absent from a checkpoint: the
# normalization buffers are attached by the model builder, not trained.
_ALLOWED_MISSING_KEYS = frozenset({'normalize.mean', 'normalize.std'})


def _strip_checkpoint_prefix(state_dict, use_projector):
    """Remove a ``backbone.``/``encoder.`` prefix from checkpoint keys.

    The prefix is detected from the first key only. When one is found, keys
    that do not contain it are dropped, and keys containing ``'fc'`` after
    stripping are dropped unless ``use_projector`` is True. When no prefix is
    detected (or the dict is empty) the dict is returned unchanged.
    """
    first_key = next(iter(state_dict), '')
    if 'backbone' in first_key:
        prefix = 'backbone.'
    elif 'encoder' in first_key:
        prefix = 'encoder.'
    else:
        return state_dict

    stripped = {}
    for key, value in state_dict.items():
        if prefix not in key:
            continue
        # str.replace removes every occurrence of the prefix, not only a
        # leading one; kept that way for compatibility with existing callers.
        new_key = key.replace(prefix, '')
        if not use_projector and 'fc' in new_key:
            continue
        stripped[new_key] = value
    return stripped


def _rename_for_multi_bn(state_dict, bn_index):
    """Map single-BN ResNet keys onto a multi-BN layout using ``bn_list.<bn_index>``.

    ``downsample.0`` becomes ``downsample.conv`` and ``downsample.1`` becomes
    ``downsample.bn.bn_list.<bn_index>``. For any other key containing
    ``'bn'``, ``bn_list.<bn_index>.`` is inserted four characters after the
    first ``'bn'`` — this assumes names like ``bn1.weight``. Other keys pass
    through unchanged.
    """
    renamed = {}
    for key, value in state_dict.items():
        if 'downsample.0' in key:
            key = key.replace('downsample.0', 'downsample.conv')
        elif 'downsample.1' in key:
            key = key.replace('downsample.1', f'downsample.bn.bn_list.{bn_index}')
        elif 'bn' in key:
            point = key.index('bn')
            key = key[:point + 4] + f'bn_list.{bn_index}.' + key[point + 4:]
        renamed[key] = value
    return renamed


def load_pretrained_model(model, ckpt_path, use_projector=False, multi_bn=False):
    """Load encoder weights from a checkpoint into ``model`` (non-strict).

    Args:
        model: Target ``nn.Module``; it is modified in place.
        ckpt_path: Path loadable by ``torch.load`` whose top-level dict stores
            the weights under ``'state_dict'``.
        use_projector: Keep ``fc`` keys when a ``backbone.``/``encoder.``
            prefix is stripped; otherwise they are dropped.
        multi_bn: Rename BatchNorm/downsample keys so the checkpoint's
            statistics load into ``bn_list.0`` of a multi-BN model; the
            ``bn_list.1`` branch is expected to stay unloaded.

    Returns:
        ``model`` after ``load_state_dict(..., strict=False)``.

    Raises:
        AssertionError: (``multi_bn``) a missing key is outside ``bn_list.1``,
            or the checkpoint contains keys the model does not have.
        RuntimeError: (not ``multi_bn``) the model has keys the checkpoint
            does not provide, other than ``normalize.mean``/``normalize.std``.
    """
    # NOTE(review): torch.load unpickles the file unless weights_only is in
    # effect — only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = _strip_checkpoint_prefix(checkpoint['state_dict'], use_projector)

    if multi_bn:
        normal_index = 0
        adv_index = normal_index ^ 1
        state_dict = _rename_for_multi_bn(state_dict, normal_index)

        msg = model.load_state_dict(state_dict, strict=False)

        for missing_key in msg.missing_keys:
            if missing_key in _ALLOWED_MISSING_KEYS:
                continue
            assert f'bn_list.{adv_index}' in missing_key, \
                f"Incorrect Batch Normalization layers are loaded: {missing_key}"

        assert not msg.unexpected_keys, f'Unexpected Key List {msg.unexpected_keys}'
    else:
        msg = model.load_state_dict(state_dict, strict=False)

        missing = [k for k in msg.missing_keys if k not in _ALLOWED_MISSING_KEYS]
        if missing:
            raise RuntimeError(f"Checkpoint {ckpt_path} is missing model keys: {missing}")

    print(f'load_state_dict result: {msg}')

    return model