import os
import sys
import functools
from typing import Optional

import torch
import yaml

from spc.dataset import LabelledDataset
from spc.model import Embedder, make_embeddings, RandomTest, MetadataWeights, MAX_SAMPLE_SIZE
from spc.resnet import resnet18
from spc.train import train_model
from spc.evaluate import evaluate_model
from spc.loading import load_transforms, load_loss_func, load_optimizer
from spc.visualization import save_combined_visuals


# Root for all experiment outputs: an absolute path to the ``experiments``
# directory that sits next to this script.
EXPERIMENTS_FOLDER = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'experiments')


def _experiment_folder_for(config_filepath):
    """Return the output folder for the experiment described by *config_filepath*.

    The layout is ``EXPERIMENTS_FOLDER/<config parent dir name>/<config file stem>``.
    The config path is made absolute first, so the result does not depend on
    the current working directory, and a relative path such as ``../exp.yaml``
    cannot contribute a ``..`` component that would escape EXPERIMENTS_FOLDER.
    """
    abs_config = os.path.abspath(config_filepath)
    experiment_name, _ = os.path.splitext(os.path.basename(abs_config))
    group_name = os.path.basename(os.path.dirname(abs_config))
    return os.path.join(EXPERIMENTS_FOLDER, group_name, experiment_name)


def _load_labelled_dataset(dataset_conf, select_channels):
    """Build a LabelledDataset from a ``dataset`` / ``eval_dataset`` config section."""
    return LabelledDataset(
        csv_fpath=dataset_conf['csv_filepath'],
        label_cols=dataset_conf['label_cols'],
        max_cache_size=None,
        dna_pct_threshold=dataset_conf.get('dna_pct_threshold', None),
        select_channels=select_channels,
    )


def _build_norm_layer(model_conf):
    """Return the normalisation-layer factory that is passed to the encoder.

    Plain ``torch.nn.BatchNorm2d`` unless *model_conf* has a ``batch_norm``
    section, in which case its ``track_running_stats``, ``momentum`` and
    ``affine`` values are bound with functools.partial.
    """
    if 'batch_norm' not in model_conf:
        return torch.nn.BatchNorm2d
    bn_conf = model_conf['batch_norm']
    return functools.partial(
        torch.nn.BatchNorm2d,
        track_running_stats=bn_conf['track_running_stats'],
        momentum=bn_conf['momentum'],
        affine=bn_conf['affine'],
    )


def _build_encoder(encoder_type, norm_layer, n_channels):
    """Instantiate the encoder named by the config's ``model.encoder`` value.

    Raises:
        ValueError: if *encoder_type* is not ``'ResNet18'`` or ``'RandomTest'``.
    """
    if encoder_type == 'ResNet18':
        return resnet18(norm_layer=norm_layer, in_channels=n_channels)
    if encoder_type == 'RandomTest':
        return RandomTest()
    raise ValueError(f"Unknown encoder: {encoder_type}")


def main(config_filepath: Optional[str] = None):
    """Train and evaluate an embedding model described by a YAML config file.

    Args:
        config_filepath: Path to the YAML experiment config. When ``None``,
            the path is taken from the first command-line argument.

    Outputs go to ``EXPERIMENTS_FOLDER/<config parent dir>/<config stem>``:
    a re-dumped copy of the config, whatever ``train_model``,
    ``evaluate_model`` and ``save_combined_visuals`` write there, and
    ``metrics.yaml`` holding the merged ``train_`` / ``eval_`` metrics.

    Raises:
        ValueError: if no config path is available, or the config names an
            unknown encoder.
    """
    if config_filepath is None:
        # ``< 2`` rather than ``== 1`` so an empty argv also gets a clear error.
        if len(sys.argv) < 2:
            raise ValueError("Must provide config filepath as command line argument or as first argument to main()")
        config_filepath = sys.argv[1]

    # NOTE(review): FullLoader can construct some Python-specific YAML tags; if
    # configs are always plain data, yaml.safe_load is the safer choice.
    with open(config_filepath, encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    experiment_folder = _experiment_folder_for(config_filepath)
    os.makedirs(experiment_folder, exist_ok=True)

    # Store the parsed config next to the outputs. It is re-serialized with
    # yaml.dump, so comments/formatting from the original file are not kept.
    config_dest = os.path.join(experiment_folder, os.path.basename(config_filepath))
    with open(config_dest, 'w', encoding='utf-8') as f:
        yaml.dump(config, f)

    model_conf = config['model']
    train_conf = config['train_conf']
    eval_conf = config['eval_conf']
    select_channels = config.get('select_channels', None)

    train_dataset = _load_labelled_dataset(config['dataset'], select_channels)
    eval_dataset = _load_labelled_dataset(config['eval_dataset'], select_channels)

    # Load one sample up front so a broken dataset fails before any model is
    # built; without an explicit channel selection its first dimension is
    # used as the encoder's input channel count.
    test_im, _ = train_dataset[0]
    n_channels = len(select_channels) if select_channels is not None else test_im.shape[0]

    embed_dim = model_conf['embed_dim']
    encoder = _build_encoder(model_conf['encoder'], _build_norm_layer(model_conf), n_channels)
    model = Embedder(
        encoder=encoder,
        encoder_out_dim=encoder.rep_dim,
        embed_dim=embed_dim,
        head_type=model_conf.get('head_type', 'mlp'),
    )

    n_classes = train_dataset.nlabels()
    md_weights = MetadataWeights(
        n_classes=n_classes,
        embed_dim=embed_dim,
        norm_after_update=train_conf['norm_after_update'],
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    md_weights.to(device)

    model_params = list(model.parameters())
    if 'class_loss_func' in train_conf:
        # The metadata weights get their own loss and optimizer.
        class_loss_func = load_loss_func(train_conf['class_loss_func'])
        class_loss_func.to(device)
        class_optimizer = load_optimizer(train_conf['class_optimizer'], md_weights.parameters())
    else:
        # No separate class loss: the metadata weights are optimized jointly
        # with the model by the model optimizer.
        class_loss_func = None
        class_optimizer = None
        model_params += list(md_weights.parameters())

    model_loss_conf = train_conf['model_loss_func']
    model_loss_func = load_loss_func(model_loss_conf, n_classes=n_classes)
    model_loss_func.to(device)
    # Edge case: an Xent loss with bias carries trainable bias terms of its
    # own, which must be handed to the model optimizer.
    if model_loss_conf['type'] == 'Xent' and model_loss_conf['add_bias']:
        model_params += list(model_loss_func.parameters())
    model_optimizer = load_optimizer(train_conf['model_optimizer'], model_params)

    train_transforms = load_transforms(train_conf['transforms'], device)
    eval_transforms = load_transforms(eval_conf['transforms'], device)

    embedding_fn = functools.partial(
        make_embeddings,
        eval_transforms=eval_transforms,
        plate_sampling_strategy=eval_conf['sampling_strategy'],
        eval_batch_size=128,
        max_sample_size=eval_conf.get('max_sample_size', MAX_SAMPLE_SIZE),
    )

    report_bbbc021_metrics = eval_conf.get('report_bbbc021_metrics', True)
    save_embeddings = eval_conf.get('save_embeddings', True)

    model = train_model(
        experiment_folder=experiment_folder,
        model=model,
        md_weights=md_weights,
        embedding_fn=embedding_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        n_epochs=train_conf['n_epochs'],
        batch_size=train_conf['batch_size'],
        train_transforms=train_transforms,
        model_loss_func=model_loss_func,
        model_optimizer=model_optimizer,
        class_loss_func=class_loss_func,
        class_optimizer=class_optimizer,
        report_bbbc021_metrics=report_bbbc021_metrics,
        early_stopping_metric=train_conf.get('early_stopping_metric', None),
    )

    # Evaluate on both splits; eval results are merged after train results,
    # so on any key collision the eval value wins (presumably the prefixes
    # keep the key sets disjoint).
    metrics = {}
    for prefix, dataset in (('train_', train_dataset), ('eval_', eval_dataset)):
        metrics.update(evaluate_model(
            experiment_folder=experiment_folder,
            model=model,
            dataset=dataset,
            embedding_fn=embedding_fn,
            save_visualizations=True,
            save_embeddings=save_embeddings,
            report_bbbc021_metrics=report_bbbc021_metrics,
            prefix=prefix,
        ))

    save_combined_visuals(
        dataset1=train_dataset,
        dataset2=eval_dataset,
        experiment_folder=experiment_folder,
        embedding_fn=embedding_fn,
        model=model,
    )

    print("Final metrics:")
    print(metrics)

    metrics_fpath = os.path.join(experiment_folder, 'metrics.yaml')
    with open(metrics_fpath, 'w', encoding='utf-8') as f:
        yaml.dump(metrics, f)


if __name__ == '__main__':
    # Script entry point: with no explicit path, main() reads it from sys.argv[1].
    main(config_filepath=None)


