import os.path
from models import StableDiffision
from seg_dataset import BucketDataset, MaskBucketDataset
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch_frame import LoggerHook, AccelerateTrainer, logger
from torch_frame.trainer import ProgressBar
from torch_frame.hooks import EvalHook, HookBase
import yaml
import torch
import argparse
import numpy as np
from tqdm import tqdm


class ShuffleBucketHook(HookBase):
    """Training hook meant to reshuffle the dataset's buckets between epochs.

    NOTE(review): as written, this class only stores a reference to
    ``dataset``; it defines no epoch/iteration callback of its own, so no
    shuffling happens here beyond whatever ``HookBase`` does by default.
    If bucket reshuffling is required, add the appropriate callback
    (confirm the ``HookBase`` callback names) that triggers it on
    ``self.dataset``.
    """

    def __init__(self, dataset) -> None:
        # Handle on the training dataset whose buckets are intended to be shuffled.
        self.dataset = dataset





def parse_args():
    """Read the command-line options of the training script.

    The only recognised option is ``--config``: the path of the YAML
    configuration file, which defaults to ``config/sdv1.yaml`` when omitted.

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute holds the path.
    """
    cli = argparse.ArgumentParser(description="Simple example of a training script.")
    cli.add_argument("--config", type=str, default="config/sdv1.yaml",
                     required=False, help="Path to config")
    return cli.parse_args()



def eval_map(model, batch):
    """Evaluate ``model`` on one batch and return its metrics as one-element lists.

    Every ``torch.Tensor`` value in ``batch`` is replaced, in place, by a copy
    on ``model.device``. Non-tensor values are left untouched. The function
    then calls ``model.evaluate(batch)``, which must return
    ``(loss, metric_dict)``.

    Args:
        model: object exposing a ``device`` attribute and ``evaluate(batch)``.
        batch: dict of field name -> value; modified in place.

    Returns:
        dict: the metric dict returned by ``model.evaluate``, updated in place
        so that ``"loss"`` maps to ``[loss.item()]`` and ``"miou"`` maps to a
        list wrapping its original value.
    """
    for key, value in list(batch.items()):
        if not isinstance(value, torch.Tensor):
            continue
        batch[key] = value.to(model.device)

    loss, metrics = model.evaluate(batch)
    # Lists (rather than scalars) so the caller can build tensors and gather them.
    metrics.update(loss=[loss.cpu().item()], miou=[metrics["miou"]])
    return metrics



def collection(data):
    """Collate a bucket of sample dicts into one batch dict.

    ``data`` is normally a list of dicts that share the same keys. If it is a
    list whose first element is itself a list, only that first inner list is
    used as the bucket. That nesting appears when the DataLoader wraps each
    dataset item in an outer list (e.g. ``batch_size=1``).

    For each key of the first sample, the values from all samples are
    collected in order. If the first of those values is a ``torch.Tensor``,
    they are stacked along a new leading dimension. Otherwise they are
    returned as a plain list.

    Args:
        data: list of sample dicts, or a list whose first element is such a list.

    Returns:
        dict: key -> stacked tensor or list of per-sample values.

    Raises:
        IndexError: if ``data`` or the selected bucket is empty.
        KeyError: if a sample lacks a key present in the first sample.
    """
    nested = isinstance(data, list) and isinstance(data[0], list)
    samples = data[0] if nested else data

    def _merge(values):
        # Only the first value's type decides whether the column is stacked.
        return torch.stack(values, 0) if isinstance(values[0], torch.Tensor) else values

    return {key: _merge([sample[key] for sample in samples])
            for key in list(samples[0].keys())}


class EvalLoraHook(EvalHook):
    """Evaluation hook for LoRA-adapted models trained with an Accelerate-based trainer.

    Overrides three parts of the base ``EvalHook``:

    * ``before_train`` wraps the evaluation dataloader with ``accelerator.prepare``.
    * ``_do_eval`` gathers each batch's metrics from all processes and logs, on
      the main process only, the mean of every metric under ``prefix + name``.
    * ``save_model`` writes weights through the model's ``save_weight`` method.
    """

    def before_train(self) -> None:
        """Let the trainer's accelerator prepare the evaluation dataloader."""
        accelerator = self.trainer.accelerator
        self.dataloader = accelerator.prepare(self.dataloader)

    @torch.no_grad()
    def _do_eval(self):
        """Run the evaluation loop and log the averaged metrics.

        For every batch, ``self._eval_func(model, batch)`` returns a dict of
        metric name -> list of numbers. Each list becomes a tensor on the
        accelerator's device and is passed through ``accelerator.gather``.
        The gathered values are accumulated per metric name.

        After the loop, the main process logs ``np.mean`` of every metric for
        the current epoch. Nothing is logged if no metric was produced.
        """
        accelerator = self.trainer.accelerator
        per_metric = {}
        progress = ProgressBar(total=len(self.dataloader), desc="eval", ascii=True)
        for batch in self.dataloader:
            batch_metrics = self._eval_func(self.trainer.model_or_module, batch)
            for name, values in batch_metrics.items():
                # gather() pulls values from every process, so presumably all
                # processes must call it for the same metric names in the same order.
                local_values = torch.tensor(values, device=accelerator.device)
                all_values = accelerator.gather(local_values).cpu().numpy().tolist()
                per_metric.setdefault(name, []).extend(all_values)
            progress.update(1)

        if not per_metric or not accelerator.is_main_process:
            return
        averages = {self.prefix + name: np.mean(values) for name, values in per_metric.items()}
        self.log(self.trainer.epoch, **averages, smooth=False, window_size=1)

    def save_model(self):
        """Save checkpoints, on the main process only.

        * When ``_max_to_keep`` is not None and is >= 1, the current epoch's
          weights are written with ``save_weight(epoch, ckpt_dir)``.
          NOTE(review): this method does not delete older checkpoints; confirm
          whether ``save_weight`` or the base hook enforces ``max_to_keep``.
        * When ``save_metric`` is set and ``is_better`` accepts that metric's
          current entry, ``cur_best`` is set to its ``avg``. The update is
          logged, and the weights are saved under the tag ``"best"``.
        """
        if not self.trainer.accelerator.is_main_process:
            return

        keep = self._max_to_keep
        if keep is not None and keep >= 1:
            # trainer.epoch ranges over [0, max_epochs - 1].
            self.trainer.model_or_module.save_weight(self.trainer.epoch, self.trainer.ckpt_dir)

        if self.save_metric is None:
            return
        current = self.trainer.metric_storage[self.save_metric]
        if not self.is_better(current):
            return
        self.cur_best = current.avg
        logger.info(f"{self.save_metric} update to {round(self.cur_best, 4)}")
        self.trainer.model_or_module.save_weight("best", self.trainer.ckpt_dir)



def main():
    """Train the Stable Diffusion segmentation model described by a YAML config.

    Steps, in order:

    1. Parse ``--config`` and load the YAML file.
    2. Resolve the pretrained weights path. If the primary path does not exist
       locally and ``pretrained_model_name_or_path_backup`` is configured,
       the backup is used instead.
    3. Build the model (optionally with LoRA), the training
       ``MaskBucketDataset``, the validation ``BucketDataset`` and their
       dataloaders.
    4. Build an AdamW optimizer, using the 8-bit bitsandbytes variant when
       ``use_8bit_adam`` is set.
    5. Assemble the hooks and run ``AccelerateTrainer``.

    Keys read from the ``train`` section (defaults in parentheses):
    ``dataset_path``, ``train_batch_size``, ``num_workers``, ``learning_rate``,
    ``use_8bit_adam``, ``max_to_keep``, ``lr_scheduler``, ``max_epoch``,
    ``workspace``, ``max_grad_norm``, ``mixed_precision``,
    ``gradient_accumulation_steps``, ``warmup_iters``, ``warmup_factor``,
    ``only_word`` (True), ``max_num_data`` (8000000),
    ``use_sam_refine`` (False) and ``valid_batch_size`` (4).
    """
    cli_args = parse_args()
    with open(cli_args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config_model = config["model"]
    config_diffusion = config_model["diffusion"]
    config_lora = config_model.get("lora", None)
    config_train = config["train"]

    # Switch to the backup location only if one is configured; otherwise keep the
    # primary value (it may be a hub id rather than a local path) instead of
    # failing with KeyError.
    pretrained_path = config_diffusion["pretrained_model_name_or_path"]
    backup_path = config_diffusion.get("pretrained_model_name_or_path_backup")
    if not os.path.exists(pretrained_path) and backup_path is not None:
        config_diffusion["pretrained_model_name_or_path"] = backup_path

    model = StableDiffision(config_diffusion,
                            config_lora=config_lora)

    train_dataset = MaskBucketDataset(
        config_train["dataset_path"],
        config_train["train_batch_size"],
        dataset_type="train",
        only_word=config_train.get("only_word", True),
        max_num_data=config_train.get("max_num_data", 8000000),
        use_refine=config_train.get("use_sam_refine", False),
    )

    if torch.cuda.device_count() <= 1:
        # batch_size=None disables automatic batching in DataLoader: collate_fn
        # is applied to each dataset item as-is (presumably one pre-built bucket).
        loader_kwargs = {"shuffle": True, "batch_size": None}
    else:
        # batch_size=1 wraps each item in an outer list, which `collection` unwraps.
        loader_kwargs = {
            "shuffle": True,
            "batch_size": 1,
            "drop_last": True,
        }
    num_workers = config_train["num_workers"]
    train_loader = DataLoader(train_dataset,
                              # DataLoader rejects persistent_workers=True when num_workers == 0.
                              persistent_workers=num_workers > 0,
                              num_workers=num_workers,
                              collate_fn=collection,
                              **loader_kwargs,
                              )

    valid_dataset = BucketDataset(
        config_train["dataset_path"],
        config_train.get("valid_batch_size", 4),
        dataset_type="valid",
    )
    valid_loader = DataLoader(valid_dataset,
                              shuffle=False,
                              persistent_workers=False,
                              batch_size=1,
                              num_workers=0,
                              collate_fn=collection,
                              drop_last=True,
                              )

    if config_train["use_8bit_adam"]:
        print("使用8bit")  # message reads "using 8bit"
        import bitsandbytes as bnb
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        model.trainable_params,
        lr=float(config_train["learning_rate"]),
        betas=(0.9, 0.999),
        weight_decay=1e-5,
        eps=1e-08,
    )

    # NOTE(review): EvalLoraHook logs metrics as prefix + name (e.g. "validmiou");
    # confirm that EvalHook maps save_metric="miou" onto the stored key.
    hooks = [
        ShuffleBucketHook(train_dataset),
        EvalLoraHook(
            valid_loader, eval_map, max_to_keep=int(config_train["max_to_keep"]), save_metric="miou", max_first=True,
            prefix="valid"
        ),
    ]

    # The process group may not be initialized yet at this point (the trainer is
    # built below), and dist.get_rank() raises in that case. Fall back to the
    # launcher-provided RANK variable, treating a plain single-process run as rank 0.
    if dist.is_available() and dist.is_initialized():
        is_rank_zero = dist.get_rank() == 0
    else:
        is_rank_zero = int(os.environ.get("RANK", "0")) == 0
    if torch.cuda.device_count() == 1 or is_rank_zero:
        hooks.append(LoggerHook())

    scheduler = config_train["lr_scheduler"]

    trainer = AccelerateTrainer(model, optimizer, scheduler, train_loader, config_train["max_epoch"],
                                config_train["workspace"], config_train["max_grad_norm"],
                                mixed_precision=config_train["mixed_precision"], hooks=hooks,
                                gradient_accumulation_steps=config_train["gradient_accumulation_steps"],
                                warmup_iters=config_train["warmup_iters"],
                                warmup_factor=config_train["warmup_factor"],
                                hook_only_main_gpu=False
                                )
    trainer.log_param(**config)
    trainer.train()


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
