import json
import numpy as np
import os
import random
import sys
from tqdm import tqdm

import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter

from mas_sat import \
    dataset_registry, graph_registry,\
    agent_registry,\
    model_registry, learner_registry
from mas_sat.learn.engine import Engine
from mas_sat.utils.metadata import dump_with_max_depth
from mas_sat.utils.state_dict import load_state_dict

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch global random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _select_device(gpu):
    """Pick the torch device for this run.

    Args:
        gpu: CUDA device index from the configuration, or None.

    Returns:
        ``torch.device("cuda:<gpu>")`` when CUDA is available and ``gpu`` is
        given, otherwise ``torch.device("cpu")``. A one-line message explaining
        the choice is printed in every case.
    """
    if not torch.cuda.is_available():
        print("No GPU available, use CPU!")
        return torch.device("cpu")
    if gpu is None:
        print("No GPU specified, use CPU!")
        return torch.device("cpu")
    print("Use GPU:{}".format(gpu))
    return torch.device("cuda:{}".format(gpu))


def _checkpoint_path(experiment_dir, step):
    """Return the path of the checkpoint file for ``step`` in ``experiment_dir``."""
    return os.path.join(experiment_dir, "checkpoint_{}.pth".format(step))


def _load_model_checkpoint(model, path):
    """Load the ``"model_state_dict"`` entry of the checkpoint at ``path`` into ``model``."""
    # Tensors are deserialized onto the CPU regardless of where they were saved.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    load_state_dict(model, state_dict["model_state_dict"])


def _evaluation_schedule(mode, num_step, evaluate_mode, num_evaluate_step):
    """Return the modes and model step counts used for periodic evaluation.

    Args:
        mode: training mode; used alone when ``evaluate_mode`` is None.
        num_step: step count paired with ``mode`` in that case.
        evaluate_mode: comma-separated evaluation modes, or None.
        num_evaluate_step: comma-separated integer step counts, one per entry
            of ``evaluate_mode``.

    Returns:
        A ``(modes, steps)`` tuple of two equal-length lists.

    Raises:
        ValueError: if ``evaluate_mode`` is set but ``num_evaluate_step`` is
            missing, or the two lists differ in length.
    """
    if evaluate_mode is None:
        return [mode], [num_step]
    if num_evaluate_step is None:
        raise ValueError("num_evaluate_step must be given when evaluate_mode is set.")
    modes = evaluate_mode.split(",")
    steps = [int(x) for x in str(num_evaluate_step).split(",")]
    if len(modes) != len(steps):
        raise ValueError(
            "evaluate_mode has {} entries but num_evaluate_step has {}.".format(
                len(modes), len(steps)))
    return modes, steps


def main():
    """Run the command selected by ``args.command``.

    The supported commands are generate, solve, train and evaluate. All
    settings are read from ``mas_sat.utils.config.args``.

    Raises:
        ValueError: for an unrecognized command, a non-"model" agent passed to
            train, or an inconsistent evaluation schedule.
    """
    from mas_sat.utils.config import args

    _seed_everything(args.seed)
    device = _select_device(args.gpu)

    # Dataset, graph and environment are built up front for every command.
    cnf_dataset = dataset_registry.get(args.dataset)(args)
    graph = graph_registry.get(args.graph)
    env = gym.make("kissat-v0.1", dataset=cnf_dataset, args=args, disable_env_checker=True)

    def generate():
        cnf_dataset.generate()

    def solve():
        cnf_dataset.solve()

    def train():
        """Train the model-based agent, checkpointing and evaluating periodically."""
        # Explicit check: an ``assert`` would be stripped under ``python -O``.
        if args.agent != "model":
            raise ValueError("Can only train model-based agents, not {}.".format(args.agent))

        # The context manager closes the writer (flushing buffered events) even on error.
        with SummaryWriter(args.experiment_dir) as writer:
            writer.add_text("command", " ".join(sys.argv))
            writer.add_text("args", str(args))

            model = model_registry.get(args.model)(args)
            model.to(device)

            agent = agent_registry.get(args.agent)(model, args)
            agent.train()

            learner = learner_registry.get(args.learner)(model, device, writer, args)
            engine = Engine(cnf_dataset, graph, env, model, agent, device, writer, args)

            # A fine-tuning source takes precedence over this experiment's own checkpoint.
            if args.finetune is not None:
                _load_model_checkpoint(model, args.finetune)
            elif args.checkpoint is not None:
                _load_model_checkpoint(
                    model, _checkpoint_path(args.experiment_dir, args.checkpoint))

            # NOTE(review): only model weights are checkpointed; learner/optimizer
            # state and the step counter are not -- confirm resume semantics.
            def save_checkpoint(step):
                torch.save({"model_state_dict": model.state_dict()},
                           _checkpoint_path(args.experiment_dir, step))

            def next_evaluate_step(step):
                # Smallest multiple of evaluate_interval strictly greater than step.
                return (step // args.evaluate_interval + 1) * args.evaluate_interval

            learned_step = learner.get_counter()
            evaluate_step = next_evaluate_step(learned_step)
            evaluate_modes, evaluate_steps = _evaluation_schedule(
                args.mode, args.num_step, args.evaluate_mode, args.num_evaluate_step)

            cnf_dataset.train()
            with tqdm(total=args.learn_step, initial=learned_step, desc="Training") as tbar:
                while learned_step < args.learn_step:
                    engine.train(args.mode, learner)
                    learn_step = learner.learn()
                    tbar.update(learn_step)
                    learned_step += learn_step
                    if learned_step >= evaluate_step:
                        save_checkpoint(evaluate_step)
                        # Switch to the validation split and eval mode, evaluate
                        # every scheduled mode, then restore training state.
                        cnf_dataset.valid()
                        agent.eval()
                        for mode, step in zip(evaluate_modes, evaluate_steps):
                            model.set_step(step)
                            engine.evaluate(mode, evaluate_step, record=True)
                        cnf_dataset.train()
                        agent.train()
                        model.set_step(args.num_step)
                        evaluate_step = next_evaluate_step(learned_step)

    def evaluate():
        """Evaluate the configured agent on the test split for each mode in ``args.mode``."""
        agent_cls = agent_registry.get(args.agent)
        if agent_cls.is_model_based():
            model = model_registry.get(args.model)(args)
            if args.checkpoint is not None:
                _load_model_checkpoint(
                    model, _checkpoint_path(args.experiment_dir, args.checkpoint))
                counter = args.checkpoint
            else:
                print("Warning: No checkpoint specified, use random model.")
                counter = 0
            model.to(device)
            agent = agent_cls(model, args)
            # NOTE(review): unlike the periodic evaluation in train(), agent.eval()
            # is never called here -- confirm the agent's default mode suits evaluation.
        else:
            model = None
            agent = agent_cls()
            counter = 0

        engine = Engine(cnf_dataset, graph, env, model, agent, device, None, args)

        cnf_dataset.test()
        for mode in args.mode.split(","):
            engine.evaluate(mode, counter)

    commands = {
        "generate": generate,
        "solve": solve,
        "train": train,
        "evaluate": evaluate,
    }
    command = commands.get(args.command)
    if command is None:
        raise ValueError("Unrecognized command: {}.".format(args.command))
    command()

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
