from argparse import ArgumentParser, Namespace
from sample.sampler import Sampler
from handle.handler import Handler
from util.scenario import Scenario
from util.interface import CausalModel, VideoGenerator
from util.util import register_scenarios
from handle.prompt_generator import PromptGeneratorV1
    
class NoneVideoGenerator(VideoGenerator):
    """Placeholder ``VideoGenerator`` whose generation method does nothing.

    ``main`` in this script passes an instance to ``Handler`` so that a
    video generator object is available without any video being rendered
    by this class.
    """

    def __init__(self):
        """Initialise the base class; this subclass adds no state of its own."""
        super().__init__()

    def generate_video_from_multi_prompts(self, prompts_file, save_paths, **kwargs):
        """Ignore the request: every argument is unused and ``None`` is returned."""
        return None

def parse_args(argv: "list[str] | None" = None) -> Namespace:
    """Parse the command-line options that control sampling.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is the behaviour the
            script had before this parameter existed. Passing an explicit
            list lets the parser be driven programmatically.

    Returns:
        A ``Namespace`` with the attributes ``sample_n_level_1``,
        ``group_num_level_2``, ``repeat_num_level_2``,
        ``sample_mode_level_2`` and ``sample_n_level_3``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = ArgumentParser(description="Sample and generate prompts.")
    parser.add_argument("--sample_n_level_1", type=int, default=10,
                        help="Number of samples for level 1")
    parser.add_argument("--group_num_level_2", type=int, default=5,
                        help="Number of groups for level 2")
    parser.add_argument("--repeat_num_level_2", type=int, default=3,
                        help="Repeat times for each group for level 2")
    parser.add_argument("--sample_mode_level_2", type=str, default="uniform",
                        choices=["uniform", "sample efficiency"],
                        help="uniform: samples for level 2 are uniformly drawn from all possible "
                             "values of root variables. sample efficiency: first draw samples "
                             "from existing samples")
    parser.add_argument("--sample_n_level_3", type=int, default=10,
                        help="Number of False/True samples for each non_root variable for level 3")
    # Forward argv unchanged: None keeps argparse's default of sys.argv[1:].
    return parser.parse_args(argv)

def main():
    """Sample from every registered scenario and generate prompts for it.

    For each scenario returned by ``Scenario.get_all_scenarios()`` (except
    the one equal to ``"test scenario"``), build a ``CausalModel``, run the
    three sampling passes of ``Sampler`` with the CLI-provided counts, then
    let a ``Handler`` generate prompts and invoke its multi-prompt video
    step with the no-op ``NoneVideoGenerator``.
    """
    cli_args: Namespace = parse_args()
    register_scenarios()
    # One generator instance is shared by the handlers of all scenarios.
    noop_generator = NoneVideoGenerator()
    # TODO: read keyword arguments for generating prompts
    prompt_kwargs = {}

    for scenario in Scenario.get_all_scenarios():
        if scenario == "test scenario":
            continue

        causal_model = CausalModel(scenario=scenario)

        # Sampling passes, in the original order: rules (level 3 count),
        # text consistency (level 1 count), generation consistency (level 2 options).
        scenario_sampler = Sampler(model=causal_model)
        scenario_sampler.generate_sample_for_rules(repeat=cli_args.sample_n_level_3)
        scenario_sampler.generate_sample_by_text_consistency(repeat=cli_args.sample_n_level_1)
        scenario_sampler.generate_sample_gen_consistency(
            repeat=cli_args.repeat_num_level_2,
            sample_num=cli_args.group_num_level_2,
            mode=cli_args.sample_mode_level_2,
        )

        scenario_handler = Handler(
            model=causal_model,
            prompt_generator=PromptGeneratorV1(model=causal_model),
            video_generator=noop_generator,
            sample_n=cli_args.sample_n_level_3,
        )
        scenario_handler.generate_prompts(**prompt_kwargs)
        # NOTE(review): presumably this delegates to the no-op generator above,
        # so no video is actually produced; confirm against Handler.
        scenario_handler.generate_videos_multi_prompts(do_level_1=True)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()