import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import graph_lib
from model import utils as mutils


def get_loss_fn(noise, graph, train, sampling_eps=1e-3, lv=False, loss_type='score_entropy', order=None):
    """Return the loss callable selected by ``loss_type``.

    Every returned callable accepts ``(model, batch, cond=None)`` so that it
    can be invoked uniformly as ``loss_fn(model, batch, cond=cond)``. ``cond``
    is accepted for that uniform signature only; none of the losses below
    read it.

    Args:
        noise: Callable mapping a batch of times ``t`` to ``(sigma, dsigma)``.
        graph: Object providing ``dim``, ``sample_transition``,
            ``sample_transition_lambada``, ``score_entropy`` and
            ``cross_entropy``. The id ``graph.dim - 1`` is the value written
            into masked positions by the 'ar' loss and looked for by the
            'lambada_cross_entropy' loss.
        train: Forwarded to ``mutils.get_score_fn``.
        sampling_eps: Lower bound for sampled times; ``t`` is drawn uniformly
            from ``[sampling_eps, 1)``.
        lv: When true and no ``t`` is supplied, raises ``NotImplementedError``.
        loss_type: ``'ar'`` selects the autoregressive loss,
            ``'lambada_cross_entropy'`` the lambada loss; anything else
            selects the time-based loss, which itself supports
            ``'score_entropy'`` and ``'cross_entropy'`` and raises
            ``NotImplementedError`` for other values when called.
        order: Sequence of position indices giving the order in which the
            'ar' loss reveals positions. ``None`` (default) means
            left-to-right over the actual sequence length of each batch.

    Returns:
        A function ``(model, batch, cond=None, ...) -> loss``.
    """

    def t_entropy_loss(model, batch, cond=None, t=None, perturbed_batch=None):
        """Time-weighted score-entropy / cross-entropy loss.

        Batch shape: [B, L] int. D given from graph.
        ``t`` and ``perturbed_batch`` are sampled here when not supplied.
        """
        if t is None:
            if lv:
                raise NotImplementedError("Yeah I gotta do this later")
            # Uniform in [sampling_eps, 1).
            t = (1 - sampling_eps) * torch.rand(batch.shape[0], device=batch.device) + sampling_eps

        sigma, dsigma = noise(t)

        if perturbed_batch is None:
            perturbed_batch = graph.sample_transition(batch, sigma[:, None])

        log_score_fn = mutils.get_score_fn(model, train=train, sampling=False)
        log_score = log_score_fn(perturbed_batch, sigma)
        if loss_type == 'cross_entropy':
            loss = graph.cross_entropy(log_score, sigma[:, None], perturbed_batch, batch)
        elif loss_type == 'score_entropy':
            loss = graph.score_entropy(log_score, sigma[:, None], perturbed_batch, batch)
        else:
            raise NotImplementedError(f"Loss type {loss_type} not implemented yet!")

        # Weight by dsigma and reduce over the last dimension.
        return (dsigma[:, None] * loss).sum(dim=-1)

    def lambada_entropy_loss(model, batch, cond=None):
        """Negative log-likelihood of the original tokens at masked positions, scaled by 1 / lambada."""
        lambada = torch.rand(batch.shape[0], device=batch.device)
        perturbed_batch = graph.sample_transition_lambada(batch, lambada[..., None])
        output = -model.get_log_condition(perturbed_batch)

        # Only positions carrying the mask id contribute to the loss.
        rel_ind = perturbed_batch == graph.dim - 1
        other_ind = batch[rel_ind]
        loss = torch.zeros(*batch.shape, device=batch.device, dtype=output.dtype)
        loss[rel_ind] = torch.gather(output[rel_ind], -1, other_ind[..., None]).squeeze(-1)

        # NOTE(review): torch.rand samples from [0, 1), so an exact 0 would
        # make this division produce inf/nan. Left unchanged; confirm whether
        # a lower bound (e.g. sampling_eps) is wanted.
        return loss.sum(dim=-1) / lambada

    def ar_loss(model, batch, cond=None):
        """Autoregressive negative log-likelihood, revealing positions in ``order``.

        At step ``i`` every position from ``order[i]`` onward is replaced by
        the mask id and the model's log-probability of the true token at
        ``order[i]`` is accumulated. This costs one model call per position.
        """
        seq_len = batch.shape[1]
        # Derive the default order from the batch so any sequence length works.
        positions = order if order is not None else torch.arange(seq_len, device=batch.device)
        rows = torch.arange(batch.shape[0], device=batch.device)

        loss = 0
        for i in range(seq_len):
            masked_batch = batch.clone()
            masked_batch[:, positions[i:]] = graph.dim - 1
            p_log_condition_i = model.get_log_condition(masked_batch)[:, positions[i]]
            loss += -p_log_condition_i[rows, batch[:, positions[i]]]
        return loss

    if loss_type == 'ar':
        return ar_loss
    elif loss_type == 'lambada_cross_entropy':
        return lambada_entropy_loss
    else:
        return t_entropy_loss


def get_optimizer(config, params):
    """Construct the optimizer named by ``config.optim.optimizer``.

    Supports ``'Adam'`` and ``'AdamW'``; both are built over ``params`` with
    the learning rate, betas, eps and weight decay read from ``config.optim``.

    Raises:
        NotImplementedError: If the configured optimizer name is not supported.
    """
    name = config.optim.optimizer
    if name == 'Adam':
        optimizer_cls = optim.Adam
    elif name == 'AdamW':
        optimizer_cls = optim.AdamW
    else:
        raise NotImplementedError(
            f'Optimizer {config.optim.optimizer} not supported yet!')

    # Both optimizers take the same hyper-parameters.
    settings = config.optim
    return optimizer_cls(
        params,
        lr=settings.lr,
        betas=(settings.beta1, settings.beta2),
        eps=settings.eps,
        weight_decay=settings.weight_decay,
    )


def optimization_manager(config):
    """Build the function that performs one optimizer update.

    The defaults for ``lr``, ``warmup`` and ``grad_clip`` are taken from
    ``config.optim`` once, when this factory is called.
    """
    base_lr = config.optim.lr
    warmup_steps = config.optim.warmup
    clip_norm = config.optim.grad_clip

    def optimize_fn(optimizer,
                    scaler,
                    params,
                    step,
                    lr=base_lr,
                    warmup=warmup_steps,
                    grad_clip=clip_norm):
        """Apply one update: unscale, warm up the LR, clip, step, update scaler.

        Linear warmup is applied when ``warmup > 0``; gradient-norm clipping
        is applied when ``grad_clip >= 0`` (a negative value disables it).
        """
        # Unscale first so clipping sees true gradient magnitudes.
        scaler.unscale_(optimizer)

        if warmup > 0:
            # Ramp linearly from 0 to lr over the first `warmup` steps.
            warmed_lr = lr * np.minimum(step / warmup, 1.0)
            for group in optimizer.param_groups:
                group['lr'] = warmed_lr

        if grad_clip >= 0:
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)

        scaler.step(optimizer)
        scaler.update()

    return optimize_fn


def get_step_fn(noise, graph, train, optimize_fn, accum, loss_type):
    """Build a single-step function for training or evaluation.

    Args:
        noise, graph, loss_type: Forwarded to ``get_loss_fn``.
        train: If true the returned function runs a training micro-step,
            otherwise an evaluation step using the EMA parameters.
        optimize_fn: Callable ``(optimizer, scaler, params, step=...)`` that
            applies one optimizer update.
        accum: Number of micro-batches accumulated per optimizer update.

    Returns:
        ``step_fn(state, batch, cond=None) -> loss`` where ``state`` is a dict
        with keys ``'model'``, ``'ema'`` and, for training, ``'optimizer'``,
        ``'scaler'`` and ``'step'``.
    """
    loss_fn = get_loss_fn(noise, graph, train, loss_type=loss_type)

    # Accumulation state shared across calls of step_fn.
    accum_iter = 0
    total_loss = 0

    def step_fn(state, batch, cond=None):
        """Run one micro-step.

        In training, the returned value is the micro-batch loss divided by
        ``accum``, except on the call that completes an accumulation window,
        where it is the sum of those values over the window. In evaluation,
        it is the mean loss computed with the EMA parameters.
        """
        nonlocal accum_iter
        nonlocal total_loss

        model = state['model']

        if train:
            optimizer = state['optimizer']
            scaler = state['scaler']
            # Divide by accum so accumulated gradients average over the window.
            loss = loss_fn(model, batch, cond=cond).mean() / accum

            scaler.scale(loss).backward()

            accum_iter += 1
            total_loss += loss.detach()
            if accum_iter == accum:
                accum_iter = 0

                state['step'] += 1
                optimize_fn(optimizer, scaler, model.parameters(), step=state['step'])
                state['ema'].update(model.parameters())
                optimizer.zero_grad()

                loss = total_loss
                total_loss = 0
        else:
            with torch.no_grad():
                ema = state['ema']
                ema.store(model.parameters())
                try:
                    ema.copy_to(model.parameters())
                    loss = loss_fn(model, batch, cond=cond).mean()
                finally:
                    # Always put the stored weights back, even if evaluation
                    # raises, so the model is not left holding EMA weights.
                    ema.restore(model.parameters())

        return loss

    return step_fn