import json
import uuid

from utils.types import JSON

# Root folder of the generated samples. get_config() reads its 'noises',
# 'labels' and 'edm-cifar10-32x32-cond-vp' subfolders.
# NOTE(review): 'PATH' looks like a placeholder for a machine-specific root — confirm.
base_data: str = 'PATH/samples/cifar10-32x32'
# Root of the real dataset. get_config() reads '{train,test}/images' and
# '{train,test}/labels' beneath it.
base_real_data: str = 'PATH'
# Each run's log file, TensorBoard dir and checkpoint folder are placed under
# '<base_results>/<run_id>' by get_config().
base_results: str = 'results'
# Used as the sampler's 'num_steps' and as the number of per-time-step
# entries (time_step 0 .. num_sample_steps - 1) in 'train_dataset'.
num_sample_steps: int = 18

# Checkpoint handed to get_config() as model.load_path (with load_keys=['model']).
model_load_path: str = 'PATH/results/lines/lines-cifar10-32x32-2-loss-lpips-vgg-64-bs-256-br-2-22074219-5353-4ac4-bd27-214df672cda5/checkpoints/last/index.pth'

def get_config(run_id: str) -> JSON:
    """Build the discriminator-training configuration for a single run.

    Args:
        run_id: Identifier of the run. The log file, TensorBoard directory
            and checkpoint folder are all nested under
            ``{base_results}/{run_id}``.

    Returns:
        A JSON-serialisable dict with the training hyper-parameters, the
        generator model to load, the discriminator optimiser settings, and
        the generated-sample and real-image datasets.
    """
    run_root = f'{base_results}/{run_id}'
    generated_images = f'{base_data}/edm-cifar10-32x32-cond-vp'

    def shard(folder, count, start=None):
        # One slice of an on-disk folder. 'start_index' is only emitted when
        # an offset is requested, so the key layout matches each split.
        entry = {'folder': folder, 'num_samples': count}
        if start is not None:
            entry['start_index'] = start
        return entry

    def generated_split(count, time_step, start=None):
        # Paired noise / generated image / label folders for one time step.
        return {
            'noise': shard(f'{base_data}/noises', count, start),
            'image': shard(generated_images, count, start),
            'label': shard(f'{base_data}/labels', count, start),
            'time_step': time_step,
        }

    def real_split(split, count):
        # Real images and labels for the given split ('train' or 'test').
        return {
            'image': shard(f'{base_real_data}/{split}/images', count),
            'label': shard(f'{base_real_data}/{split}/labels', count),
        }

    checkpoint = {
        'folder': f'{run_root}/checkpoints',
        'save_steps': 50_000,
        'last_steps': 10_000,
        'save_time': None,
        # presumably seconds (i.e. 4 hours) — confirm against the trainer
        'last_time': 4 * 60 * 60,
    }
    model = {
        'name': 'edm-cifar10-32x32-cond-vp',
        'load_path': model_load_path,
        'load_keys': ['model'],
    }
    discriminator_optimizer = {
        'name': 'r-adam',
        'learning_rate': 0.002,
        'params': {'betas': (0.5, 0.9)},
    }

    return {
        'log_filepath': f'{run_root}/log/log.log',
        'tensorboard_log_dir': f'{run_root}/tensorboard/logs',
        'train_steps': 100_000_000_000,
        'batch_size': 512,
        'batch_repeats': 1,
        'data_loader_workers': 4,
        'dataset': {'dataset_name': 'cifar10-32x32'},
        'report': {'train_steps': 50, 'test_steps': None},
        'checkpoint': checkpoint,
        'edm_scheduler': {},
        'edm_sampler': {'num_steps': num_sample_steps},
        'model': model,
        'discriminator': {},
        'discriminator_optimizer': discriminator_optimizer,
        # One training entry per sampler time step, each covering the first
        # 1M generated samples.
        'train_dataset': [
            generated_split(1_000_000, step)
            for step in range(num_sample_steps)
        ],
        # Held-out generated samples stored after the first 1M, at time step 0.
        'test_dataset': [
            generated_split(10_000, 0, start=1_000_000),
        ],
        'train_real_dataset': [real_split('train', 50_000)],
        'test_real_dataset': [real_split('test', 10_000)],
    }


def main() -> None:
    """Generate the config for a new run and save it as JSON under ``configs/``.

    A single UUID is drawn per invocation. It is used both for the run id
    embedded in the config's output paths and for the config file name.
    This lets ``configs/<run_id>.json`` be matched to the run it configures.
    The config is also printed to stdout.
    """
    from pathlib import Path

    name: str = 'discriminator-cifar10-32x32-1-bs-512-br-1'
    # Draw the UUID once. Calling uuid4() separately for the config and the
    # file name would make the file name disagree with the run_id inside it.
    run_id: str = f'{name}-{uuid.uuid4()}'
    config: JSON = get_config(run_id)
    print(config)

    config_dir = Path('configs')
    # Create the output folder on first use instead of failing with
    # FileNotFoundError.
    config_dir.mkdir(parents=True, exist_ok=True)
    with open(config_dir / f'{run_id}.json', 'w', encoding='utf-8') as file:
        json.dump(config, file)


if __name__ == '__main__':
    main()
