import json
import uuid

import numpy as np

from utils.types import JSON

# Root folder of the noise / generated-image / label sample folders read by get_config().
base_data: str = 'PATH/samples/imagenet-64x64'
# Root folder of the real dataset (images / labels) used for 'train_real_dataset'.
base_real_data: str = 'PATH/datasets/imagenet-64x64'
# Parent folder of each run's 'base_folder'.
base_results: str = 'results'

# Checkpoint directory that model, optimizer, discriminator and EMA state files are loaded from.
load_path: str = (
    'PATH/results/'
    'lines-gan-imagenet-64x64-11-bs-16-br-32-2559fdee-2ed7-40ad-8959-2b7c1dfe6d7d.0/'
    'checkpoints/last'
)

# Passed to the EDM sampler as 'num_steps'.
num_sample_steps: int = 40
# One training-dataset entry is generated per value, stored as its 'time_step'.
train_steps: list[int] = list(range(3))


def get_config(run_id: str) -> JSON:
    """Assemble the full configuration dictionary for a single training run.

    Folder and checkpoint paths are derived from the module-level
    ``base_data``, ``base_real_data``, ``base_results`` and ``load_path``
    constants. One training-dataset entry is produced for each value in
    ``train_steps``.

    Args:
        run_id: Identifier of the run; its output folder is
            ``{base_results}/{run_id}``.

    Returns:
        A nested dict of plain Python values (one tuple, for the
        discriminator optimizer betas). The two distributed-sampler seeds are
        drawn from NumPy's global RNG on every call, so repeated calls differ
        in those two fields.
    """

    def sample_source(subfolder: str, count: int, **extra: int) -> dict:
        # A fresh dict on every call, so no two dataset entries share an object.
        return {'folder': f'{base_data}/{subfolder}', 'num_samples': count, **extra}

    def real_source(subfolder: str) -> dict:
        return {'folder': f'{base_real_data}/{subfolder}', 'num_samples': 1281167}

    def ema_entry(checkpoint: str, **settings: object) -> dict:
        # beta / update_every are shared by every EMA; per-entry settings are
        # placed between them and load_path to keep the emitted key order.
        return {
            'beta': 0.999,
            'update_every': 1,
            **settings,
            'load_path': f'{load_path}/ema/{checkpoint}',
        }

    # Samples at or after this index feed the test and FID-test splits.
    held_out_start = 5_000_000
    generated_images = 'edm-imagenet-64x64-cond-adm'

    # Both seeds come from NumPy's global RNG: the sampler seed first, then the real-data one.
    sampler_seed = int(np.random.randint(0, 2 ** 31))
    real_sampler_seed = int(np.random.randint(0, 2 ** 31))

    ema = [
        ema_entry('beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_0.66667.pth'),
        ema_entry('beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_0.75.pth', power=0.75),
        ema_entry('beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_1.0.pth', power=1.0),
        # NOTE(review): loads the same checkpoint file as the entry above even
        # though this EMA uses update_after_step=0 / start_step=100 — confirm intended.
        ema_entry(
            'beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_1.0.pth',
            update_after_step=0,
            power=1.0,
            start_step=100,
        ),
    ]

    train_dataset = [
        {
            'noise': sample_source('noises', 4_000_000),
            'image': sample_source(generated_images, 4_000_000),
            'label': sample_source('labels', 4_000_000),
            'time_step': step,
        }
        for step in train_steps
    ]

    test_dataset = [
        {
            'noise': sample_source('noises', 10, start_index=held_out_start),
            'image': sample_source(generated_images, 10, start_index=held_out_start),
            'label': sample_source('labels', 10, start_index=held_out_start),
            'time_step': 0,
        }
    ]

    train_real_dataset = [{'image': real_source('images'), 'label': real_source('labels')}]

    fid_train_dataset = [
        {
            'noise': sample_source('noises', 50_000),
            'label': sample_source('labels', 50_000),
        }
    ]

    fid_test_dataset = [
        {
            'noise': sample_source('noises', 50_000, start_index=held_out_start),
            'label': sample_source('labels', 50_000, start_index=held_out_start),
        }
    ]

    config = {
        'base_folder': f'{base_results}/{run_id}',
        'train_steps': 100_000_000_000,
        'batch_size': 16,
        'batch_repeats': 32,
        'data_loader_workers': 4,
        'validate_ddp_consistency_steps': 200,
        'free_memory_every_sub_step': True,
        'dataset': {'dataset_name': 'imagenet-64x64'},
        'report': {'train_steps': 20},
        'checkpoint': {'save_steps': 5_000, 'last_steps': 1_000},
        'edm_scheduler': {},
        'edm_sampler': {'num_steps': num_sample_steps},
        'model': {
            'name': 'edm-imagenet-64x64-cond-adm',
            'load_path': f'{load_path}/model.pth',
        },
        'model_optimizer': {
            'name': 'r-adam',
            'learning_rate': 0.000008,
            'load_path': f'{load_path}/model_optimizer.pth',
        },
        'discriminator': {
            'discriminator_load_path': f'{load_path}/discriminator.pth',
            'discriminator_feature_extractor_load_path': f'{load_path}/discriminator_feature_extractor.pth',
            'conditional': True,
        },
        'discriminator_optimizer': {
            'name': 'r-adam',
            'learning_rate': 0.002,
            'params': {'betas': (0.5, 0.9)},
            'load_path': f'{load_path}/discriminator_optimizer.pth',
        },
        'ema': ema,
        'fid': {
            'steps': None,
            'batch_size': 5,
            'reference_path': 'PATH/imagenet-64x64-edm.npz',
        },
        'reconstruction_loss': {'name': 'lpips-vgg', 'params': {'size': 224}},
        'reconstruction_lambda': 0.5,
        'amp': {
            'use_fp16': True,
            'use_autocast': False,
            # NOTE(review): 'gard' looks like a typo of 'grad'; kept verbatim in
            # case the consumer reads this exact key — verify before renaming.
            'use_gard_scaler': True,
        },
        'distributed_sampler_seed': sampler_seed,
        'distributed_sampler_real_seed': real_sampler_seed,
        'train_dataset': train_dataset,
        'test_dataset': test_dataset,
        'train_real_dataset': train_real_dataset,
        'fid_train_dataset': fid_train_dataset,
        'fid_test_dataset': fid_test_dataset,
    }
    return config


def main() -> None:
    """Generate the config for a new run, print it, and save it as JSON.

    A single UUID is drawn per invocation and shared by the run id (which
    becomes the config's ``base_folder``) and the saved file name
    ``configs/<run_id>.json``, so each config file can be matched to its
    results folder. The ``configs`` directory is created if it is missing.
    """
    from pathlib import Path

    name: str = 'lines-gan-imagenet-64x64-15-bs-16-br-32-three-step'
    # Draw the UUID once: the file name must carry the same id as the run folder.
    run_id = f'{name}-{uuid.uuid4()}'
    config: JSON = get_config(run_id)
    print(config)

    config_dir = Path('configs')
    config_dir.mkdir(parents=True, exist_ok=True)
    with open(config_dir / f'{run_id}.json', 'w', encoding='utf-8') as file:
        json.dump(config, file)


if __name__ == '__main__':
    main()
