# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import time
from typing import List, Union

import json

from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
from swift.utils import get_logger

logger = get_logger()


class SwiftSampling(SwiftPipeline):
    """Pipeline that samples from one piece of a dataset and writes the result to a file.

    The dataset is split into ``total_piece`` equally sized pieces (taken from
    ``args.data_range``) and this instance processes piece ``cur_piece``.
    Batches are sampled with the sampler selected by ``args.sampler_type`` and
    the generated lines are written to ``args.output_file`` inside
    ``args.output_dir``.
    """

    args_class = SamplingArguments
    args: args_class

    def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
        """Set up the output directory, the data piece and the sampler.

        Args:
            args: Forwarded unchanged to ``SwiftPipeline.__init__``.

        Raises:
            ValueError: If ``args.sampler_type`` is not one of ``"sample"``,
                ``"mcts"`` or ``"distill"``.
        """
        super().__init__(args)
        self.args.save_args()
        os.makedirs(self.args.output_dir, exist_ok=True)
        # Default: a single piece that covers the whole dataset.
        self.cur_piece = 0
        self.total_piece = 1

        if self.args.data_range:
            self.cur_piece, self.total_piece = self.args.data_range

        # Sampler modules are imported inside the branches so that only the
        # selected implementation is imported.
        if self.args.sampler_type == "sample":
            from swift.llm.sampling.vanilla_sampler import VanillaSampler

            self.sampler = VanillaSampler(self.args)
        elif self.args.sampler_type == "mcts":
            from swift.llm.sampling.mcts import MctsSampler

            self.sampler = MctsSampler(self.args)
        elif self.args.sampler_type == "distill":
            from swift.llm.sampling.distill_sampler import DistillSampler

            self.sampler = DistillSampler(self.args)
        else:
            raise ValueError(f"Unsupported sampler type: {self.args.sampler_type}")

    def _get_dataset(self):
        """Load ``args.dataset`` and return the piece assigned to this instance.

        The piece length is ``len(dataset) // total_piece``, so every piece has
        the same size and rows past ``piece_len * total_piece`` are not part of
        any piece.
        NOTE(review): dropping that remainder looks intentional (equal-sized
        pieces) but should be confirmed.

        Returns:
            The result of ``select`` on the loaded dataset, restricted to rows
            ``[piece_len * cur_piece, piece_len * (cur_piece + 1))``.
        """
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        # load_dataset returns a pair; only the first element is used here.
        sampling_dataset, _ = load_dataset(
            args.dataset,
            split_dataset_ratio=0.0,
            shuffle=args.dataset_shuffle,
            **dataset_kwargs,
        )
        logger.info(f"Sampling_dataset: {sampling_dataset}")
        piece_len = len(sampling_dataset) // self.total_piece
        piece_start = piece_len * self.cur_piece
        return sampling_dataset.select(range(piece_start, piece_start + piece_len))

    def run(self):
        """Sample the dataset piece batch by batch and write the output file.

        Files used inside ``args.output_dir``:

        * ``<output_file>``: the final result.
        * ``<output_file>.tmp``: the file generated lines are written to.
        * ``<output_file>.resume``: copy of the tmp file made after each batch.
        * ``ckpt_state.json``: ``{"index": i}`` for the last completed batch.

        If the final file already exists and ``args.override_exist_file`` is
        false, nothing is done. With ``args.resume`` the run continues after
        the checkpointed batch, but only when both the ``.resume`` file and the
        checkpoint exist; otherwise the run starts from the first batch with an
        empty tmp file. A previously existing final file is renamed with a Unix
        timestamp suffix before the new result takes its place.

        ``args.num_sampling_per_gpu_batches`` is clamped in place to the number
        of full batches available; a trailing partial batch is not sampled.
        NOTE(review): skipping the partial batch is kept as-is; confirm it is
        intended.
        """
        args = self.args
        os.makedirs(args.output_dir, exist_ok=True)
        iter_file = os.path.join(args.output_dir, args.output_file)
        resume_file = os.path.join(args.output_dir, args.output_file + ".resume")
        tmp_file = os.path.join(args.output_dir, args.output_file + ".tmp")
        ckpt_state_file = os.path.join(args.output_dir, "ckpt_state.json")
        if os.path.exists(iter_file) and not args.override_exist_file:
            logger.info(f"Sample file {iter_file} already exists, skip sampling.")
            return

        index_resume = -1
        write_mode = "w"
        # True once resume_file holds output that belongs to this run, either
        # restored from a previous attempt or written by the loop below.
        has_output = False
        if args.resume:
            if os.path.exists(resume_file) and os.path.exists(ckpt_state_file):
                # The checkpoint index is only meaningful together with the
                # data it describes, so both files are required to resume.
                shutil.copyfile(resume_file, tmp_file)
                with open(ckpt_state_file, "r", encoding="utf-8") as ckpt_state:
                    index_resume = json.load(ckpt_state).get("index", -1)
                logger.info(f"Loaded index_resume: {index_resume}")
                write_mode = "a"
                has_output = True
            elif os.path.exists(resume_file) or os.path.exists(ckpt_state_file):
                # A lone checkpoint (e.g. left behind by a finished run) would
                # skip batches whose output is missing; a lone resume file
                # would get every batch appended a second time.
                logger.warning(
                    "Incomplete resume state in "
                    f"{args.output_dir}, sampling restarts from the first batch."
                )
        # In every non-resumed case mode "w" truncates a stale tmp file.

        dataset = self._get_dataset()
        batch_size = args.num_sampling_per_gpu_batch_size
        total_iters = int(len(dataset) // batch_size)

        if (
            args.num_sampling_per_gpu_batches is None
            or args.num_sampling_per_gpu_batches > total_iters
        ):
            args.num_sampling_per_gpu_batches = total_iters

        first_index = max(0, index_resume + 1)
        with open(tmp_file, write_mode, encoding="utf-8") as f:
            for _index in range(first_index, args.num_sampling_per_gpu_batches):
                logger.info(f" Sampling index:{_index}")
                slices = dataset[batch_size * _index : batch_size * (_index + 1)]
                slices = self.sampler.truncate_input(slices)
                generated = self.sampler.do_sample(slices)
                f.writelines(generated)
                # Flush before copying so the resume file contains this batch.
                f.flush()
                shutil.copy(tmp_file, resume_file)
                with open(ckpt_state_file, "w", encoding="utf-8") as ckpt_state:
                    json.dump({"index": _index}, ckpt_state)
                has_output = True

        if not has_output:
            # No batch was sampled and nothing was restored: there is no
            # resume file from this run to promote. Return before renaming an
            # existing output file so it is not displaced for nothing.
            logger.warning(f"No batch was sampled, {iter_file} is not generated.")
            return

        if os.path.exists(iter_file):
            # Keep the previous result under a timestamped name.
            shutil.move(iter_file, iter_file + "." + str(int(time.time())))
        shutil.move(resume_file, iter_file)
        logger.info(f"Sample file {iter_file} generated.")


def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
    """Create a ``SwiftSampling`` pipeline from ``args`` and return the result of its ``main()``.

    Args:
        args: Passed unchanged to the ``SwiftSampling`` constructor.
    """
    pipeline = SwiftSampling(args)
    return pipeline.main()
