"""Script for a pretraining run."""

import torch
import hydra

import os
import time
import datetime
import logging
from collections import defaultdict

import cramming

import itertools

from flash_attn import linear
import collections
import math

log = logging.getLogger(__name__)  # module-level logger for training progress / stopping messages

class UnboundedHistogram:
    """Sliding window of recent scores that answers percentile-rank queries.

    Despite the name, the window is bounded: ``history`` is a
    ``collections.deque(maxlen=max_history)``, so once ``max_history`` values
    are stored each append evicts the oldest one (``None`` disables the bound).
    """

    def __init__(self, max_history):
        self.max_history = max_history
        self.history = collections.deque(maxlen=self.max_history)

    def append(self, value):
        """Record ``value`` in the window (evicting the oldest value when full)."""
        self.history.append(value)

    def get_count(self, it, score):
        """Return how many items of ``it`` are strictly less than ``score``."""
        return sum(1 for item in it if item < score)

    def percentile_of_score(self, score):
        """Return the percentage (0-100) of stored values strictly below ``score``.

        An empty window yields 0.0 instead of dividing by zero.
        """
        if not self.history:
            return 0.0
        return self.get_count(self.history, score) * 100. / len(self.history)


class LossSelector(object):
    """Randomly keep per-sample losses, favouring those high relative to recent history.

    Each loss is kept with probability ``max(sampling_min, (p / 100) ** beta)``,
    where ``p`` is its percentile among the last ``history_length`` observed loss
    values (the current batch is recorded before ranking). The training objective
    returned by ``get_selected_loss`` is the mean of the kept losses.

    Args:
        device: device on which ``get_probability`` allocates its result.
        sampling_min: lower bound on any sample's keep-probability.
        history_length: number of recent loss values used for percentile ranking.
        beta: exponent applied to the percentile fraction.
    """

    def __init__(self, device, sampling_min, history_length, beta):
        self.device = device
        self.historical_losses = UnboundedHistogram(history_length)
        self.sampling_min = sampling_min
        self.beta = beta

    @staticmethod
    def _loss_values(losses):
        """Return ``losses`` as a flat list of Python floats.

        A tensor is copied to the host with a single ``tolist()`` instead of one
        ``.item()`` (and one device sync) per element.
        """
        if torch.is_tensor(losses):
            return losses.detach().reshape(-1).tolist()
        return [loss.item() for loss in losses]

    def update_history(self, losses):
        """Append every value in ``losses`` to the loss history."""
        for value in self._loss_values(losses):
            self.historical_losses.append(value)

    def calculate_probability(self, loss):
        """Return ``(percentile(loss) / 100) ** beta`` for a float ``loss`` (not clamped)."""
        percentile = self.historical_losses.percentile_of_score(loss)
        return math.pow(percentile / 100., self.beta)

    def get_probability(self, losses):
        """Record ``losses`` in the history, then return each entry's keep-probability.

        The result is a float32 tensor on ``self.device``; for tensor input it has
        the same shape as ``losses`` (1-D for a plain sequence of scalars).
        """
        values = self._loss_values(losses)
        for value in values:
            self.historical_losses.append(value)
        probs = torch.tensor(
            [max(self.sampling_min, self.calculate_probability(value)) for value in values],
            device=self.device,
            dtype=torch.float32,
        )
        if torch.is_tensor(losses):
            probs = probs.reshape(losses.shape)
        return probs

    def get_selected_loss(self, losses):
        """Sample a subset of ``losses`` and return the mean of the kept entries.

        If no entry is kept, returns the (zero) sum of the empty selection instead
        of the NaN an empty ``mean()`` produces; its autograd graph is intact, so
        backward yields zero gradients rather than NaNs.
        """
        # Build the mask on the same device as ``losses`` so boolean indexing works
        # even when ``self.device`` differs (e.g. other ranks or CPU runs).
        probs = self.get_probability(losses).to(losses.device)
        mask = torch.rand_like(probs) < probs
        selected = losses[mask]
        if selected.numel() == 0:
            return selected.sum()
        return selected.mean()


def main_training_process(cfg, setup):
    """This function controls the central training loop.

    Builds the model, pretraining corpus and backend engine, then runs steps
    until a stopping criterion fires: the step/time budget, early termination on
    a bad loss, a non-finite loss, or ``cfg.dryrun`` after a few steps. Each
    step computes per-sample losses, back-propagates only a stochastically
    selected subset (``LossSelector``), and records the full-batch mean for
    logging. On the main process a summary and the final model are saved.
    """
    local_time = time.time()
    model = cramming.construct_model(cfg.arch, cfg.data.vocab_size)
    dataset, tokenizer = cramming.load_pretraining_corpus(cfg.data, cfg.impl)

    model_engine, _, _, dataloader = cramming.load_backend(
        model,
        dataset,
        tokenizer,
        cfg.train,
        cfg.impl,
        setup=setup,
    )
    model_engine.train(cfg.train.pretrain_in_train_mode)
    stats = defaultdict(list)

    # Start the clocks now:
    wallclock_timer = time.time()
    train_time = time.time()  # Crude time measurement for print_loss_every_nth_step
    training_allowed = True
    loss_vals = []
    loss = None  # full-batch mean loss of the latest step; stays None if no step runs

    iterable_data = enumerate(dataloader)
    if cfg.train.gradinit.enabled:
        model_engine.gradinit(iterable_data, cfg.train.optim, cfg.train.gradinit)

    # Accumulators for the disabled weight-ratio diagnostic kept (commented out) below.
    ratio_avg = 0
    ratio_N = 0
    # Use the engine's device instead of a hard-coded "cuda:0" so CPU runs work and
    # each distributed rank keeps its selection tensors on its own GPU.
    # Args: minimum keep-probability, history window length, percentile exponent.
    loss_selector = LossSelector(model_engine.setup["device"], 1e-9, 1024, 0.4786)
    # Launch training
    for step, batch in iterable_data:
        # Heavy lifting is moved to engines
        device_batch = model_engine.to_device(batch)
        losses = model_engine.forward(**device_batch)["loss"]  # per-sample losses
        loss = losses.mean()  # full-batch mean: used for logging and checkpoint naming only
        selected_loss = loss_selector.get_selected_loss(losses)
        model_engine.backward(selected_loss)
        model_engine.optimizer_step()
        loss_vals.append(loss.detach())

        # device_batch = model_engine.to_device(batch)
        # loss = model_engine.step(device_batch)
        # loss_vals.append(loss.detach())


        # if step % 500 == 0:
        #     linear.test = True
        #     linear.prepare(model_engine.setup["device"])
        #     model_engine.eval()
        #     org_state = torch.get_rng_state().clone()
        #     for i, batch in itertools.islice(iterable_data, 2):
        #         device_batch = model_engine.to_device(batch)
        #         cramming.utils.set_random_seed(i)
        #         loss = model_engine.forward(**device_batch)["loss"]
        #         (loss / model_engine.accumulation_steps_expected).backward()
        #         model_engine.optimizer.zero_grad()
        #     linear.test = False

        #     linear.cal_var()
        #     linear.update_weight_ratio()
        #     linear.reset_dict()

        #     torch.set_rng_state(org_state)

        #     S = linear.S
        #     weight_ratio_dict = linear.weight_ratio_dict
        #     ratio = (2 + sum([sum(weight_ratio_dict[i] for i in range(4 * j, 4 * j + 4)) / 4 for j in range(16)]) / 16) / 3
        #     ratio_avg = (ratio_avg * ratio_N + ratio) / (ratio_N + 1)
        #     logging.info(f"ratio: {ratio}, ratio_avg: {ratio_avg}")
        #     ratio_N += 1
        #     sample_metric = {"step": [step], "S": [S], "weight_ratio[0]": [weight_ratio_dict[0]], "weight_ratio[1]": [weight_ratio_dict[1]], "weight_ratio[2]": [weight_ratio_dict[2]], "weight_ratio[3]": [weight_ratio_dict[3]], "ratio": [ratio], "ratio_avg": [ratio_avg]}
        #     # sample_metric = {"step": [step], "sgd_var": [sgd_var], }
        #     cramming.utils.wandb_log(sample_metric, cfg)
        #     model_engine.train(cfg.train.pretrain_in_train_mode)

        # Check stopping criteria
        if cfg.train.scheduler == "same-as-baseline":
            if step == cfg.train.baseline_steps:
                training_allowed = False
                log.info("Reached baseline steps. Stopping training ...")
        else:
            if check_deadline(wallclock_timer, cfg.budget) or step == cfg.train.steps:
                training_allowed = False
                log.info("Reached deadline. Stopping training ...")

        # Collect stats and print to console and upload to wandb
        if step % cfg.impl.print_loss_every_nth_step == 0:
            loss_vals, train_time = collect_stats(step, loss_vals, train_time, stats, model_engine, dataloader, cfg)
            if check_early_termination(wallclock_timer, stats["loss"][-1], cfg.impl.early_termination):
                training_allowed = False
                log.info("Loss higher than allowed threshold. Stopping training early...")

        # Checkpointing is triggered from stopping criteria and normal intervals
        if cfg.impl.save_intermediate_checkpoints and step % cfg.impl.save_every_nth_step == 0:
            # state = dict(step=step, tokenizer_name=tokenizer.name)
            state = dict(step=step)
            checkpoint_id = loss.item()
            if cramming.utils.is_main_process():
                model_engine.save_training_checkpoint(checkpoint_id, state=state)

        if not loss.detach().isfinite():
            training_allowed = False
            log.info("Ending training due to non-finite loss.")

        # Use the all-rank consensus: if any rank stops, every rank must stop together,
        # otherwise the remaining ranks would block forever in collective operations.
        training_allowed = flag_communication(training_allowed)

        if (cfg.dryrun and step > 2) or not training_allowed:
            break

    if cramming.utils.is_main_process():
        # Save to summary:
        metrics = dict(num_params=sum([p.numel() for p in model.parameters()]))
        cramming.utils.save_summary("pretrain", cfg, metrics, stats, time.time() - local_time, setup)
        # Save final checkpoint (loss is None only if the loop never ran a step):
        final_loss = float("nan") if loss is None else loss.item()
        now = datetime.datetime.now()
        checkpoint_id = f"{''.join(cfg.arch.architectures)}_{now.strftime('%Y-%m-%d')}_{final_loss:2.4f}"
        model_engine.save_final_model(os.path.join(cfg.base_dir, cfg.name), checkpoint_id, tokenizer, cfg.arch, cfg.dryrun)


def check_deadline(launch_time, hour_limit):
    """These measurements are deliberately wall-clock based.

    Return True once more than ``hour_limit`` hours have elapsed since the
    ``time.time()`` timestamp ``launch_time``, otherwise False.
    """
    elapsed_hours = (time.time() - launch_time) / 3600
    return bool(elapsed_hours > hour_limit)


def check_early_termination(launch_time, loss, early_termination):
    """Early termination based on terrible loss.

    Return True only when ``early_termination.enabled`` is set, ``loss`` exceeds
    ``early_termination.loss_threshold``, and more than ``early_termination.budget``
    hours have passed since ``launch_time``; otherwise False.
    """
    if not early_termination.enabled:
        return False
    if not loss > early_termination.loss_threshold:
        return False
    hours_elapsed = (time.time() - launch_time) / 3600
    return bool(hours_elapsed > early_termination.budget)


def collect_stats(step, loss_vals, train_time, stats, model_engine, dataloader, cfg):
    """Record statistics for ``step``, publish them, and reset the interval accumulators.

    Appends step/epoch/token/loss/lr/batch-size entries to the list-valued
    ``stats`` mapping, uploads it via ``cramming.utils.wandb_log`` and logs a
    summary line. ``loss_vals`` holds the scalar loss tensors gathered since the
    previous call (torch.stack requires it to be non-empty) and ``train_time`` is
    the timestamp of that call; per-step timing is only recorded when ``step > 0``.

    Returns:
        A fresh empty loss list and the current timestamp for the next interval.
    """
    stats["step"].append(step)
    stats["epoch"].append(dataloader.epoch_counter)

    per_step_tokens = cramming.utils.num_processes() * model_engine.record_tokens_per_step()
    stats["tokens"].append(step * per_step_tokens)
    stats["loss"].append(torch.stack(loss_vals).mean().item())  # Averaged loss

    lr_now = model_engine.optimizer.param_groups[0]["lr"]
    message_parts = [
        f"Train loss {loss_vals[-1].item():2.4f} at step {step} with lr {lr_now:.5f}. ",
        f"[Avg: {stats['loss'][-1]:2.4f}] ",
    ]
    if step > 0:
        seconds_per_step = (time.time() - train_time) / cfg.impl.print_loss_every_nth_step
        stats["train_time"].append(seconds_per_step)
        eta = str(datetime.timedelta(seconds=seconds_per_step * cfg.train.steps))
        throughput = per_step_tokens / seconds_per_step
        stats["tok/sec"].append(int(throughput))
        message_parts.append(f" Perf: {seconds_per_step:2.4f}s per step ({throughput:.0f}t/s). ")
        message_parts.append(f"Estimated Total Train: {eta}.")

    # Adaptive optim stats
    stats["lr"].append(lr_now)
    stats["batch_size"].append(model_engine.record_batch_size())
    # Replaced rather than appended: only the latest sequence length is kept.
    stats["seq_length"] = [model_engine.current_seq_length]

    # Publish
    cramming.utils.wandb_log(stats, cfg)
    log.info("".join(message_parts))

    # Hand back cleared accumulators for the next reporting interval.
    return [], time.time()


def flag_communication(training_allowed):
    """A quick and dirty communication through NCCL. Should not be a major burden.

    With an initialized process group, the flag is all-reduced with ``MIN`` so the
    result is True only if every rank allows training to continue. Without one,
    ``training_allowed`` is returned unchanged.
    """
    if not torch.distributed.is_initialized():
        return training_allowed
    flag = torch.as_tensor(training_allowed).cuda()
    torch.distributed.all_reduce(flag, torch.distributed.ReduceOp.MIN, async_op=False)
    return bool(flag >= 1)


@hydra.main(config_path="cramming/config", config_name="cfg_pretrain", version_base="1.1")
def launch(cfg):
    """Hydra entry point: hand the composed config to cramming's launcher for pretraining."""
    run_launcher = cramming.utils.main_launcher
    run_launcher(cfg, main_training_process, job_name="pretraining")


if __name__ == "__main__":
    # Hydra parses the command line and supplies ``cfg``.
    launch()
