from pathlib import Path
from argparse import ArgumentParser, Namespace
from sample.sampler import Sampler
from handle.handler import Handler
from util.scenario import Scenario
from util.interface import CausalModel, VideoGenerator
from util.util import register_scenarios
from handle.prompt_generator import PromptGeneratorV1
    
class NoneVideoGenerator(VideoGenerator):
    """Placeholder ``VideoGenerator`` whose generation step does nothing.

    Used when only samples and prompts are wanted: every call to
    ``generate_video_from_multi_prompts`` returns ``None`` without
    producing any output.
    """

    def __init__(self):
        # No state of its own; defer entirely to the base class.
        super().__init__()

    def generate_video_from_multi_prompts(self, prompts_file, save_paths, **kwargs):
        """Ignore all arguments and return ``None`` (intentional no-op)."""
        return None

def parse_args(argv=None) -> Namespace:
    """Parse the command-line options controlling how many samples are drawn.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly as before. Passing a
            list allows programmatic use and testing.

    Returns:
        A ``Namespace`` with ``sample_n_level_1``, ``group_num_level_2``,
        ``repeat_num_level_2``, ``sample_mode_level_2`` and
        ``sample_n_level_3``.
    """
    parser = ArgumentParser(description="Sample and generate prompts.")
    parser.add_argument("--sample_n_level_1", type=int, default=100, help="Number of samples for level 1")
    parser.add_argument("--group_num_level_2", type=int, default=16, help="Number of groups for level 2")
    parser.add_argument("--repeat_num_level_2", type=int, default=5, help="Repeat times for each group for level 2")
    parser.add_argument("--sample_mode_level_2", type=str, default="uniform", choices=["uniform", "sample efficiency"],
                        help="uniform: samples for level 2 are uniformly drawn from all possible values of root variables. sample efficiency: first draw samples from existing samples")
    parser.add_argument("--sample_n_level_3", type=int, default=50, help="Number of False/True samples for each non_root variable for level 3")
    return parser.parse_args(argv)

def main():
    """Register the sample scenarios and run the sampling/prompt pipeline on each.

    Steps:
      1. Parse CLI options with ``parse_args``.
      2. Register scenarios from the JSON files in ``config/samples`` next to
         this script.
      3. For each scenario key, do the following:
         - Draw rule (level 3), text-consistency (level 1) and
           generation-consistency (level 2) samples with ``Sampler``.
         - Generate prompts through ``Handler``/``PromptGeneratorV1``.
         - Invoke the video step with ``NoneVideoGenerator``, whose
           generation method does nothing.
    """
    args: Namespace = parse_args()
    # Use the module's __file__ (not the string "__file__") so the config
    # directory is located relative to this script rather than the current
    # working directory.
    json_file_path = Path(__file__).resolve().parent / "config" / "samples"
    json_files = [json_file_path / "large4_0.json", json_file_path / "large13_2.json"]
    for json_file in json_files:
        Scenario.add_scenario_from_json(json_file_path=json_file)
    # No-op generator: this script only samples and generates prompts.
    video_generator = NoneVideoGenerator()
    prompt_generate_kwargs = {}  # TODO: read keyword arguments for generating prompts
    # NOTE(review): keys look like "<json file stem>%<scenario description>";
    # presumably they must match names registered from the JSON files above.
    scenarios = ["large4_0%A burning candle is placed with wind and rain.",
                 "large13_2%A ray of light is shining on a wooden block."]
    for scenario in scenarios:
        # Skip the placeholder scenario if it ever appears in the list.
        if scenario == "test scenario":
            continue
        model = CausalModel(scenario=scenario)
        sampler = Sampler(model=model)
        sampler.generate_sample_for_rules(repeat=args.sample_n_level_3)
        sampler.generate_sample_by_text_consistency(repeat=args.sample_n_level_1)
        sampler.generate_sample_gen_consistency(repeat=args.repeat_num_level_2,
                                                sample_num=args.group_num_level_2,
                                                mode=args.sample_mode_level_2)
        prompt_generator = PromptGeneratorV1(model=model)
        handler = Handler(model=model, prompt_generator=prompt_generator,
                          video_generator=video_generator, sample_n=args.sample_n_level_3)
        handler.generate_prompts(**prompt_generate_kwargs)
        handler.generate_videos_multi_prompts(do_level_1=True)

# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()