import os
import sys
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import torch.distributed as dist
from omegaconf import OmegaConf
import argparse
import torch
import time
import gc
import json
from einops import rearrange
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs
from tqdm import tqdm
import logging
import datasets
import transformers
import shutil
import diffusers
from diffusers.utils import export_to_video
from diffusers.training_utils import EMAModel
from copy import deepcopy

current_file_path = os.path.abspath(__file__)
# Prepend this file's directory and its two ancestors to sys.path (skipping any
# already present), presumably so the ``videox_fun`` imports below resolve when
# the script is launched directly rather than as an installed package.
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from videox_fun.models.wan_dmd import DMD
from videox_fun.data.dataset import ImageVideoDataset, BucketDistributedSampler, cycle
from videox_fun.utils.utils import get_video_to_video_latent, resize_mask
from videox_fun.pipeline import WanFunInpaintPipeline, WanFunCausalInpaintPipeline
from videox_fun.pipeline import WanFunControlPipeline, WanFunCausalControlPipeline



class Trainer:
    """DMD distillation trainer for Wan control / causal-control video generators.

    Wraps a ``DMD`` model (generator, fake-score critic and an optional
    discriminator) with Hugging Face ``accelerate`` for distribution, mixed
    precision and gradient accumulation. Every micro-step of one accumulation
    cycle trains the same side (generator or critic), chosen from
    ``config.dfake_gen_update_ratio``.
    """

    def __init__(self, config):
        """Set up the accelerator, models, optimizers, dataloader and validation pipeline.

        Args:
            config: Training configuration; ``main`` passes an OmegaConf
                ``DictConfig``. Values are read with attribute access.

        Raises:
            ValueError: If ``config.generator_name`` is neither ``"wan"`` nor
                ``"causal_wan"``.
        """
        self.config = config

        logging_dir = os.path.join(config.output_dir, config.logging_dir)
        accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
        # Only consumed by the (currently disabled) ``kwargs_handlers`` option below.
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        self.accelerator = accelerator = Accelerator(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            mixed_precision=config.mixed_precision,
            log_with=config.report_to,
            project_config=accelerator_project_config,
            # kwargs_handlers=[ddp_kwargs],
        )
        deepspeed_plugin = accelerator.state.deepspeed_plugin
        if deepspeed_plugin is not None:
            print(f"Using DeepSpeed Zero stage: {int(deepspeed_plugin.zero_stage)}")
        else:
            print("DeepSpeed is not enabled.")
        if accelerator.is_main_process:
            self.writer = SummaryWriter(log_dir=logging_dir)

        # Make one log on every process with the configuration for debugging.
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO,
        )
        logging.info(accelerator.state)
        if accelerator.is_local_main_process:
            datasets.utils.logging.set_verbosity_warning()
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            datasets.utils.logging.set_verbosity_error()
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        # Offset the seed by the process index so each rank draws different randoms.
        if config.seed is not None:
            set_seed(config.seed + accelerator.process_index)

        if accelerator.is_main_process:
            if config.output_dir is not None:
                os.makedirs(config.output_dir, exist_ok=True)

            # vars() on an OmegaConf DictConfig exposes its private internals
            # (``_content``, ``_metadata``, ...) instead of the user's keys, so
            # convert it to a plain container before flattening.
            if OmegaConf.is_config(config):
                config_items = OmegaConf.to_container(config, resolve=True)
            else:
                config_items = vars(config)
            accelerator.init_trackers(config.tracker_project_name, config=self._flatten_log_config(config_items))

        self.dtype = torch.bfloat16 if config.mixed_precision == 'bf16' else torch.float32
        self.device = torch.cuda.current_device()

        # Step 2: Initialize the model and optimizer
        self.distillation_model = DMD(config, device=self.device, dtype=self.dtype)

        if config.use_ema:
            # diffusers' EMAModel keeps its own detached clones of the parameters
            # it is given, so the live generator can be passed directly; a
            # deepcopy would only add a transient extra copy of the weights.
            base_generator = self.distillation_model.generator
            self.ema_generator = EMAModel(
                base_generator.parameters(),
                model_cls=base_generator.__class__,
                model_config=base_generator.config,
            )
            self.ema_generator.to(self.device)

        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.distillation_model.generator.parameters()
             if param.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )
        # The critic optimizer owns the fake-score model and, if present, the discriminator.
        critic_trainable_params = [param for param in self.distillation_model.fake_score.parameters() if param.requires_grad]
        if self.distillation_model.discriminator is not None:
            print("Discriminator is not None, add discriminator parameters to critic optimizer")
            critic_trainable_params += [param for param in self.distillation_model.discriminator.parameters() if param.requires_grad]
        self.critic_optimizer = torch.optim.AdamW(
            critic_trainable_params,
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )

        self.distillation_model.generator = accelerator.prepare(self.distillation_model.generator)
        self.distillation_model.fake_score = accelerator.prepare(self.distillation_model.fake_score)
        if self.distillation_model.discriminator is not None:
            self.distillation_model.discriminator = accelerator.prepare(self.distillation_model.discriminator)
        self.generator_optimizer = accelerator.prepare(self.generator_optimizer)
        self.critic_optimizer = accelerator.prepare(self.critic_optimizer)

        # Step 3: Initialize the dataloader
        self.backward_simulation = getattr(config, "backward_simulation", True)

        dataset = ImageVideoDataset(
            data_root=config.train_data_dir,
            dataset_file=config.train_data_meta,
            caption_column=config.caption_column,
            video_column=config.video_column,
            resolution_buckets=config.video_resolution_buckets,
            frame_interval=config.video_sample_stride,
        )
        sampler = BucketDistributedSampler(
            dataset,
            config.batch_size,
            accelerator.num_processes,
            accelerator.process_index
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=config.dataloader_num_workers,
            pin_memory=True,
        )
        self.dataloader = cycle(dataloader, sampler)

        self.step = 0  # micro-step counter (one per train_one_step call)
        self.global_step = 0  # optimizer-update counter
        self.avg_loss_dict = {}
        # Gradient-clipping threshold; defaults to 1.0 when the config does not set it.
        max_grad_norm = getattr(config, "max_grad_norm", None)
        self.max_grad_norm = 1.0 if max_grad_norm is None else max_grad_norm
        self.previous_time = None

        pipeline_classes = {
            "wan": WanFunControlPipeline,
            "causal_wan": WanFunCausalControlPipeline,
        }
        if config.generator_name not in pipeline_classes:
            raise ValueError(
                f"Unsupported generator_name {config.generator_name!r}; "
                f"expected one of {sorted(pipeline_classes)}"
            )
        pipeline_cls = pipeline_classes[config.generator_name]
        # The validation pipeline shares the (unwrapped) training generator, so
        # it always samples with the current weights.
        self.pipeline = pipeline_cls(
            vae=self.distillation_model.vae,
            text_encoder=self.distillation_model.text_encoder,
            tokenizer=self.distillation_model.tokenizer,
            transformer=accelerator.unwrap_model(self.distillation_model.generator),
            scheduler=self.distillation_model.scheduler,
            clip_image_encoder=self.distillation_model.clip_image_encoder,
        )
        self.pipeline.to(device=self.device, dtype=self.dtype)

    @staticmethod
    def _flatten_log_config(config_items):
        """Flatten a config mapping into tracker-friendly hyper-parameters.

        Scalars and tensors are kept as-is, one level of nested dicts becomes
        ``"outer.inner"`` keys with stringified values, and anything else is
        stringified.

        Args:
            config_items: Mapping of config keys to values.

        Returns:
            dict: The flattened hyper-parameter dict.
        """
        log_config = {}
        for k, v in config_items.items():
            if isinstance(v, (int, float, str, bool, torch.Tensor)):
                log_config[k] = v
            elif isinstance(v, dict):
                for kk, vv in v.items():
                    log_config[f"{k}.{kk}"] = str(vv)
            else:
                log_config[k] = str(v)
        return log_config

    def train_one_step(self):
        """Run a single micro-step (one batch) of DMD training.

        All micro-steps of one accumulation cycle train the same side: the
        generator when ``(step // gradient_accumulation_steps) %
        dfake_gen_update_ratio == 0``, otherwise the critic (fake score plus
        the optional discriminator). Optimizer steps only happen when
        ``accelerator.sync_gradients`` is true.

        Returns:
            dict[str, float]: Loss values computed in this micro-step.
        """
        self.distillation_model.train()

        grad_accum_steps = self.config.gradient_accumulation_steps
        TRAIN_GENERATOR = (self.step // grad_accum_steps) % self.config.dfake_gen_update_ratio == 0
        # TRAIN_GAN = TRAIN_GENERATOR or (self.step // self.config.gradient_accumulation_steps + 1) % self.config.dfake_gen_update_ratio == 0
        TRAIN_GAN = True
        VISUALIZE = self.global_step % self.config.validation_iters == 0 and not self.config.no_visualize

        # Step 1: Fetch the next batch and move the pixel tensors to the training device/dtype.
        batch = next(self.dataloader)
        text_prompts = batch["text"]
        pixel_values = batch["pixel_values"].to(device=self.device, dtype=self.dtype)
        clip_pixel_values = batch["clip_pixel_values"].to(device=self.device, dtype=self.dtype)
        control_pixel_values = batch["control_pixel_values"].to(device=self.device, dtype=self.dtype)
        ref_pixel_values = batch["ref_pixel_values"].to(device=self.device, dtype=self.dtype)

        with torch.no_grad():
            def _batch_encode_vae(pixel_values):
                # (b, f, c, h, w) -> (b, c, f, h, w), then encode one sample per VAE call.
                pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
                bs = 1
                new_pixel_values = []
                for i in range(0, pixel_values.shape[0], bs):
                    pixel_values_bs = pixel_values[i : i + bs]
                    pixel_values_bs = self.distillation_model.vae.encode(pixel_values_bs)[0]
                    pixel_values_bs = pixel_values_bs.sample()
                    new_pixel_values.append(pixel_values_bs)
                return torch.cat(new_pixel_values, dim=0)

            latents = _batch_encode_vae(pixel_values)
            clean_latent = latents

            # Encode the control video and the reference image.
            control_latents = _batch_encode_vae(control_pixel_values)
            ref_latents = _batch_encode_vae(ref_pixel_values)

            if self.config.train_mode == "control_ref_repeat":
                # Repeat the reference latent along the frame axis.
                ref_latents_conv_in = ref_latents.repeat(1, 1, latents.size()[2], 1, 1)
            else:
                # Put the reference latent in the first frame only; the rest stay zero.
                ref_latents_conv_in = torch.zeros_like(latents).to(ref_latents.device, ref_latents.dtype)
                ref_latents_conv_in[:, :, :1] = ref_latents

            # Make first frame control to zero
            control_latents[:, :, :1] = 0

            # Concatenate along channels: [control | reference].
            control_latents = torch.cat([control_latents, ref_latents_conv_in], dim=1)

            clip_context = []
            for clip_pixel_value in clip_pixel_values:
                clip_image = Image.fromarray(np.uint8(clip_pixel_value.float().cpu().numpy()))
                clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(self.device, self.dtype)
                _clip_context = self.distillation_model.clip_image_encoder([clip_image[:, None, :, :]])

                # Replace the CLIP conditioning with zeros for ~10% of samples.
                zero_init_clip_in = np.random.choice([True, False], p=[0.1, 0.9])
                clip_context.append(_clip_context if not zero_init_clip_in else torch.zeros_like(_clip_context))

            clip_context = torch.cat(clip_context)

        batch_size = len(text_prompts)

        # With self forcing, long videos are randomly cropped along the frame axis.
        num_frames = clean_latent.shape[2]
        if self.backward_simulation and num_frames > self.config.num_last_frames_with_grad:
            # Uniform in [num_last_frames_with_grad, num_frames]; randint's upper bound is exclusive.
            keep_frame = torch.randint(self.config.num_last_frames_with_grad, num_frames + 1, (1,), device=self.device)
            if dist.is_available() and dist.is_initialized():
                # Share rank 0's draw so every rank keeps the same number of frames.
                dist.broadcast(keep_frame, src=0)
            keep_frame = keep_frame.item()
            clean_latent = clean_latent[:, :, :keep_frame, :, :]
            control_latents = control_latents[:, :, :keep_frame, :, :]
            print(f"rank {self.accelerator.process_index} keep {keep_frame} from {num_frames} frames (min {self.config.num_last_frames_with_grad} frames, max {num_frames})")

        clean_latent = rearrange(clean_latent, "b c f h w -> b f c h w")
        control_latents = rearrange(control_latents, "b c f h w -> b f c h w")
        image_or_video_shape = clean_latent.shape

        # Step 2: Extract the conditional infos
        with torch.no_grad():
            conditional_dict = self.distillation_model.encode_text(text_prompts)

            # The negative-prompt encoding is cached, but only for the batch size
            # it was built for; re-encode when the batch size changes.
            cached_unconditional = getattr(self, "unconditional_dict", None)
            if not cached_unconditional or getattr(self, "_unconditional_batch_size", None) != batch_size:
                unconditional_dict = self.distillation_model.encode_text([self.config.negative_prompt] * batch_size)
                unconditional_dict = {k: v.detach() for k, v in unconditional_dict.items()}
                self.unconditional_dict = unconditional_dict
                self._unconditional_batch_size = batch_size
            else:
                unconditional_dict = cached_unconditional

        # Step 3: Train the generator (or, on critic cycles, the critic)
        accumulate_models = [
            model
            for model in (
                self.distillation_model.generator,
                self.distillation_model.fake_score,
                self.distillation_model.discriminator,
            )
            if model is not None
        ]
        with self.accelerator.accumulate(*accumulate_models):
            if TRAIN_GENERATOR:
                critic_loss_dict, critic_log_dict = {}, {}

                generator_loss_dict, generator_log_dict = self.distillation_model.generator_loss(
                    image_or_video_shape=image_or_video_shape,
                    conditional_dict=conditional_dict,
                    unconditional_dict=unconditional_dict,
                    clean_latent=clean_latent,
                    inpaint_latents=control_latents,
                    clip_context=clip_context,
                )

                generator_loss = generator_loss_dict["dmd_loss"] * self.config.dmd_loss_weight
                if self.config.adv_g_loss_weight > 0 and TRAIN_GAN:
                    generator_loss = generator_loss + generator_loss_dict["gan_g_loss"] * self.config.adv_g_loss_weight

                for k, v in generator_loss_dict.items():
                    self.avg_loss_dict[k] = self.avg_loss_dict.get(k, 0.0) + v.item() / grad_accum_steps
                generator_loss = generator_loss / grad_accum_steps
                self.accelerator.backward(generator_loss)
                # Discard any gradients the generator loss left on critic parameters.
                self.critic_optimizer.zero_grad()
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.distillation_model.generator.parameters(), self.max_grad_norm)
                    self.generator_optimizer.step()
                    self.generator_optimizer.zero_grad()  # clean up gan gradient

                    if self.config.use_ema:
                        self.ema_generator.step(self.distillation_model.generator.parameters())
            else:
                generator_log_dict = {}

                # Step 4: Train the critic
                critic_loss_dict, critic_log_dict = self.distillation_model.critic_loss(
                    image_or_video_shape=image_or_video_shape,
                    conditional_dict=conditional_dict,
                    unconditional_dict=unconditional_dict,
                    clean_latent=clean_latent,
                    inpaint_latents=control_latents,
                    clip_context=clip_context,
                )

                critic_loss = critic_loss_dict["critic_loss"]
                if self.config.adv_d_loss_weight > 0 and TRAIN_GAN:
                    critic_loss = critic_loss + critic_loss_dict["gan_d_loss"] * self.config.adv_d_loss_weight

                for k, v in critic_loss_dict.items():
                    self.avg_loss_dict[k] = self.avg_loss_dict.get(k, 0.0) + v.item() / grad_accum_steps
                critic_loss = critic_loss / grad_accum_steps
                self.accelerator.backward(critic_loss)
                # Discard any gradients the critic loss left on generator parameters.
                self.generator_optimizer.zero_grad()
                if self.accelerator.sync_gradients:
                    critic_trainable_params = list(self.distillation_model.fake_score.parameters())
                    if self.distillation_model.discriminator is not None:
                        critic_trainable_params += list(self.distillation_model.discriminator.parameters())
                    self.accelerator.clip_grad_norm_(critic_trainable_params, self.max_grad_norm)
                    self.critic_optimizer.step()
                    self.critic_optimizer.zero_grad()

        # Step 5: Logging (once per optimizer update)
        if self.accelerator.sync_gradients:
            if self.accelerator.is_main_process:
                if VISUALIZE:
                    input_dict = {
                        "original_pixel_values": (pixel_values / 2 + 0.5).clamp(0, 1),
                        "clean_latent": clean_latent,
                        # First 16 channels: the control latents before the reference concat.
                        "control_latents": control_latents[:, :, :16],
                    }
                    self.add_visualization(generator_log_dict, critic_log_dict, input_dict)

                self.writer.add_scalars("loss", self.avg_loss_dict, self.global_step)
            # Reset the running averages on every rank for the next update.
            self.avg_loss_dict = {}

        loss_dict = {k: v.item() for k, v in critic_loss_dict.items()}
        if TRAIN_GENERATOR:
            loss_dict.update({k: v.item() for k, v in generator_loss_dict.items()})
        return loss_dict

    def add_visualization(self, generator_log_dict, critic_log_dict, input_dict):
        """Decode this step's latents and write every sample as an MP4 under ``visualization/``.

        Args:
            generator_log_dict: Generator-side logs; decoded when it contains
                ``dmdtrain_clean_latent``.
            critic_log_dict: Critic-side logs; decoded when non-empty.
            input_dict: ``original_pixel_values`` (already clamped to [0, 1])
                plus the ``clean_latent`` / ``control_latents`` fed to the model.
        """
        def decode(latents):
            # Latents are frame-major (b, f, c, h, w); decode_latents is given
            # them channel-major and the decoded result is permuted back.
            if latents is None:
                return None
            return self.pipeline.decode_latents(latents.permute(0, 2, 1, 3, 4))[0].permute(0, 2, 1, 3, 4)

        visual_dict = {
            "original_pixel_values": input_dict["original_pixel_values"],
            "input_clean_latent": decode(input_dict["clean_latent"]),
            "input_control_latents": decode(input_dict["control_latents"]),
        }

        if critic_log_dict:
            for key in ("critictrain_latent", "critictrain_noisy_latent", "critictrain_pred_image"):
                visual_dict[key] = decode(critic_log_dict[key])

        if "dmdtrain_clean_latent" in generator_log_dict:
            visual_dict["dmdtrain_generator_noisy_input"] = decode(generator_log_dict.get('generator_noisy_input', None))
            for key in ("dmdtrain_clean_latent", "dmdtrain_noisy_latent", "dmdtrain_pred_real_image", "dmdtrain_pred_fake_image"):
                visual_dict[key] = decode(generator_log_dict[key])

        for k, v in visual_dict.items():
            if v is None:
                continue
            for i in range(v.shape[0]):
                # (f, c, h, w) -> (f, h, w, c) for export_to_video.
                video = v[i].permute(0, 2, 3, 1).float().cpu().numpy()
                save_path = os.path.join(self.config.output_dir, "visualization", f"step_{self.global_step}_{k}_{i:03d}.mp4")
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                export_to_video(video, save_path, fps=8)

    def train(self):
        """Run the training loop until ``config.max_train_steps`` optimizer updates.

        Validates once before training. After each completed accumulation
        cycle it saves a checkpoint every ``checkpoint_iters`` updates (unless
        ``no_save``) and validates every ``validation_iters`` updates. If the
        last update did not land on a checkpoint boundary, a final
        save (unless ``no_save``) and validation run at the end.
        """
        self.validate()

        progress_bar = tqdm(
            range(0, self.config.max_train_steps),
            initial=0,
            desc="Steps",
            # Only show the progress bar once on each machine.
            disable=not self.accelerator.is_local_main_process,
        )
        grad_accum_steps = self.config.gradient_accumulation_steps
        while True:
            loss_logs = self.train_one_step()

            # Periodic work only on the micro-step that completes an accumulation
            # cycle; otherwise it would repeat once per micro-step of one update.
            if (self.step + 1) % grad_accum_steps == 0:
                next_global_step = self.global_step + 1
                if (not self.config.no_save) and next_global_step % self.config.checkpoint_iters == 0:
                    self.save()
                    torch.cuda.empty_cache()

                if next_global_step % self.config.validation_iters == 0:
                    self.validate()

            self.step += 1
            if self.step % grad_accum_steps == 0:
                self.global_step += 1
                progress_bar.update(1)
            progress_bar.set_postfix(loss_logs)

            if self.global_step >= self.config.max_train_steps:
                break
        if self.global_step % self.config.checkpoint_iters != 0:
            if not self.config.no_save:
                self.save()
            self.validate()

    @staticmethod
    def _sorted_checkpoint_dirs(names):
        """Return the ``checkpoint-<step>`` entries of *names*, oldest step first.

        Entries that start with ``"checkpoint"`` but have no integer step
        suffix are ignored instead of breaking the integer sort.
        """
        prefix = "checkpoint-"
        checkpoints = [name for name in names if name.startswith(prefix) and name[len(prefix):].isdigit()]
        return sorted(checkpoints, key=lambda name: int(name[len(prefix):]))

    def save(self):
        """Save generator, fake score, optional discriminator and EMA weights.

        Runs on the main process only: prunes the oldest checkpoints to respect
        ``checkpoints_total_limit``, then writes ``checkpoint-<global_step+1>``
        under ``config.output_dir``. All ranks wait for the save to finish.
        """
        if self.accelerator.is_main_process:
            if self.config.checkpoints_total_limit is not None:
                checkpoints = self._sorted_checkpoint_dirs(os.listdir(self.config.output_dir))

                # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                if len(checkpoints) >= self.config.checkpoints_total_limit:
                    num_to_remove = len(checkpoints) - self.config.checkpoints_total_limit + 1
                    removing_checkpoints = checkpoints[0:num_to_remove]

                    logging.info(
                        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                    )
                    logging.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                    for removing_checkpoint in removing_checkpoints:
                        removing_checkpoint = os.path.join(self.config.output_dir, removing_checkpoint)
                        shutil.rmtree(removing_checkpoint, ignore_errors=True)

            output_dir = os.path.join(self.config.output_dir, f"checkpoint-{self.global_step+1}")
            os.makedirs(output_dir, exist_ok=True)
            generator = self.accelerator.unwrap_model(self.distillation_model.generator)
            generator.save_pretrained(os.path.join(output_dir, "generator"))
            print(f"Save generator model to {os.path.join(output_dir, 'generator')}")
            fake_score = self.accelerator.unwrap_model(self.distillation_model.fake_score)
            fake_score.save_pretrained(os.path.join(output_dir, "fake_score"))
            print(f"Save fake score model to {os.path.join(output_dir, 'fake_score')}")
            if self.distillation_model.discriminator is not None:
                discriminator = self.accelerator.unwrap_model(self.distillation_model.discriminator)
                discriminator.save_pretrained(os.path.join(output_dir, "discriminator"))
            if self.config.use_ema:
                self.ema_generator.save_pretrained(os.path.join(output_dir, "ema_generator"))
                print(f"Save ema generator model to {os.path.join(output_dir, 'ema_generator')}")
        self.accelerator.wait_for_everyone()

    @torch.no_grad()
    def validate(self):
        """Generate the validation videos on the main process and save them as MP4s.

        Reads ``config.validation_data_meta`` (a JSON list of samples with
        ``prompt``, ``control_video``, ``length``, ``image_start`` and an
        optional ``fps``) and writes ``validation/step_<n>_output_<i>.mp4``
        under ``config.output_dir``. When EMA is enabled, the EMA weights are
        swapped into the generator for the duration. All ranks wait at the end.
        """
        if self.accelerator.is_main_process:
            if self.config.use_ema:
                # Temporarily load the EMA weights into the live generator.
                self.ema_generator.store(self.distillation_model.generator.parameters())
                self.ema_generator.copy_to(self.distillation_model.generator.parameters())

            print(f"Start validation at step {self.global_step+1}...")
            with open(self.config.validation_data_meta) as meta_file:
                validation_samples = json.load(meta_file)

            if self.config.seed is None:
                generator = None
            else:
                generator = torch.Generator(device=self.device).manual_seed(self.config.seed)

            if isinstance(self.config.video_resolution_buckets[0], str):
                # NOTE(review): eval() executes config-provided text. Acceptable for a
                # trusted local config; ast.literal_eval would be safer if buckets are
                # always plain tuple literals.
                self.config.video_resolution_buckets = [eval(bucket) for bucket in self.config.video_resolution_buckets]

            for index, sample in enumerate(validation_samples):
                prompt = sample["prompt"]
                control_video = sample["control_video"]
                video_length = sample["length"]
                fps = sample.get("fps", 8)
                ref_image = sample["image_start"]
                with Image.open(ref_image) as ref_pil:
                    width, height = ref_pil.size
                # Buckets are indexed as (?, height, width); pick the closest aspect
                # ratio, breaking ties by the smallest absolute size difference.
                sorted_buckets = sorted(self.config.video_resolution_buckets, key=lambda x: (abs(x[1]/x[2]-height/width), abs(x[1]-height)+abs(x[2]-width)))
                sample_size = sorted_buckets[0][1:3]

                # Snap multi-frame lengths to k * temporal_compression_ratio + 1.
                video_length = int(video_length // self.distillation_model.vae.config.temporal_compression_ratio * self.distillation_model.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
                input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=ref_image)
                video = self.pipeline(
                    prompt,
                    num_frames = video_length,
                    negative_prompt = "色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走",
                    height      = sample_size[0],
                    width       = sample_size[1],
                    guidance_scale = self.config.guidance_scale,
                    generator   = generator,

                    control_video = input_video,
                    ref_image = ref_image,
                    clip_image = clip_image,
                    mode = self.config.train_mode,
                    timesteps = self.config.denoising_step_list if self.config.generator_name == "wan" else None,
                    context_perturbation = self.config.context_perturbation,
                ).videos[0]
                video = video.permute(1, 2, 3, 0).cpu().numpy()  # (c f h w) -> (f h w c)

                save_path = os.path.join(self.config.output_dir, "validation", f"step_{self.global_step+1}_output_{index:03d}.mp4")
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                # Write at the fps the control video was sampled with (default 8).
                export_to_video(video, save_path, fps=fps)

            if self.config.use_ema:
                self.ema_generator.restore(self.distillation_model.generator.parameters())

            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        self.accelerator.wait_for_everyone()


def main():
    """Parse command-line flags, load the YAML config with OmegaConf and run training.

    ``--no_save`` and ``--no_visualize`` are copied onto the loaded config,
    overriding any values of the same name in the file.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config_path", type=str, required=True)
    for switch in ("--no_save", "--no_visualize"):
        arg_parser.add_argument(switch, action="store_true")
    cli_args = arg_parser.parse_args()

    cfg = OmegaConf.load(cli_args.config_path)
    cfg.no_save = cli_args.no_save
    cfg.no_visualize = cli_args.no_visualize

    Trainer(cfg).train()

if __name__ == "__main__":
    main()
