import os
import torch
from tqdm import tqdm
from torch.optim import SGD, Adam
from torch.utils.data.dataloader import DataLoader

from core.data import GraphDataset, SizeBasedSampler, get_net_size_fn
from core.utils import set_seed
from core import logic_net


def induction(positive_graphs, negative_graphs, args,
              shuffle=False, asserting_debug_fn=None):
    """Train a soft logic equation that separates positive from negative graphs.

    Builds a ``GraphDataset`` labelled ``+1`` for ``positive_graphs`` and
    ``-1`` for ``negative_graphs``, constructs a ``logic_net`` model
    (``DeduceNet`` if ``args.deduce_attribute`` is set, else
    ``DeduceRelationNet`` if ``args.deduce_relation`` is set, else
    ``LogicNet``) and trains it with SGD or Adam for epochs
    ``start_step .. args.n_epochs`` (inclusive). Training optionally resumes
    from, and periodically writes to, ``args.ckpt_file``.

    Args:
        positive_graphs: list of graphs labelled +1. ``.to(args.device)`` is
            called on every graph and its return value discarded, so the
            graphs are presumably moved in place.
        negative_graphs: list of graphs labelled -1.
        args: configuration namespace. It also supplies ``args.print``, which
            is used for all logging.
        shuffle: forwarded to ``SizeBasedSampler``.
        asserting_debug_fn: callable invoked once with the freshly built model
            when ``args.asserting_debug`` is truthy.

    Returns:
        None. The learned equation is reported through ``args.print``.

    Raises:
        ValueError: if no graphs are given, if ``args.optimizer`` is not
            ``"SGD"`` or ``"Adam"``, or if ``args.asserting_debug`` is set
            without an ``asserting_debug_fn``.
    """
    all_graphs = positive_graphs + negative_graphs
    if not all_graphs:
        # The predicate count below is read from the first graph.
        raise ValueError("induction() needs at least one graph")
    for _graph in all_graphs:
        _graph.to(args.device)
    n_pos = len(positive_graphs)
    n_neg = len(negative_graphs)
    n_predicates = len(all_graphs[0].attribute_names +
                       all_graphs[0].relation_names)
    graph_dataset = GraphDataset(
        all_graphs, [1] * n_pos + [-1] * n_neg,
        cached_num=args.cached_num, device=args.device
    )

    set_seed(args.seed)
    size_fn = get_net_size_fn(n_predicates, args)
    sampler = SizeBasedSampler([graph.n_nodes for graph in all_graphs],
                               size_fn=size_fn,
                               total_size=args.batch_total_size,
                               shuffle=shuffle)
    dataloader = DataLoader(
        graph_dataset, collate_fn=graph_dataset.collate_fn,
        batch_sampler=sampler, num_workers=args.num_workers, pin_memory=False
    )

    # Choose the model class from which deduction target (if any) is set.
    if args.deduce_attribute is not None:
        soft_equation = logic_net.DeduceNet.from_graphs(
            all_graphs, args.n_variables, args.n_layers, args.n_width,
            args.deduce_attribute, args.init_var, args.temperature, args.seed)
    elif args.deduce_relation is not None:
        soft_equation = logic_net.DeduceRelationNet.from_graphs(
            all_graphs, args.n_variables, args.n_layers, args.n_width,
            args.deduce_relation, args.init_var, args.temperature, args.seed)
    else:
        soft_equation = logic_net.LogicNet.from_graphs(
            all_graphs, args.n_variables, args.n_layers, args.n_width,
            args.init_var, args.temperature, args.seed)
    soft_equation.to(args.device)

    if args.optimizer == "SGD":
        optimizer = SGD(soft_equation.parameters(), lr=args.lr)
    elif args.optimizer == "Adam":
        optimizer = Adam(soft_equation.parameters(), lr=args.lr)
    else:
        # Without this, `optimizer` would be unbound and fail later with a
        # confusing NameError.
        raise ValueError(f"unknown optimizer: {args.optimizer!r} "
                         "(expected 'SGD' or 'Adam')")

    start_step = 0
    if args.ckpt_file is not None and os.path.exists(args.ckpt_file):
        if not args.abort_ckpt:
            args.print("loading ckpt from: " + args.ckpt_file)
            # map_location lets a checkpoint written on one device (e.g. a
            # GPU) be resumed on another (e.g. a CPU-only machine).
            ckpt = torch.load(args.ckpt_file, map_location=args.device)
            soft_equation.load_state_dict(ckpt[0])
            optimizer.load_state_dict(ckpt[1])
            # ckpt[2] is the epoch that had just finished training when the
            # checkpoint was written, so resume with the following epoch
            # instead of repeating it.
            start_step = ckpt[2] + 1
        else:
            args.print(f"ckpt found but overwritten at {args.ckpt_file}")

    if args.asserting_debug:
        if asserting_debug_fn is None:
            raise ValueError("args.asserting_debug is set but no "
                             "asserting_debug_fn was given")
        asserting_debug_fn(soft_equation)

    all_labels = torch.tensor(graph_dataset.labels)
    # Initialise each sample's score to its target (+1 -> 1, -1 -> 0) and its
    # loss to 0. Both are overwritten per sample as batches are processed.
    all_scores = (all_labels + 1) / 2
    all_losses = torch.zeros(len(graph_dataset))
    entropy = None  # entropy of the most recent batch; None until one runs

    with tqdm(total=args.n_epochs, initial=start_step) as pbar:
        # Inclusive bound: the epoch with index args.n_epochs also runs.
        while pbar.n <= args.n_epochs:
            # The entropy-regularisation weight only depends on the epoch, so
            # compute it once per epoch rather than once per batch.
            if args.entropy_reg_increasing == "none":
                lambda_entropy = 1
            else:
                # Ramp from 0 at epoch 0 to 1 at epoch args.n_epochs.
                lambda_entropy = (pbar.n / args.n_epochs
                                  if args.n_epochs > 0 else 1)
                if args.entropy_reg_increasing == "square":
                    lambda_entropy = lambda_entropy ** 2

            for batch, y, indices in dataloader:
                # Skip batches whose entries 0 and 2 are both None
                # (presumably nothing to classify; confirm against
                # GraphDataset.collate_fn).
                if batch[0] is None and batch[2] is None:
                    continue
                optimizer.zero_grad()
                loss, scores, score_tensor = soft_equation.classify(
                    batch, y, args.soft_logic, args.final_forall_logic,
                    args.loss_fn, args.distinct_variables,
                )
                entropy = soft_equation.entropy()
                # Record per-sample statistics for the periodic log below.
                all_losses[indices] = loss.cpu().detach()
                all_scores[indices] = scores.cpu().detach()

                total_loss = (loss.sum() +
                              entropy * args.entropy_regularization
                              * lambda_entropy)
                total_loss.backward()
                optimizer.step()

            if pbar.n % args.log_interval == 0:
                args.print(f"======== step {pbar.n} =======")
                args.print(soft_equation)
                if args.print_parameters:
                    args.print(*soft_equation.operator_parameters(), sep="\n")
                if pbar.n > 0:
                    args.print("mean positive scores:",
                               all_scores[all_labels == 1].mean())
                    args.print("mean negative scores:",
                               all_scores[all_labels == -1].mean())
                    args.print("mean loss:", all_losses.mean())
                    if entropy is not None:
                        args.print("entropy:", entropy.cpu().detach())
                    args.print("positive correct rate:",
                               (all_scores[all_labels == 1] > 0.5)
                               .float().mean())
                    if args.ckpt_file is not None:
                        # Stores the index of the epoch that just finished;
                        # the loader above resumes at the next one.
                        torch.save([soft_equation.state_dict(),
                                    optimizer.state_dict(),
                                    pbar.n],
                                   args.ckpt_file)
            pbar.update()

    args.print("Final Equation: ", soft_equation)
