import os
import torch
import math
from random import randrange
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import numpy as np
from CNNDT_Architecture import DTNet, DTNet_4
from blocks import BasicBlock2D as BasicBlock
from tqdm import tqdm
from torch.optim import SGD, Adam, AdamW
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.optim import Optimizer
import logging
import json
from collections import OrderedDict
from almost_unique_id import generate_id
import argparse

# Command-line interface. Only the help texts for --epochs and --width were
# corrected (they were copy-pasted from other options); flag names, types and
# defaults are unchanged.
parser = argparse.ArgumentParser(description="Train a model")
parser.add_argument("--logdir", default="default-log-dir", type=str,
                    help="location where log of experiments will be stored")
parser.add_argument("--exp", default="default", type=str, help="name of experiments")
parser.add_argument("--epochs", default=10, type=int, help="number of epochs")
parser.add_argument("--max_iters", default=40, type=int, help="number of iterations in DT model when training.")
parser.add_argument("--batch_size", default=5, type=int, help="batch_size")
parser.add_argument("--num_samples", default=1000, type=int, help="num samples in dataset")
parser.add_argument("--best", action="store_true", help="plot best model")
parser.add_argument("--type_2", action="store_true", help="Other idea model")
parser.add_argument("--in_channels", default=3, type=int, help="num channels")
parser.add_argument("--width", default=512, type=int, help="width passed to the model constructor")
parser.add_argument("--big_run", action="store_true", help="Test on all data model")
FLAGS = parser.parse_args()
## HYPERPARAMS ##
folder_name = FLAGS.logdir
# "default" is a sentinel meaning "no experiment name given": a generated id
# is used instead.
if FLAGS.exp == "default":
    run_id = generate_id()
else:
    run_id = FLAGS.exp
# NOTE: `id` shadows the builtin, but the name is read further down the
# module (checkpoint path, stats files), so it is kept as-is.
id = "outputs/"+folder_name +"/"+ run_id
os.makedirs(id, exist_ok=True)
train_batch_size = FLAGS.batch_size
epochs = FLAGS.epochs
max_iters = FLAGS.max_iters
# Prefer the first CUDA device when one is available, otherwise use the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("testing:")
print(f"run id = {id}")
print(f"max iters  = {max_iters}")
print(f"num epochs = {epochs}")
print(f"batch size = {train_batch_size}")

class CreateDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each image array with its label array.

    Both NumPy inputs are converted to float tensors once, at construction
    time. ``transform`` is accepted for signature compatibility but is not
    used.
    """

    def __init__(self, image_data, count_data, transform=None):
        images = torch.from_numpy(image_data)
        counts = torch.from_numpy(count_data)
        self.image_data = images.to(torch.float)
        self.count_data = counts.to(torch.float)

    def __len__(self):
        # The sample count is the length of the leading axis of the images.
        return len(self.image_data)

    def __getitem__(self, index):
        # Return the (image, label) pair stored at the same position.
        return self.image_data[index], self.count_data[index]

# One-hot encoding applied to every integer cell read from an instance file.
vals = {0:[1,0,0], 1:[0,1,0], 2:[0,0,1]}


def _read_instances(path):
    """Parse one instance file into a list of one-hot encoded grids.

    Each non-blank line is a row of whitespace-separated integers (keys of
    ``vals``); a line consisting of just a newline closes the current grid.
    """
    instances = []
    current = []
    with open(path, 'r') as f:
        for line in f:
            if line != '\n':
                # split() already discards the trailing newline, so a final
                # line without one keeps its last digit.
                current.append([vals[int(token)] for token in line.split()])
            else:
                instances.append(current)
                current = []
    # A file that does not end with a blank separator line still contributes
    # its last grid instead of silently dropping it.
    if current:
        instances.append(current)
    return instances


def get_data(problem_file, label_file):
    """Load problems and their labels from two parallel text files.

    Args:
        problem_file: path of the file holding the problem grids.
        label_file: path of the file holding the label grids.

    Returns:
        A tuple ``(X, Y)`` of nested lists indexed as
        ``[instance][row][column][one_hot_component]``.

    Raises:
        OSError: if either file cannot be opened.
        ValueError / KeyError: if a token is not an integer key of ``vals``.
    """
    X = _read_instances(problem_file)
    Y = _read_instances(label_file)
    return X,Y

def _take_samples(problems, labels):
    """Convert parsed instances to arrays and keep the first ``FLAGS.num_samples``."""
    return (np.array(problems)[:FLAGS.num_samples],
            np.array(labels)[:FLAGS.num_samples])


def _to_dataset(data, labels):
    """Wrap two arrays in a CreateDataset after moving the last axis to position 1."""
    channel_first = (0, 3, 1, 2)
    return CreateDataset(np.transpose(data, channel_first),
                         np.transpose(labels, channel_first))


def _to_loader(dataset, shuffle):
    """Build a DataLoader with the script-wide batch size, dropping the last partial batch."""
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=train_batch_size,
                                       shuffle=shuffle, drop_last=True)


# Training and validation splits; both loaders shuffle.
training_data, training_labels = _take_samples(
    *get_data('variable_training_instances.txt', 'variable_training_instances_labels.txt'))
validation_data, validation_labels = _take_samples(
    *get_data('variable_validation_instances.txt', 'variable_validation_instances_labels.txt'))

train_dataset = _to_dataset(training_data, training_labels)
train_loader = _to_loader(train_dataset, shuffle=True)

validatation_dataset = _to_dataset(validation_data, validation_labels)
validation_loader = _to_loader(validatation_dataset, shuffle=True)

# Test splits that are always evaluated.
test_data_11, test_labels_11 = _take_samples(
    *get_data('11x11_test.txt', '11x11_test_labels.txt'))
test_data_13, test_labels_13 = _take_samples(
    *get_data('13x13_test.txt', '13x13_test_labels.txt'))

# Extra test splits, only loaded when --big_run is given.
if FLAGS.big_run:
    test_data_12, test_labels_12 = _take_samples(
        *get_data('12x12_test.txt', '12x12_test_labels.txt'))
    test_data_big_clauses, test_labels_big_clauses = _take_samples(
        *get_data('bigger_clauses_test.txt', 'bigger_clauses_test_labels.txt'))
    test_data_big_vars, test_labels_big_vars = _take_samples(
        *get_data('bigger_clauses_vars.txt', 'bigger_clauses_vars_labels.txt'))

    test_dataset_big_clauses = _to_dataset(test_data_big_clauses, test_labels_big_clauses)
    test_loader_big_clauses = _to_loader(test_dataset_big_clauses, shuffle=False)
    test_dataset_12 = _to_dataset(test_data_12, test_labels_12)
    test_loader_12 = _to_loader(test_dataset_12, shuffle=False)
    test_dataset_big_vars = _to_dataset(test_data_big_vars, test_labels_big_vars)
    test_loader_big_vars = _to_loader(test_dataset_big_vars, shuffle=False)

test_dataset = _to_dataset(test_data_11, test_labels_11)
test_loader = _to_loader(test_dataset, shuffle=False)

test_dataset_13 = _to_dataset(test_data_13, test_labels_13)
test_loader_13 = _to_loader(test_dataset_13, shuffle=False)

def get_predicted(inputs, outputs):
    """Return the arg-max index along dim 1 of ``outputs``, flattened per sample.

    ``inputs`` is not used. The result has shape ``(outputs.size(0), -1)``.
    """
    class_ids = outputs.clone().argmax(1)
    batch = class_ids.size(0)
    return class_ids.view(batch, -1)

def validate_generated_solution(problem,solution):
    """Check a predicted grid against a problem grid; return True when it is accepted.

    ``problem`` and ``solution`` are indexed as ``[row][column]``. In the
    problem grid, cells equal to 1 and 0 are checked and any other value is
    ignored; in the first row, the cells before the first 2 set the column
    count. (Presumably rows are clauses, columns are literals and 2 marks
    padding -- confirm against the data generator.)

    The first half of the counted columns gets positive ids ``1..variables``;
    every later column gets a negative id. A solution is rejected when:
      * some id and its negation are both marked 1 anywhere in ``solution``;
      * a problem row (before the first row containing no 1) has no 1-cell
        whose id is marked in the solution;
      * in such a row, a 1-cell whose id is marked is not 1 in the solution,
        or a 0-cell is not 0 in the solution.
    """
    header = problem[0]
    n_cols = len(header)

    # Count the cells of the first row that precede the first 2.
    prefix = 0
    while prefix < n_cols and header[prefix] != 2:
        prefix += 1
    variables = prefix // 2

    def literal_id(col):
        # Columns below `variables` are positive ids, the rest negative ones.
        if col >= variables:
            return -(col - variables) - 1
        return col + 1

    # Every id that is marked 1 somewhere in the solution grid.
    true_variables = {
        literal_id(col)
        for row in solution
        for col in range(len(row))
        if row[col] == 1.0
    }

    # An id and its negation must not both be marked.
    if any(-lit in true_variables for lit in true_variables):
        return False

    # First pass: every leading row that contains a 1 needs a marked literal.
    for clause in problem:
        if 1 not in clause:
            break
        satisfied = any(
            clause[col] == 1 and literal_id(col) in true_variables
            for col in range(len(clause))
        )
        if not satisfied:
            return False

    # Second pass: the solution grid must agree with the problem cell by cell.
    for row_idx in range(len(problem)):
        clause = problem[row_idx]
        if 1 not in clause:
            break
        for col in range(n_cols):
            cell = clause[col]
            if cell == 1.0:
                if literal_id(col) in true_variables and solution[row_idx][col] != 1.0:
                    return False
            elif cell == 0:
                if solution[row_idx][col] != 0:
                    return False
    return True

def _fold_second_half(tensor):
    """Stack channel 1 of the right half of ``tensor`` onto its left half along dim 1.

    ``tensor`` is unpacked as four axes ``(b, c, h, w)``; ``w`` must be even.
    """
    w = tensor.shape[3]
    assert w % 2 == 0, "Width should be even to split it in half."
    half = w // 2
    return torch.cat([tensor[:, :, :, :half], tensor[:, 1:2, :, half:]], dim=1)


def test(model, dataset):
    """Evaluate ``model`` on every batch of ``dataset`` and report accuracy per iteration.

    The model is run for 90 iterations per batch. Only every fifth output
    (indices 0, 5, 10, ...) is scored; the remaining entries stay at 0. A
    sample counts as correct when its prediction equals the target, or when
    ``validate_generated_solution`` accepts it, in which case the target for
    that sample is replaced by the prediction for the rest of the batch.

    Args:
        model: callable as ``model(inputs, iters_to_do=...)``; the result is
            indexed as ``[:, iteration]``.
        dataset: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Dict mapping iteration number ``1..90`` to accuracy in percent
        (output index ``i`` is reported under key ``i + 1``).
    """
    iters = list(range(1, 91))
    max_iters = max(iters)
    model.eval()
    corrects = torch.zeros(max_iters)
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataset, leave=False):
            # Follow the module-level device choice instead of assuming CUDA,
            # so evaluation does not crash on CPU-only hosts.
            inputs, targets = inputs.to(device), targets.to(device)
            if FLAGS.type_2:
                targets = _fold_second_half(targets)
                inputs = _fold_second_half(inputs)
            targets = targets.argmax(dim=1)
            targets = targets.view(targets.size(0), -1)

            all_outputs = model(inputs, iters_to_do=max_iters)

            # Loop invariants: class-index view of the inputs, grid height,
            # and the real batch size (not the global batch-size setting).
            problems = inputs.argmax(1)
            n_rows = inputs.size(2)
            batch = inputs.size(0)

            for i in range(0,all_outputs.size(1),5):
                outputs = all_outputs[:, i]
                predicted = get_predicted(inputs, outputs)
                for k in range(batch):
                    # Accept any prediction the validator approves, even if it
                    # differs from the stored label.
                    if validate_generated_solution(problems[k], predicted[k].view(n_rows, -1)):
                        targets[k] = predicted[k]
                # A sample is correct only if every cell matches.
                corrects[i] += torch.amin(predicted == targets, dim=[1]).sum().item()

            total += targets.size(0)

    accuracy = 100.0 * corrects / total
    return {ite: accuracy[ite-1].item() for ite in iters}

def test_model(model, datasets):
    accs = []
    for loader in datasets:
        accuracy = test(model, loader)
        accs.append(accuracy)
    return accs

def get_model(width=512, in_channels=3, max_iters=40):
    """Build the network selected by ``FLAGS.type_2``.

    ``DTNet_4`` is used when the flag is set, ``DTNet`` otherwise; both get
    the same constructor arguments. ``max_iters`` is accepted but not used.
    """
    architecture = DTNet_4 if FLAGS.type_2 else DTNet
    return architecture(BasicBlock, [2], width=width, in_channels=in_channels, recall=True)

def load_model_from_checkpoint(device, id):
    """Build the network and load its weights from the run directory ``id``.

    The checkpoint file is ``model_best.pth`` when ``FLAGS.best`` is set and
    ``model_.pth`` otherwise.

    Args:
        device: ``torch.device`` to place the network on and map the
            checkpoint to.
        id: run directory containing the checkpoint file.

    Returns:
        ``(net, epoch, optimizer)``; ``epoch`` is always 0 and ``optimizer``
        is always None.

    Raises:
        OSError: if the checkpoint file cannot be opened.
        KeyError: if the checkpoint has no ``"net"`` entry.
    """
    model_path = id+f"/model_{'best' if FLAGS.best else ''}.pth"
    net = get_model(FLAGS.width, in_channels=FLAGS.in_channels, max_iters=FLAGS.max_iters)
    # .to(device) works for both CPU and CUDA devices; the former
    # net.cuda(device=...) call failed whenever the device was the CPU.
    net = net.to(device)
    if device == torch.device("cuda:0"):
        print("parralled")
        net = torch.nn.DataParallel(net)
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict["net"])

    epoch = 0
    optimizer = None
    return net, epoch, optimizer

# Load the checkpoint for this run and evaluate it on every available split.
model, ep, optimizer_state = load_model_from_checkpoint(device, id)

print("beggining testing:")
test_acc_11 = test_model(model, [test_loader])
test_acc_13 = test_model(model, [test_loader_13])
val_acc = test_model(model,[validation_loader])
train_acc = test_model(model,[train_loader])

if FLAGS.big_run:
    # The extra splits exist only when --big_run was given.
    test_acc_12 = test_model(model, [test_loader_12])
    test_acc_big_clauses = test_model(model,[test_loader_big_clauses])
    test_acc_big_vars = test_model(model, [test_loader_big_vars ])
    acc_dict = {
        'test_acc_11': test_acc_11,
        'test_acc_12': test_acc_12,
        'test_acc_13': test_acc_13,
        'val_acc': val_acc,
        'train_acc': train_acc,
        'big_vars_acc': test_acc_big_vars,
        'big_clauses_acc': test_acc_big_clauses,
    }
else:
    acc_dict = {
        'test_acc_11': test_acc_11,
        'test_acc_13': test_acc_13,
        'val_acc': val_acc,
        'train_acc': train_acc,
    }

# --best takes precedence over --big_run when choosing the output file name.
if FLAGS.best:
    stats_suffix = "best"
elif FLAGS.big_run:
    stats_suffix = "big"
else:
    stats_suffix = ""

with open(f"{id}/stats_{stats_suffix}.json", "w") as outfile:
    json.dump(acc_dict, outfile)

print("completed")