import os, sys
# Normalised current working directory, appended to the module search path so
# that imports can also be resolved relative to the launch directory.
CODE_HOME = os.path.normpath(os.getcwd())
sys.path.append(CODE_HOME)

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..model_abstract import Classifier_abstract

class Classify_MNIST(Classifier_abstract):
    """Convolutional classifier with a 10-way output layer.

    ``encoder`` turns a single-channel image into a ``z_dim``-sized feature
    vector; ``decoder`` maps that vector to 10 outputs.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Build the encoder and decoder stacks.

        All arguments are forwarded unchanged to ``Classifier_abstract``.
        ``self.z_dim`` is presumably set by the base constructor -- TODO confirm.
        """
        super().__init__(cfg, log, verbose)
        width = 32

        # (in_channels, out_channels, extra Conv2d kwargs); every conv is 4x4.
        # Stride-2 stages halve the spatial size, 'same' stages keep it.
        conv_plan = (
            (1, width, {'stride': 2, 'padding': 1}),
            (width, width, {'padding': 'same'}),
            (width, 2 * width, {'stride': 2, 'padding': 1}),
            (2 * width, 2 * width, {'padding': 'same'}),
        )

        stages = []
        for c_in, c_out, extra in conv_plan:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, **extra))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))

        # 49 * 2 * width input features, i.e. a 7x7 map with 2*width channels
        # (assumes 28x28 inputs -- TODO confirm against the data pipeline).
        stages.append(nn.Flatten())
        stages.append(nn.Linear(49 * 2 * width, self.z_dim))
        self.encoder = nn.Sequential(*stages)

        head = [
            nn.BatchNorm1d(self.z_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.z_dim, 10),
        ]
        self.decoder = nn.Sequential(*head)

        # Module groups exposed as lists for the encoder and decoder sides.
        self.encoder_trainable = [self.encoder]
        self.decoder_trainable = [self.decoder]

class Classify_eYaleB(Classifier_abstract):
    """Convolutional classifier whose output size is read from the config.

    ``encoder`` turns a single-channel image into a batch-normalised
    ``z_dim``-sized vector; ``decoder`` applies ReLU and a linear layer with
    ``y_dim`` outputs.
    """

    def __init__(self, cfg, log, verbose = 1):
        """Build the encoder and decoder stacks.

        Reads ``cfg['train_info']['y_dim']`` for the number of outputs; all
        arguments are forwarded unchanged to ``Classifier_abstract``.
        ``self.z_dim`` is presumably set by the base constructor -- TODO confirm.
        """
        super().__init__(cfg, log, verbose)
        self.y_dim = int(cfg['train_info']['y_dim'])
        base = 64

        # (in_channels, out_channels, kernel_size, padding); every conv uses
        # stride 2 and no bias (each is followed directly by BatchNorm2d).
        conv_plan = (
            (1, base, 5, 2),
            (base, 2 * base, 5, 2),
            (2 * base, 4 * base, 5, 2),
            (4 * base, 8 * base, 3, 1),
            (8 * base, 16 * base, 3, 1),
        )

        stages = []
        for c_in, c_out, kernel, pad in conv_plan:
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=2, padding=pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # 16 * 16 * base input features, i.e. a 4x4 map with 16*base channels
        # (consistent with e.g. 128x128 inputs -- TODO confirm).
        stages += [
            nn.Flatten(),
            nn.Linear(16 * 16 * base, self.z_dim),
            nn.BatchNorm1d(self.z_dim),
        ]
        self.encoder = nn.Sequential(*stages)

        self.decoder = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(self.z_dim, self.y_dim),
        )

        # Module groups exposed as lists for the encoder and decoder sides.
        self.encoder_trainable = [self.encoder]
        self.decoder_trainable = [self.decoder]

class Classify_vggface2(Classifier_abstract):
    """Classifier with a main class head and an auxiliary attribute head.

    A shared convolutional ``encoder`` feeds two linear branches:
    ``encoder_main`` (``z_dim`` features, passed on to ``decoder`` for the
    class prediction) and ``encoder_sub`` (``attr_dim`` outputs trained with
    binary cross-entropy against the selected attribute labels).
    """

    def __init__(self, cfg, log, verbose = 1):
        """Build the shared encoder, both branches and the decoder.

        Reads ``num_classes`` and ``lambda`` from ``cfg['train_info']``; all
        arguments are forwarded unchanged to ``Classifier_abstract``.
        ``self.z_dim`` is presumably set by the base constructor -- TODO confirm.
        """
        super().__init__(cfg, log, verbose)
        # Cast explicitly, like the other numeric config entries in this file:
        # nn.Linear requires an int for its output size.
        self.class_no = int(cfg['train_info']['num_classes'])
        self.lamb = float(cfg['train_info']['lambda'])
        # 0/1 selector over the 11 columns of the attribute label tensor; the
        # seven 1-entries are the attributes predicted by ``encoder_sub``.
        self.attr = torch.Tensor([1,0,0,0,0,1,1,1,1,1,1]) # Male, Longhair, Mustache, Hat, Eyeglass, Sunglass, mouth open
        self.attr_dim = int(self.attr.sum().item())
        d = 128
        self.encoder = nn.Sequential(
            nn.Conv2d(3, d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(d),
            nn.ReLU(True),

            nn.Conv2d(d, 2*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(2*d),
            nn.ReLU(True),

            nn.Conv2d(2*d, 4*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(4*d),
            nn.ReLU(True),

            nn.Conv2d(4*d, 8*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(8*d),
            nn.ReLU(True),

            nn.Flatten(),
            )

        # 64 * 8 * d input features, i.e. an 8x8 map with 8*d channels
        # (consistent with e.g. 128x128 inputs -- TODO confirm).
        self.encoder_main = nn.Sequential(
            nn.Linear(64*8*d, self.z_dim),
            nn.BatchNorm1d(self.z_dim),
        )

        self.encoder_sub = nn.Sequential(
            nn.Linear(64*8*d, self.attr_dim),
        )

        self.decoder = nn.Sequential(
            nn.ReLU(True),
            nn.Linear(self.z_dim, self.class_no),
            )

        # NOTE(review): ``init_params`` is neither defined nor imported in the
        # visible part of this file -- presumably provided elsewhere; confirm.
        # NOTE(review): ``encoder_main`` / ``encoder_sub`` are not passed to
        # ``init_params`` -- confirm this is intended.
        init_params(self.encoder)
        init_params(self.decoder)

        self.encoder_trainable = [self.encoder, self.encoder_main, self.encoder_sub]
        self.decoder_trainable = [self.decoder]

    @staticmethod
    def _squeeze_labels(y):
        """Return ``y`` with a trailing singleton column removed ((N, 1) -> (N,))."""
        if len(y.shape) == 2 and y.shape[1] == 1:
            return y.squeeze(1)
        return y

    def encode(self, x):
        """Return ``(main_features, attribute_outputs)`` for the batch ``x``."""
        zz = self.encoder(x)
        return self.encoder_main(zz), self.encoder_sub(zz)

    def forward(self, x):
        """Return the decoder output for ``x``; the attribute branch is ignored."""
        zz, _ = self.encode(x)
        return self.decode(zz)

    def _get_losses(self, batch):
        """Compute the class loss, the attribute loss and the main-batch accuracy.

        Args:
            batch: mapping with ``'main'`` -> ``(x1, y1)`` and
                ``'sub'`` -> ``(x2, y2, s)``, where ``s`` holds the attribute
                labels whose columns are selected by ``self.attr``.

        Returns:
            ``(celoss, celoss_sub, acc)``: cross-entropy averaged over both
            batches, binary cross-entropy of the attribute branch on the
            ``'sub'`` batch, and the fraction of correct predictions on the
            ``'main'`` batch as a Python float.
        """
        x1, y1 = batch['main']  # When batch returns both image and label_attr
        mm, _ = self.encode(x1)
        p1 = self.decode(mm)
        y1 = self._squeeze_labels(y1)

        x2, y2, s = batch['sub']
        m2, ps = self.encode(x2)
        p2 = self.decode(m2)
        y2 = self._squeeze_labels(y2)

        # Sum-reduce each batch, then divide by the combined sample count so
        # both batches are weighted per sample.
        ce_total = F.cross_entropy(p1, y1, reduction = 'sum') + F.cross_entropy(p2, y2, reduction = 'sum')
        celoss = ce_total/(len(y1) + len(y2))

        # Select columns with an explicit boolean mask. Casting ``self.attr``
        # to the dtype of ``s`` only acts as a mask when ``s`` is bool/uint8;
        # for integer labels it would become gather-style integer indexing and
        # for float labels it is not a valid index at all.
        attr_mask = self.attr.to(device = s.device, dtype = torch.bool)
        celoss_sub = F.binary_cross_entropy_with_logits(ps, s[:, attr_mask].type_as(ps))

        acc = (p1.max(dim = 1).indices == y1).sum().item()/len(x1)
        return celoss, celoss_sub, acc

    def training_step(self, batch, batch_idx):
        """Log the losses and accuracy, and return ``loss + lamb * loss_sub``."""
        loss, loss_sub, acc = self._get_losses(batch)
        # Progress-bar-only values (not sent to the logger).
        self.log("CEloss", loss, prog_bar=True, logger = False)
        self.log("CEloss_sub", loss_sub, prog_bar=True, logger = False)
        self.log("acc", acc, prog_bar=True, logger = False)

        # Epoch-level values under the "train/" prefix.
        self.log("train/CEloss", loss, on_step = False, on_epoch = True)
        self.log("train/CEloss_sub", loss_sub, on_step = False, on_epoch = True)
        self.log("train/acc", acc, on_step = False, on_epoch = True)

        return loss + self.lamb * loss_sub