import os
import torch
import math
from random import randrange
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import numpy as np
from CNNDT_Architecture import DTNet
from blocks import BasicBlock2D as BasicBlock
from tqdm import tqdm
from torch.optim import SGD, Adam, AdamW
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.optim import Optimizer
import logging
import json
from collections import OrderedDict
from almost_unique_id import generate_id
import argparse

parser = argparse.ArgumentParser(description="Train a model")
parser.add_argument("--logdir", default="default-log-dir", type=str,
                    help="location where log of experiments will be stored")
parser.add_argument("--exp", default="default", type=str, help="name of experiments")
parser.add_argument("--model_path", default=None, type=str,
                    help="directory containing model_.pth to resume training from")
parser.add_argument("--epochs", default=10, type=int, help="number of training epochs")
parser.add_argument("--max_iters", default=40, type=int, help="number of iterations in DT model when training.")
parser.add_argument("--batch_size", default=5, type=int, help="batch_size")
parser.add_argument("--in_channels", default=3, type=int, help="number of input channels")
parser.add_argument("--width", default=512, type=int, help="width of the DTNet (passed as DTNet(width=...))")
FLAGS = parser.parse_args()

## HYPERPARAMS ##
folder_name = FLAGS.logdir
# A generated id is used unless an explicit experiment name was given.
run_id = generate_id() if FLAGS.exp == "default" else FLAGS.exp
# NOTE: `id` shadows the builtin, but later code uses it as the output directory.
id = "outputs/" + folder_name + "/" + run_id
# Reusing an existing run directory is allowed only when resuming from a
# checkpoint; a fresh run must not silently overwrite an earlier one.
os.makedirs(id, exist_ok=FLAGS.model_path is not None)
train_batch_size = FLAGS.batch_size
epochs = FLAGS.epochs
max_iters = FLAGS.max_iters
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(f"run id = {id}")
print(f"max iters  = {max_iters}")
print(f"num epochs = {epochs}")
print(f"batch size = {train_batch_size}")

class CreateDataset(torch.utils.data.Dataset):
    """In-memory dataset pairing input grids with target grids.

    Both numpy arrays are converted once to float32 tensors. Indexing returns
    the ``index``-th input (optionally passed through ``transform``) together
    with the ``index``-th target.
    """

    def __init__(self, image_data, count_data, transform=None):
        if len(image_data) != len(count_data):
            raise ValueError(
                f"inputs and targets differ in length: {len(image_data)} vs {len(count_data)}"
            )
        self.image_data = torch.from_numpy(image_data).to(torch.float)
        self.count_data = torch.from_numpy(count_data).to(torch.float)
        # Previously accepted but silently ignored; applied to the input only.
        self.transform = transform

    def __getitem__(self, index):
        image = self.image_data[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.count_data[index]

    def __len__(self):
        return len(self.image_data)

# One-hot encoding of the three cell classes used in the instance files.
vals = {0: [1, 0, 0], 1: [0, 1, 0], 2: [0, 0, 1]}


def _read_grids(path):
    """Read blank-line-separated grids of ints from *path*, one-hot encoding each cell."""
    grids = []
    current = []
    with open(path, 'r') as f:
        for line in f:
            # split() tolerates a missing trailing newline and stray whitespace;
            # the previous line[:-1] chopped the last digit off such a line.
            tokens = line.split()
            if tokens:
                current.append([vals[int(tok)] for tok in tokens])
            elif current:
                # Blank line ends a grid; repeated blank lines add nothing.
                grids.append(current)
                current = []
    # The last grid is kept even if the file does not end with a blank line.
    if current:
        grids.append(current)
    return grids


def get_data(problem_file, label_file):
    """Load problem grids and label grids from two text files.

    Each file holds grids of whitespace-separated ints in {0, 1, 2}, separated by
    blank lines. Every cell is one-hot encoded via ``vals``.

    Returns:
        ``(X, Y)``: lists of grids (rows of one-hot cells) from *problem_file*
        and *label_file* respectively.

    Raises:
        KeyError: a cell value outside ``vals``.
    """
    return _read_grids(problem_file), _read_grids(label_file)

training_data, training_labels = get_data('variable_training_instances.txt', 'variable_training_instances_labels.txt')
validation_data, validation_labels = get_data('variable_validation_instances.txt', 'variable_validation_instances_labels.txt')

# Each instance is a grid of one-hot cells, so the arrays are (N, H, W, 3)
# (assuming every grid in a file has the same size). Only the first 1000
# validation instances are used; slicing the lists first avoids converting
# instances that would be discarded.
training_data = np.array(training_data)
training_labels = np.array(training_labels)
validation_data = np.array(validation_data[:1000])
validation_labels = np.array(validation_labels[:1000])

# (N, H, W, 3) -> (N, 3, H, W): the one-hot axis becomes the channel axis.
train_dataset = CreateDataset(np.transpose(training_data, (0, 3, 1, 2)),
                              np.transpose(training_labels, (0, 3, 1, 2)))
# drop_last=True guarantees every batch holds exactly train_batch_size samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=train_batch_size,
                                           shuffle=True, drop_last=True)

# Accuracy is order-independent, so validation is not shuffled; this also
# makes the samples dropped by drop_last the same every epoch.
validation_dataset = CreateDataset(np.transpose(validation_data, (0, 3, 1, 2)),
                                   np.transpose(validation_labels, (0, 3, 1, 2)))
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                batch_size=train_batch_size,
                                                shuffle=False, drop_last=True)

def get_optimizer(net, max_iters, optimizer_name="adam", lr=0.001, weight_decay=2e-4):
    """Build the optimizer for DT training.

    Parameters whose name contains ``"recur"`` go into a second param group
    whose learning rate is ``lr / max_iters``; all others use ``lr``.

    Args:
        net: module whose ``named_parameters()`` are optimized.
        max_iters: divisor for the recurrent-parameter learning rate.
        optimizer_name: ``"sgd"`` (momentum 0.9), ``"adam"`` or ``"adamw"``.
        lr: base learning rate.
        weight_decay: weight decay applied to both groups.

    Returns:
        The configured ``torch.optim`` optimizer.

    Raises:
        ValueError: unknown ``optimizer_name``.
    """
    # Reducing the lr here for the recurrent layers helps with stability,
    # To date (July 21, 2021), we may only need this for maze models.
    base_params = [p for n, p in net.named_parameters() if "recur" not in n]
    recur_params = [p for n, p in net.named_parameters() if "recur" in n]
    # Both groups are always present (even if one is empty) so the group layout
    # of saved optimizer state dicts stays stable.
    all_params = [{"params": base_params}, {"params": recur_params, "lr": lr / max_iters}]

    if optimizer_name == "sgd":
        return SGD(all_params, lr=lr, weight_decay=weight_decay, momentum=0.9)
    if optimizer_name == "adam":
        return Adam(all_params, lr=lr, weight_decay=weight_decay)
    if optimizer_name == "adamw":
        return AdamW(all_params, lr=lr, weight_decay=weight_decay)
    raise ValueError(f"unknown optimizer: {optimizer_name!r}")

def get_output_for_prog_loss(inputs, max_iters, net):
    """Run a randomly split unroll for the progressive loss.

    Draws ``n`` warm-up iterations from ``[0, max_iters)`` and then ``k``
    further iterations from ``[1, max_iters - n]``. When ``n > 0`` the features
    after the warm-up are detached before seeding the second pass; otherwise
    the second pass starts with ``interim_thought=None``.

    Returns:
        ``(outputs, k)``: outputs of the second pass and its iteration count.
    """
    warmup_iters = randrange(0, max_iters)
    extra_iters = randrange(1, max_iters - warmup_iters + 1)

    seed_thought = None
    if warmup_iters > 0:
        _, warm_thought = net(inputs, iters_to_do=warmup_iters)
        # Detached: no gradient flows back through the warm-up iterations.
        seed_thought = warm_thought.detach()

    outputs, _ = net(inputs, iters_elapsed=warmup_iters, iters_to_do=extra_iters,
                     interim_thought=seed_thought)
    return outputs, extra_iters

def get_predicted(inputs, outputs):
    """Return per-cell class predictions flattened to ``(batch, -1)``.

    Args:
        inputs: unused; kept so existing call sites keep working.
        outputs: scores with the class dimension at axis 1.

    Returns:
        Argmax class indices over axis 1, reshaped to ``(batch, -1)``.
    """
    # argmax does not modify its input, so the former defensive clone (a full
    # copy of the score tensor on every call) was unnecessary.
    predicted = outputs.argmax(1)
    return predicted.view(predicted.size(0), -1)

def _as_nested_list(grid):
    """Return *grid* as nested Python lists (tensors/arrays via ``tolist()``)."""
    # Element-wise indexing of a tensor creates a new tensor per cell (and on
    # CUDA forces a host sync per comparison); one bulk conversion is far cheaper.
    return grid.tolist() if hasattr(grid, "tolist") else grid


def validate_generated_solution(problem, solution):
    """Check whether *solution* marks a consistent, satisfying assignment for *problem*.

    Both arguments are 2-D grids of class indices (rows = clauses, columns =
    literal slots). They may be tensors, arrays or nested lists. The literal
    columns are those before the first 2 in the first row of *problem*. The
    first half of them are the positive literals x1..xv and the second half
    their negations.

    Every cell equal to 1 anywhere in *solution* marks its column's literal as
    true. The solution is accepted iff all of the following hold:

    * no variable is marked both true and negated;
    * every clause row contains at least one true literal. Clause rows are the
      rows of *problem* up to the first row that contains no 1.
    * within those rows, every occurrence (1 in *problem*) of a true literal is
      1 in *solution*, and every cell that is 0 in *problem* is 0 in *solution*.

    Returns:
        bool
    """
    problem = _as_nested_list(problem)
    solution = _as_nested_list(solution)

    first_row = problem[0]
    literal_cols = next((col for col, cell in enumerate(first_row) if cell == 2), len(first_row))
    variables = literal_cols // 2

    def signed_literal(col):
        # Column -> signed variable id: +(col+1) in the first half,
        # -(col-variables+1) in the second half.
        return col + 1 if col < variables else -(col - variables) - 1

    true_literals = {
        signed_literal(col)
        for row in solution
        for col, cell in enumerate(row)
        if cell == 1
    }

    # A variable may not be both true and negated (ids are never 0).
    if any(-lit in true_literals for lit in true_literals):
        return False

    clause_rows = []
    for row in problem:
        if 1 not in row:
            break
        clause_rows.append(row)

    # Every clause must contain at least one true literal.
    for row in clause_rows:
        if not any(cell == 1 and signed_literal(col) in true_literals
                   for col, cell in enumerate(row)):
            return False

    # Cell-level agreement between the problem and the marked solution.
    width = len(first_row)
    for r, row in enumerate(clause_rows):
        for col in range(width):
            cell = row[col]
            if cell == 1:
                if signed_literal(col) in true_literals and solution[r][col] != 1:
                    return False
            elif cell == 0 and solution[r][col] != 0:
                return False
    return True

def train_model(model, dataset, optimizer, max_iters, alpha=0.01, clip=None):
    """Train *model* for one epoch with the deep-thinking progressive loss.

    For every batch the model is unrolled for ``max_iters`` iterations. Any
    sample whose prediction passes ``validate_generated_solution`` gets that
    prediction as its target, so alternative valid solutions are not penalised.
    The loss is ``(1 - alpha) * CE(full unroll) + alpha * CE(progressive
    unroll from get_output_for_prog_loss)``.

    Args:
        model: called as ``model(inputs, iters_to_do=...)`` and expected to
            return ``(outputs, interim_thought)`` in train mode.
        dataset: iterable of ``(inputs, one_hot_targets)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        max_iters: iterations for the full unroll.
        alpha: weight of the progressive loss term.
        clip: optional max gradient norm; ``None`` disables clipping.

    Returns:
        ``(acc, train_loss)``. ``acc`` is the percentage of samples whose
        full-unroll prediction equals the (possibly replaced) target.
        ``train_loss`` is the sum of the per-batch losses.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    model.train()
    # Follow the model's device instead of hard-coding CUDA.
    device = next(model.parameters()).device
    train_loss = 0
    correct = 0
    total = 0
    for inputs, targets in tqdm(dataset, leave=False):
        inputs, targets = inputs.to(device), targets.to(device)
        # One-hot targets -> class indices, flattened per sample.
        targets = targets.argmax(dim=1)
        targets = targets.view(targets.size(0), -1)

        optimizer.zero_grad()

        # The full unroll is always computed: its predictions drive the target
        # replacement below, even when alpha == 1.
        outputs_max_iters, _ = model(inputs, iters_to_do=max_iters)
        outputs_max_iters = outputs_max_iters.view(outputs_max_iters.size(0), outputs_max_iters.size(1), -1)
        predicted = get_predicted(inputs, outputs_max_iters)

        # The validity check walks grids cell by cell, so it runs on CPU copies.
        problems = inputs.argmax(1).cpu()
        predicted_cpu = predicted.cpu()
        grid_rows = inputs.size(2)
        # Use the actual batch size (not the configured one) so short batches work.
        for i in range(inputs.size(0)):
            if validate_generated_solution(problems[i], predicted_cpu[i].view(grid_rows, -1)):
                targets[i] = predicted[i]

        if alpha != 1:
            loss_max_iters = criterion(outputs_max_iters, targets)
        else:
            loss_max_iters = torch.zeros_like(targets).float()

        # The progressive term is skipped (set to 0) when alpha == 0.
        if alpha != 0:
            outputs, _ = get_output_for_prog_loss(inputs, max_iters, model)
            outputs = outputs.view(outputs.size(0), outputs.size(1), -1)
            loss_progressive = criterion(outputs, targets)
        else:
            loss_progressive = torch.zeros_like(targets).float()

        loss = (1 - alpha) * loss_max_iters.mean() + alpha * loss_progressive.mean()
        loss.backward()

        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        train_loss += loss.item()
        # A sample counts as correct only if every cell matches.
        correct += (predicted == targets).all(dim=-1).sum().item()
        total += targets.size(0)

    acc = 100.0 * correct / total if total else 0.0
    return acc, train_loss

def test(model, dataset, eval_every=5):
    """Measure per-iteration accuracy of *model* on *dataset*.

    The model is unrolled once per batch for 90 iterations. In eval mode it is
    expected to return all per-iteration outputs stacked on dim 1, with index
    ``i`` being the output after ``i + 1`` iterations. This is an assumption
    taken from how the result is indexed here; confirm against DTNet.

    A prediction that passes ``validate_generated_solution`` counts as correct
    even if it differs from the stored label. To bound the cost of that
    per-sample check, only every ``eval_every``-th iteration count is scored
    (5, 10, ..., 90 by default).

    Returns:
        dict mapping every iteration count 1..90 to accuracy in percent.
        Counts that were not scored stay at 0.
    """
    iters = list(range(1, 91))
    max_iters = max(iters)
    model.eval()
    device = next(model.parameters()).device
    corrects = torch.zeros(max_iters)
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataset, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            batch = inputs.size(0)
            grid_rows = inputs.size(2)
            base_targets = targets.argmax(dim=1).view(batch, -1)
            problems = inputs.argmax(1).cpu()
            all_outputs = model(inputs, iters_to_do=max_iters)
            # Start at eval_every - 1 so scored indices map to iteration counts
            # eval_every, 2*eval_every, ...; the old start at 0 scored counts
            # 1, 6, 11, ... and left e.g. count 40 permanently at 0.
            for i in range(eval_every - 1, all_outputs.size(1), eval_every):
                predicted = get_predicted(inputs, all_outputs[:, i])
                predicted_cpu = predicted.cpu()
                # Fresh copy per iteration count, so a replacement made for one
                # count does not change what the next count is compared with.
                iter_targets = base_targets.clone()
                for k in range(batch):
                    if validate_generated_solution(problems[k], predicted_cpu[k].view(grid_rows, -1)):
                        iter_targets[k] = predicted[k]
                corrects[i] += (predicted == iter_targets).all(dim=1).sum().item()
            total += batch

    accuracy = 100.0 * corrects / total
    return {ite: accuracy[ite - 1].item() for ite in iters}

def test_model(model, datasets):
    """Evaluate *model* on each loader in *datasets*.

    Returns:
        A list with one result of ``test(model, loader)`` per loader, in the
        order the loaders were given.
    """
    return [test(model, loader) for loader in datasets]

def get_model(width=512, in_channels=3, max_iters=40):
    """Construct the recall-enabled DTNet used by this script.

    Args:
        width: forwarded as ``DTNet(width=...)``.
        in_channels: forwarded as ``DTNet(in_channels=...)``.
        max_iters: accepted for call-site compatibility; not forwarded.

    Returns:
        A new ``DTNet`` built from ``BasicBlock`` with block config ``[2]``.
    """
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        recall=True,
    )

def load_model_from_checkpoint(device):
    """Build the DTNet from CLI flags, place it on *device*, and optionally resume.

    When ``FLAGS.model_path`` is set, ``<model_path>/model_.pth`` is loaded.
    The network and optimizer states are restored from it.

    Returns:
        ``(net, epoch, optimizer)``. ``epoch`` is the next epoch to run: 0 for
        a fresh run, the saved epoch + 1 when resuming.
    """
    model_path = FLAGS.model_path
    net = get_model(FLAGS.width, in_channels=FLAGS.in_channels, max_iters=FLAGS.max_iters)
    # .to() works for the CPU fallback too; .cuda(device=cpu) would raise.
    net = net.to(device)
    if device == torch.device("cuda:0"):
        print("parralled")
        # NOTE(review): DataParallel changes state_dict key names, so checkpoints
        # presumably only load under the same device setup — confirm.
        net = torch.nn.DataParallel(net)
    optimizer = get_optimizer(net, FLAGS.max_iters)
    epoch = 0
    if model_path is not None:
        state_dict = torch.load(f"{model_path}/model_.pth", map_location=device)
        net.load_state_dict(state_dict["net"])
        epoch = state_dict["epoch"] + 1
        optimizer.load_state_dict(state_dict["optimizer"])

    return net, epoch, optimizer

# load_model_from_checkpoint already builds the optimizer with get_optimizer
# (and restores its state when resuming), so it is used as returned.
model, ep, optimizer = load_model_from_checkpoint(device)

highest_val_acc_so_far = -1
for epoch in range(ep, epochs):
    print("starting epoch ", epoch)
    acc, loss = train_model(model, train_loader, optimizer, max_iters)
    print(acc, loss)
    torch.cuda.empty_cache()

    # NOTE: test() scores only some iteration counts; max_iters must be one of
    # them, or this reads an unscored 0.
    val_acc = test_model(model, [validation_loader])[0][max_iters]

    state = {"net": model.state_dict(), "epoch": epoch, "optimizer": optimizer.state_dict()}
    # Always refresh the latest checkpoint: --model_path resumes from model_.pth.
    # Previously it was skipped on "best" epochs (always including epoch 0),
    # so resuming could load a stale epoch or find no file at all.
    torch.save(state, f"{id}/model_.pth")
    if val_acc > highest_val_acc_so_far:
        print("best")
        highest_val_acc_so_far = val_acc
        torch.save(state, f"{id}/model_best.pth")

    # Keep a permanent snapshot every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(state, f"{id}/model_{epoch}.pth")

print("completed")