import os
import torch
import hydra
import sys
import subprocess
from omegaconf import DictConfig

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import generate_mathematical_expressions, generate_classifier_data, generate_vocab
from multiguide.training.helpers import train_classifier, setup_translator, translate, evaluate_translations

# Torch device for this script: the default CUDA device when CUDA is usable,
# otherwise the CPU.
# NOTE(review): not referenced inside main(); possibly imported elsewhere —
# confirm before removing.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _generate_data(config: DictConfig, data_dir: str) -> None:
    '''
        Phase "generate_data": write the mathematical expressions, the classifier data
        and the vocabulary for the dataset configured in classifier_guidance.dataset.
    '''
    dataset_cfg = config.classifier_guidance.dataset
    print(f'Generating # {dataset_cfg.max_num_expressions} mathematical expressions...')
    generate_mathematical_expressions(max_num_expressions=dataset_cfg.max_num_expressions,
                                      max_depth=dataset_cfg.max_depth, config=config)
    print(f'Generated mathematical expressions for {dataset_cfg.data_dir}')
    # NOTE(review): train.src / train.tgt are presumably written into data_dir by
    # generate_mathematical_expressions (it only receives `config`) — confirm.
    generate_classifier_data(tgt_file=os.path.join(data_dir, "train.tgt"),
                             # TODO: are both of these limits needed?
                             completion_lower_limit=dataset_cfg.partial_sequence_completion_lower_limit,
                             min_length_limit=dataset_cfg.partial_sequence_min_length_limit,
                             config=config)
    print(f'Generated classifier data for {dataset_cfg.data_dir}')
    generate_vocab(vocab_file=os.path.join(data_dir, "vocab.txt"),
                   src_file=os.path.join(data_dir, "train.src"),
                   tgt_file=os.path.join(data_dir, "train.tgt"))


def _train_onmt(config: DictConfig, data_dir: str) -> None:
    '''
        Phase "train_onmt": train the OpenNMT seq2seq model in a subprocess using
        configs/toy_experiment_onmt/train.yml. Exits the process with status 1 if
        training fails.

        `config` and `data_dir` are unused; the signature matches the other phase handlers.
    '''
    train_config_path = os.path.join(PROJECT_ROOT,
                                     "configs",
                                     "toy_experiment_onmt",
                                     "train.yml")
    # Pass argv as a list (no shell) so paths with spaces or shell metacharacters
    # are forwarded verbatim, and reuse the running interpreter so the subprocess
    # sees the same installed packages (e.g. onmt) as this process.
    cmd = [sys.executable, "-m", "onmt.bin.train", "-config", train_config_path]
    if torch.cuda.is_available():
        cmd += ["--gpu_ranks", "0"]
    # NOTE(review): capture_output=True buffers the whole training log; it is only
    # shown (stderr) when training fails.
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"ONMT training failed: {result.stderr}", file=sys.stderr)
        sys.exit(1)
    print("ONMT training completed successfully")


def _train_classifier(config: DictConfig, data_dir: str) -> None:
    '''
        Phase "train_classifier": train the classifier on classifier_train.src/.tgt,
        using the ONMT checkpoint at classifier_guidance.onmt_checkpoint_path.
    '''
    dataset_name = config.classifier_guidance.dataset.data_dir
    print(f'Training classifier for {dataset_name}')
    train_classifier(src_file=os.path.join(data_dir, "classifier_train.src"),
                     lengths_file=os.path.join(data_dir, "classifier_train.tgt"),
                     onmt_checkpoint_path=config.classifier_guidance.onmt_checkpoint_path,
                     config=config)
    print(f'Trained classifier for {dataset_name}')


def _translate(config: DictConfig, data_dir: str) -> None:
    '''
        Phase "translate": translate test.src with the (guided or regular) seq2seq
        model built by setup_translator, then evaluate the translations.
    '''
    guidance_cfg = config.classifier_guidance
    print(f'Translating {guidance_cfg.dataset.data_dir}')
    translator, opt = setup_translator(classifier_config=config)
    # readlines() keeps trailing newlines; the lines are passed to translate() as-is.
    with open(os.path.join(data_dir, "test.src"), "r", encoding="utf-8") as f:
        src_lines = f.readlines()
    print(f'Translating {guidance_cfg.experiment_name}')
    print(f'src_lines: {src_lines[:10]}')
    src_lines, all_predictions, all_scores = translate(src_lines=src_lines, translator=translator, opt=opt)
    evaluate_translations(source_lines=src_lines, all_predictions=all_predictions, all_scores=all_scores, config=config)


# Maps each value of classifier_guidance.phase to its handler.
_PHASES = {
    "generate_data": _generate_data,
    "train_onmt": _train_onmt,
    "train_classifier": _train_classifier,
    "translate": _translate,
}


@hydra.main(config_path="../configs", config_name="config.yaml")
def main(config: DictConfig):
    '''
        Run the toy experiment: simplifying mathematical expressions to different lengths.
        The code works in 4 sequential phases, selected through the hydra variable
        classifier_guidance.phase (in toy_experiment.yaml):

        - generate_data: generate the mathematical expressions and the classifier data used to train and evaluate the onmt seq2seq model,
          the classifier model and their combination (guided model).
        - train_onmt: train the onmt seq2seq model (onmt.bin.train with configs/toy_experiment_onmt/train.yml);
          exits with status 1 if training fails.
        - train_classifier: train the classifier model.
        - translate: translate the mathematical expressions using either a guided or regular seq2seq model.

        The dataset directory PROJECT_ROOT/data/toy_experiment/<classifier_guidance.dataset.data_dir>
        is created (if missing) before any phase runs.

        Raises:
            ValueError: if classifier_guidance.phase is not one of the phases above.
    '''
    phase = config.classifier_guidance.phase
    data_dir = os.path.join(PROJECT_ROOT, "data", "toy_experiment", config.classifier_guidance.dataset.data_dir)
    os.makedirs(data_dir, exist_ok=True)
    handler = _PHASES.get(phase)
    if handler is None:
        raise ValueError(f'Invalid phase: {phase}')
    handler(config, data_dir)

if __name__ == "__main__":
    # @hydra.main builds `config` from the config files and command-line
    # overrides, so main() is invoked here without arguments.
    main()