import torch
import torch.nn.functional as F
from AIDomains.abstract_layers import Sequential, Flatten, Linear, ReLU, Conv2d, _BatchNorm, BatchNorm2d, BatchNorm1d, Normalization
from AIDomains.zonotope import HybridZonotope
from AIDomains.ai_util import construct_C
from AIDomains.wrapper import propagate_abs


from networks import get_network
from loaders import get_loaders
import argparse
from utils import seed_everything

def propagate_eps(input, net, C, abs:bool, relu_adjust=None):
    '''
    Propagate a tensor through the weights of ``net``, layer by layer,
    without adding any bias or shift term.

    Args:
        input: tensor fed to the first layer. It is never modified in place.
        net: sequence of layers; ``len(net)`` is used to recognise a Linear
            layer sitting in the last position.
        C: matrix multiplied onto the weight of a Linear layer in the last
            position ("elision"). Not used for any other layer.
        abs: if True, every weight is replaced by its elementwise absolute
            value (for the last Linear layer: the absolute value of
            ``C @ weight``). The name shadows the builtin but is part of the
            public signature.
        relu_adjust: how ReLU layers rescale the signal, based on the bounds
            stored in ``layer.bounds``:
            ``None`` - ReLU is treated as the identity;
            ``"local"`` / ``"center"`` - entries whose stored upper bound is
            negative are zeroed;
            ``"shrink"`` - entries are scaled by the ratio of the box size
            after and before clamping at zero, and zeroed where the upper
            bound is not positive.

    Returns:
        The tensor obtained after the last layer.

    Raises:
        NotImplementedError: for an unsupported layer type or an unknown
            ``relu_adjust`` value (the latter only when a ReLU is reached).
    '''
    for i, layer in enumerate(net):
        if isinstance(layer, Normalization):
            # sigma is taken to be positive here, so abs makes no difference
            input = input / layer.sigma
        elif isinstance(layer, _BatchNorm):
            # multiplicative part of batch norm only; the mean/bias shift is dropped
            w = (layer.weight / torch.sqrt(layer.current_var + layer.eps)).view(layer.view_dim)
            if abs:
                input = input * w.abs()
            else:
                input = input * w
        elif isinstance(layer, Linear):
            if i != len(net) - 1:
                if abs:
                    input = F.linear(input, layer.weight.abs())
                else:
                    input = F.linear(input, layer.weight)
            else:
                # last linear, apply elision: fold C into the weight first so
                # that abs (if requested) is taken of the combined matrix
                if abs:
                    elided_weight = torch.matmul(C, layer.weight).abs()
                else:
                    elided_weight = torch.matmul(C, layer.weight)
                input = torch.matmul(elided_weight, input.unsqueeze(-1)).squeeze(-1)
        elif isinstance(layer, Conv2d):
            if abs:
                input = F.conv2d(input, layer.weight.abs(), stride=layer.stride, padding=layer.padding, dilation=layer.dilation)
            else:
                input = F.conv2d(input, layer.weight, stride=layer.stride, padding=layer.padding, dilation=layer.dilation)
        elif isinstance(layer, Flatten):
            input = input.flatten(1, -1)
        elif isinstance(layer, ReLU):
            if relu_adjust is None:
                pass
            elif relu_adjust in ["local", "center"]:
                _, ub = layer.bounds
                # Out-of-place on purpose: an indexed assignment would write
                # into the caller's tensor (and fails on a leaf that requires
                # grad). Assumes the bounds have the same shape as the
                # activation - TODO confirm.
                input = input.masked_fill(ub < 0, 0)
            elif relu_adjust == "shrink":
                lb, ub = layer.bounds
                pre_size = ub - lb
                not_dead = ub > 0
                lb, ub = lb.clamp(min=0), ub.clamp(min=0)
                post_size = ub - lb
                # the clamps at 1e-8 guard the division against zero-width boxes
                input = input * (post_size.clamp(min=1e-8) / pre_size.clamp(min=1e-8)) * not_dead
            else:
                raise NotImplementedError(f"Unknown ReLU adjustment: {relu_adjust}")
        else:
            raise NotImplementedError(f"Unknown layer: {layer}")
    return input

def compute_tightness(net, batch_x, batch_y, eps, data_range=(0,1), num_classes:int=10, relu_adjust=None, detach_denom:bool=False, detach_num:bool=False, error_check:bool=False):
    '''
    Compute, per sample and per output dimension, the ratio between two
    propagations of the input radius through ``net``.

    The input radius is half the width of the eps-box around ``batch_x``
    after clipping it to ``data_range``. The denominator pushes this radius
    through the absolute weights of the network. The numerator pushes it
    through the signed weights, once per output dimension, with the radius
    multiplied by the sign of that output's gradient w.r.t. the input.

    Warning: this resets the box bounds stored in ``net`` and clears the
    gradients of its parameters (``net.zero_grad()``).

    Args:
        net: network to analyse; must provide ``reset_bounds`` and
            ``zero_grad`` and be iterable over its layers.
        batch_x: batch of inputs.
        batch_y: labels passed to ``construct_C``; if None, an identity
            matrix per sample is used instead.
        eps: perturbation radius before clipping to ``data_range``.
        data_range: (min, max) the perturbed inputs are clipped to.
        num_classes: number of outputs of the network.
        relu_adjust: None, "local", "center" or "shrink"; selects which
            bounds are stored in the ReLU layers before propagation and is
            forwarded to ``propagate_eps``.
        detach_denom: detach the denominator from the autograd graph.
        detach_num: detach the numerator from the autograd graph.
        error_check: if True, raise when any ratio exceeds 1.01.

    Returns:
        ``num.clamp(min=1e-8) / denom.clamp(min=1e-8)``.

    Raises:
        RuntimeError: if ``error_check`` is set and a ratio above 1.01 is
            found; the network is saved to "buggie.ckpt" first.
    '''
    input_eps = ((batch_x+eps).clamp(max=data_range[1]) - (batch_x-eps).clamp(min=data_range[0])) / 2
    num = input_eps.clone().detach()

    if batch_y is None:
        C = torch.eye(num_classes, device=batch_x.device).repeat(batch_x.shape[0], 1, 1)
    else:
        C = construct_C(num_classes, batch_y)

    # NOTE(review): nothing here freezes BatchNorm statistics; whether the
    # forward passes below update them depends on the layer implementation.

    # set relu adjustment here
    # test status: correct. relu stat does not change inside this function after setting it below.
    net.reset_bounds()
    if relu_adjust == "local":
        # use the activation pattern at the original input as the adjustment
        with torch.no_grad():
            abs_input = HybridZonotope.construct_from_noise(batch_x, 0, domain="box")
            _ = net(abs_input)
    elif relu_adjust == "center":
        # use the activation pattern at the centre of the clipped eps-box
        with torch.no_grad():
            center = ((batch_x+eps).clamp(max=data_range[1]) + (batch_x-eps).clamp(min=data_range[0])) / 2
            abs_input = HybridZonotope.construct_from_noise(center, 0, domain="box")
            _ = net(abs_input)
    elif relu_adjust == "shrink":
        # for unstable neurons, shrink the coefficient to ensure the same box size
        # (not wrapped in no_grad, unlike the two branches above)
        abs_input = HybridZonotope.construct_from_noise(batch_x, eps, domain="box")
        _ = net(abs_input)

    # infer signs of numerator here
    with torch.enable_grad():
        num.requires_grad = True
        out = propagate_eps(num, net, C, abs=False, relu_adjust=relu_adjust)
        net.zero_grad()
        # sum over batch because we only want the grad w.r.t. the batch eps which are unconnected
        # thus, the grad of the sum is their individual grad
        # test status: correct; tested via comparing the individual backward with it
        #
        # torch.autograd.grad restricts differentiation to `num`; calling
        # .backward() here would also accumulate unrelated gradients into the
        # network parameters, which a later loss.backward() would pick up.
        signs = [
            torch.autograd.grad(out[..., i].sum(), num, retain_graph=True)[0].sign()
            for i in range(out.shape[-1])
        ]

    # compute the numerator
    # test status: no error found; tested via checking whether all num are the largest and positive
    per_sign_out = [
        propagate_eps(input_eps * sign, net, C, abs=False, relu_adjust=relu_adjust)
        for sign in signs
    ]
    # stacked[..., o, i] is output o under the sign pattern of output i; the
    # diagonal keeps output i under its own sign pattern
    num = torch.diagonal(torch.stack(per_sign_out, dim=-1), dim1=-2, dim2=-1)

    # compute the denominator
    # test status: correct; tested via comparing direct propagation on a Deep Linear Network
    # Numerical Problem with BN: result has <0.001% inconsistency
    denom = propagate_eps(input_eps, net, C, abs=True, relu_adjust=relu_adjust)

    if detach_num:
        num = num.detach()
    if detach_denom:
        denom = denom.detach()

    net.reset_bounds()

    # clamp both sides away from zero so the division is always defined
    ratio = num.clamp(min=1e-8) / denom.clamp(min=1e-8)

    if error_check and not (ratio <= 1.01).all():
        # the 1% slack absorbs numerical error; anything larger is reported
        mask = ratio > 1
        print(num[mask])
        print(denom[mask])
        torch.save(net, "buggie.ckpt")
        raise RuntimeError("PI > 1 detected.")
    return ratio


if __name__ == "__main__":
    # Smoke test: load a trained MNIST model and run compute_tightness over
    # the test set. The returned ratios are not used.
    seed_everything(0)
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    # No command-line options are declared; the settings are attached to the
    # namespace directly before it is handed to get_loaders.
    args.dataset = "mnist"
    args.train_batch = 128
    args.test_batch = 128
    args.grad_accu_batch = None
    args.frac_valid = None
    args.net = "cnn_3layer"
    args.init = "default"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loaders, num_train, input_size, input_channel, n_class = get_loaders(args)
    train_loader, test_loader = loaders
    input_dim = (input_channel, input_size, input_size)

    net = get_network(args.net, args.dataset, device, init=args.init)
    net = Sequential.from_concrete_network(net, input_dim, disconnect=True)
    # map_location keeps the load working on a CPU-only host even when the
    # checkpoint was saved from a CUDA device
    net.load_state_dict(torch.load("test_models/mnist/eps0.1/box_trained/cnn_3layer/init_fast/alpha5.0/fast_reg/model.ckpt", map_location=device))

    print(net)

    eps = 0.3
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)
        compute_tightness(net, x, y, eps, relu_adjust=None)