import os
import time

from meshflow.platform import get_backend


class MFTimer:
    """Measure the average run time of a zero-argument callable.

    The timing strategy is chosen from the backend reported by
    ``get_backend()`` and, for the torch backend, from the target device:

    * ``"jax"``                -> wall clock, with a device barrier after each call
    * ``"torch"`` + ``"cuda"`` -> a pair of CUDA events around each call
    * ``"torch"`` + ``"cpu"``  -> wall clock

    Every strategy first makes ``warmup_trials`` untimed calls, then averages
    over ``trials`` timed calls.
    """

    def __init__(self, func, trials=3, warmup_trials=3, in_ms=True, device=None) -> None:
        """Configure the timer.

        Args:
            func: Zero-argument callable to time; its return value is ignored.
            trials: Number of timed calls to average over; must be >= 1.
            warmup_trials: Number of untimed calls made before timing starts.
                Values <= 0 skip warmup.
            in_ms: If True, results are reported in milliseconds, otherwise
                in seconds.
            device: Target device (e.g. ``"cpu"`` or ``"cuda"``). When None,
                it is read from the ``MESHFLOW_DEVICE`` environment variable,
                defaulting to ``"cpu"``.

        Raises:
            ValueError: If ``trials`` is less than 1 (the per-trial average
                would divide by zero).
        """
        if trials < 1:
            raise ValueError(f"trials must be >= 1, got {trials}")

        self.func = func
        self.warmup_trials = warmup_trials
        self.trials = trials
        self.in_ms = in_ms

        if device is None:
            device = os.environ.get("MESHFLOW_DEVICE", "cpu")
        self.device = device

        self.backend = get_backend()

    def time(self):
        """Time ``func`` with the strategy matching the backend and device.

        Returns:
            The average time per trial (milliseconds if ``in_ms`` else
            seconds). Returns None when the backend/device combination is not
            supported: a backend other than ``"jax"``/``"torch"``, or a torch
            device other than exactly ``"cpu"`` or ``"cuda"`` (an indexed
            device such as ``"cuda:0"`` is not matched).
        """
        if self.backend == "jax":
            return self.time_jax()
        if self.backend == "torch":
            # str() lets torch.device objects match as well as plain strings.
            device = str(self.device)
            if device == "cuda":
                return self.time_torch_cuda()
            if device == "cpu":
                return self.time_cpu()
        return None

    def _convert_seconds(self, seconds):
        """Express a duration given in seconds in the unit selected by ``in_ms``."""
        return seconds * 1000 if self.in_ms else seconds

    def _time_wall_clock(self, sync=None):
        """Return the average wall-clock seconds per timed call of ``func``.

        Args:
            sync: Optional zero-argument callable invoked after every call to
                ``func`` (warmup and timed). It runs inside the timed region,
                so its own cost is included in the measurement.
        """
        def run_once():
            self.func()
            if sync is not None:
                sync()

        for _ in range(self.warmup_trials):
            run_once()

        start_t = time.perf_counter()
        for _ in range(self.trials):
            run_once()
        return (time.perf_counter() - start_t) / self.trials

    def time_cpu(self):
        """Time ``func`` with the host wall clock (``time.perf_counter``)."""
        return self._convert_seconds(self._time_wall_clock())

    def time_torch_cuda(self):
        """Time ``func`` with one pair of ``torch.cuda.Event`` objects per trial.

        Each event is recorded on the current CUDA stream, and the per-trial
        durations come from ``Event.elapsed_time`` (milliseconds).
        """
        import torch

        for _ in range(self.warmup_trials):
            self.func()

        timed_events = [
            (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
            for _ in range(self.trials)
        ]
        for start_evt, end_evt in timed_events:
            start_evt.record()
            self.func()
            end_evt.record()

        # Event timestamps can only be read once the GPU has reached them.
        torch.cuda.synchronize()
        avg_ms = sum(start.elapsed_time(end) for start, end in timed_events) / self.trials

        return avg_ms if self.in_ms else avg_ms / 1000

    def time_jax(self):
        """Time ``func`` with the wall clock, waiting for the device after each call."""
        import jax

        def barrier():
            # Block until a trivial computation, enqueued after ``func``, has finished.
            # NOTE(review): this assumes JAX runs the work dispatched by ``func``
            # before this op (in-order execution on the default device). Work
            # placed on other devices may not be waited on; confirm.
            (jax.device_put(0.) + 0).block_until_ready()

        return self._convert_seconds(self._time_wall_clock(sync=barrier))
