import os
import time
import json
import wandb
import torch
import random
import argparse
from torch.optim import AdamW
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
# NOTE(review): this switches autograd anomaly detection on globally at import
# time. The PyTorch docs describe it as a debugging aid that slows execution --
# confirm it is meant to stay enabled for regular training runs.
torch.autograd.set_detect_anomaly(True)
from transformers import set_seed
from locolms.models.loco.model import LoCoLM
from locolms.models.loco.trainer import Trainer

# Exported at import time. Presumably this silences the Hugging Face tokenizers
# parallelism/fork warning -- TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Fixed seed applied through transformers.set_seed when the module is imported.
set_seed(1337)

def ddp_setup(rank, world_size:int, port:int):
    """Initialise the NCCL process group for one single-node DDP worker.

    Args:
        rank: Index of this worker process; also used as its CUDA device index.
        world_size: Total number of worker processes in the group.
        port: Rendezvous port on localhost. It is stringified before being
            exported, so both ints and numeric strings are accepted.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    # Bind this process to its own GPU before the NCCL communicator is created,
    # as the PyTorch DDP tutorial does, so that every rank does not default to
    # device 0.
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

# =================================================

def main(rank:int, world_size:int, config:object, wandb_run:object, run_parallel:bool, port:int=12355):
    """Build the model, logical module and trainer, then run the configured task.

    Args:
        rank: Index of this worker; used as the GPU id for the model and trainer.
        world_size: Number of worker processes (only used when ``run_parallel``).
        config: Mapping with the run configuration. Reads ``loss``, ``task``,
            ``model``, ``lr``, ``val_interval`` and ``dataset``, plus the
            optional ``quantization`` and ``checkpoint`` entries.
        wandb_run: Object handed to the trainer as its ``wandb`` argument
            (``None`` when logging is disabled).
        run_parallel: When true, a process group is created before the run and
            destroyed afterwards.
        port: Rendezvous port forwarded to ``ddp_setup``.

    Raises:
        ValueError: If ``config["loss"]`` is not ``"wmc"``/``"mc"`` or
            ``config["task"]`` is not ``"train"``/``"eval"``.
    """
    # Validate the run mode up front: failing here is much cheaper than after
    # the model has been loaded, and an unknown loss would otherwise leave
    # `constraints` unbound and surface later as a confusing NameError.
    loss_name = config["loss"]
    if loss_name not in ("wmc", "mc"):
        raise ValueError(f"Unsupported loss {loss_name!r}: expected 'wmc' or 'mc'.")
    task = config["task"]
    if task not in ("train", "eval"):
        raise ValueError(f"Invalid task selected: {task!r}.")

    if run_parallel: ddp_setup(rank, world_size=world_size, port=port)
    config.setdefault("quantization", False)

    model = LoCoLM(
        quantization=config["quantization"],
        model_hf_name=config["model"],
        gpu_id=rank
    )

    # Logical model (imported lazily so only the selected circuit is loaded)
    if loss_name == "wmc":
        print("[+] Loading WMC logical module")
        from locolms.circuits.wmc import WMC
        constraints = WMC(eps=1e-20)
    else:
        print("[+] Loading MC logical module")
        from locolms.circuits.mc import MC
        constraints = MC(eps=1e-20)

    optimizer = AdamW(model.parameters(), lr = config["lr"])

    # Load model
    if "checkpoint" in config:
        print(f"[-] Model checkpoint to load: {config['checkpoint']}")
        # NOTE(review): torch.load unpickles the file, so only checkpoints from
        # trusted sources should be passed here.
        checkpoint = torch.load(config["checkpoint"], map_location="cpu")
        model.model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    model.to(rank)

    print("[-] Running trainer...")
    trainer = Trainer(
        model=model,
        lr=config["lr"],
        constraint_mg=constraints,
        wandb=wandb_run,
        checkpoints_path=os.path.join("checkpoints"),
        optimizer=optimizer,
        gpu_id=rank,
        config=config,
        run_parallel=run_parallel,
        val_interval=config["val_interval"],
        dataset=config["dataset"]
    )

    if task == "train":
        # Training mode
        trainer.run_train()
    else:
        # Eval (timed)
        start = time.time()
        trainer.run_eval()
        end = time.time()
        print(f"[-] Time elapsed: {end-start}s")

    # Cleanup: drop the local references before asking CUDA to release its
    # cached memory, then tear down the process group if one was created.
    del model
    del optimizer
    del constraints
    del trainer
    torch.cuda.empty_cache()
    if run_parallel: destroy_process_group()

# =================================================

def str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no", "1"/"0".

    argparse's ``type=bool`` turns every non-empty string (including "False")
    into ``True``; this converter interprets the text instead.

    Args:
        value: The raw command-line string (a ``bool`` is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, matching what bool("") used to produce.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == '__main__':

    # Configuration
    parser = argparse.ArgumentParser(
                        prog='python train.py',
                        description='LOgically Consistent Language Models. (2024) Diego Calanzone, Stefano Teso, Antonio Vergari.')

    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--constraint_type', type=str, choices=["implication", "inverse_implication", "negation", "all"], default="all")
    # Parsed as int so the CLI value and the default share one type.
    parser.add_argument('--port', type=int, default=12355)
    parser.add_argument('--run_name', type=str)
    parser.add_argument('--model', type=str)
    parser.add_argument('--checkpoint', type=str)
    # str2bool instead of bool: "--quantization False" must disable the option.
    parser.add_argument('--quantization', type=str2bool)
    parser.add_argument('--lr_scheduler', type=str2bool)
    parser.add_argument('--patience', type=int, default=3)
    parser.add_argument('--loss', type=str, default="wmc", choices=["wmc", "mc"])
    parser.add_argument('--prompt_format', type=str, default="zero_shot", choices=["zero_shot", "few_shot"])
    parser.add_argument('--dataset', type=str, choices=["beliefbank", "conceptnet"], default="beliefbank")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

    # Command-line overrides. These are applied BEFORE wandb is initialised so
    # that the logged run configuration matches what is actually executed.
    if args.constraint_type is not None:
        config["constraint_type"] = args.constraint_type

    if args.model is not None:
        config["model"] = args.model

    if args.quantization is not None:
        config["quantization"] = args.quantization

    if args.checkpoint is not None:
        config["checkpoint"] = args.checkpoint

    # The scheduler is driven by the flag alone: off when it is not given, and
    # otherwise whatever boolean was passed.
    config["lr_scheduler"] = bool(args.lr_scheduler)

    config["loss"] = args.loss
    config["dataset"] = args.dataset
    config["patience"] = args.patience
    config["prompt_format"] = args.prompt_format

    # Wandb setup
    if config["wandb"] == True:
        wandb.login()
        wandb_run = wandb.init(
            project="large-semantic-language-models",
            name=args.run_name,
            config={
                "model": config["model"],
                "learning_rate": config["lr"],
                "batch_size": config["batch_size"],
                "accumulation_steps": config["accumulation_steps"],
                "epochs": config["epochs"]
            }
        )
    else: wandb_run = None

    if config["parallel"] is not True:
        print("[-] Running on single gpu")
        main(rank=0, world_size=1, config=config, wandb_run=wandb_run, run_parallel=False, port=args.port)
    else:
        ngpus = torch.cuda.device_count()
        # Spawning zero workers would return immediately without doing any
        # work, so report the missing devices explicitly.
        if ngpus < 1:
            raise RuntimeError("Parallel run requested but no CUDA devices are visible.")
        print(f"[-] Running on {ngpus} gpus")
        mp.spawn(main, args=(ngpus, config, wandb_run, True, args.port), nprocs=ngpus)
