#!/usr/bin/env python
# coding: utf-8

# In[1]:

from __future__ import print_function

import sys
sys.settrace

import os

# os.environ["CUDA_VISIBLE_DEVICES"]="0"
# os.environ['TQDM_MININTERVAL']='1000000000000000000000000000000000000000000000000000'

import tqdm

import unittest, torch
import torch.nn.functional as F

from random import seed, sample

from torchvision import datasets
from torchvision.transforms import ToTensor

from tql import Database, Query, Table
from tql.sqrl import SQRL

import time

# get_ipython().run_line_magic('load_ext', 'autoreload')
# get_ipython().run_line_magic('autoreload', '2')


# MNIST evaluation split (train=False), stored under ./data; download=True
# lets torchvision fetch the files when requested.
test_data = datasets.MNIST(root='data', train=False, transform=ToTensor(), download=True)
# MNIST training split (train=True), same storage root and transform.
train_data = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)

# Register both splits with a tql database under the names "test" and "train";
# the first query further down reads from the "train" registration.
db = Database("mnist")
db.register_dataset(test_data, "test")
db.register_dataset(train_data, "train")


# In[2]:


class row_sum_batch(torch.nn.Module):
    """Sum one fixed row of channel 0 for every item in a batch.

    `row` is the index used on the third axis of the input; the reduction
    runs over dim 1 of the selected slice.
    """

    def __init__(self, row):
        super(row_sum_batch, self).__init__()
        self.row = row

    def forward(self, image_tensor, label):
        """Return `image_tensor[:, 0, row]` summed over dim 1.

        `label` is accepted so the module can be called with (images, labels)
        pairs, but it is not used in the computation.
        """
        selected_rows = image_tensor[:, 0, self.row]
        return torch.sum(selected_rows, dim=1)



# In[3]:


# Sanity check: show the shape of the label tensor of the test split.
print(test_data.targets.shape)


# In[4]:


# Two chained projections over the "train" registration:
#   1. append the row-10 sum produced by row_sum_batch(10) to each (image, label) pair;
#   2. append a boolean flag that is True when the label equals 5 and the
#      row sum lies inside the closed interval given by the two constants.
# NOTE(review): the exact batching semantics of `project` come from tql and
# are not visible here.
q = (
    Query('satisfies_rule_sum10', base='train')
    .project(
        lambda imgs, labels: zip(imgs, labels, row_sum_batch(10)(imgs, labels)),
        batch_size=256,
    )
    .project(
        lambda imgs, label, sum_row_10: zip(
            imgs,
            label,
            sum_row_10,
            torch.logical_and(
                label == 5.0,
                torch.logical_and(
                    0.7984313845634461 <= sum_row_10,
                    sum_row_10 <= 15.972551059722871,
                ),
            ),
        ),
        batch_size=256,
    )
)


# In[5]:


# Execute the query against the database registered above.
res = q(db, disable=True)


# In[6]:


import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


# In[7]:


class Net(nn.Module):
    """Small convolutional classifier with ten output classes.

    Layout: two 3x3 convolutions (1 -> 32 -> 64 channels), a 2x2 max-pool,
    dropout, and two fully connected layers (9216 -> 128 -> 10).  The 9216
    input width of the first linear layer matches 64 channels of 12x12
    activations, i.e. what the conv/pool stack yields for 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Classifier head on the flattened feature maps.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)



# In[8]:


print("Is CUDA enabled?",torch.cuda.is_available())

# Place the model on the GPU when one is available and otherwise stay on the
# CPU, instead of failing on hosts without CUDA right after reporting that
# CUDA is unavailable.  Behaviour on CUDA hosts is unchanged.
model = Net().to('cuda' if torch.cuda.is_available() else 'cpu')
# DataParallel wraps the model; the parameters handed to the optimizer below
# are those of the wrapped module.
model = nn.DataParallel(model)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)


# In[9]:


#model.load_state_dict(torch.load("mnist_cnn.pt"))


# In[10]:


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of NLL-loss training over `train_loader`.

    `args` is accepted for signature compatibility but is not read.
    Each batch is moved to `device`, a forward/backward pass is performed and
    `optimizer` takes one step.  A progress line is printed for every 100th
    batch (starting with batch 0); `epoch` is only used in that line.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Only every 100th batch is reported.
        if step % 100:
            continue
        samples_seen = step * len(inputs)
        total_samples = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, total_samples, percent_done, batch_loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    i=0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            i+=1
            if i%10 == 0:
                print(i, test_loss)

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# In[11]:


# Device that batches are moved to by the train/test helpers.  Prefer the GPU
# and fall back to the CPU so the script does not hard-fail on hosts without
# CUDA; on CUDA hosts this is still 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# In[12]:


# Tensor conversion followed by normalisation with a fixed mean/std pair
# (presumably the MNIST statistics -- TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# Only the training split requests a download; both splits share '../data'.
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)


# In[13]:


batch_size = 256
train_kwargs = {'batch_size': batch_size}
test_kwargs = {'batch_size': batch_size}
# Pinned host memory is only requested when a GPU is actually present;
# shuffling is applied to both loaders, as before.
cuda_kwargs = {'pin_memory': torch.cuda.is_available(),
               'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)


# In[14]:


# Notebook display leftover: has no visible effect when run as a script.
dataset2.data.shape


# In[15]:


train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)


# In[16]:


# Baseline evaluation of the freshly initialised model.
test(model, device, test_loader)


# In[17]:


# Notebook display leftover: has no visible effect when run as a script.
model


# In[18]:


class row_sum_batch(torch.nn.Module):
    """Sum the slice at index `row` of dim 1 for every item in a batch.

    This redefinition replaces the earlier class of the same name; unlike
    that one it does not index a channel axis first.
    NOTE(review): presumably the tensors handed over by tql in the rule
    functions below carry no channel axis -- confirm.
    """

    def __init__(self, row):
        super(row_sum_batch, self).__init__()
        self.row = row

    def forward(self, image_tensor, label):
        """Return `image_tensor[:, row]` summed over dim 1; `label` is unused."""
        selected_rows = image_tensor[:, self.row]
        return torch.sum(selected_rows, dim=1)

def train_with_rules(model, device, train_loader, optimizer, epoch, rule_lambda=0.01, rule_only=False):
    """Run one training epoch with an extra rule-based loss term computed via tql.

    For every batch the inputs, arg-max predictions and model outputs are
    registered in a fresh tql Database, joined, and projected into a per-row
    penalty whose sum is the rule loss.

    With `rule_only=True` the loss is the rule loss alone (`rule_lambda` is
    then not used); otherwise it is `nll_loss + rule_lambda * rule_loss`.
    A progress line is printed for every 100th batch.
    """
    model.train()
    # NOTE(review): `i` is never incremented in this function, so every batch
    # uses the "..._0" database/table names -- confirm this is intended.
    i = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        # ---- rule loss through tql -------------------------------------
        pred = output.argmax(dim=1, keepdim=True)
        data_name = "mnist_data_{}".format(i)
        preds_name = "mnist_preds_{}".format(i)
        ops_name = "mnist_ops_{}".format(i)

        batch_db = Database("mnist_{}".format(i))
        batch_db.register_dataset(data, data_name, disable=True)
        batch_db.register_dataset(pred, preds_name, disable=True)
        batch_db.register_dataset(output.reshape((len(pred), 1, 10)), ops_name, disable=True)

        def foo(imgs, label, outputs, sum_row_10):
            # Rule: label == 5 and the row-10 sum lies inside the closed interval.
            is_five = label == 5.0
            in_band = torch.logical_and(0.7984313845634461 <= sum_row_10,
                                        sum_row_10 <= 15.972551059722871)
            violated = torch.logical_not(torch.logical_and(is_five, in_band))
            violated = violated.reshape(len(sum_row_10), 1)
            # NOTE(review): softmax runs over dim=0 while gather indexes dim=1
            # -- confirm dim=0 is the intended axis.
            label_score = F.softmax(outputs, dim=0).gather(1, label.reshape((len(label), 1)))
            return zip(imgs, label, sum_row_10, violated * label_score)

        rule_query = (
            Query('satisfies_rule_sum10', data_name)
            .join(preds_name)
            .join(ops_name, key=lambda idx, *row: idx[-1])
            .project(lambda imgs, labels, outputs: zip(imgs, labels, outputs, row_sum_batch(10)(imgs, labels)), batch_size=256)
            .project(foo, batch_size=256)
        )
        # Keep only the penalty column (position 3) of every result row.
        penalties = rule_query(batch_db, disable=True).project(lambda *row: row[3], disable=True)

        # NOTE(review): torch.tensor(...) builds a new leaf tensor, so the rule
        # loss looks disconnected from the model's autograd graph; with
        # rule_only=True backward() would then yield no parameter gradients.
        # Confirm whether that is intended.
        res = torch.tensor(penalties.rows, dtype=torch.float, requires_grad=True)
        rule_loss = res.sum()
        # -----------------------------------------------------------------

        loss = rule_loss if rule_only else F.nll_loss(output, target) + rule_lambda*rule_loss

        loss.backward()
        optimizer.step()

        if batch_idx % 100:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tRule Loss: {:.3f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item(), rule_loss.item()))

        

def test_with_rules(model, device, test_loader):
    """Evaluate `model` and report NLL loss, rule loss and accuracy.

    For every batch the inputs, arg-max predictions and model outputs are
    registered in a per-batch tql Database, joined, and projected into a
    per-row penalty; the penalties are summed into the rule loss.  Both
    losses are divided by the dataset size before printing.

    The rule query is executed once per batch (the previous version ran it
    twice and discarded the first result).
    """
    model.eval()
    test_loss = 0
    correct = 0
    rule_loss = 0

    with torch.no_grad():
        # `i` numbers the batches from 0 and is used in the per-batch names.
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            db = Database("mnist_{}".format(i))
            db.register_dataset(data, "mnist_data_{}".format(i), disable=True)
            db.register_dataset(pred, "mnist_preds_{}".format(i), disable=True)
            db.register_dataset(output.reshape((len(pred), 1, 10)), "mnist_ops_{}".format(i), disable=True)

            q = Query('satisfies_rule_sum10', "mnist_data_{}".format(i)).join("mnist_preds_{}".format(i))\
                .join("mnist_ops_{}".format(i), key = lambda idx, *row: idx[-1])

            def foo(imgs, label, outputs, sum_row_10):
                # Penalty per row: 0 where the rule (label == 5 and row-10 sum
                # inside the closed interval) holds, otherwise the gathered
                # softmax score of the predicted label.
                # NOTE(review): softmax runs over dim=0 while gather indexes
                # dim=1 -- confirm dim=0 is the intended axis.
                return zip(imgs, label, sum_row_10,
                        (torch.logical_not(torch.logical_and(label == 5.0,
                        torch.logical_and(0.7984313845634461 <= sum_row_10,
                                        sum_row_10 <= 15.972551059722871))).reshape(len(sum_row_10), 1) * F.softmax(outputs, dim=0).gather(1, label.reshape((len(label), 1)))))

            q = q.project(lambda imgs, labels, outputs: zip(imgs, labels, outputs, row_sum_batch(10)(imgs, labels)), batch_size=256)
            q = q.project(foo, batch_size=256)

            # Run the query once and add up the penalty column (position 3).
            res = q(db, disable=True)
            s = sum(res.project(lambda *row: row[3], disable=True))
            rule_loss += s.item()

    test_loss /= len(test_loader.dataset)
    rule_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Average rule loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, rule_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# In[19]:


# Rule-aware evaluation of the current model before any rule training.
test_with_rules(model, device, test_loader)


# In[ ]:


# rules + original loss
# The model is placed on the module-level `device`, the same one the
# train/test helpers move every batch to, instead of a separate hard-coded
# 'cuda' string.
model = Net().to(device)
model = nn.DataParallel(model)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
# Multiply the learning rate by 0.7 after every epoch.
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1, 5):
    train_with_rules(model, device, train_loader, optimizer, epoch, rule_lambda=0.01)
    test_with_rules(model, device, test_loader)
    scheduler.step()


# In[ ]:


# using only rules
# Fresh model and optimizer so this run does not continue from the one above.
model = Net().to(device)
model = nn.DataParallel(model)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1, 5):
    # With rule_only=True the loss is the rule loss alone; rule_lambda is not
    # applied in that branch of train_with_rules.
    train_with_rules(model, device, train_loader, optimizer, epoch, rule_lambda=1.0, rule_only=True)
    test_with_rules(model, device, test_loader)
    scheduler.step()


# In[ ]:




