import math

import jax
import jax.numpy as jnp
import mlxu
import numpy as np
from absl import flags
from clu import metric_writers

# Module-level flags registered through mlxu, together with their defaults.
# NOTE(review): `log_ppl_n_buckets` is not read anywhere in the visible part of
# this file -- presumably it is consumed by callers; confirm before removing.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    log_ppl_n_buckets=8,
)


def create_logger(log_dir, enable):
    """Build a CLU metric writer for ``log_dir``.

    Args:
        log_dir: Forwarded unchanged to
            ``metric_writers.create_default_writer``.
        enable: When falsy, the writer is requested with
            ``just_logging=True`` (presumably a logging-only writer -- see the
            CLU documentation); when truthy, ``just_logging=False``.

    Returns:
        Whatever ``metric_writers.create_default_writer`` returns.
    """
    logging_only = not enable
    writer = metric_writers.create_default_writer(
        log_dir, just_logging=logging_only
    )
    return writer


def tf_arr_to_string(tf_array):
    """Render an array-like object as a bracketed, comma-separated string.

    A one-dimensional input (``len(tf_array.shape) == 1``) has each element
    converted with ``str``; any other input is iterated along its first axis
    and each item is rendered recursively. No whitespace is inserted, e.g. a
    2x2 integer array becomes ``"[[1,2],[3,4]]"``.

    Args:
        tf_array: Iterable object exposing a ``shape`` attribute.

    Returns:
        The string representation described above.
    """
    is_vector = len(tf_array.shape) == 1
    # Leaves are stringified directly; higher ranks recurse one axis down.
    render = str if is_vector else tf_arr_to_string
    return "[{}]".format(",".join(render(item) for item in tf_array))


def fill_missing_dict_keys_with_zero(list_of_dicts):
    """Return copies of the given dicts that all share the same key set.

    The union of keys over every input dictionary is computed, and each
    output dictionary holds ``0`` for any key its source lacked. The input
    dictionaries are not modified.

    Args:
        list_of_dicts: Sequence of dictionaries (iterated twice).

    Returns:
        A new list of new dictionaries, in the same order as the input.
    """
    key_universe = set().union(*(entry.keys() for entry in list_of_dicts))

    padded = []
    for entry in list_of_dicts:
        # Start from an all-zero row, then overlay the values actually present.
        row = dict.fromkeys(key_universe, 0)
        row.update(entry)
        padded.append(row)
    return padded


def metrics_assign_group(metrics_dict, group, index=0):
    """Insert ``group`` as a path component into every ``/``-separated key.

    Each key is split on ``"/"`` and ``group`` is inserted at position
    ``index`` taken modulo the number of components of that key. With
    ``index=-1`` the group therefore lands immediately before the last
    component (``"train/loss"`` -> ``"train/<group>/loss"``).

    Args:
        metrics_dict: Mapping from string keys to metric values.
        group: String to insert; it may itself contain ``"/"``.
        index: Insertion position, wrapped with ``%`` per key.

    Returns:
        A new dictionary with rewritten keys and the original values.
    """
    regrouped = {}
    for name, value in metrics_dict.items():
        parts = name.split("/")
        parts.insert(index % len(parts), group)
        regrouped["/".join(parts)] = value
    return regrouped


class LogAggregator:
    """Buffer per-step metric dictionaries and summarise them on demand.

    Logs are appended with :meth:`add` / :meth:`add_list` and turned into a
    flat metrics dictionary by :meth:`get_logs`.

    Args:
        keep_last: Maximum number of buffered logs; ``-1`` means unbounded.
        provide_mean: If true, :meth:`get_logs` reports the mean, standard
            deviation and sample size of every leaf over the buffered logs.
        provide_latest: If true, :meth:`get_logs` reports the most recently
            added log under the ``last`` group.
        reset_on_get: If true, the buffer is emptied by :meth:`get_logs`.
        device_get_at_add: If true, logs are passed through
            ``jax.device_get`` when they are added; otherwise that call is
            deferred until :meth:`get_logs`.

    Raises:
        AssertionError: If ``keep_last`` is neither ``-1`` nor positive, or
            if the buffer is both unbounded and never reset.
    """

    def __init__(
        self,
        keep_last=-1,
        provide_mean=True,
        provide_latest=True,
        reset_on_get=True,
        device_get_at_add=True,
    ):
        assert keep_last == -1 or keep_last > 0
        self.keep_last = keep_last
        self.provide_mean = provide_mean
        self.provide_latest = provide_latest
        self.reset_on_get = reset_on_get
        self.device_get_at_add = device_get_at_add
        self.logs = []

        # An unbounded buffer that is never reset would grow without limit.
        assert keep_last != -1 or reset_on_get

    def add(self, new_logs) -> None:
        """Append a single log; equivalent to ``add_list([new_logs])``."""
        # The device transfer is left to add_list so that
        # `device_get_at_add=False` is honoured here as well.
        self.add_list([new_logs])

    def add_list(self, new_logs) -> None:
        """Append several logs, then trim the buffer to ``keep_last`` entries."""
        if self.device_get_at_add:
            new_logs = jax.device_get(new_logs)
        self.logs += new_logs
        if self.keep_last != -1 and len(self.logs) > self.keep_last:
            self.logs = self.logs[-self.keep_last :]

    def get_logs(self):
        """Summarise the buffered logs into one metrics dictionary.

        Depending on the constructor flags the result contains, for each key
        of the buffered logs, entries grouped under ``aggregated/mean``,
        ``aggregated/std``, ``aggregated/sample_size`` and ``last`` (the group
        is inserted before the last ``/``-separated key component). For every
        resulting key ending in ``aggregated/mean/loss`` an additional
        ``aggregated/total_perplexity`` entry holding ``exp(mean loss)`` is
        added.

        Returns:
            The metrics dictionary; empty when no logs are buffered.
        """
        metrics = {}
        if len(self.logs) == 0:
            return metrics

        if not self.device_get_at_add:
            self.logs = jax.device_get(self.logs)

        if self.provide_mean:
            # `jax.tree_map` is deprecated/removed in recent JAX releases;
            # `jax.tree_util.tree_map` is the stable spelling.
            reducers = (
                ("aggregated/mean", lambda *leaves: np.mean(np.stack(leaves))),
                ("aggregated/std", lambda *leaves: np.std(np.stack(leaves))),
                ("aggregated/sample_size", lambda *leaves: len(leaves)),
            )
            for group, reduce_leaves in reducers:
                reduced = jax.tree_util.tree_map(reduce_leaves, *self.logs)
                metrics.update(metrics_assign_group(dict(**reduced), group, -1))

        if self.provide_latest:
            latest = dict(**self.logs[-1])
            metrics.update(metrics_assign_group(latest, "last", -1))

        if self.reset_on_get:
            self.logs = []

        loss_suffix = "aggregated/mean/loss"
        perplexity = {
            key[: -len(loss_suffix)] + "aggregated/total_perplexity": np.exp(value)
            for key, value in metrics.items()
            if key.endswith(loss_suffix)
        }
        metrics.update(perplexity)

        return metrics

    def get_logs_per_field(self, field: str):
        """Return the value stored under ``field`` in each buffered log."""
        return [log[field] for log in self.logs]


def compute_bucketed_log_pplx(token_level_losses: list, first_boundary: int = 16):
    """Average token-level losses over index ranges that double in width.

    Every array in ``token_level_losses`` is flattened, and the flattened
    index space is cut into the half-open ranges ``[0, b)``, ``[b, 2b)``,
    ``[2b, 4b)``, ... with ``b = first_boundary``. For each range the losses
    of all arrays are pooled and averaged.

    Only ranges that fit entirely within the flattened length of the *first*
    array are produced, so trailing elements beyond the last boundary are
    ignored.

    NOTE(review): because the arrays are flattened, the indices run across
    rows for multi-row inputs -- presumably callers pass one sequence per
    array; confirm.

    Args:
        token_level_losses: List of arrays exposing ``flatten()``.
        first_boundary: Exclusive upper index of the first bucket; each later
            boundary is twice the previous one.

    Returns:
        A list of dictionaries with the keys ``min_index``, ``max_index`` and
        ``avg_log_pplx``, one per bucket in increasing index order. The list
        is empty when ``token_level_losses`` is empty.

    Raises:
        ValueError: If ``first_boundary`` is not positive (the boundary
            doubling would otherwise never advance).
    """
    if first_boundary <= 0:
        raise ValueError(
            f"first_boundary must be positive, got {first_boundary}"
        )
    if len(token_level_losses) == 0:
        return []

    flattened_batches = [batch.flatten() for batch in token_level_losses]

    # All arrays are assumed to share the flattened length of the first one.
    reference_len = len(flattened_batches[0])

    bucket_list = []
    start_idx, end_idx = 0, first_boundary
    while end_idx <= reference_len:
        total_loss = 0.0
        count = 0
        for flattened_batch in flattened_batches:
            bucket = flattened_batch[start_idx:end_idx]
            total_loss += np.sum(bucket)
            count += len(bucket)

        avg_loss = total_loss / count if count != 0 else 0.0
        bucket_list.append(
            {"min_index": start_idx, "max_index": end_idx, "avg_log_pplx": avg_loss}
        )

        start_idx, end_idx = end_idx, end_idx * 2

    return bucket_list


def list_to_dict_perplexity(bucket_list):
    """Flatten a list of bucket records into a metrics dictionary.

    Each record contributes one entry whose key is
    ``aggregated/mean/log_ppl_<min_index>`` and whose value is the record's
    ``avg_log_pplx``. Records sharing a ``min_index`` overwrite earlier ones.

    Args:
        bucket_list: Iterable of dictionaries holding at least the keys
            ``min_index`` and ``avg_log_pplx``.

    Returns:
        The dictionary described above.
    """
    return {
        f"aggregated/mean/log_ppl_{entry['min_index']}": entry["avg_log_pplx"]
        for entry in bucket_list
    }
