import csv
import os
import time
from typing import Callable, Dict, Optional

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from spc.evaluate import evaluate_model
from spc.model import Embedder, MetadataWeights
from spc.dataset import LabelledDataset


def train_one_epoch(
        model: torch.nn.Module,
        md_weights: MetadataWeights,
        dataset: LabelledDataset,
        epoch: int,
        max_epoch: int,
        batch_size: int,
        train_transforms: Callable,
        model_loss_func: torch.nn.Module,
        model_optimizer: torch.optim.Optimizer,
        class_loss_func: Optional[torch.nn.Module] = None,
        class_optimizer: Optional[torch.optim.Optimizer] = None,
) -> float:
    """Run one optimisation pass over ``dataset`` and return the mean model loss.

    Batches are drawn shuffled, and the last incomplete batch is dropped.
    For every batch the model optimizer takes one step on
    ``model_loss_func(z, w, labels)``; when ``class_loss_func`` is given, the
    class optimizer then takes one step on the same loss signature evaluated
    on the detached embeddings.

    Args:
        model: Network called on the augmented batch to produce ``z``.
        md_weights: Called with no arguments each step to obtain ``w``. Its
            ``renorm()`` is invoked after every step when
            ``norm_after_update`` is truthy.
        dataset: Dataset yielding ``(x, labels)`` pairs.
        epoch: Current epoch number (only used for the progress bar).
        max_epoch: Total number of epochs (only used for the progress bar).
        batch_size: Number of samples per batch.
        train_transforms: Augmentation applied to each sample individually.
        model_loss_func: Loss called as ``model_loss_func(z, w, labels)``.
        model_optimizer: Optimizer stepped on the model loss.
        class_loss_func: Optional second loss, called as
            ``class_loss_func(z.detach(), w, labels)``.
        class_optimizer: Optimizer stepped on the class loss. Required when
            ``class_loss_func`` is given.

    Returns:
        The model loss averaged over the batches of this epoch.

    Raises:
        ValueError: If ``class_loss_func`` is given without
            ``class_optimizer``, or if the dataset yields no full batch.
    """
    # Fail before any parameter update rather than midway through the first step.
    if class_loss_func is not None and class_optimizer is None:
        raise ValueError("class_optimizer is required when class_loss_func is given.")

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
    )
    n_batches = len(data_loader)
    if n_batches == 0:
        # drop_last=True discards everything when the dataset is smaller than
        # one batch; the mean loss would otherwise be a division by zero.
        raise ValueError(
            f"Dataset yields no full batch of size {batch_size}; cannot train."
        )

    running_loss = 0.0
    with tqdm(total=n_batches, desc=f'Epoch: {epoch}/{max_epoch}') as progress:
        for step, (x, labels) in enumerate(data_loader):
            model_optimizer.zero_grad()
            # sample wise augmentation
            x = torch.stack([train_transforms(x_i) for x_i in x], dim=0)

            z = model(x)
            w = md_weights()

            model_loss = model_loss_func(z, w, labels)
            model_loss.backward()
            model_optimizer.step()

            if class_loss_func is not None:
                class_optimizer.zero_grad()
                # NOTE(review): `w` was produced before model_loss.backward();
                # if md_weights() builds an autograd graph with saved tensors,
                # this second backward may need a fresh md_weights() call — confirm.
                class_loss = class_loss_func(z.detach(), w, labels)
                class_loss.backward()
                class_optimizer.step()

            if md_weights.norm_after_update:
                md_weights.renorm()

            running_loss += model_loss.item()

            progress.set_postfix({'loss': f"{running_loss / (step + 1):.3f}"})
            progress.update()
    return running_loss / n_batches


def log_epoch(experiment_folder: str, epoch_data: Dict):
    """Append one row of per-epoch values to ``log.csv`` in ``experiment_folder``.

    When the log file does not exist yet, a header row built from the keys of
    ``epoch_data`` is written first. Later calls only append values, in the
    iteration order of ``epoch_data``; they are not checked against the header.

    Args:
        experiment_folder: Existing directory that holds ``log.csv``.
        epoch_data: Mapping of column name to value. Values are converted
            with ``str``.
    """
    log_fpath = os.path.join(experiment_folder, 'log.csv')
    needs_header = not os.path.exists(log_fpath)
    # newline='' lets the csv module control line endings, as its docs recommend.
    with open(log_fpath, 'a', newline='') as f:
        # csv.writer quotes fields containing commas, quotes or newlines, so
        # such values no longer shift the columns of a row.
        writer = csv.writer(f, lineterminator='\n')
        if needs_header:
            writer.writerow(list(epoch_data.keys()))
        writer.writerow([str(v) for v in epoch_data.values()])


def save_model(
        experiment_folder: str,
        model: torch.nn.Module,
        epoch: int,
        performance: float,
        metric: str = 'loss'
):
    """Save ``model``'s state dict as the single checkpoint kept for ``metric``.

    The checkpoint is written to ``<metric>.<epoch>.<performance>.pt`` inside
    ``experiment_folder``, where the performance is formatted with four
    decimals and its decimal point is replaced by ``P``. Any other
    ``<metric>.*.pt`` file in the folder is removed afterwards.

    Args:
        experiment_folder: Existing directory holding the checkpoints.
        model: Module whose ``state_dict()`` is saved.
        epoch: Epoch number encoded in the file name.
        performance: Metric value encoded in the file name.
        metric: Metric name used as the file-name prefix.
    """
    encoded_performance = f"{performance:.4f}".replace('.', 'P')
    save_filename = f"{metric}.{epoch}.{encoded_performance}.pt"
    save_fpath = os.path.join(experiment_folder, save_filename)

    # Write the new checkpoint before deleting the old one, and move it into
    # place in one step, so a failed save never leaves the folder without a
    # checkpoint for this metric. The temporary name does not end in '.pt',
    # so checkpoint scans ignore it.
    tmp_fpath = save_fpath + '.tmp'
    torch.save(model.state_dict(), tmp_fpath)
    os.replace(tmp_fpath, save_fpath)

    # remove previous best model(s) for this metric
    for file in os.listdir(experiment_folder):
        if file == save_filename:
            continue
        filename, ext = os.path.splitext(file)
        if ext != '.pt':
            continue
        # Files without a '.' in their stem do not follow the naming scheme
        # above; leave them alone instead of failing on them.
        old_metric, sep, _ = filename.partition('.')
        if sep and old_metric == metric:
            os.remove(os.path.join(experiment_folder, file))


def load_model(model: Embedder, experiment_folder: str, early_stopping_metric: str) -> Embedder:
    """Load the checkpoint stored for ``early_stopping_metric`` into ``model``.

    Scans ``experiment_folder`` for a file named
    ``<metric>.<epoch>.<performance>.pt`` (performance encoded with ``P`` in
    place of the decimal point) whose metric equals ``early_stopping_metric``.
    The first match is loaded. Files that do not follow this naming scheme,
    or belong to another metric, are skipped.

    Args:
        model: Model that receives the stored state dict (modified in place).
        experiment_folder: Directory holding the checkpoints.
        early_stopping_metric: Metric name to look for.

    Returns:
        ``model``, with the checkpoint loaded if one was found and otherwise
        unchanged.
    """
    for file in os.listdir(experiment_folder):
        filename, ext = os.path.splitext(file)
        if ext != '.pt':
            continue

        parts = filename.split('.', 2)
        # Check the metric before parsing anything, so unrelated or oddly
        # named .pt files cannot break the lookup.
        if len(parts) != 3 or parts[0] != early_stopping_metric:
            continue
        old_metric, old_epoch, old_performance = parts
        try:
            old_epoch = int(old_epoch)
            old_performance = float(old_performance.replace('P', '.'))
        except ValueError:
            # Same prefix, but the rest of the name is not <epoch>.<performance>.
            continue

        print(f"Found previous model at epoch={old_epoch}, {old_metric}={old_performance:.4f}, reloading.")
        # map_location='cpu' makes a checkpoint saved on another device
        # loadable here; load_state_dict copies the values into the model's
        # own parameters.
        state_dict = torch.load(os.path.join(experiment_folder, file), map_location='cpu')
        model.load_state_dict(state_dict)
        return model

    print("No previous models found.")
    return model


def train_model(
        experiment_folder: str,
        model: Embedder,
        md_weights: MetadataWeights,
        embedding_fn: Callable,
        train_dataset: LabelledDataset,
        eval_dataset: LabelledDataset,
        n_epochs: int,
        batch_size: int,
        train_transforms: Callable,
        model_loss_func: torch.nn.Module,
        model_optimizer: torch.optim.Optimizer,
        class_loss_func: Optional[torch.nn.Module],
        class_optimizer: Optional[torch.optim.Optimizer],
        report_bbbc021_metrics: bool,
        early_stopping_metric: str = 'loss',
) -> Embedder:
    """Train ``model`` for ``n_epochs`` epochs, logging and checkpointing each one.

    Every epoch trains once over ``train_dataset``, evaluates on both the
    training and the evaluation dataset (metric names prefixed ``train_`` and
    ``eval_``), appends a row to the experiment log, and saves a checkpoint
    when the epoch loss reaches a new minimum or the early-stopping metric
    reaches a new maximum. After the last epoch the checkpoint stored for
    ``early_stopping_metric`` is reloaded (unless it is ``None``) and the
    resulting model is saved once more under the ``end`` prefix.

    Returns:
        The model after the optional reload of the best checkpoint.
    """
    lowest_loss = float('inf')
    top_metric_value = -float('inf')

    # Keyword arguments common to the train-set and eval-set evaluation calls.
    shared_eval_kwargs = dict(
        experiment_folder=experiment_folder,
        model=model,
        embedding_fn=embedding_fn,
        save_visualizations=False,
        save_embeddings=False,
        report_bbbc021_metrics=report_bbbc021_metrics,
    )

    loss = 0.0
    for epoch in range(1, n_epochs + 1):
        t_start = time.time()

        model.train()
        loss = train_one_epoch(
            model=model,
            md_weights=md_weights,
            dataset=train_dataset,
            epoch=epoch,
            max_epoch=n_epochs,
            batch_size=batch_size,
            train_transforms=train_transforms,
            model_loss_func=model_loss_func,
            model_optimizer=model_optimizer,
            class_loss_func=class_loss_func,
            class_optimizer=class_optimizer,
        )
        t_trained = time.time()

        model.eval()
        # Train-set metrics are computed first, then eval-set metrics.
        metrics = {
            **evaluate_model(dataset=train_dataset, prefix='train_', **shared_eval_kwargs),
            **evaluate_model(dataset=eval_dataset, prefix='eval_', **shared_eval_kwargs),
        }

        print(f"Metrics: {metrics}")
        print()
        t_evaluated = time.time()

        epoch_row = {
            'epoch': epoch,
            'time': time.strftime("%H:%M:%S", time.gmtime()),
            'train_time': t_trained - t_start,
            'eval_time': t_evaluated - t_trained,
            'loss': loss,
        }
        epoch_row.update(metrics)
        log_epoch(experiment_folder, epoch_row)

        # Higher is treated as better for the monitored metric.
        if early_stopping_metric in metrics:
            monitored = metrics[early_stopping_metric]
            if monitored > top_metric_value:
                save_model(experiment_folder, model, epoch, monitored, early_stopping_metric)
                top_metric_value = monitored
        # The training loss always gets its own best checkpoint (lower is better).
        if loss < lowest_loss:
            save_model(experiment_folder, model, epoch, loss, 'loss')
            lowest_loss = loss

    # early stopping: fall back to the best checkpoint of the monitored metric
    if early_stopping_metric is not None:
        model = load_model(model, experiment_folder, early_stopping_metric)
    save_model(experiment_folder, model, n_epochs, loss, 'end')

    return model


