import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
import datasets
import utils
import pandas as pd
import random
import os
from datetime import datetime
import numpy as np
from utils.evaluation import calculate_auc, find_threshold
from utils import basics

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With padding 1 the spatial size is kept for ``stride=1`` and becomes
    ``ceil(size / stride)`` otherwise.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Main path: 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The result is added to the shortcut (``downsample(x)`` when given,
    otherwise ``x``) and passed through a final ReLU.

    Parameters
    ----------
    inplanes: int
        Number of input channels.
    planes: int
        Output channels of both convolutions.
    stride: int, default: 1
        Stride of the first convolution.
    downsample: nn.Module or None
        Applied to the input to form the shortcut. It must be supplied when
        the main path changes shape, otherwise the residual addition fails.
    """
    # Ratio of output channels to ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Bias-free 3x3 convolutions with padding 1; BatchNorm follows each.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x``."""
        shortcut = x if self.downsample is None else self.downsample(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return self.relu(h + shortcut)

class DeepVAE(nn.Module):
    """ResNet-style VAE for 224x224 inputs.

    Encoder: 7x7/2 stem convolution, max-pool, four residual stages built
    from ``block``, a 7x7 average pool, then ``fc`` (-> 512) and ``fc1``
    (-> ``2 * zdim``). The output is split into ``mu`` and ``logvar``.

    Decoder: ``fc3`` (-> 500) and ``fc4`` (-> 32*14*14), reshaped to
    (batch, 32, 14, 14). Four stride-2 transposed convolutions then upsample
    to 224x224 with ``input_channel`` channels, and a sigmoid is applied.

    Parameters
    ----------
    opt: dict
        Must provide ``'input_channel'``, ``'num_classes'``, ``'zdim'`` and
        ``'device'``.
    block: nn.Module subclass, default: BasicBlock
        Residual block type. It must define an ``expansion`` class attribute
        and accept ``(inplanes, planes, stride, downsample)``.
    layers: sequence of 4 ints, default: (2, 2, 2, 2)
        Number of blocks in each residual stage.
    """

    def __init__(self, opt, block=BasicBlock, layers=(2, 2, 2, 2)):
        super(DeepVAE, self).__init__()
        self.input_channel = opt['input_channel']
        self.output_channel = self.input_channel
        self.num_classes = opt['num_classes']
        self.zdim = opt['zdim']
        self.device = opt['device']

        # Encoder
        self.inplanes = 64
        self.conv1 = nn.Conv2d(self.input_channel, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.leakyrelu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # For a 224x224 input, layer4 yields a 7x7 map, which this pool
        # collapses to 1x1 so that fc sees 512 * expansion features.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, 512)
        self.fc1 = nn.Linear(512, self.zdim * 2)

        # Decoder: each transposed conv (k=3, s=2, p=1, output_padding=1)
        # doubles the spatial size: 14 -> 28 -> 56 -> 112 -> 224.
        self.fc3 = nn.Linear(self.zdim, 500)
        self.fc4 = nn.Linear(500, 14 * 14 * 32)
        self.deconv1 = nn.ConvTranspose2d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv4 = nn.ConvTranspose2d(16, self.output_channel, kernel_size=3, stride=2, padding=1, output_padding=1)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        The first block applies ``stride``. When the stride or channel count
        changes, it also gets a 1x1-conv + BN shortcut. This method updates
        ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def encode(self, x):
        """Encode ``x`` into ``(mu, logvar)``, each of shape (batch, zdim)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.leakyrelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.leakyrelu(x)
        stats = self.fc1(x)
        # Clone so that later in-place edits of mu/logvar do not write
        # through to `stats`.
        mu, logvar = stats[:, : self.zdim].clone(), stats[:, self.zdim:].clone()
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        ``eps`` is drawn with ``torch.randn_like``, so it matches ``std`` in
        shape, dtype and device. Tensors have no out-of-place ``normal()``
        method, only the in-place ``normal_``.
        """
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latents of shape (batch, zdim) into images in (0, 1)."""
        x = self.fc3(z)
        x = self.leakyrelu(x)
        x = self.fc4(x)
        x = self.leakyrelu(x)
        x = x.view(-1, 32, 14, 14)
        x = self.deconv1(x)
        x = self.leakyrelu(x)
        x = self.deconv2(x)
        x = self.leakyrelu(x)
        x = self.deconv3(x)
        x = self.leakyrelu(x)
        x = self.deconv4(x)
        x = self.sigmoid(x)

        return x

    def forward(self, x):
        """Run only the encoder and return ``(mu, logvar)``.

        Sampling and decoding are done through ``reparametrize`` / ``decode``.
        """
        mu, logvar = self.encode(x)
        return mu, logvar


"""Module for a VAE Convolutional Neural Network for CI-MNIST dataset
Parameters
----------
args: dict
        Contains all model and shared input arguments; ConvVAE reads 'zdim'
"""
class ConvVAE(nn.Module):
    """Convolutional VAE encoder/decoder for 32x32 RGB inputs (CI-MNIST).

    Inputs are 3-channel 32x32 images with float pixel values in [0, 1].
    Three stride-2 convolutions reduce 32 -> 16 -> 8 -> 4, giving
    64 * 4 * 4 = 1024 features. These are mapped through 128 units to
    ``2 * zdim`` statistics. The decoder mirrors this back to 3x32x32 and
    returns raw outputs (no final squashing).

    NOTE(review): there are no activations between layers, so ``enc`` and
    ``dec`` are each a composition of affine maps. Confirm this is intended.

    Parameters
    ----------
    args: dict
        Must provide ``'zdim'``, the latent dimensionality.
    """

    def __init__(self, args):
        """Build the encoder/decoder stacks and initialise the conv weights."""
        super(ConvVAE, self).__init__()

        self.z_dim = args['zdim']

        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.Conv2d(64, 64, 4, 2, 1),
            Resize((-1, 1024)),
            nn.Linear(1024, 128),
            nn.Linear(128, 2 * self.z_dim)
        )

        self.dec = nn.Sequential(
            nn.Linear(self.z_dim, 128),
            nn.Linear(128, 1024),
            Resize((-1, 64, 4, 4)),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
        )
        self.weight_init()

    def weight_init(self, mode="normal"):
        """Xavier-uniform initialise every Conv2d / ConvTranspose2d weight.

        ``mode`` is accepted for interface compatibility but is unused.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                torch.nn.init.xavier_uniform_(m.weight)

    def encode(self, x):
        """Encode ``x`` and return ``(mu, logvar)``, each of shape (batch, z_dim).

        The statistics are flattened from dim 1 rather than squeezed, so a
        batch of size 1 keeps its batch axis.
        """
        stats = self.enc(x).flatten(start_dim=1)
        mu, logvar = stats[:, : self.z_dim], stats[:, self.z_dim :]
        return mu, logvar

    def decode(self, z):
        """Decode a latent representation into raw decoder outputs."""
        stats = self.dec(z)
        return stats


"""Module for a VAE MLP for Adult dataset
Parameters
----------
num_neurons: list, dtype: int
        Number of neurons in each layer
zdim: int
        number of latent dims; in encode mode, the first zdim outputs are mu and the rest logvar
activ: string, default: "leakyrelu"
        Activation function
"""
class MLP(nn.Module):
    """Fully connected stack whose output is interpreted according to ``mode``.

    Parameters
    ----------
    num_neurons: list of int
        Layer widths, input first. ``len(num_neurons) - 1`` linear layers are
        built and Xavier-uniform initialised.
    zdim: int
        Split point for ``"encode"`` mode: the first ``zdim`` outputs are
        returned as ``mu`` and the rest as ``logvar``.
    activ: str, default: "leakyrelu"
        Activation applied after every linear layer. One of ``"leakyrelu"``,
        ``"relu"``, ``"tanh"`` or ``"sigmoid"``.

    Raises
    ------
    ValueError
        If ``activ`` is not a supported activation name.
    """

    # Supported values of ``activ`` mapped to their functions.
    _ACTIVATIONS = {
        "leakyrelu": F.leaky_relu,
        "relu": F.relu,
        "tanh": torch.tanh,
        "sigmoid": torch.sigmoid,
    }

    def __init__(self, num_neurons, zdim, activ="leakyrelu"):
        """Initialise the MLP layers."""
        super(MLP, self).__init__()
        if activ not in self._ACTIVATIONS:
            raise ValueError(
                "Unknown activation {!r}; choose one of {}".format(
                    activ, sorted(self._ACTIVATIONS)
                )
            )
        self.num_neurons = num_neurons
        self.num_layers = len(self.num_neurons) - 1
        self.hiddens = nn.ModuleList(
            [
                nn.Linear(self.num_neurons[i], self.num_neurons[i + 1])
                for i in range(self.num_layers)
            ]
        )
        for hidden in self.hiddens:
            torch.nn.init.xavier_uniform_(hidden.weight)
        self.activ = activ
        self.zdim = zdim

    def forward(self, inputs, mode):
        """Run the stack and format the output for ``mode``.

        Parameters
        ----------
        inputs: Tensor
            Input of shape (batch, num_neurons[0]).
        mode: str
            One of the following:

            - ``"encode"`` returns ``(mu, logvar)``, split at ``zdim``.
            - ``"decode"`` returns the squeezed output.
            - ``"discriminator"`` returns ``(logits, softmax(logits, dim=1))``.
            - ``"classify"`` returns the output as-is.

        Raises
        ------
        ValueError
            For any other ``mode``.
        """
        act = self._ACTIVATIONS[self.activ]
        out = inputs
        # NOTE: the activation is also applied after the last layer, so the
        # returned logits/statistics are activation outputs (leaky ReLU scales
        # negative values by 0.01).
        for hidden in self.hiddens:
            out = act(hidden(out))
        if mode == "encode":
            return out[:, : self.zdim], out[:, self.zdim :]
        if mode == "decode":
            return out.squeeze()
        if mode == "discriminator":
            return out, F.softmax(out, dim=1)
        if mode == "classify":
            return out
        raise ValueError(
            "Wrong mode {!r}; choose one of encode/decode/discriminator/classify".format(mode)
        )


"""Module for FFVAE network
Parameters
----------
opt: dict
        Contains all model and shared input arguments
wandb:
        Logger object stored on the model and used for metric logging
"""
class FFVAE(nn.Module):
    """Flexibly Fair VAE: a VAE encoder/decoder, an MLP discriminator and an MLP classifier.

    The first ``n_sens`` latent dimensions (``b``) are trained as logits of the
    sensitive attribute, and the non-sensitive dimensions (``z``) are Gaussian.

    The discriminator is trained to tell the real code ``zb`` apart from codes
    whose ``b`` and ``z`` parts were shuffled independently across the batch.
    Its logit gap on real codes is added to the VAE loss (weight ``gamma``).

    A separate classifier predicts the task labels from the posterior means,
    with the sensitive dimension replaced by noise.

    Parameters
    ----------
    opt: dict
        Model and training options. Keys read here:

        - ``input_channel``, ``num_classes``, ``gamma``, ``alpha``, ``zdim``,
          ``device``
        - ``save_folder``, ``print_freq``, ``optimizer_setting['lr']``
        - ``adepth``, ``awidths``, ``cdepth``, ``cwidths``, ``batch_size``
        - ``random_seed``, ``data_setting``
        - optionally ``ffvae_epochs`` (default 10)

        ``DeepVAE`` reads its keys from the same dict.
    wandb:
        Logger used by ``log_result`` when no ``log_writer`` is attached; it
        must expose ``log(dict)``. Presumably this is the ``wandb`` module or
        a run object; confirm against the caller. May be ``None``.
    """

    def __init__(self, opt, wandb):
        super(FFVAE, self).__init__()
        self.input_channel = opt['input_channel']
        self.num_classes = opt['num_classes']
        self.gamma = opt['gamma']
        self.alpha = opt['alpha']
        self.zdim = opt['zdim']
        self.device = opt['device']

        self.epoch = 0
        # Number of initial epochs spent on VAE + discriminator updates before
        # _train switches to classifier-only updates.
        self.ffvae_epochs = opt.get('ffvae_epochs', 10)
        self.save_path = opt['save_folder']
        self.print_freq = opt['print_freq']
        self.init_lr = opt['optimizer_setting']['lr']
        # Optional writer exposing add_scalars(name, dict, step). When it is
        # None, log_result falls back to self.wandb.
        self.log_writer = None
        self.wandb = wandb
        self.set_data(opt)

        self.best_val_acc = 0.
        self.best_val_loss = float("inf")

        # VAE encoder/decoder
        self.conv_vae = DeepVAE(opt).to(self.device)

        # MLP discriminator: latent code -> 2 logits (real / shuffled)
        self.adv_neurons = [opt['zdim']] + opt['adepth'] * [opt['awidths']] + [2]
        self.discriminator = MLP(self.adv_neurons, opt['zdim']).to(self.device)

        # MLP classifier: latent code -> num_classes logits
        self.class_neurons = (
            [opt['zdim']] + opt['cdepth'] * [opt['cwidths']] + [self.num_classes]
        )
        self.classifier = MLP(self.class_neurons, opt['zdim']).to(self.device)

        # The first n_sens latent dims are the sensitive part 'b'.
        self.n_sens = 1
        self.sens_idx = list(range(self.n_sens))
        # NOTE(review): only indices below zdim // 2 are treated as 'z', so
        # latent dims zdim // 2 .. zdim - 1 of zb stay zero in forward().
        # Confirm this halving is intended.
        self.nonsens_idx = [
            i for i in range(int(self.zdim / 2)) if i not in self.sens_idx
        ]
        self.count = 0
        self.batch_size = opt['batch_size']

        (
            self.optimizer_ffvae,
            self.optimizer_disc,
            self.optimizer_class,
        ) = self.get_optimizer()

        print(self.conv_vae)
        print(self.classifier)
        print(self.discriminator)

    @staticmethod
    def build_model(args, wandb=None):
        """Build an FFVAE from ``args``, forwarding ``wandb`` to the constructor."""
        return FFVAE(args, wandb)

    def vae_params(self):
        """Return the VAE parameters (encoder and decoder)."""
        return list(self.conv_vae.parameters())

    def discriminator_params(self):
        """Return the discriminator parameters."""
        return list(self.discriminator.parameters())

    def classifier_params(self):
        """Return the classifier parameters."""
        return list(self.classifier.parameters())

    def get_optimizer(self):
        """Return Adam optimizers for the VAE, the discriminator and the classifier.

        All three use the learning rate from ``opt['optimizer_setting']['lr']``.
        """
        optimizer_ffvae = torch.optim.Adam(self.vae_params(), lr=self.init_lr)
        optimizer_disc = torch.optim.Adam(self.discriminator_params(), lr=self.init_lr)
        optimizer_class = torch.optim.Adam(self.classifier_params(), lr=self.init_lr)
        return optimizer_ffvae, optimizer_disc, optimizer_class

    def set_data(self, opt):
        """Build the CheXpert train/val/test datasets and dataloaders.

        Preprocessing:

        - Images are resized to 256 and cropped to 224.
        - Training uses a random crop plus horizontal flip when
          ``data_setting['augment']`` is set, and a centre crop otherwise.
        - All splits are normalised with ImageNet statistics.

        Row limits:

        - Only the first ``data_setting['train_nrows']`` (default 20000)
          training rows are read.
        - Only the first ``data_setting['val_nrows']`` (default 2000)
          validation rows are read.
        - Set either value to ``None`` to read the whole file.

        Only the training loader shuffles, so validation and test batches are
        the same every epoch.
        """
        data_setting = opt['data_setting']

        # normalize according to ImageNet
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)

        if data_setting['augment']:
            transform_train = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        g = torch.Generator()
        g.manual_seed(opt['random_seed'])

        def seed_worker(worker_id):
            # Give each dataloader worker its own deterministic numpy/random seed.
            np.random.seed(opt['random_seed'] + worker_id)
            random.seed(opt['random_seed'] + worker_id)

        image_path = data_setting['image_feature_path']
        train_meta = pd.read_csv(data_setting['train_meta_path'],
                                 nrows=data_setting.get('train_nrows', 20000))
        val_meta = pd.read_csv(data_setting['val_meta_path'],
                               nrows=data_setting.get('val_nrows', 2000))
        test_meta = pd.read_csv(data_setting['test_meta_path'])

        self.train_data = datasets.CheXpert(train_meta, image_path, transform_train)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_data, batch_size=opt['batch_size'],
            shuffle=True, num_workers=4, worker_init_fn=seed_worker,
            generator=g)

        self.val_data = datasets.CheXpert(val_meta, image_path, transform_test)
        self.val_loader = torch.utils.data.DataLoader(
            self.val_data, batch_size=opt['batch_size'],
            shuffle=False, num_workers=4, worker_init_fn=seed_worker,
            generator=g)
        self.test_data = datasets.CheXpert(test_meta, image_path, transform_test)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_data, batch_size=opt['batch_size'],
            shuffle=False, num_workers=4, worker_init_fn=seed_worker,
            generator=g)

    def log_result(self, name, result, step):
        """Log a dict of scalar values under ``name`` at ``step``.

        If a ``log_writer`` is attached, ``log_writer.add_scalars`` is used.
        Otherwise the values are sent to ``self.wandb.log`` as keys
        ``"<name>/<key>"``, plus ``"<name>/step"``. When neither is
        available, nothing is logged.
        """
        log_writer = getattr(self, 'log_writer', None)
        if log_writer is not None:
            log_writer.add_scalars(name, result, step)
            return
        logger = getattr(self, 'wandb', None)
        if logger is not None:
            payload = {'{}/{}'.format(name, key): float(value) for key, value in result.items()}
            payload['{}/step'.format(name)] = step
            logger.log(payload)

    @staticmethod
    def _sensitive_bce(b_logits, attrs):
        """Return the mean BCE-with-logits between sensitive latents and attribute labels.

        ``b_logits`` has shape (batch, n_sens). ``attrs`` holds the matching
        binary labels, either (batch,) when ``n_sens == 1`` or (batch, n_sens).
        Each sensitive dimension contributes the same number of terms, so the
        result equals the mean of the per-attribute mean losses.
        """
        targets = attrs.to(device=b_logits.device, dtype=b_logits.dtype)
        return F.binary_cross_entropy_with_logits(b_logits, targets.reshape(b_logits.shape))

    def forward(self, inputs, labels, attrs, mode="ffvae_train"):
        """Compute all FFVAE losses for a batch and, depending on ``mode``, take a step.

        Parameters
        ----------
        inputs: Tensor
            Image batch fed to ``conv_vae``.
        labels: Tensor
            Float task targets for ``binary_cross_entropy_with_logits``
            against the classifier logits; must have the same shape.
        attrs: Tensor
            Binary sensitive-attribute labels, (batch,) or (batch, n_sens).
        mode: str, default: "ffvae_train"
            - ``"ffvae_train"`` steps the VAE and discriminator optimizers.
            - ``"train"`` steps the classifier optimizer.
            - Anything else (e.g. ``"test"``) only computes the losses.

        Returns
        -------
        (Tensor, dict)
            Classifier logits, and a dict with keys ``ffvae_cost``,
            ``disc_cost`` and ``main_cost``.
        """
        x = inputs

        # Encode: q(z, b | x).
        _mu, _logvar = self.conv_vae.encode(x)

        # Only the non-sensitive dims are modelled as a diagonal Gaussian.
        mu = _mu[:, self.nonsens_idx]
        logvar = _logvar[:, self.nonsens_idx]
        std = (logvar / 2).exp()
        q_zIx = torch.distributions.Normal(mu, std)

        # The sensitive dims 'b' are used deterministically as attribute logits.
        b_logits = _mu[:, self.sens_idx]

        # Assemble the latent code from b and a reparameterised sample of z.
        z = q_zIx.rsample()
        zb = torch.zeros_like(_mu)
        zb[:, self.sens_idx] = b_logits
        zb[:, self.nonsens_idx] = z

        # Decode: p(x | z, b). conv_vae.decode already ends in a sigmoid, so
        # no second sigmoid is applied here; a double sigmoid would be stuck
        # in (0.5, 0.73). If conv_vae is swapped for a decoder without a
        # final sigmoid, add one here.
        # NOTE(review): set_data normalises x with ImageNet statistics, so x
        # is not confined to the decoder's (0, 1) range. Confirm the intended
        # reconstruction target.
        xIz_params = self.conv_vae.decode(zb)
        p_xIz = torch.distributions.Normal(loc=xIz_params, scale=1.0)
        # Per-example log-likelihood (negative reconstruction error).
        recon_term = p_xIz.log_prob(x).reshape(len(x), -1).sum(1)

        # Prior p(z) = N(0, I); analytic KL(q(z|x) || p(z)) per example.
        p_z = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        kl = torch.distributions.kl_divergence(q_zIx, p_z).sum(1)
        elbo = recon_term - kl

        # p(a | b): BCE between the sensitive logits and the attribute labels.
        clf_loss = self._sensitive_bce(b_logits, attrs)

        # Penalty for the VAE: the discriminator's logit gap on the real code.
        # Its gradient reaches the VAE through zb.
        logits_joint, _ = self.discriminator(zb, "discriminator")
        total_corr = logits_joint[:, 0] - logits_joint[:, 1]

        ffvae_loss = (
            -1.0 * elbo.mean()
            + self.gamma * total_corr.mean()
            + self.alpha * clf_loss
        )

        # Discriminator: real codes (label 0) vs codes whose b and z parts
        # were shuffled independently across the batch (label 1). Its inputs
        # are detached, so disc_loss produces gradients for the discriminator
        # only.
        zb_detached = zb.detach()
        z_fake = torch.zeros_like(zb_detached)
        z_fake[:, 0] = zb_detached[:, 0][torch.randperm(zb_detached.shape[0])]
        z_fake[:, 1:] = zb_detached[:, 1:][torch.randperm(zb_detached.shape[0])]
        z_fake = z_fake.to(self.device)

        logits_real_d, _ = self.discriminator(zb_detached, "discriminator")
        logits_fake_d, _ = self.discriminator(z_fake, "discriminator")
        zeros = torch.zeros(logits_real_d.shape[0], dtype=torch.long, device=self.device)
        ones = torch.ones(logits_fake_d.shape[0], dtype=torch.long, device=self.device)
        disc_loss = 0.5 * (
            F.cross_entropy(logits_real_d, zeros)
            + F.cross_entropy(logits_fake_d, ones)
        )

        # Classifier input: the posterior means, with the sensitive dim
        # replaced by noise.
        # - detach: class_loss trains only the classifier.
        # - clone: the in-place write below cannot alias _mu.
        encoded_x = _mu.detach().clone()
        encoded_x[:, 0] = torch.randn_like(encoded_x[:, 0])
        pre_softmax = self.classifier(encoded_x, "classify")
        class_loss = F.binary_cross_entropy_with_logits(pre_softmax, labels)

        cost_dict = dict(
            ffvae_cost=ffvae_loss, disc_cost=disc_loss, main_cost=class_loss
        )

        if mode == "ffvae_train":
            self.optimizer_ffvae.zero_grad()
            ffvae_loss.backward()
            # ffvae_loss also reaches the discriminator's weights through
            # total_corr. Clear those gradients so the discriminator follows
            # disc_loss only.
            self.optimizer_disc.zero_grad()
            disc_loss.backward()

            torch.nn.utils.clip_grad_norm_(self.vae_params(), 5.0)
            self.optimizer_ffvae.step()
            torch.nn.utils.clip_grad_norm_(self.discriminator_params(), 5.0)
            self.optimizer_disc.step()
        elif mode == "train":
            self.optimizer_class.zero_grad()
            class_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.classifier_params(), 5.0)
            self.optimizer_class.step()

        return pre_softmax, cost_dict

    def _train(self, loader):
        """Train for one epoch over ``loader`` and log per-iteration and per-epoch stats.

        During the first ``ffvae_epochs`` epochs the VAE and discriminator are
        updated; after that, only the classifier is. The reported AUC is the
        mean of per-batch AUCs on output 0 / target column 0, in percent.
        """
        self.conv_vae.train()
        self.classifier.train()
        self.discriminator.train()

        mode = "ffvae_train" if self.epoch < self.ffvae_epochs else "train"
        train_loss_ffvae, train_loss_disc, train_loss_main = 0., 0., 0.
        auc = 0.
        num_batches = 0
        for i, (images, targets, sensitive_attr) in enumerate(loader):
            images = images.to(self.device)
            targets = targets.to(self.device)
            sensitive_attr = sensitive_attr.to(self.device)
            outputs, loss_dict = self.forward(images, targets, sensitive_attr, mode=mode)

            probs = torch.sigmoid(outputs[:, 0]).detach().cpu().numpy()
            auc += calculate_auc(probs, targets[:, 0].cpu().numpy())
            num_batches += 1

            ffvae_cost = loss_dict['ffvae_cost'].item()
            disc_cost = loss_dict['disc_cost'].item()
            main_cost = loss_dict['main_cost'].item()
            self.log_result('Train iteration',
                            {'ffvae loss': ffvae_cost, 'disc loss': disc_cost, 'main loss': main_cost},
                            len(loader) * self.epoch + i)
            train_loss_ffvae += ffvae_cost
            train_loss_disc += disc_cost
            train_loss_main += main_cost

            if self.print_freq and (i % self.print_freq == 0):
                print('Training epoch {}: [{}|{}], ffvae loss:{}, disc loss:{}, main loss:{}'.format(
                      self.epoch, i + 1, len(loader), ffvae_cost, disc_cost, main_cost))

        # Divide by the number of batches actually seen. The expression
        # 1 + len(dataset) // batch_size over-counts by one when the dataset
        # size is a multiple of the batch size.
        auc = 100 * auc / max(num_batches, 1)
        self.log_result('Train epoch', {'ffvae loss': train_loss_ffvae, 'disc loss': train_loss_disc,
                                        'main loss': train_loss_main, 'AUC': auc}, self.epoch)
        print('Training epoch {}: [{}|{}], AUC:{}'.format(
              self.epoch, num_batches, len(loader), auc))
        self.epoch += 1

    def _val(self, loader):
        """Evaluate on ``loader`` without gradient updates.

        Returns
        -------
        (Tensor, float)
            The mean ``ffvae_cost`` over batches (a 0-d tensor), and the mean
            per-batch AUC on output 0 / target column 0, in percent.
        """
        self.conv_vae.eval()
        self.classifier.eval()
        self.discriminator.eval()

        val_loss = 0.
        auc = 0.
        num_batches = 0
        with torch.no_grad():
            for images, targets, sensitive_attr in loader:
                images = images.to(self.device)
                targets = targets.to(self.device)
                sensitive_attr = sensitive_attr.to(self.device)
                outputs, loss_dict = self.forward(images, targets, sensitive_attr, mode="test")

                probs = torch.sigmoid(outputs[:, 0]).cpu().numpy()
                auc += calculate_auc(probs, targets[:, 0].cpu().numpy())
                val_loss += loss_dict['ffvae_cost'].item()
                num_batches += 1

        num_batches = max(num_batches, 1)
        auc = 100 * auc / num_batches
        val_loss = torch.tensor(val_loss / num_batches)
        print('Validation epoch {}: validation loss:{}, AUC:{}'.format(
              self.epoch, val_loss.item(), auc))
        self.log_result('validation epoch', {'loss': val_loss, 'AUC': auc}, self.epoch)
        return val_loss, auc

    def train(self, epoch=True):
        """Train for one epoch, validate, and save the best checkpoint.

        ``nn.Module.eval()`` calls ``self.train(False)``. A bool argument is
        therefore handled as the standard ``nn.Module.train(mode)`` call,
        which only switches train/eval mode. An integer epoch runs a full
        epoch.
        """
        if isinstance(epoch, bool):
            return super(FFVAE, self).train(epoch)

        start_time = datetime.now()
        self._train(self.train_loader)

        val_loss, val_auc = self._val(self.val_loader)
        self.log_result('Val epoch', {'loss': val_loss, 'AUC': val_auc}, self.epoch)

        # NOTE(review): model selection uses only the VAE objective. Once
        # _train switches to classifier-only updates, the VAE weights stop
        # changing, so this loss barely moves and best.pth may never capture
        # the trained classifier. Consider selecting on val_auc in that phase.
        if val_loss < self.best_val_loss:
            self.best_val_loss = float(val_loss)
            basics.save_state_dict(self.state_dict(), os.path.join(self.save_path, 'best.pth'))
            print('saving best model in epoch ', epoch, '........')

        duration = datetime.now() - start_time
        print('Finish training epoch {}, Val AUC: {}, time used: {}'.format(self.epoch, val_auc, duration))


"""Module that reshapes its input (nn.Module used inside nn.Sequential)
Parameters
----------
size: tuple
        target shape
tensor: torch tensor
        input tensor
Returns
----------
Reshaped tensor
"""
class Resize(torch.nn.Module):
    """Reshape a tensor to a fixed target shape, for use inside ``nn.Sequential``.

    Parameters
    ----------
    size: tuple of int
        Target shape passed to ``Tensor.reshape``. It may contain one ``-1``,
        e.g. ``(-1, 1024)`` to infer the batch dimension.
    """

    def __init__(self, size):
        super(Resize, self).__init__()
        self.size = size

    def forward(self, tensor):
        """Return ``tensor`` reshaped to ``self.size``.

        ``reshape`` returns a view when the memory layout allows it and a
        copy otherwise, so non-contiguous inputs are accepted too.
        """
        return tensor.reshape(self.size)