import json
import os
import random

import click
import datasets
import matplotlib.pyplot as plt
import neptune
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from diagram import System
from models import Decoder, Encoder
from rules.autoencoder import identity
from utils import draw


# Device that models and batches are moved to: CUDA when available, else CPU.
DEVICE = (torch.device('cuda') if torch.cuda.is_available()
          else torch.device('cpu'))

# Registries that resolve the class names given in the training config
# (``arch``, ``optimizer`` and ``lr_scheduler`` sections) to actual classes.
MODELS_MAP = dict(
    Decoder=Decoder,
    Encoder=Encoder,
)
OPTIMIZERS_MAP = dict(
    AdamW=torch.optim.AdamW,
)
LR_SCHEDULERS_MAP = dict(
    CosineAnnealingWarmRestarts=(
        torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    ),
)



@click.command()
@click.option('--train-config-file',
              default='train_config.json',
              help='Path to training config JSON file.')
@click.option('--neptune-config-file',
              default='.secrets/neptune_config.json',
              help='Path to neptune config JSON file.')
def train(train_config_file: str, neptune_config_file: str):
    """Train an autoencoder on the specified dataset.

    Reads the training and neptune settings from the two JSON files,
    trains the encoder/decoder pair configured under ``arch``, logs losses
    and sample reconstructions to neptune, and saves both state dicts
    under ``checkpoints/``.
    """

    # Use context managers so the config file handles are closed promptly.
    with open(train_config_file) as config_fp:
        train_config = json.load(config_fp)
    with open(neptune_config_file) as config_fp:
        run_config = json.load(config_fp)

    # Init neptune run
    print('Starting run.')
    run_config['name'] = train_config['name']
    run_config['tags'] = train_config['tags']
    run_config['project'] = 'quantinuum-ml/AIAI'

    run = neptune.init_run(**run_config)

    # Hyperparameters, with defaults for anything the config omits.
    seed = train_config.get('seed', 0)
    batch_size = train_config.get('batch_size', 64)
    epochs = train_config.get('epochs', 1000)
    log_interval = train_config.get('log_interval', 50)
    num_workers = train_config.get('num_workers', 8)
    input_size = train_config['arch']['encoder']['params']['input_size']
    clip_gradients = train_config.get('clip_gradients', 1)

    # Helper preprocessing functions
    transform = transforms.Compose([
        transforms.Resize((input_size, input_size),
                          interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    def _normalize(examples: dict) -> dict:
        """Apply the resize/to-tensor transform to every image in a batch."""
        examples['image'] = [transform(image) for image in examples['image']]
        return examples

    print('Preparing data.')
    dataset_name = train_config['dataset_name']
    dataset = datasets.load_dataset(dataset_name,
                                    split='train')

    # Materialise the preprocessed dataset as a list once, up front.
    dataset = list(tqdm(
        dataset.map(_normalize,
                    batched=True,
                    num_proc=num_workers).with_format('torch')
    ))

    print('Initialising systems and rules.')
    system = System()
    identity_loss = torch.nn.MSELoss()
    system.add_rule('identity',
                    identity,
                    identity_loss,
                    encoder='encoder',
                    decoder='decoder',
                    data='data')

    print('Initialising models.')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    # Work on shallow copies of the parameter dicts: popping the positional
    # arguments from the originals would strip ``input_size`` and
    # ``latent_size`` from the config that is logged to neptune below.
    encoder_config = train_config['arch']['encoder']
    encoder_cls = MODELS_MAP[encoder_config['class']]
    encoder_params = dict(encoder_config['params'])
    encoder_input_size = encoder_params.pop('input_size')
    encoder_latent_size = encoder_params.pop('latent_size')

    decoder_config = train_config['arch']['decoder']
    decoder_cls = MODELS_MAP[decoder_config['class']]
    decoder_params = dict(decoder_config['params'])
    decoder_input_size = decoder_params.pop('input_size')
    decoder_latent_size = decoder_params.pop('latent_size')

    encoder = encoder_cls(encoder_input_size, encoder_latent_size,
                          **encoder_params).to(DEVICE)
    decoder = decoder_cls(decoder_input_size, decoder_latent_size,
                          **decoder_params).to(DEVICE)

    # A single ModuleList lets one optimizer cover both networks.
    models = torch.nn.ModuleList([encoder, decoder])
    optimizer_cls = OPTIMIZERS_MAP[train_config['optimizer']['class']]
    optimizer_params = train_config['optimizer']['params']
    optimizer = optimizer_cls(models.parameters(), **optimizer_params)

    # The scheduler is optional; bind the name unconditionally so the
    # training loop can test it without raising NameError.
    scheduler = None
    lr_sched = train_config.get('lr_scheduler')
    if lr_sched is not None:
        lr_sched_cls = LR_SCHEDULERS_MAP[lr_sched['class']]
        scheduler = lr_sched_cls(optimizer,
                                 **lr_sched['params'])

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            drop_last=True,
                            num_workers=num_workers,
                            shuffle=True)

    run['params'] = train_config

    for rule_name, rule in system.rules.items():
        run['params'][f'rules/{rule_name}'] = {
            'loss_fn': type(rule.loss_function).__name__,
        }

    print('Starting training')

    step = 0

    def _log_step(encoder, decoder, results, loss, image, run, epoch, step):
        """Print the current losses and log one reconstruction to neptune."""
        print(f'Ep {epoch} | Step {step}\nLoss: {loss:g} <-',
              ' '.join(f'{k}={v["loss"]:g}' for k, v in results.items()))

        # Reconstruct a single sample in eval mode without tracking
        # gradients; the training loop switches back to train mode.
        with torch.no_grad():
            encoder.eval()
            decoder.eval()
            reconstructed = decoder(encoder(image.unsqueeze(0)))

        # NOTE(review): if draw() returns a matplotlib figure it is never
        # closed here -- confirm whether it should be released after upload.
        fig = draw(image, *reconstructed)
        description = f'{epoch = } | {step = }'
        run['train/figure'].append(fig, step=step, description=description)

    with tqdm(total=epochs) as progress_bar:
        for epoch in range(epochs):
            for batch in dataloader:
                data_batch = batch['image'].to(DEVICE)

                # _log_step leaves the models in eval mode, so re-enable
                # train mode at the start of every iteration.
                encoder.train()
                decoder.train()

                results = system(encoder=encoder,
                                 decoder=decoder,
                                 data=data_batch)
                losses = {k: v['loss'] for k, v in results.items()}

                loss = losses['identity']

                for k, v in losses.items():
                    run[f'loss/{k}'].append(v)

                optimizer.zero_grad()
                loss.backward()
                if clip_gradients:
                    torch.nn.utils.clip_grad_value_(models.parameters(),
                                                    clip_gradients)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                step += 1
                if step % log_interval == 0:
                    _log_step(encoder, decoder, results, loss,
                              data_batch[0], run, epoch, step)

            progress_bar.update(1)

    print('Saving artifacts.')
    # Create the output directory so a missing folder cannot throw away a
    # finished training run at the very last step.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(encoder.state_dict(), 'checkpoints/encoder.pth')
    torch.save(decoder.state_dict(), 'checkpoints/decoder.pth')

    run['models/encoder'].upload('checkpoints/encoder.pth')
    run['models/decoder'].upload('checkpoints/decoder.pth')
    run.stop()

    print('Run ended.')


# Script entry point: train is a click command, so calling it with no
# arguments lets click parse the command-line options.
if __name__ == '__main__':
    train()
