import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50
import torch.backends.cudnn as cudnn
from archs.cifar_resnet import resnet as resnet_cifar
from datasets import get_normalize_layer,SimuDataset
from torch.nn.functional import interpolate
from torch.sparse import mm
from torch.nn import Softmax
import numpy as np
from train_utils import setup_seed
# Architecture names recognised by this module:
# resnet50 - the classic ResNet-50, sized for ImageNet
# cifar_resnet20 - a 20-layer residual network sized for CIFAR
# cifar_resnet110 - a 110-layer residual network sized for CIFAR
# neural - NOTE(review): has no dedicated branch in get_architecture; the
#          AWA / word50 dataset branches there ignore `arch`, so this name is
#          presumably used with those datasets - confirm against callers.
ARCHITECTURES = ["resnet50", "cifar_resnet20", "cifar_resnet110",'neural']

def get_architecture(arch: str, dataset: str, classes: int = 50) -> torch.nn.Module:
    """Build the classifier for ``arch``/``dataset`` with input normalization prepended.

    Branches are tested in the order below and the first match wins:

    * ``arch == "resnet50"`` and ``dataset == "imagenet"``: torchvision
      ResNet-50 with pretrained weights, wrapped in ``DataParallel``. Only
      ``classes == 1000`` is supported. Also enables ``cudnn.benchmark``.
    * ``dataset == "AWA"``: ``AWADNN(classes=classes)`` in ``DataParallel``
      (``arch`` is ignored).
    * ``dataset`` is ``"word50_character"`` or ``"word50_word"``:
      ``MLP(n_class=classes)`` in ``DataParallel`` (``arch`` is ignored).
    * ``arch == "cifar_resnet20"`` / ``"cifar_resnet110"``: CIFAR ResNet of
      depth 20 / 110 with 10 outputs (``classes`` is ignored, no
      ``DataParallel``).

    Every model is moved to the GPU with ``.cuda()``.

    :param arch: the architecture - should be in the ARCHITECTURES list above
    :param dataset: the dataset - should be in the datasets.DATASETS list
    :param classes: number of output classes for the branches that use it
    :return: ``torch.nn.Sequential(normalize_layer, model)``
    :raises ValueError: if no branch supports the requested combination
    """
    if arch == "resnet50" and dataset == "imagenet":
        # Previously a non-1000 class count left `model` unbound and the
        # function crashed with UnboundLocalError at the return statement.
        if classes != 1000:
            raise ValueError(
                f"resnet50 on imagenet is only available with classes=1000, got {classes}"
            )
        model = torch.nn.DataParallel(resnet50(pretrained=True)).cuda()
        cudnn.benchmark = True
    elif dataset == "AWA":
        model = torch.nn.DataParallel(AWADNN(classes=classes)).cuda()
    elif dataset in ("word50_character", "word50_word"):
        model = torch.nn.DataParallel(MLP(n_class=classes)).cuda()
    elif arch == "cifar_resnet20":
        model = resnet_cifar(depth=20, num_classes=10).cuda()
    elif arch == "cifar_resnet110":
        model = resnet_cifar(depth=110, num_classes=10).cuda()
    else:
        raise ValueError(
            f"unsupported architecture/dataset combination: arch={arch!r}, dataset={dataset!r}"
        )
    normalize_layer = get_normalize_layer(dataset)
    return torch.nn.Sequential(normalize_layer, model)

class AWADNN(nn.Module):
    """ResNet-50 whose final fully connected layer is replaced by one with ``classes`` outputs."""

    def __init__(self, classes=50, model_type='res50', pretrained=True):
        """Create the backbone and swap its ``fc`` head.

        :param classes: number of outputs of the replacement ``fc`` layer
        :param model_type: backbone name; only ``'res50'`` is implemented
        :param pretrained: forwarded to ``resnet50(pretrained=...)``
        :raises NotImplementedError: for any ``model_type`` other than ``'res50'``
        """
        super(AWADNN, self).__init__()
        if model_type != 'res50':
            raise NotImplementedError()
        backbone = resnet50(pretrained=pretrained)
        # Keep the backbone's feature width, change only the output size.
        backbone.fc = nn.Linear(backbone.fc.in_features, classes)
        self.model = backbone
        self.classes = classes

    def forward(self, x):
        """Return the backbone's output for ``x``."""
        return self.model(x)
    
class MLP(nn.Module):
    """Three-layer perceptron with ReLU activations and dropout after each hidden layer.

    Two configurations exist, selected by the number of output classes:

    * ``n_class == 26``: input of ``28 * 28`` values, two hidden layers of
      512 units, dropout probability 0.2.
    * ``n_class == 50``: input of ``28 * 28 * 5`` values, two hidden layers
      of 1024 units, dropout probability 0.3.
    """

    # n_class -> (flattened input size, hidden width, dropout probability)
    _CONFIGS = {
        26: (28 * 28, 512, 0.2),
        50: (28 * 28 * 5, 1024, 0.3),
    }

    def __init__(self, n_class = 26):
        """Build the layers for the configuration matching ``n_class``.

        :param n_class: number of output classes; must be 26 or 50
        :raises ValueError: if ``n_class`` has no configuration
        """
        super(MLP, self).__init__()
        if n_class not in self._CONFIGS:
            # Previously this fell through and failed later with an
            # AttributeError on the never-assigned ``self.dim``.
            raise ValueError(
                f"MLP supports n_class in {sorted(self._CONFIGS)}, got {n_class!r}"
            )
        self.dim, hidden, drop_p = self._CONFIGS[n_class]
        # NOTE: the attribute name 'droput' (sic) is kept as-is because code
        # outside this class may refer to it.
        self.droput = nn.Dropout(drop_p)
        # linear layer (dim -> hidden)
        self.fc1 = nn.Linear(self.dim, hidden)
        # linear layer (hidden -> hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        # linear layer (hidden -> n_class)
        self.fc3 = nn.Linear(hidden, n_class)

    def forward(self, x):
        """Flatten ``x`` to rows of ``self.dim`` values and return the ``fc3`` output.

        :param x: tensor whose element count is a multiple of ``self.dim``
        :return: tensor with ``n_class`` values per flattened row
        """
        # reshape (rather than view) also accepts non-contiguous inputs
        x = x.reshape(-1, self.dim)
        x = self.droput(F.relu(self.fc1(x)))
        x = self.droput(F.relu(self.fc2(x)))
        return self.fc3(x)

class ReasonNN(nn.Module):
    """Scores each class by the weighted amount by which linear clauses are violated.

    Clauses are accumulated as rows of ``self.A`` (coefficients) and
    ``self.B`` (bias). ``A`` has ``num_rv + n_class`` columns: the first
    ``num_rv`` are multiplied with the input ``x``; column ``num_rv + c`` is
    a constant that is added only when class ``c`` is being scored.
    ``build_graph`` splits ``A`` into ``C`` (input part) and ``D`` (class
    part), and ``forward`` returns, per sample and class ``c``::

        score[c] = -sum_r  w[r] * relu(C[r] . x + B[r] + D[r, c])

    ``self.weight`` is the only trainable parameter; ``C``, ``D`` and ``B``
    are stored as parameters with ``requires_grad=False``.
    """

    def __init__(self, dataset = "AWA"):
        """Load the ground-truth tables from disk and build all clauses.

        :param dataset: ``"AWA"`` or a name containing ``"word50"``
        :raises ValueError: for any other dataset name (raised before file I/O)
        """
        super(ReasonNN, self).__init__()
        self.C = None
        self.D = None
        self.B = torch.FloatTensor([])
        self.A = torch.FloatTensor([])
        self.dataset = dataset
        if self.dataset == "AWA":
            self.n_class = 28
            # NOTE(review): presumably 28 class sensors + 85 attribute
            # sensors + 50 hierarchy sensors - confirm against the data files.
            self.num_rv = 28 + 85 + 50
        elif "word50" in self.dataset:
            self.n_class = 50
            # NOTE(review): presumably 50 word sensors + 26 letters at each
            # of 5 positions - confirm against the data files.
            self.num_rv = 50 + 26 * 5
        else:
            # Previously an unknown name left n_class/num_rv unset and failed
            # later with an AttributeError.
            raise ValueError(f"unsupported dataset for ReasonNN: {dataset!r}")
        self.weight = []
        self.index = None
        # NOTE(review): all three files are loaded regardless of `dataset`,
        # so each of them must exist even when it is not used.
        self.word50_gt_pred = torch.load('word50_gt_pred.pt')
        # the grounding predictor vector.
        self.attribute_matrix = torch.load( '../data/Animals_with_Attributes2/attribute_gt_pred.pt')
        self.hierarchy_matrix = torch.load( '../data/Animals_with_Attributes2/hierarchy_gt_pred.pt')
        self.build_graph()

    def forward(self, x):
        """Return the (non-positive for non-negative weights) clause score of every class.

        :param x: 2-D tensor with ``num_rv`` columns, one row per sample
        :return: tensor of shape ``(x.shape[0], n_class)``
        """
        # One weight per clause row; the two rows of a main sensor share one.
        self.w = self.weight[self.index].unsqueeze(0)
        # (rows, batch) -> (rows, 1, batch), plus D as (rows, n_class, 1)
        inner = (torch.mm(self.C, x.T) + self.B).unsqueeze(1) + self.D.unsqueeze(2)
        # -> (batch, rows, n_class) so the weights can contract over rows
        inner = inner.permute(2, 0, 1)
        score = -(self.w @ F.relu(inner)).squeeze(1)
        return score

    def _new_clause(self):
        """Return an all-zero coefficient row and bias for a single clause."""
        row = torch.zeros((1, self.num_rv + self.n_class))
        bias = torch.zeros((1, 1))
        return row, bias

    def _append_rows(self, rows, biases):
        """Append coefficient rows and their biases to ``self.A`` / ``self.B``."""
        self.A = torch.cat((self.A, rows), dim=0)
        self.B = torch.cat((self.B, biases), dim=0)

    def set_main_sensor(self):
        """Add two rows per class tying input ``i`` to class column ``num_rv + i``.

        Rows ``2i`` and ``2i + 1`` are ``x_i - d_i`` and ``d_i - x_i``; after
        the ReLU their sum is ``|x_i - d_i|``. Both rows share one weight, so
        only ``n_class`` weights are added.
        """
        add_A = torch.zeros((2 * self.n_class, self.num_rv + self.n_class))
        add_B = torch.zeros((2 * self.n_class, 1))
        for i in range(self.n_class):
            class_col = self.num_rv + i
            add_A[2 * i, i] = 1
            add_A[2 * i + 1, i] = -1
            add_A[2 * i, class_col] = -1
            add_A[2 * i + 1, class_col] = 1
        self._append_rows(add_A, add_B)
        self.weight.extend([1.] * self.n_class)

    def add_lor_sensor(self, head=(), body=()):
        """Add the clause ``relu(sum(head) - sum(body))`` with weight 1.

        :param head: column indices receiving coefficient ``+1``
        :param body: column indices receiving coefficient ``-1``
        """
        add_A, add_B = self._new_clause()
        for index in head:
            add_A[0, index] = 1.0
        for index in body:
            add_A[0, index] = -1.0
        self._append_rows(add_A, add_B)
        self.weight.append(1.)

    def add_land_sensor(self, head=(), body=()):
        """Add the clause ``relu(mean(head) - mean(body))`` with weight 1.

        :param head: column indices receiving coefficient ``+1 / len(head)``
        :param body: column indices receiving coefficient ``-1 / len(body)``
        """
        add_A, add_B = self._new_clause()
        for index in head:
            add_A[0, index] = 1.0 / len(head)
        for index in body:
            add_A[0, index] = -1.0 / len(body)
        self._append_rows(add_A, add_B)
        self.weight.append(1.)

    def add_lnot_sensor(self, head=(), body=()):
        """Add the clause ``relu(sum(head) + sum(body) + 1 - n)`` with weight 1.

        ``n`` is ``len(head) + len(body)``, so with values in ``[0, 1]`` the
        clause is only positive when all listed columns are (nearly) 1.
        ``head`` and ``body`` are treated identically.
        """
        add_A, add_B = self._new_clause()
        for index in head:
            add_A[0, index] = 1
        for index in body:
            add_A[0, index] = 1
        add_B[0] = 1 - len(head) - len(body)
        self._append_rows(add_A, add_B)
        self.weight.append(1.)

    def build_graph(self):
        """Create every clause for ``self.dataset`` and freeze them into parameters.

        Clause order determines the layout of ``self.weight`` and is therefore
        kept stable: main sensors first, then the "attribute direction"
        clauses, then the "hierarchy direction" clauses.
        """
        self.set_main_sensor()

        if self.dataset == "AWA":
            # attribute direction: entry 1 -> lor clause (class, attribute),
            # entry 0 -> lnot clause, entry -1 -> no clause.
            for index, row in enumerate(self.attribute_matrix):
                head = [self.num_rv + index]
                for i, j in enumerate(row.tolist()):
                    body = [self.n_class + i]
                    if j == 0:
                        self.add_lnot_sensor(body, head)
                    elif j == -1:
                        continue
                    elif j == 1:
                        self.add_lor_sensor(head, body)

            # hierarchy direction: same encoding, with the class column as body.
            for index, row in enumerate(self.hierarchy_matrix):
                body = [self.num_rv + index]
                for i, j in enumerate(row.tolist()):
                    head = [self.n_class + 85 + i]
                    if j == 0:
                        self.add_lnot_sensor(body, head)
                    elif j == -1:
                        continue
                    elif j == 1:
                        self.add_lor_sensor(head, body)

        elif "word50" in self.dataset:
            # Plain ints make the column arithmetic independent of how the
            # table is stored (tensor rows or nested lists).
            words = [[int(c) for c in word] for word in self.word50_gt_pred]

            # attribute direction: the word's class column is the head, one
            # or two of its characters form the body.
            for index, chars in enumerate(words):
                # Fix: `head` was never assigned in this branch (NameError).
                # Mirrors the AWA branch, where the class column is
                # num_rv + index.
                head = [self.num_rv + index]
                for q in range(5):
                    body = [self.n_class + chars[q]]
                    self.add_land_sensor(head, body)

                for q in range(4):
                    for w in range(q + 1, 5):
                        body = [self.n_class + chars[q], self.n_class + chars[w]]
                        self.add_land_sensor(head, body)

            # hierarchy direction: the remaining characters form the head,
            # the word's class column is the body.
            for index, chars in enumerate(words):
                # Fix: `body` used to be whatever the previous loop left
                # behind (the last character pair of the last word).
                body = [self.num_rv + index]
                for q in range(5):
                    head = [self.n_class + chars[k] for k in range(5) if k != q]
                    self.add_land_sensor(head, body)

                for q in range(4):
                    for w in range(q + 1, 5):
                        head = [self.n_class + chars[k] for k in range(5) if k not in (q, w)]
                        self.add_land_sensor(head, body)

        self.weight = torch.nn.Parameter(torch.FloatTensor(self.weight), requires_grad=True)
        # Map clause rows to weights: rows 2i and 2i+1 of the main sensors
        # both use weight i; every later row has its own weight.
        self.index = [i // 2 for i in range(2 * self.n_class)]
        self.index.extend(range(self.n_class, len(self.weight)))

        # Fixed (non-trainable) split of A into input and class columns.
        self.C = torch.nn.Parameter(self.A[:, :self.num_rv].clone(), requires_grad=False)
        self.D = torch.nn.Parameter(self.A[:, self.num_rv:].clone(), requires_grad=False)
        self.B = torch.nn.Parameter(self.B, requires_grad=False)

        print("Total clauses:",len(self.weight))
