import functools

import torch as th
import numpy as np
from rich.progress import track

from . import logger

def get_simple_t(skip_type, t_T, t_0, N):
    """Return ``N + 1`` evenly spaced times from ``t_T`` to ``t_0`` (both inclusive).

    Only the ``'time_uniform'`` schedule is supported; any other ``skip_type``
    trips the assertion.
    """
    assert skip_type == 'time_uniform'
    num_points = N + 1
    return th.linspace(t_T, t_0, steps=num_points)

def get_model_input_time(t, total_N):
    """Map continuous time ``t`` onto the model's input-time scale.

    ``t`` is shifted back by one grid spacing (``1 / total_N``) and the
    result is scaled by 1000. Works for Python floats and tensors alike.
    """
    grid_spacing = 1. / total_N
    return 1000. * (t - grid_spacing)

def _dpm_solver_step_list(step_respacing, dpm_solver_steps):
    """Return the sorted, de-duplicated, non-zero integer timesteps of a DPM-Solver schedule.

    The schedule is ``time_uniform`` from 1 to ``1 / step_respacing`` with
    ``dpm_solver_steps`` steps. Each point is mapped through
    ``get_model_input_time`` and turned into an integer twice: once by
    truncation and once by round-half-up. Both candidates are kept, then
    duplicates and timestep 0 are removed.
    """
    t_0 = 1. / float(step_respacing)
    simple_steps = get_simple_t('time_uniform', 1., t_0, dpm_solver_steps)
    model_times = get_model_input_time(simple_steps, float(step_respacing))

    # Keep both the truncated and the rounded integer for every solver time,
    # so that neither neighbouring discrete timestep is missed.
    candidates = th.cat(
        (model_times.type(th.int64), (model_times + 0.5).type(th.int64)),
        axis=0,
    )
    unique_steps = th.unique(candidates, sorted=True)
    # Timestep 0 is excluded (the original code removed it explicitly);
    # masking also tolerates 0 being absent instead of raising.
    return unique_steps[unique_steps != 0]

def _summed_micro_batch_losses(model, diffusion, batch, cond, t_input, micro_size, num_micro):
    """Run ``diffusion.training_losses`` over ``num_micro`` full micro-batches at one timestep.

    Returns ``(sum_loss, sum_mse, sum_vb, has_vb)``. Each ``sum_*`` is the sum,
    over micro-batches, of that micro-batch's mean loss term. ``has_vb`` is
    True if any micro-batch's loss dict contained a ``"vb"`` entry.
    """
    sum_loss = 0.
    sum_mse = 0.
    sum_vb = 0.
    has_vb = False
    for j in track(range(num_micro)):
        window = slice(j * micro_size, (j + 1) * micro_size)
        losses = diffusion.training_losses(
            model,
            batch[window],
            t_input[window],
            model_kwargs={k: v[window] for k, v in cond.items()},
        )
        sum_loss += losses["loss"].mean().item()
        sum_mse += losses["mse"].mean().item()
        if "vb" in losses:
            has_vb = True
            sum_vb += losses["vb"].mean().item()
    return sum_loss, sum_mse, sum_vb, has_vb

def cf_loss_in_dpm_solver_steps(batch, cond_, model, diffusion, step_respacing, dpm_solver_steps, micro_size, device):
    """Evaluate the diffusion training loss at every timestep of a DPM-Solver schedule.

    The integer timestep list is built from a ``time_uniform`` schedule with
    ``dpm_solver_steps`` solver steps (see ``_dpm_solver_step_list``). Then,
    for each timestep, ``diffusion.training_losses`` is run over ``batch`` in
    micro-batches of ``micro_size`` samples.

    Args:
        batch: input tensor; dim 0 is sliced into micro-batches.
        cond_: dict of conditioning tensors, moved to ``device`` and sliced
            along dim 0 in lockstep with ``batch``; passed as ``model_kwargs``.
        model: forwarded unchanged to ``diffusion.training_losses``.
        diffusion: object exposing ``training_losses(model, x, t, model_kwargs=...)``,
            which returns a dict with ``"loss"``, ``"mse"`` and optionally ``"vb"``.
        step_respacing: its reciprocal is the final solver time, and it is the
            grid size passed to ``get_model_input_time``.
        dpm_solver_steps: number of solver steps N (the schedule has N + 1 points).
        micro_size: micro-batch size. Only full micro-batches are evaluated;
            trailing samples that do not fill one are ignored (and logged).
        device: device the batch, conditioning and timesteps are moved to.

    Returns:
        ``(steps, loss, mse, vb)``:
            - ``steps`` is the list of integer timesteps (numpy int64 scalars).
            - ``loss``, ``mse`` and ``vb`` hold one value per step: the SUM over
              micro-batches of each micro-batch's mean loss term (not an average).
            - ``vb`` only gets an entry for steps where a ``"vb"`` term was returned.

    Raises:
        ValueError: if ``micro_size`` is not positive, or the batch is smaller
            than one micro-batch. Previously these cases crashed with
            ZeroDivisionError or UnboundLocalError.
    """
    batch_size = batch.shape[0]
    if micro_size <= 0:
        raise ValueError(f"micro_size must be positive, got {micro_size}")
    num_micro = batch_size // micro_size
    if num_micro == 0:
        raise ValueError(
            f"batch size ({batch_size}) must be at least micro_size ({micro_size})"
        )

    step_list = _dpm_solver_step_list(step_respacing, dpm_solver_steps)
    logger.log('The dpm solver step list is', list(step_list.cpu().numpy()))

    dropped = batch_size - num_micro * micro_size
    if dropped:
        logger.log("Warning:", dropped, "trailing samples do not fill a micro batch and are ignored.")

    batch = batch.to(device)
    cond = {k: v.to(device) for k, v in cond_.items()}

    # NOTE(review): losses are only read via .item(); if training_losses does
    # not need autograd, wrapping this loop in th.no_grad() would save memory.
    loss = []
    mse = []
    vb = []
    for step in step_list:
        logger.log("Use step:", step.item(), "to compute training loss.")
        logger.log("Batch size is", batch_size, ", micro size is", micro_size)
        logger.log("There will be", num_micro, "samples every step.")
        # Broadcast the scalar timestep to one entry per sample in the batch.
        t_input = step.expand(batch_size).to(device)

        sum_loss, sum_mse, sum_vb, has_vb = _summed_micro_batch_losses(
            model, diffusion, batch, cond, t_input, micro_size, num_micro
        )

        logger.log("###############################")
        loss.append(sum_loss)
        mse.append(sum_mse)
        if has_vb:
            vb.append(sum_vb)
    logger.log("Sample with this model finished.")
    logger.log("##############################################################")

    return list(step_list.cpu().numpy()), loss, mse, vb
