from typing import Iterable
from functools import partial

from jax import lax, random
import jax.numpy as jnp
import equinox as eqx

from loss.llt_loss import compute_loss_llt, compute_loss_llt_with_cond
from loss.llt_minus_A_loss import compute_loss_llt_minus_A, compute_loss_llt_minus_A_with_cond

from utils import batch_indices

def _select_loss_fns(loss_name, with_cond, repeat_step):
    """Return the ``(training_loss, validation_loss)`` callables for ``loss_name``.

    The training loss is called as ``f(model, X, y)`` and returns a scalar loss.
    The validation loss is called the same way and returns ``(loss, cond)``.
    When ``with_cond`` is false, ``cond`` is the placeholder constant ``1``.

    Raises:
        ValueError: if ``loss_name`` is neither ``'llt'`` nor ``'llt_minus_A'``.
    """
    if loss_name == 'llt':
        train_loss, cond_loss = compute_loss_llt, compute_loss_llt_with_cond
    elif loss_name == 'llt_minus_A':
        train_loss, cond_loss = compute_loss_llt_minus_A, compute_loss_llt_minus_A_with_cond
    else:
        raise ValueError('Invalid loss name.')

    if with_cond:
        return train_loss, partial(cond_loss, repeat_step=repeat_step)

    def loss_without_cond(model, X, y):
        # Keep the (loss, cond) return shape so the epoch loop is uniform.
        return train_loss(model, X, y), 1

    return train_loss, loss_without_cond


def train(model, data, train_config, loss_name, key=42, repeat_step=1, with_cond=True):
    """Train ``model`` with mini-batch updates inside nested ``lax.scan`` loops.

    Args:
        model: Equinox model; its array leaves are the optimised parameters.
        data: Sized iterable ``(X_train, X_test, y_train, y_test)``. ``X_train``
            and ``X_test`` are iterables of arrays; every array in ``X_train``
            is indexed along its leading axis to build a batch.
        train_config: Dict with keys ``'optimizer'`` (called as
            ``optimizer(lr, **optim_params)``; presumably an optax-style
            factory, since ``.init``/``.update`` are used), ``'lr'``,
            ``'optim_params'``, ``'batch_size'`` and ``'epoch_num'``.
        loss_name: ``'llt'`` or ``'llt_minus_A'``.
        key: Integer seed, or an existing JAX PRNG key, used for batch
            shuffling. A distinct key is derived from it for every epoch.
        repeat_step: Forwarded to the ``*_with_cond`` validation loss.
        with_cond: If true, validation uses the ``*_with_cond`` loss; otherwise
            the plain loss is used and the reported ``cond`` is ``1``.

    Returns:
        ``(model, losses)`` where ``losses`` is a list of three per-epoch
        arrays: mean training loss, test loss and the test ``cond`` value.

    Raises:
        ValueError: if ``loss_name`` is not recognised (raised before the
            optimizer is built).
        AssertionError: if ``data``/``train_config`` are malformed or the batch
            size exceeds the dataset size.
    """
    from numbers import Integral

    assert isinstance(train_config, dict)
    assert isinstance(data, Iterable)
    assert len(data) == 4
    X_train, X_test, y_train, y_test = data
    assert isinstance(X_train, Iterable)
    assert isinstance(X_test, Iterable)

    # Validate the loss name first so a typo fails before any optimizer work.
    compute_loss, compute_loss_cond = _select_loss_fns(loss_name, with_cond, repeat_step)

    optim = train_config['optimizer'](train_config['lr'], **train_config['optim_params'])
    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    batch_size = train_config['batch_size']
    assert len(X_train[1]) >= batch_size, 'Batch size is greater than the dataset size'

    # Accept either an integer seed (the default) or a ready-made PRNG key.
    base_key = random.PRNGKey(key) if isinstance(key, Integral) else key
    compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)

    def make_step(carry, ind):
        """Apply one optimizer update on the batch selected by ``ind``."""
        model, opt_state = carry
        batched_X = [arr[ind, ...] for arr in X_train]

        # NOTE(review): y_train is passed whole, not indexed by ``ind``;
        # presumably the loss does not use per-sample targets -- confirm.
        loss, grads = compute_loss_and_grads(model, batched_X, y_train)
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return (model, opt_state), loss

    def train_body(carry, epoch):
        """Run one epoch of batch updates, then evaluate on the full test set."""
        # Derive the epoch key from the caller's seed so ``key`` actually
        # controls shuffling while each epoch still gets a different key.
        epoch_key = random.fold_in(base_key, epoch)
        # batch_indices presumably returns a stack of index batches; it is
        # scanned over its leading axis below.
        batches = batch_indices(epoch_key, X_train[0], batch_size)

        (model, opt_state), loss_train = lax.scan(make_step, carry, batches)
        loss_test, cond_test = compute_loss_cond(model, X_test, y_test)
        return (model, opt_state), [jnp.mean(loss_train), loss_test, cond_test]

    carry_init = (model, opt_state)
    (model, _), losses = lax.scan(train_body, carry_init, jnp.arange(train_config['epoch_num']))
    return model, losses