import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import graph_lib
from model import utils as mutils
import ot
import time


def get_loss_fn(noise, graph, train, config, sampling_eps=1e-3, lv=False,
                use_optimal_transport=True, ot_reg=0.0101):
    """Build the per-batch loss function.

    Args:
        noise: Unused here; kept for interface compatibility with callers.
        graph: Object providing `sample_limit(shape)` and
            `sample_transition(source, data, t)`.
        train: Forwarded to `mutils.get_score_fn` as its `train` flag.
        config: Configuration; only `config.tokens` is read here.
        sampling_eps: Times are drawn uniformly from `[0, 1 - sampling_eps)`.
        lv: Unused here; kept for interface compatibility with callers.
        use_optimal_transport: When true, source and data rows are re-paired
            by sampling from a Sinkhorn transport plan before perturbation.
        ot_reg: Entropic regularisation strength passed to `ot.sinkhorn`.

    Returns:
        `loss_fn(model, databatch, cond=None, t=None, perturbed_batch=None)`,
        which returns the masked cross-entropy summed over the last axis.
    """

    def loss_fn(model, databatch, cond=None, t=None, perturbed_batch=None):
        """
        Batch shape: [B, L] int. D given from graph

        `cond` is accepted but not used. `t` and `perturbed_batch` are
        sampled here when they are None; values supplied by the caller are
        used as given.
        """
        sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

        if use_optimal_transport:
            try:
                # Unwrap DataParallel/DDP-style wrappers; a bare model is used as-is.
                embed = getattr(model, 'module', model).vocab_embed
                sourcebatch, databatch = _sample_ot_pairing(
                    embed, sourcebatch, databatch, ot_reg)
            except Exception as exc:
                # Best effort: keep the independent pairing for this batch, but
                # let KeyboardInterrupt/SystemExit propagate and report the cause.
                print('Error calculating optimal transport, continuing with '
                      f'independent sampling this batch ({exc!r})')

        if t is None:
            t = (1 - sampling_eps) * torch.rand(databatch.shape[0], device=databatch.device)

        if perturbed_batch is None:
            perturbed_batch = graph.sample_transition(sourcebatch, databatch, t[:, None])

        log_score_fn = mutils.get_score_fn(model, train=train, sampling=False)
        log_score = log_score_fn(perturbed_batch, t)

        # Flatten to one row of logits / one target token per position.
        # reshape (not view) so a non-contiguous input batch is accepted too.
        log_score = log_score.reshape(-1, log_score.shape[-1])
        targets = databatch.reshape(-1)
        criterion = torch.nn.CrossEntropyLoss(reduction='none')

        # Only positions whose perturbed token id is >= config.tokens contribute
        # (presumably the extra mask/absorbing state -- confirm against `graph`).
        dim = config.tokens + 1
        mask = (perturbed_batch >= (dim - 1)) * 1

        loss = criterion(log_score, targets)
        loss = loss.view(*perturbed_batch.shape) * mask
        return loss.sum(dim=-1)

    return loss_fn


def _sample_ot_pairing(embed, sourcebatch, databatch, reg):
    """Re-pair source and data rows by sampling from a Sinkhorn transport plan.

    The cost is the squared Euclidean distance between the flattened
    embeddings of each source row and each data row, normalised by its
    maximum. `B` (source, data) index pairs are drawn with replacement from
    the resulting plan, and both batches are re-indexed accordingly.
    """
    B = databatch.shape[0]

    # The embeddings only define the cost matrix; no gradient flows through them.
    source_emb = embed(sourcebatch).reshape(B, -1).detach()
    data_emb = embed(databatch).reshape(B, -1).detach()

    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, clamped against round-off negatives.
    source_sq = (source_emb ** 2).sum(dim=1).reshape(-1, 1)  # Shape (B, 1)
    data_sq = (data_emb ** 2).sum(dim=1).reshape(1, -1)  # Shape (1, B)
    cost = source_sq + data_sq - 2 * source_emb @ data_emb.T
    cost = torch.clamp(cost, min=0)
    cost = cost / cost.max()

    # Uniform marginals over the batch on both sides.
    a = np.ones(B) / B
    b = np.ones(B) / B
    plan = torch.from_numpy(ot.sinkhorn(a, b, cost.cpu().numpy(), reg))

    # Sample flat cell indices of the plan, then split into (row, column).
    flat_indices = torch.multinomial(plan.flatten(), B, replacement=True)
    n_cols = plan.shape[1]
    source_indices = flat_indices // n_cols
    data_indices = flat_indices % n_cols

    return sourcebatch[source_indices], databatch[data_indices]


def get_optimizer(config, params):
    """Construct the optimizer named by `config.optim.optimizer`.

    Supported names are 'Adam' and 'AdamW'. Learning rate, betas, eps and
    weight decay are read from `config.optim`.

    Raises:
        NotImplementedError: If the configured optimizer name is not supported.
    """
    requested = config.optim.optimizer
    if requested == 'Adam':
        optimizer_cls = optim.Adam
    elif requested == 'AdamW':
        optimizer_cls = optim.AdamW
    else:
        raise NotImplementedError(
            f'Optimizer {config.optim.optimizer} not supported yet!')

    # Both supported optimizers take the same hyper-parameters.
    return optimizer_cls(
        params,
        lr=config.optim.lr,
        betas=(config.optim.beta1, config.optim.beta2),
        eps=config.optim.eps,
        weight_decay=config.optim.weight_decay,
    )


def optimization_manager(config):
    """Create the optimizer-update routine whose defaults come from `config`.

    The returned `optimize_fn` takes its default `lr`, `warmup` and
    `grad_clip` from `config.optim` at the time this factory is called.
    """

    def optimize_fn(optimizer,
                    scaler,
                    params,
                    step,
                    lr=config.optim.lr,
                    warmup=config.optim.warmup,
                    grad_clip=config.optim.grad_clip):
        """Apply one update: linear LR warmup, optional clipping, scaler step.

        Warmup is skipped when `warmup` is not positive; clipping is skipped
        when `grad_clip` is negative.
        """
        # Unscale first so the clipping below sees the unscaled gradients.
        scaler.unscale_(optimizer)

        if warmup > 0:
            # Ramp linearly from 0 to `lr` over `warmup` steps, then hold at `lr`.
            warm_lr = lr * np.minimum(step / warmup, 1.0)
            for group in optimizer.param_groups:
                group['lr'] = warm_lr

        if grad_clip >= 0:
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)

        scaler.step(optimizer)
        scaler.update()

    return optimize_fn


def get_step_fn(noise, graph, train, optimize_fn, accum, config):
    """Build a single-step function for training or evaluation.

    Args:
        noise, graph, config: Forwarded to `get_loss_fn`.
        train: If true the step back-propagates and, every `accum` calls,
            updates the model; otherwise it evaluates with EMA weights.
        optimize_fn: Called as
            `optimize_fn(optimizer, scaler, parameters, step=...)` once per
            `accum` training calls.
        accum: Number of micro-batches accumulated per optimizer update.

    Returns:
        `step_fn(state, batch, cond=None)`. `state` is a mapping with keys
        'model' and 'ema', plus 'optimizer', 'scaler' and 'step' when
        training.
    """
    loss_fn = get_loss_fn(noise, graph, train, config)

    # Gradient-accumulation bookkeeping shared across calls to step_fn.
    accum_iter = 0
    total_loss = 0

    def step_fn(state, batch, cond=None):
        """Run one micro-batch and return its loss.

        In training mode the returned value is the micro-batch loss divided
        by `accum`, except on the call that triggers the optimizer update,
        where the sum over the accumulated micro-batches is returned.
        """
        nonlocal accum_iter
        nonlocal total_loss

        model = state['model']

        if train:
            optimizer = state['optimizer']
            scaler = state['scaler']
            # Divide by accum so the accumulated gradient is an average.
            loss = loss_fn(model, batch, cond=cond).mean() / accum

            scaler.scale(loss).backward()

            accum_iter += 1
            total_loss += loss.detach()
            if accum_iter == accum:
                accum_iter = 0

                state['step'] += 1
                optimize_fn(optimizer, scaler, model.parameters(), step=state['step'])
                state['ema'].update(model.parameters())
                optimizer.zero_grad()

                loss = total_loss
                total_loss = 0
        else:
            with torch.no_grad():
                ema = state['ema']
                ema.store(model.parameters())
                try:
                    ema.copy_to(model.parameters())
                    loss = loss_fn(model, batch, cond=cond).mean()
                finally:
                    # Always put the stored weights back, even if evaluation
                    # raises, so the model is never left holding EMA weights.
                    ema.restore(model.parameters())

        return loss

    return step_fn
