# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

from absl.testing import absltest
from absl.testing import parameterized
import ml_collections
import tensorflow_datasets as tfds

from vit_jax import test_utils
from vit_jax import train
from vit_jax.configs import common
from vit_jax.configs import models

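# JPG_BLACK_1PX holds the bytes of a JPEG file containing a single black pixel.
# It was produced with a snippet along these lines:
#   from PIL import Image
#   import numpy as np
#   Image.fromarray(np.array([[[0, 0, 0]]], np.uint8)).save('black1px.jpg')
#   print(repr(open('black1px.jpg', 'rb').read()))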
JPG_BLACK_1PX = (b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c'
                 b' $.\' '
                 b'",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x00\x01\x00\x01\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&\'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xc4\x00\x1f\x01\x00\x03\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x11\x00\x02\x01\x02\x04\x04\x03\x04\x07\x05\x04\x04\x00\x01\x02w\x00\x01\x02\x03\x11\x04\x05!1\x06\x12AQ\x07aq\x13"2\x81\x08\x14B\x91\xa1\xb1\xc1\t#3R\xf0\x15br\xd1\n\x16$4\xe1%\xf1\x17\x18\x19\x1a&\'()*56789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00?\x00\xf9\xfe\x8a(\xa0\x0f\xff\xd9')  # pylint: disable=line-too-long


class TrainTest(parameterized.TestCase):
  """Exercises `train.train_and_evaluate` with a TFDS name and a directory."""

  def _write_directory_dataset(self, root):
    """Populates `root` with `<mode>/<class>/<i>.jpg` copies of JPG_BLACK_1PX.

    Two modes (`train`, `test`) times two classes (`test1`, `test2`) are
    created, each holding eight image files.

    Args:
      root: Directory under which the mode sub-directories are created.
    """
    for split in ('train', 'test'):
      for label in ('test1', 'test2'):
        label_dir = os.path.join(root, split, label)
        # One call per class directory is enough; exist_ok keeps it idempotent.
        os.makedirs(label_dir, exist_ok=True)
        for index in range(8):
          with open(os.path.join(label_dir, f'{index}.jpg'), 'wb') as image:
            image.write(JPG_BLACK_1PX)

  @parameterized.named_parameters(
      ('tfds', 'tfds'),
      ('directory', 'directory'),
  )
  def test_train_and_evaluate(self, dataset_source):
    """Calls `train.train_and_evaluate` and checks `checkpoint_1` exists.

    Args:
      dataset_source: Either 'tfds' (dataset given by name) or 'directory'
        (dataset given as a path to image files written by this test).

    Raises:
      ValueError: If `dataset_source` is neither of the two known values.
    """
    config = common.get_config()
    config.model = models.get_testing_config()
    # Small batch sizes and a single step, so the run stays short.
    config.batch = 64
    config.accum_steps = 2
    config.batch_eval = 8
    config.total_steps = 1

    with tempfile.TemporaryDirectory() as workdir:
      if dataset_source == 'directory':
        dataset_root = os.path.join(workdir, 'dataset')
        config.dataset = dataset_root
        config.pp = ml_collections.ConfigDict({'crop': 224})
        self._write_directory_dataset(dataset_root)
      elif dataset_source == 'tfds':
        config.dataset = 'cifar10'
        config.pp = ml_collections.ConfigDict({
            'train': 'train[:98%]',
            'test': 'test',
            'crop': 224,
        })
      else:
        raise ValueError(f'Unknown dataset_source: "{dataset_source}"')

      # The checkpoint for the testing model is written into the same
      # temporary directory that is configured as `pretrained_dir`.
      config.pretrained_dir = workdir
      test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')

      train.train_and_evaluate(config, workdir)

      checkpoint_path = f'{workdir}/checkpoint_1'
      self.assertTrue(os.path.exists(checkpoint_path))


# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
