import os
from os.path import join

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torch.backends.cudnn as cudnn

from topk_operator import TopK_custom

# Per-dataset channel plan consumed by Encoder._make_layers: every entry except
# the last is the output width of one conv+BN+ReLU stage; the last entry (2) is
# the width of the final conv, which Encoder.forward splits into mu / logvar.
# Each dataset gets its own list object so they can be modified independently.
config = dict(
    cifar10=[64] * 2 + [32] * 2 + [16] * 2 + [1 * 2],
    cifar100=[64] * 2 + [32] * 2 + [16] * 2 + [1 * 2],
    imagenet=[64] * 4 + [32] * 5 + [16] * 6 + [1 * 2],
)


class Encoder(nn.Module):
    """Convolutional encoder producing a ``(mu, logvar)`` pair of feature maps.

    The trunk is a stack of 3x3 convolutions (stride 1, padding 1), each
    followed by batch-norm and ReLU, whose widths come from
    ``config[dataset_name]``. The input is expected to have 3 channels.

    Two output modes are supported:

    * ``sep=False``: the trunk ends with one extra conv of width
      ``config[dataset_name][-1]`` (followed by BN + ReLU only when ``relu``
      is true). ``forward`` returns channel 0 as ``mu`` and the remaining
      channels as ``logvar``.
    * ``sep=True``: the trunk stops after the hidden stages and two separate
      single-channel heads are applied: ``mu`` (conv + BN + ReLU) and
      ``logvar`` (plain conv).

    Args:
        num_classes: Stored on the instance; not used by this class itself.
        dataset_name: Key into the module-level ``config`` table.
        device: Stored on the instance; not used by this class itself.
        relu: In non-``sep`` mode, whether the final conv is followed by
            BN + ReLU.
        sep: Whether to use separate ``mu`` / ``logvar`` heads.
    """

    def __init__(self, num_classes, dataset_name, device, relu=True, sep=False):
        super().__init__()
        self.num_classes = num_classes
        self.device = device
        self.relu = relu
        self.sep = sep

        channel_plan = config[dataset_name]
        self.encoder = self._make_layers(channel_plan)
        if sep:
            # Both heads read the output of the last hidden stage.
            head_in = channel_plan[-2]
            self.mu = nn.Sequential(
                nn.Conv2d(head_in, 1, 3, 1, 1),
                nn.BatchNorm2d(1),
                nn.ReLU(),
            )
            self.logvar = nn.Conv2d(head_in, 1, 3, 1, 1)

    def _make_layers(self, config):
        """Build the trunk ``nn.Sequential`` from a list of channel widths.

        Relies on ``self.sep`` and ``self.relu`` already being set.
        """
        modules = []
        prev_width = 3
        for width in config[:-1]:
            modules.append(nn.Conv2d(prev_width, width, 3, 1, 1))
            modules.append(nn.BatchNorm2d(width))
            modules.append(nn.ReLU())
            prev_width = width

        if self.sep:
            # Separate heads are created in __init__; no final conv here.
            return nn.Sequential(*modules)

        modules.append(nn.Conv2d(prev_width, config[-1], 3, 1, 1))
        if self.relu:
            modules.append(nn.BatchNorm2d(config[-1]))
            modules.append(nn.ReLU())
        return nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(mu, logvar)``."""
        z = self.encoder(x)
        if self.sep:
            return self.mu(z), self.logvar(z)
        # Single trunk output: first channel is mu, the rest is logvar.
        return z[:, :1, :, :], z[:, 1:, :, :]
        
def reparameterize(mu, logvar):
    """Draw ``mu + eps * std`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar / 2)``.

    The noise tensor is sampled with ``torch.randn_like`` so it matches the
    shape, dtype and device of the standard deviation.
    """
    half_logvar = 0.5 * logvar
    std = half_logvar.exp()
    noise = torch.randn_like(std)
    return mu + noise * std


def load_pretrained_base_model(arch, dataset, n_cls, device):
    """Load a torchvision model with pretrained weights, wrapped in DataParallel.

    Args:
        arch: ``'vgg16'`` (loads ``vgg16_bn``) or ``'resnet50'``.
        dataset: Accepted for signature compatibility; not used here.
        n_cls: Accepted for signature compatibility; not used here.
        device: Device the model is moved to before wrapping.

    Returns:
        The model wrapped in ``torch.nn.DataParallel``.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.

    Side effects:
        Sets ``cudnn.benchmark = True`` for the whole process.
    """
    # Validate up front: previously an unknown arch left `model` unbound and
    # failed later with an opaque UnboundLocalError.
    if arch == 'vgg16':
        model = tv.models.vgg16_bn(pretrained=True)
    elif arch == 'resnet50':
        model = tv.models.resnet50(pretrained=True)
    else:
        raise ValueError(
            f"Unsupported arch {arch!r}; expected 'vgg16' or 'resnet50'"
        )
    model = model.to(device)
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True
    return model

def load_topk_module(device):
    """Create the ``TopK_custom`` operator on ``device``, wrapped in DataParallel.

    The operator is constructed with ``epsilon=1e-3`` and ``max_iter=100``.
    NOTE(review): the meaning of these two settings is defined by
    ``topk_operator.TopK_custom``, which is not visible from this file.
    """
    topk = TopK_custom(epsilon=1e-3, max_iter=100, device=device).to(device)
    return torch.nn.DataParallel(topk)

def load_pretrained_encoder(arch, dataset, baseline_, n_cls, device, relu, sep, pretrained=False):
    """Build an ``Encoder`` wrapped in DataParallel, optionally restoring weights.

    Args:
        arch: Base-model name; only used to locate the checkpoint directory.
        dataset: Dataset name; selects the encoder layout and is part of the
            checkpoint path.
        baseline_: Value interpolated into the ``bs_{baseline_}`` path segment.
        n_cls: Forwarded to ``Encoder`` as ``num_classes``.
        device: Device the encoder is moved to (and checkpoint tensors are
            mapped to).
        relu: Forwarded to ``Encoder``.
        sep: Forwarded to ``Encoder``.
        pretrained: When true, load weights from
            ``runs/{dataset}_{arch}/bs_{baseline_}/encoder_tmp.pth``.

    Returns:
        The encoder wrapped in ``nn.DataParallel``.
    """
    encoder = Encoder(n_cls, dataset, device, relu, sep)
    encoder = encoder.to(device)
    encoder = nn.DataParallel(encoder)
    if pretrained:
        # map_location keeps loading from failing when the checkpoint was saved
        # on a device that is not available here (e.g. GPU checkpoint, CPU host).
        ckpt = torch.load(
            f'runs/{dataset}_{arch}/bs_{baseline_}/encoder_tmp.pth',
            map_location=device,
        )
        # The state dict is loaded into the DataParallel wrapper, so its keys
        # presumably carry the wrapper's `module.` prefix; verify against the
        # code that saves the checkpoint.
        encoder.load_state_dict(ckpt['encoder'])
    return encoder

