from auto_LiRPA import PerturbationLpNorm, BoundedParameter, BoundedModule
from auto_LiRPA.operators.convex_concave import BoundExp
from torch.nn.utils import clip_grad_norm_
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch

# Multiplier applied to the per-parameter clipping threshold computed in
# clip_model_grads (threshold = GC_SCALE * eta * ||parameter||).
GC_SCALE=10

def define_bounds(model, norm=2):
    """Wrap the weights and biases of ``model``'s sub-modules as bounded parameters.

    Every module yielded by ``model.modules()`` is visited, except
    ``nn.BatchNorm2d`` instances, which are skipped (see the TODO below).
    A ``weight`` attribute, and a ``bias`` attribute that is not ``None``,
    is replaced in place by a ``BoundedParameter`` carrying a
    ``PerturbationLpNorm``. The initial ``eps`` of that perturbation is the
    ``norm``-norm of the flattened parameter values.

    Args:
        model: module whose sub-modules are modified in place.
        norm: order of the Lp norm, used both for the perturbation itself
            and for computing the initial ``eps``.
    """
    def _as_bounded(param):
        # The radius is computed on a flattened CPU copy of the values.
        flat_values = param.data.view(-1).cpu().numpy()
        radius = np.linalg.norm(flat_values, ord=norm)
        perturbation = PerturbationLpNorm(norm=norm, eps=radius)
        return BoundedParameter(param.data, perturbation)

    for module in model.modules():
        # TODO check/fix batch norm later
        if isinstance(module, nn.BatchNorm2d):
            continue
        if hasattr(module, 'weight'):
            module.weight = _as_bounded(module.weight)
        if getattr(module, 'bias', None) is not None:
            module.bias = _as_bounded(module.bias)

def set_eps(model, eta, norm=2):
    """Rescale the perturbation radius of every ``BoundedParameter`` in ``model``.

    Each bounded parameter's ``eps`` is set to ``factor * ||p||`` where
    ``||p||`` is the ``norm``-norm of the flattened parameter data.

    Args:
        model: module whose ``parameters()`` are scanned; parameters that are
            not ``BoundedParameter`` instances are left untouched.
        eta: either a single factor shared by all parameters, or a list. With
            a list, the factor is ``eta[i % 2]`` where ``i`` is the position of
            the parameter in ``model.parameters()`` (non-bounded parameters
            count towards ``i`` as well).
        norm: order of the norm used to measure the parameter.
    """
    eta_is_list = isinstance(eta, list)
    for position, param in enumerate(model.parameters()):
        if not isinstance(param, BoundedParameter):
            continue
        #TODO this is a temp fix (the alternating eta[i % 2] lookup)
        factor = eta[position % 2] if eta_is_list else eta
        param.ptb.set_eps(factor * torch.norm(param.data.view(-1), p=norm))

def clip_model_grads(model, eta, norm=2):
    """Clip the gradient of every ``BoundedParameter`` relative to its own size.

    For each bounded parameter the gradient's 2-norm is clipped (in place, via
    ``clip_grad_norm_``) to ``GC_SCALE * factor * ||p||``, where ``||p||`` is
    the ``norm``-norm of the flattened parameter.

    Args:
        model: module whose ``parameters()`` are scanned; parameters that are
            not ``BoundedParameter`` instances are skipped.
        eta: a single factor shared by all parameters, or a list indexed by the
            parameter's position in ``model.parameters()``.
        norm: order of the norm used to measure the parameter (the gradient
            itself is always clipped in the 2-norm).
    """
    eta_is_list = isinstance(eta, list)
    for position, param in enumerate(model.parameters()):
        if not isinstance(param, BoundedParameter):
            continue
        factor = eta[position] if eta_is_list else eta
        threshold = GC_SCALE * factor * torch.norm(param.view(-1), p=norm)
        clip_grad_norm_(param, threshold, norm_type=2)


class ModelWrapper(BoundedModule):
    """``BoundedModule`` around ``model`` that also keeps a list of its parameters."""

    def __init__(self, model, input, args):
        """Build the bounded module and record the wrapped model's parameters.

        Args:
            model: module handed to ``BoundedModule``.
            input: example input handed to ``BoundedModule``.
            args: object providing ``device``, forwarded as the ``device``
                argument of ``BoundedModule``.
        """
        bound_options = {'matmul': 'economic'}
        super().__init__(model, input, device=args.device, bound_opts=bound_options)
        self.num_classes = 0
        # Snapshot of the original model's parameters, taken after wrapping.
        self.params = list(model.parameters())

def get_exp_module(bounded_module):
    """Return the first ``BoundExp`` node of ``bounded_module``, or ``None``.

    The nodes are searched in the order produced by
    ``bounded_module.named_modules()``.
    """
    # Find the Exp neuron in computational graph
    exp_nodes = (
        node
        for _, node in bounded_module.named_modules()
        if isinstance(node, BoundExp)
    )
    return next(exp_nodes, None)

class ModelWrapperLoss(nn.Module):
    """Wrap a classifier so that its forward pass returns a per-sample loss value.

    The returned quantity is ``sum_j exp(y_j - y_label)`` over the first
    ``n_classes`` outputs, i.e. the exponential of the cross-entropy loss
    restricted to those classes.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, labels, n_classes):
        """Compute ``sum_j exp(y_j - y_label)`` for every sample.

        Args:
            x: input batch passed to the wrapped model.
            labels: integer class index per sample, used as a gather index
                along the last dimension.
            n_classes: tensor holding a single value; only the first
                ``n_classes.item()`` output columns are kept.

        Returns:
            Tensor with the class dimension summed out.
        """
        # for some reason the labels/n_classes is perturbed
        outputs = self.model(x)
        kept = outputs[:, :n_classes.item()]
        label_scores = torch.gather(kept, dim=-1, index=labels.unsqueeze(-1))
        shifted = kept - label_scores
        return torch.exp(shifted).sum(dim=-1)

class mlp_2layer(nn.Module):
    """Fully connected network with one hidden ReLU layer."""

    def __init__(self, in_dim=28*28*1, out_dim=10, h_dim=100):
        """Create the two linear layers.

        Args:
            in_dim: number of features each sample is flattened to.
            out_dim: number of outputs of the final layer.
            h_dim: width of the hidden layer.
        """
        super().__init__()
        self.in_dim = in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        self.num_classes = 0
        # Layers in forward order, kept as a plain list next to the
        # registered sub-modules.
        self.params = [self.fc1, self.fc2]

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_dim)`` and apply fc1 -> ReLU -> fc2."""
        flat = x.view(-1, self.in_dim)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

class mlp_3layer(nn.Module):
    """Fully connected network with two hidden ReLU layers of equal width."""

    def __init__(self, in_dim=28 * 28 * 1,  h_dim=100, out_dim=10):
        """Create the three linear layers.

        Args:
            in_dim: number of features each sample is flattened to.
            h_dim: width of both hidden layers.
            out_dim: number of outputs of the final layer.
        """
        super().__init__()
        self.in_dim = in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, out_dim)
        self.num_classes = 0
        # Layers in forward order, kept as a plain list next to the
        # registered sub-modules.
        self.params = [self.fc1, self.fc2, self.fc3]

    def forward(self, x, returnt='out'):
        """Run the network on ``x`` after flattening it to ``(-1, in_dim)``.

        Args:
            x: input batch.
            returnt: ``'features'`` returns the activations after the second
                ReLU; any other value returns the output of ``fc3``.
        """
        flat = x.view(-1, self.in_dim)
        hidden = F.relu(self.fc1(flat))
        features = F.relu(self.fc2(hidden))
        if returnt == 'features':
            return features
        return self.fc3(features)


class mlp_conv(nn.Module):
    """Small CNN: two 5x5 conv + 2x2 max-pool stages followed by two linear layers.

    The first convolution takes 3 input channels and the classifier emits 10
    values. ``fc1`` expects 64 feature maps of size 5x5, which is what the two
    stages produce for 32x32 inputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(3, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Apply both conv stages, flatten, then fc1 -> ReLU -> fc2."""
        for conv in (self.conv1, self.conv2):
            # Each stage: convolution, ReLU, then 2x2 max-pooling with stride 2.
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        flat = x.view(-1, 64 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

class mlp_3layer_weight_perturb(BoundedModule):
    """``BoundedModule`` wrapper with per-layer weight/bias perturbation helpers.

    ``self.params`` is taken from ``model.params``. Every entry must expose
    ``weight`` and ``bias`` attributes (for example the ``nn.Linear`` layers
    listed by ``mlp_3layer.params``).
    """

    def __init__(self, model, input, args):
        """Wrap ``model`` and keep a reference to its layer list.

        Args:
            model: module handed to ``BoundedModule``; must provide ``params``.
            input: example input handed to ``BoundedModule``.
            args: object providing ``device``.
        """
        super(mlp_3layer_weight_perturb, self).__init__(model, input, device=args.device, bound_opts={'matmul': 'economic'})
        self.num_classes = 0
        self.params = model.params

    def update_eps(self, eta, norm=2):
        """Set each layer's weight and bias ``eps`` relative to its current norm.

        The new radius is ``factor * ||tensor||`` with the ``norm``-norm taken
        over the flattened tensor. The weights and biases must already carry a
        ``ptb`` attribute (e.g. after ``bound_parameters``).

        Args:
            eta: a single factor used for every tensor, or a list in which
                layer ``i`` uses ``eta[2*i]`` for its weight and
                ``eta[2*i + 1]`` for its bias.
            norm: order of the norm used to measure each tensor.
        """
        eta_is_list = isinstance(eta, list)
        # enumerate() supplies the layer index needed for the per-tensor
        # lookup; a layer object itself cannot be unpacked into (index, layer).
        for i, p in enumerate(self.params):
            if eta_is_list:
                weight_eta, bias_eta = eta[2*i], eta[2*i + 1]
            else:
                weight_eta = bias_eta = eta
            p.weight.ptb.set_eps(weight_eta * torch.norm(p.weight.data.view(-1), p=norm))
            p.bias.ptb.set_eps(bias_eta * torch.norm(p.bias.data.view(-1), p=norm))

    def bound_parameters(self, d_list, pert_weight=True, pert_bias=True, norm=float('inf')):
        """Replace every layer's weight and bias by a ``BoundedParameter``.

        Args:
            d_list: sequence in which entry ``2*i`` sets the weight radius of
                layer ``i`` and entry ``2*i + 1`` its bias radius; the radius
                is ``np.linalg.norm(entry, ord=norm)``.
            pert_weight: accepted for interface compatibility; currently not
                consulted, weights are always wrapped.
            pert_bias: accepted for interface compatibility; currently not
                consulted, biases are always wrapped.
            norm: order of the Lp perturbation and of the radius computation.
        """
        # NOTE(review): np.linalg.norm treats a 2-D input as a matrix norm
        # (ord=inf is the max absolute row sum, ord=2 the largest singular
        # value). If d_list holds unflattened weight-shaped arrays and a
        # vector norm is intended, flatten them first - confirm with callers.
        for i,p in enumerate(self.params):
            ptb = PerturbationLpNorm(norm=norm, eps=np.linalg.norm(d_list[2*i], ord=norm))
            p.weight = BoundedParameter(p.weight.data, ptb)
            ptb = PerturbationLpNorm(norm=norm, eps=np.linalg.norm(d_list[2*i + 1], ord=norm))
            p.bias = BoundedParameter(p.bias.data, ptb)


class ModelWrapperLossViT(nn.Module):
    """Loss wrapper for models whose forward pass also takes prompts.

    The returned quantity is ``sum_j exp(y_j - y_label)`` over the first
    ``n_classes`` outputs, i.e. the exponential of the cross-entropy loss
    restricted to those classes.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, labels, n_classes, prompts=None):
        """Compute ``sum_j exp(y_j - y_label)`` for every sample.

        Args:
            x: input batch passed to the wrapped model.
            labels: integer class index per sample, used as a gather index
                along the last dimension.
            n_classes: tensor holding a single value; only the first
                ``n_classes.item()`` output columns are kept.
            prompts: forwarded unchanged as the model's second positional
                argument.

        Returns:
            Tensor with the class dimension summed out.
        """
        # for some reason the labels/n_classes is perturbed
        outputs = self.model(x, prompts)
        kept = outputs[:, :n_classes.item()]
        label_scores = torch.gather(kept, dim=-1, index=labels.unsqueeze(-1))
        shifted = kept - label_scores
        return torch.exp(shifted).sum(dim=-1)

class Prompt(nn.Module):
    """Container holding one randomly initialised prompt tensor per layer."""

    def __init__(self, embed_dim, prompt_length, n_layers=12):
        """Create ``n_layers`` prompts of shape ``(prompt_length, embed_dim)``.

        Args:
            embed_dim: size of the last dimension of every prompt.
            prompt_length: number of rows in every prompt.
            n_layers: number of prompts stored in ``self.prompts``.
        """
        super().__init__()
        self.prompt_length = prompt_length
        # Drawn from a standard normal, in layer order.
        initial_values = []
        for _ in range(n_layers):
            initial_values.append(torch.randn(prompt_length, embed_dim))
        self.prompts = nn.ParameterList(initial_values)


class ViTEmbed(nn.Module):
    """Patch embedding: a 16x16 convolution with stride 16, flattened to a token sequence."""

    def __init__(self, input_channels, embed_dim):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=16, stride=16)

    def forward(self, x):
        """Return patch embeddings of shape (batch_size, num_patches, embed_dim)."""
        patch_grid = self.conv(x)
        # Merge the spatial dimensions, then move the channel axis last.
        patch_seq = patch_grid.flatten(2)
        return patch_seq.permute(0, 2, 1)

class ViT(nn.Module):
    """Vision-transformer style classifier with optional per-layer prompts.

    Images are turned into patch tokens by ``ViTEmbed``, a learned class token
    is appended at the END of the sequence, the tokens pass through
    ``num_layers`` ``ViTLayer`` blocks, and the final linear layer classifies
    the last token.
    """

    def __init__(self, embed_dim, num_heads, num_layers, num_classes):
        """Build the embedder, class token, transformer stack and classifier.

        Args:
            embed_dim: token width.
            num_heads: number of attention heads in every layer.
            num_layers: number of ``ViTLayer`` blocks.
            num_classes: number of outputs of the final linear layer.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.embedder = ViTEmbed(3, embed_dim)  # Assuming input channels are 3 (RGB images)
        self.cls_token = nn.Parameter(data=torch.randn(embed_dim))

        # Transformer stack
        blocks = [ViTLayer(embed_dim, num_heads) for _ in range(num_layers)]
        self.transformer_layers = nn.ModuleList(blocks)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x, prompts=None):
        """Classify a batch of images.

        Args:
            x: image batch fed to the patch embedder.
            prompts: optional tensor indexed as ``prompts[i, :, :]`` to obtain
                the prompt handed to layer ``i``; ``None`` disables prompting.

        Returns:
            Output of the final linear layer applied to the last token.
        """
        tokens = self.embedder(x)

        # Broadcast the class token over the batch and append it after the
        # patch tokens.
        cls_tok = self.cls_token.expand((tokens.size(0), 1, tokens.size(2)))
        tokens = torch.concat((tokens, cls_tok), dim=1)
        for index, block in enumerate(self.transformer_layers):
            layer_prompt = None if prompts is None else prompts[index, :, :]
            tokens = block(tokens, prompt=layer_prompt)
        # The class token sits in the last position.
        return self.fc(tokens[:, -1, :])

class ViTLayer(nn.Module):
    """Transformer block: attention and a 4x-wide ReLU MLP, each with a residual and batch norm."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = MultiheadAttention(embed_dim, num_heads)
        self.linear1 = nn.Linear(embed_dim, embed_dim * 4)
        self.linear2 = nn.Linear(embed_dim * 4, embed_dim)
        self.norm1 = nn.BatchNorm1d(embed_dim)
        self.norm2 = nn.BatchNorm1d(embed_dim)

    def forward(self, x, prompt=None):
        """Apply the block to ``x``.

        Args:
            x: token tensor; dimensions 1 and 2 are swapped around each
                ``BatchNorm1d`` call so that dimension 2 of ``x`` is the one
                normalised as the channel dimension.
            prompt: optional prompt forwarded to the attention module.
        """
        # Attention sub-block: residual add, then batch norm.
        attended = x + self.attention(x, prompt=prompt)
        x = self.norm1(attended.transpose(1, 2)).transpose(1, 2)
        # Feed-forward sub-block: residual add, then batch norm.
        expanded = F.relu(self.linear1(x))
        mixed = x + self.linear2(expanded)
        return self.norm2(mixed.transpose(1, 2)).transpose(1, 2)

class MultiheadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention with optional prompt tokens.

    When a prompt is given it is prepended to the sequence used for keys and
    values only; queries (and therefore the output length) come from ``x``.
    ``head_dim`` is ``embed_dim // num_heads`` and no divisibility check is
    performed.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected):
        """Reshape (B, L, C) to (B, num_heads, L, head_dim)."""
        batch, length, _ = projected.shape
        per_head = projected.reshape(batch, length, self.num_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, x, prompt=None):
        """Attend over ``x`` (plus the prompt, if any).

        Args:
            x: tensor of shape (B, N, C).
            prompt: optional tensor expandable to (B, P, C); it is placed in
                front of ``x`` when computing keys and values.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        if prompt is None:
            context = x
        else:
            # Ensure prompt has the same batch size
            context = torch.concat((prompt.expand((B, -1, C)), x), dim=1)

        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(context))
        v = self._split_heads(self.v_proj(context))

        scores = (q @ k.transpose(-2, -1)) * self.scaling
        weights = self.softmax(scores)

        # Merge the heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.out_proj(merged)