# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import numpy as np
import math


def sample_to_batch(mod_dict, device, domains):
    """Turn a single sample into a batch of size one on ``device``.

    Args:
        mod_dict: Mapping of modality name to a dict of tensor-like values
            (anything exposing ``unsqueeze`` and ``to``).
        device: Target handed to each value's ``.to`` call.
        domains: Container of modality names to keep; all other modalities
            are dropped from the result.

    Returns:
        A new nested dict holding only the selected modalities, where every
        value got a new leading dimension and was moved to ``device``.
    """
    batched = {}
    for modality, entries in mod_dict.items():
        if modality not in domains:
            continue
        moved = {}
        for key, value in entries.items():
            # non_blocking=True is forwarded as-is to the transfer call.
            moved[key] = value.unsqueeze(0).to(device, non_blocking=True)
        batched[modality] = moved
    return batched


def unbatch(tensor):
    """Detach ``tensor``, squeeze dimension 0, and return the result on the CPU.

    ``squeeze(0)`` only drops the leading dimension when its size is one.
    """
    detached = tensor.detach()
    squeezed = detached.squeeze(0)
    return squeezed.cpu()


def batch_to_sample(mod_dict, domains):
    """Convert a batched modality dict back into a single sample.

    Args:
        mod_dict: Mapping of modality name to a dict of tensor-like values.
        domains: Container of modality names to keep; all other modalities
            are dropped from the result.

    Returns:
        A new nested dict holding only the selected modalities, with
        ``unbatch`` applied to every value.
    """
    sample = {}
    for modality, entries in mod_dict.items():
        if modality not in domains:
            continue
        unbatched = {}
        for key, value in entries.items():
            unbatched[key] = unbatch(value)
        sample[modality] = unbatched
    return sample


def batch_to_device(mod_dict, device, domains):
    """Move every value of the selected modalities to ``device``.

    Args:
        mod_dict: Mapping of modality name to a dict of tensor-like values
            (anything exposing ``to``).
        device: Target handed to each value's ``.to`` call.
        domains: Container of modality names to keep; all other modalities
            are dropped from the result.

    Returns:
        A new nested dict holding only the selected modalities, with each
        value moved via ``.to(device, non_blocking=True)``.
    """
    on_device = {}
    for modality, entries in mod_dict.items():
        if modality not in domains:
            continue
        moved = {}
        for key, value in entries.items():
            moved[key] = value.to(device, non_blocking=True)
        on_device[modality] = moved
    return on_device


def cosine_schedule(num_steps, total_tokens):
    """Split ``total_tokens`` over ``num_steps`` steps following a cosine curve.

    The curve ``0.5 * (1 + cos(pi * i / num_steps))`` falls from 1 towards 0;
    step ``i`` receives the rounded share of tokens corresponding to the drop
    of the curve between ``i`` and ``i + 1``, and the final step receives
    whatever is left so that the entries sum to ``total_tokens``.

    Args:
        num_steps: Number of steps in the schedule.
        total_tokens: Total number of tokens to distribute.

    Returns:
        ``np.ndarray`` of per-step token counts summing to ``total_tokens``.
        For a non-negative ``total_tokens`` every entry is non-negative.
    """
    iters = np.arange(num_steps)
    base_value = 1
    final_value = 0
    schedule = np.array([
        final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / len(iters)))
        for i in iters
    ])
    schedule_tokens = [round(total_tokens * drop) for drop in (schedule[:-1] - schedule[1:])]
    remainder = total_tokens - sum(schedule_tokens)

    # Each step is rounded independently, so the rounded counts can add up to
    # more than total_tokens (e.g. num_steps=8, total_tokens=14), which used
    # to leave a negative count for the final step. Hand the overshoot back
    # from the latest steps; schedules that were already valid are unchanged.
    for idx in range(len(schedule_tokens) - 1, -1, -1):
        if remainder >= 0:
            break
        give_back = min(schedule_tokens[idx], -remainder)
        schedule_tokens[idx] -= give_back
        remainder += give_back

    schedule_tokens.append(remainder)
    return np.array(schedule_tokens)


def linear_schedule(num_steps, total_tokens):
    """Split ``total_tokens`` into near-equal integer chunks over ``num_steps``.

    Args:
        num_steps: Number of steps to spread the tokens over.
        total_tokens: Total number of tokens to distribute.

    Returns:
        ``np.ndarray`` of per-step token counts in descending order, with
        trailing zero-sized steps removed.
    """
    # Integer boundaries of the evenly spaced grid; their gaps are the chunks.
    boundaries = np.linspace(0, total_tokens, num_steps + 1, dtype=int)
    ascending = np.sort(np.diff(boundaries))
    descending = ascending[::-1]
    return np.trim_zeros(descending, 'b')


def continue_schedule(schedule, num_current_tokens):
    """Return the part of ``schedule`` still left after ``num_current_tokens``.

    Steps whose cumulative token count is already covered by
    ``num_current_tokens`` are dropped, and the first remaining step is
    shortened so that it only accounts for the tokens still missing from it.

    Args:
        schedule: 1-D sequence or array of per-step token counts.
        num_current_tokens: Number of tokens already produced.

    Returns:
        A new ``np.ndarray`` with the remaining per-step token counts. It is
        empty when ``num_current_tokens`` reaches or exceeds the schedule
        total. The input is not modified.
    """
    schedule = np.asarray(schedule)
    schedule_cumsum = np.cumsum(schedule)
    keep_mask = schedule_cumsum > num_current_tokens

    # Boolean indexing copies, so writing into the result leaves the caller's
    # array untouched.
    new_schedule = schedule[keep_mask]
    if new_schedule.size == 0:
        # Nothing left to schedule; previously this raised IndexError.
        return new_schedule

    new_schedule[0] = schedule_cumsum[keep_mask][0] - num_current_tokens
    return new_schedule


def decreasing_temp_schedule(max, min, token_schedule):
    """Temperature per step, decreasing linearly with the share of tokens done.

    Args:
        max: Upper end of the temperature range.
        min: Temperature reached once the cumulative token share is 1,
            i.e. at the last step.
        token_schedule: Per-step token counts.

    Returns:
        ``np.ndarray`` with one temperature per step, computed as
        ``min + (max - min) * (1 - cumulative_share)``.
    """
    # Fraction of all tokens covered up to and including each step.
    progress = np.cumsum(token_schedule) / np.sum(token_schedule)
    span = max - min
    return np.array(min + span * (1 - progress))


def onex_temp_schedule(max_t, min_t, token_schedule, power=0.5, min_linspace=1, max_linspace=100):
    """Arbitrary temperature schedule shaped like one over x (``1 / x**power``).

    A ``1 / x**power`` curve is sampled on ``sum(token_schedule)`` evenly
    spaced points between ``min_linspace`` and ``max_linspace`` and rescaled
    to span [0, 1]. Its leading samples are paired one-to-one with the steps
    of ``token_schedule``, damped by the share of tokens still outstanding
    after each step, and mapped onto ``[min_t, max_t]``.

    Args:
        max_t: Upper end of the temperature range.
        min_t: Lower end of the temperature range.
        token_schedule: Per-step token counts.
        power: Exponent applied to ``x`` in ``1 / x**power``.
        min_linspace: Start of the sampled ``x`` range.
        max_linspace: End of the sampled ``x`` range.

    Returns:
        ``np.ndarray`` of temperatures, clipped from below at ``1e-9``.
    """
    grid = np.linspace(min_linspace, max_linspace, num=sum(token_schedule))
    curve = 1 / (grid ** power)
    curve = curve - min(curve)
    curve = curve / max(curve)

    progress = np.cumsum(token_schedule) / np.sum(token_schedule)

    # Pair curve samples with steps one-to-one, stopping at the shorter of
    # the two sequences.
    count = min(len(curve), len(progress))
    damped = (1 - progress[:count]) * curve[:count]

    temperatures = min_t + (max_t - min_t) * damped
    return temperatures.clip(min=1e-9)


def linear_temp_schedule(temp, token_schedule):
    """Temperature that decays in proportion to the tokens still outstanding.

    The first step uses ``temp`` itself; each later step uses ``temp`` scaled
    by the fraction of tokens left after the preceding step. Values are
    clipped from below at ``1e-9``.

    Args:
        temp: Starting temperature.
        token_schedule: ``np.ndarray`` of per-step token counts.

    Returns:
        ``np.ndarray`` with one temperature per step.
    """
    total = token_schedule.sum()
    outstanding = total - token_schedule.cumsum()
    decayed = temp * outstanding / total
    first = np.array([temp * 1.0])
    # The last decayed value (nothing outstanding) is dropped so the result
    # has one entry per step.
    temperatures = np.concatenate([first, decayed[:-1]])
    return temperatures.clip(min=1e-9)
