import os
import sys
import time
import math

import numpy as np
import torch

import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable

from model_defs import Flatten

from bound_layers import *

import random


# separate handling for manifold mixup
class BoundIBPLargeModel(nn.Module):
    """Five-conv / two-linear network built from bound-propagation layers.

    The network is ``conv1..conv5`` (each followed by a ReLU), a flatten,
    and two linear layers separated by a ReLU.  Besides the plain forward
    pass it supports input mixup / manifold mixup (mixing at a randomly
    chosen conv stage) and interval bound propagation through
    ``interval_range``.

    The instance is invoked with a mandatory ``method_opt`` keyword that
    selects what to run; see ``__call__``.

    NOTE: ``interval_range`` walks ``self._modules`` in registration order,
    so the order in which layers are assigned in ``__init__`` must match
    the order they are applied in ``forward``.
    """

    def __init__(self, in_ch, in_dim, linear_size=512, num_classes=10):
        """Build the layers.

        Args:
            in_ch: number of input channels.
            in_dim: input height/width (a square input is assumed by the
                ``linear1`` size computation below).
            linear_size: width of the hidden fully-connected layer.
            num_classes: number of output logits; also used to one-hot
                encode targets in ``forward``.
        """
        super(BoundIBPLargeModel, self).__init__()

        self.num_classes = num_classes

        self.conv1 = BoundConv2d(in_ch, 64, 3, stride=1, padding=1)
        self.relu1 = BoundReLU(self.conv1)
        self.conv2 = BoundConv2d(64, 64, 3, stride=1, padding=1)
        self.relu2 = BoundReLU(self.conv2)
        self.conv3 = BoundConv2d(64, 128, 3, stride=2, padding=1)
        self.relu3 = BoundReLU(self.conv3)
        self.conv4 = BoundConv2d(128, 128, 3, stride=1, padding=1)
        self.relu4 = BoundReLU(self.conv4)
        self.conv5 = BoundConv2d(128, 128, 3, stride=1, padding=1)
        self.relu5 = BoundReLU(self.conv5)

        self.flatten = BoundFlatten()
        # conv3 is the only strided conv, so the flattened feature map is
        # sized as (in_dim // 2)^2 * 128.
        # NOTE(review): presumably BoundConv2d follows nn.Conv2d shape
        # rules, in which case this only matches for even in_dim - confirm.
        self.linear1 = BoundLinear((in_dim//2) * (in_dim//2) * 128, linear_size)
        self.relu6 = BoundReLU(self.linear1)
        # The output width follows num_classes (it used to be hard-coded
        # to 10, which disagreed with the one-hot targets for any other
        # class count).
        self.linear2 = BoundLinear(linear_size, num_classes)

    def __call__(self, *input, **kwargs):
        """Dispatch on the mandatory ``method_opt`` keyword argument.

        ``"interval_range"`` runs ``self.interval_range``;
        ``"full_backward_range"`` and ``"backward_range"`` are forwarded to
        methods of the same name (not defined in this class as visible
        here); any other value runs the regular ``nn.Module`` call, i.e.
        ``forward``.  A ``disable_multi_gpu`` keyword is accepted and
        ignored.

        Raises:
            ValueError: if ``method_opt`` is not supplied.
        """
        if "method_opt" not in kwargs:
            raise ValueError("Please specify the 'method_opt' as the last argument.")
        opt = kwargs.pop("method_opt")
        kwargs.pop("disable_multi_gpu", None)

        if opt == "full_backward_range":
            return self.full_backward_range(*input, **kwargs)
        elif opt == "backward_range":
            return self.backward_range(*input, **kwargs)
        elif opt == "interval_range":
            return self.interval_range(*input, **kwargs)
        else:
            return super(BoundIBPLargeModel, self).__call__(*input, **kwargs)

    def forward(self, x, target= None, mixup=False, mixup_hidden=False, mixup_alpha=None):
        """Run the network, optionally applying (manifold) mixup.

        Args:
            x: input batch.
            target: class-index labels; when given they are one-hot encoded
                with ``self.num_classes`` and returned alongside the output.
            mixup: mix at the input (stage 0).  Ignored when
                ``mixup_hidden`` is set.
            mixup_hidden: mix before a conv stage drawn uniformly from
                0..4 (stage 0 is the raw input).
            mixup_alpha: Beta-distribution parameter passed to
                ``get_lambda`` to draw the mixing coefficient.

        Returns:
            ``out`` when ``target`` is None, otherwise
            ``(out, target_reweighted)``.

        Raises:
            ValueError: if mixup is requested without both ``target`` and
                ``mixup_alpha``.
        """
        if mixup_hidden:
            layer_mix = random.randint(0, 4)
        elif mixup:
            layer_mix = 0
        else:
            layer_mix = None

        # Mixing needs a coefficient and labels to re-weight; fail with a
        # clear message instead of an UnboundLocalError further down.
        if layer_mix is not None and (target is None or mixup_alpha is None):
            raise ValueError("mixup requires both 'target' and 'mixup_alpha'.")

        out = x

        if mixup_alpha is not None:
            lam = get_lambda(mixup_alpha)
            # Keep the coefficient on the same device as the input rather
            # than forcing the default CUDA device.
            lam = torch.from_numpy(np.array([lam]).astype('float32')).to(x.device)
            lam = Variable(lam)

        if target is not None:
            target_reweighted = to_one_hot(target, self.num_classes)

        conv_stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
            (self.conv4, self.relu4),
            (self.conv5, self.relu5),
        )
        for depth, (conv, relu) in enumerate(conv_stages):
            # Mix the activations entering the selected stage.
            if layer_mix == depth:
                out, target_reweighted = mixup_process(out, target_reweighted, lam=lam)
            out = relu(conv(out))

        out = self.flatten(out)
        out = self.linear1(out)
        out = self.relu6(out)
        out = self.linear2(out)

        if target is not None:
            return out, target_reweighted
        else:
            return out

    def interval_range(self, norm=np.inf, x_U=None, x_L=None, eps=None, C=None):
        """Propagate interval bounds through every registered submodule.

        Each submodule's ``interval_propagate`` is called in registration
        order with the current ``norm`` and bounds; the norm it returns is
        fed to the next layer.  Only the final submodule receives ``C``.

        Returns:
            ``(h_U, h_L, losses, unstable, dead, alive)`` where the last
            four are sums of the per-layer values returned by
            ``interval_propagate``.
        """
        losses = 0
        unstable = 0
        dead = 0
        alive = 0
        h_U = x_U
        h_L = x_L
        layers = list(self._modules.values())
        for module in layers[:-1]:
            # all internal layers should have Linf norm, except for the first layer
            norm, h_U, h_L, loss, uns, d, a = module.interval_propagate(norm, h_U, h_L, eps)
            # this is some stability loss used for initial experiments, not used in CROWN-IBP as it is not very effective
            losses += loss
            unstable += uns
            dead += d
            alive += a
        # last layer has C to merge
        norm, h_U, h_L, loss, uns, d, a = layers[-1].interval_propagate(norm, h_U, h_L, eps, C)
        losses += loss
        unstable += uns
        dead += d
        alive += a
        return h_U, h_L, losses, unstable, dead, alive
            

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    Args:
        x: input tensor; samples are indexed along dimension 0.
        y: targets, indexable by a permutation tensor along dimension 0.
        alpha: Beta(alpha, alpha) parameter for the mixing coefficient;
            for alpha <= 0 the coefficient is fixed to 1 (no mixing).
        use_cuda: when true, the permutation index is moved to the device
            ``x`` lives on (so CPU inputs keep working on hosts without a
            GPU); when false the index stays on the CPU.

    Returns:
        (mixed_x, y_a, y_b, lam), where ``y_a`` is ``y`` and ``y_b`` is
        ``y`` permuted the same way as the second mixing operand.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.

    batch_size = x.size()[0]
    # Draw on the CPU first so the random stream is the same whichever
    # device the data is on, then follow x instead of a bare .cuda().
    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.to(x.device)

    # x[index] (rather than x[index, :]) also accepts 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
    
    
def mixup_data_fixed_lam(x, y, lam, use_cuda=True):
    '''Mix ``x`` with a random permutation of itself using a given ``lam``.

    Args:
        x: input tensor; samples are indexed along dimension 0.
        y: targets, returned untouched.
        lam: mixing coefficient applied to the un-permuted samples.
        use_cuda: when true, the permutation index is moved to the device
            ``x`` lives on; when false the index stays on the CPU.

    Returns:
        (mixed_x, y).
        NOTE(review): ``y`` is returned unchanged and the permutation is
        not exposed, so callers cannot mix the targets to match - confirm
        this is intended.
    '''
    batch_size = x.size()[0]
    # Draw on the CPU, then follow x's device instead of a bare .cuda()
    # so CPU tensors work on hosts without a GPU.
    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.to(x.device)

    # x[index] (rather than x[index, :]) also accepts 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y
    

def mixup_criterion(y_a, y_b, lam):
    '''Build a loss function that blends a criterion over both target sets.

    The returned callable takes ``(criterion, pred)`` and evaluates
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    '''
    def mixed_loss(criterion, pred):
        loss_a = criterion(pred, y_a)
        loss_b = criterion(pred, y_b)
        return lam * loss_a + (1 - lam) * loss_b

    return mixed_loss
    
    
# BCE implementation
    
def mixup_process(out, target_reweighted, lam):
    '''Mix a batch of activations and their soft targets with one permutation.

    Args:
        out: activation tensor; samples are indexed along dimension 0.
        target_reweighted: per-sample target tensor (one-hot in the caller
            visible in this file), indexed along dimension 0.
        lam: weight given to the un-permuted samples; ``1 - lam`` goes to
            the permuted ones.

    Returns:
        (mixed_out, mixed_target), both on the device of ``out``.
    '''
    # Follow the activations' device instead of a bare .cuda() so the
    # function also works on CPU tensors and on non-default GPUs.
    device = out.device
    target_reweighted = target_reweighted.to(device)
    # Draw on the CPU first so the random stream does not depend on device.
    indices = torch.randperm(out.size(0)).to(device)

    out = out * lam + out[indices] * (1 - lam)
    target_shuffled_onehot = target_reweighted[indices]
    target_reweighted = target_reweighted * lam + target_shuffled_onehot * (1 - lam)
    return out, target_reweighted
    
def to_one_hot(inp,num_classes):
    '''One-hot encode a 1-D tensor of class indices.

    Returns a CPU float32 tensor of shape ``(inp.size(0), num_classes)``
    with a 1 at each row's class index and 0 elsewhere.
    '''
    # Column of class indices, detached and on the CPU to match the result.
    class_column = inp.unsqueeze(1).detach().cpu()
    encoded = torch.zeros(inp.size(0), num_classes, dtype=torch.float32, device="cpu")
    return encoded.scatter_(1, class_column, 1)
    
def get_lambda(alpha=1.0):
    '''Draw a mixing coefficient from Beta(alpha, alpha).

    For any ``alpha`` that is not strictly positive the coefficient is
    fixed to 1 and no random number is drawn.
    '''
    if not alpha > 0.:
        return 1.
    return np.random.beta(alpha, alpha)
