"""This is an integration end to end test for the trainer"""

import os
from dataclasses import dataclass
from typing import Dict, Optional

import jax
from flax import linen as nn
from jax import numpy as jnp
from torch.utils.data import Dataset

from latte_trans.evals.base import Evaluator
from latte_trans.trainer.dfsp_jax import DFSDPTrainer, shard_module_params


class FSDPClassifier(nn.Module):
    """Small two-layer MLP used to exercise the sharded trainer.

    Both dense layers are built through ``shard_module_params`` (axis name
    ``"B"``, ``min_weight_size=12``), with a SiLU activation and dropout
    between them.
    """

    hidden_size: int = 128  # feature count of the first dense layer
    dropout_rate: float = 0.1  # rate handed to nn.Dropout
    num_classes: int = 10  # feature count of the second dense layer
    dtype: jnp.dtype = jnp.bfloat16  # dtype handed to both dense layers

    @nn.compact
    def __call__(
        self, x: jax.Array, labels: jax.Array, train: bool
    ) -> Dict[str, jax.Array]:
        """Run the forward pass.

        Args:
            x: Input array fed to the first dense layer.
            labels: Accepted for interface compatibility; not used in the
                computation.
            train: When True dropout is active; when False it is disabled.

        Returns:
            A dict with ``"logits"`` (output of the last dense layer cast to
            float32) and ``"loss"`` (the sum of all logits -- a placeholder
            scalar, not a real classification loss).
        """
        # Wrap nn.Dense once and reuse the wrapped constructor for both layers.
        sharded_dense = shard_module_params(
            nn.Dense,
            axis_name="B",
            min_weight_size=12,
        )
        x = sharded_dense(
            features=self.hidden_size,
            dtype=self.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = sharded_dense(
            features=self.num_classes,
            dtype=self.dtype,
            name="output_dense",
        )(x)
        # Cast up before reducing so the sum is computed in float32.
        logits = x.astype(jnp.float32)
        loss = logits.sum()
        return {"loss": loss, "logits": logits}


class DummyEval(Evaluator):
    """Stub evaluator that returns fixed values instead of evaluating."""

    def __init__(self) -> None:
        """Do nothing: no state is set up and the base initializer is not called."""

    def compute_metrics(self, *args, **kwargs):
        """Ignore every argument and return the constant ``-1``."""
        return -1

    def evaluate(self, trainer_eval_fn, prefix="eval_", **kwargs):
        """Return a single fake metric, ``{prefix + "loss": 1.0}``.

        ``trainer_eval_fn`` and any extra keyword arguments are ignored.
        """
        metrics = {}
        metrics[prefix + "loss"] = 1.0
        return metrics


class DummyDataset(Dataset):
    """Synthetic dataset of 100 normally distributed samples of shape (10, 1)."""

    def __init__(self, key) -> None:
        """Draw all samples up front from ``key`` (a JAX PRNG key)."""
        sample_shape = (100, 10, 1)
        self.x = jax.random.normal(key, shape=sample_shape)

    def __len__(self):
        """Return the number of samples (size of the leading axis)."""
        return len(self.x)

    def __getitem__(self, index):
        """Return sample ``index`` paired with the constant label ``[1]``."""
        sample = self.x[index]
        label = jnp.array([1])
        return {"input_ids": sample, "labels": label}


@dataclass
class Config:
    """Hyper-parameters handed to the trainer by this test.

    The exact meaning of each field is defined by the trainer that consumes
    the config; only the default values are fixed here.
    """

    batchnorm: bool = False
    epochs: int = 5
    eval_steps: int = 4
    batch_size: int = 2
    max_checkpoints: int = 2
    shuffle_train: bool = False
    # Defaults to None, so the annotation must allow it.
    # NOTE(review): presumably None lets the trainer derive the step count
    # from `epochs` -- confirm against DFSDPTrainer.
    train_steps: Optional[int] = None
    grad_accumulation_steps: int = 1
    warmup: int = 2
    lr_decay_fn: str = "cosine"
    lr: float = 1e-3
    weight_decay: float = 0.04
    max_seq_len: int = 100


def data_collator(batch):
    """Combine per-sample dicts into one batch dict of JAX arrays.

    Each element of ``batch`` must provide ``"input_ids"`` and ``"labels"``;
    the values are gathered in order and converted with ``jnp.array``.
    """
    # Single pass over the batch, reading both keys from each element.
    pairs = [(item["input_ids"], item["labels"]) for item in batch]
    batched_inputs = jnp.array([ids for ids, _ in pairs])
    batched_labels = jnp.array([lab for _, lab in pairs])
    return {"input_ids": batched_inputs, "labels": batched_labels}


def test_trainer():
    """Build a DFSDPTrainer on the dummy model and data and run ``train``.

    The test passes if construction and training finish without raising; no
    assertions are made on the results. The directory containing this file is
    passed to the trainer as ``out_dir``.
    """
    root_key = jax.random.PRNGKey(0)
    # Independent keys for trainer init, training and data generation; the
    # fourth key from the split is not used.
    init_rng, train_rng, data_rng, _ = jax.random.split(root_key, num=4)
    this_dir = os.path.dirname(os.path.abspath(__file__))
    trainer = DFSDPTrainer(
        config=Config(),
        out_dir=this_dir,
        model=FSDPClassifier(),
        train_data=DummyDataset(data_rng),
        data_collator=data_collator,
        evaluator=DummyEval(),
        rng=init_rng,
        model_inputs_orded=("input_ids", "labels"),
    )
    trainer.train(train_rng)


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_trainer()


"""
pdm run python3 /home/user/latte_trans/tests/trainer/test_dfsp.py
pdm run python -m tests.trainer.test_dfsp

"""
