import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm
from data.mnist.model import model_dict
from data.utils import generate_mask_hash, eval_loop
from typing import List, Optional
import os
from tqdm.auto import tqdm

# The MNIST test split is loaded once at import time and shared by every
# training preset below. Per the original note it is identical to the
# mnist_noisy test split, because no labels are flipped in the testset.
from data.mnist import MNIST
testset = MNIST(train=False, download=True)

# Per-model keyword arguments; each entry is unpacked into train_model()
# by the closures produced in create_train_fn().
train_params = dict(
    ann_l=dict(
        loss_fn=nn.CrossEntropyLoss(),
        optimizer='adam',
        scheduler='onecycle',
        lr=0.001,
        num_epochs=50,
        weight_decay=0,
        batch_size=512,
        testset=testset,
    ),
)

def create_train_fn(model_name: str):
    """
    Build a training callable bound to a single model name.
    :param model_name: key into ``train_params`` selecting the training preset
    :return: a function ``train_fn(dataset, weights, device, use_model_cache, verbose)``
             that forwards to ``train_model`` with the preset's keyword arguments
    """
    def train_fn(dataset, weights=None, device='cuda', use_model_cache=False, verbose=False):
        # The preset is looked up on every call, so an unknown model name
        # surfaces as a KeyError when training is requested, not at creation.
        preset = train_params[model_name]
        return train_model(
            model_name,
            dataset,
            weights,
            device=device,
            use_model_cache=use_model_cache,
            verbose=verbose,
            **preset,
        )
    return train_fn

# One ready-to-call training function per preset defined in train_params.
train_fns = {name: create_train_fn(name) for name in train_params}


def train_model(model_name: str,
                dataset: Dataset,
                weights: Optional[List[bool]] = None,
                device: str = 'cpu',
                use_model_cache: Optional[bool] = False,
                loss_fn: Optional[nn.Module] = nn.CrossEntropyLoss(),
                lr: Optional[float] = 0.1,
                optimizer: Optional[str] = 'adam',
                scheduler: Optional[str] = 'onecycle',
                momentum: Optional[float] = 0.9,
                num_epochs: Optional[int] = 50,
                weight_decay: Optional[float] = 1e-3,
                batch_size: Optional[int] = 128,
                verbose: Optional[bool] = False,
                return_losses: Optional[bool] = False,
                testset: Optional[Dataset] = None) -> nn.Module:
    """
    Train a freshly constructed model on (a subset of) the given dataset.
    :param model_name: key into ``model_dict`` selecting the architecture to build
    :param dataset: the dataset to train the model on
    :param weights: optional per-sample mask; only samples whose entry equals 1 are
                    used, and the epoch count is scaled up by
                    ``len(dataset) / subset_size``
    :param device: device the model and batches are moved to
    :param use_model_cache: if true, load a previously saved model for the same mask
                            from ``saved_models/<model_name>/<mask_hash>/`` next to
                            this file, or save the newly trained model, mask and
                            losses there
    :param loss_fn: the loss used for training and for the periodic test evaluation
    :param lr: learning rate (also the ``max_lr`` of the OneCycleLR scheduler)
    :param optimizer: ``'adam'`` or ``'sgd'``
    :param scheduler: ``'onecycle'``, ``'step'`` or ``'exponential'``
    :param momentum: momentum, used only by the SGD optimizer
    :param num_epochs: number of epochs before any subset scaling
    :param weight_decay: the L2 regularization strength passed to the optimizer
    :param batch_size: the batch size for training and evaluation
    :param verbose: print progress, periodic losses and final accuracies
    :param return_losses: return ``(losses, test_losses)`` tensors instead of the model
    :param testset: optional dataset evaluated whenever a training loss is recorded
    :return: the trained model, or ``(losses, test_losses)`` when ``return_losses`` is
             true. Losses are recorded every ``num_epochs // 10`` epochs and only
             when ``num_epochs > 10``. A cache hit or an empty training mask
             always returns a model, regardless of ``return_losses``.
    :raises ValueError: if ``optimizer`` or ``scheduler`` is not a supported name
    """
    if use_model_cache:
        if weights is None:
            # An explicit all-ones mask is needed to derive the cache key.
            weights = [1] * len(dataset)
        mask_hash = generate_mask_hash(weights)
        model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 'saved_models', model_name, mask_hash)
        model_path = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_path):
            if verbose:
                print("Model exists for provided training mask, loading from", model_dir)
            # map_location lets a model cached on one device be loaded for another
            # (e.g. saved on CUDA, requested on CPU), matching the `.to(device)`
            # of the non-cached path.
            # NOTE(review): the file holds a fully pickled nn.Module; newer PyTorch
            # releases may need `weights_only=False` here -- confirm against the
            # installed version. Only load cache files this project wrote itself.
            return torch.load(model_path, map_location=device)

    model = model_dict[model_name]().to(device)
    model.train()

    # Initialize
    torch.cuda.empty_cache()
    total_data_size = len(dataset)
    if weights is not None:
        if verbose:
            print("Getting data subset")
        if sum(weights) == 0:
            return model
        indices = [index for index, value in enumerate(weights) if value == 1]
        if not indices:
            # Non-zero sum but no entry equal to 1: nothing to train on, and the
            # epoch scaling below would divide by zero.
            return model
        dataset = Subset(dataset, indices)
        # Scale epochs by the fraction of data used
        num_epochs = int(num_epochs * (total_data_size / len(dataset)))
        if verbose:
            print(f"Scaled epochs to {num_epochs}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Optimizer
    if optimizer == 'adam':
        opt = optim.Adam(model.parameters(), lr=lr,
                         weight_decay=weight_decay)
    elif optimizer == 'sgd':
        opt = optim.SGD(model.parameters(), lr=lr,
                        momentum=momentum,
                        weight_decay=weight_decay)
    else:
        # Fail early with a clear message instead of handing a string to the scheduler.
        raise ValueError(f"Invalid optimizer: {optimizer}")

    if verbose:
        print("Training Data Subset Size = ", len(dataloader.dataset), "/", total_data_size)

    # Scheduler
    if scheduler == 'onecycle':
        sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=lr, epochs=num_epochs,
                                              steps_per_epoch=len(dataloader))
    elif scheduler == 'step':
        # Clamp to 1 so runs shorter than 4 epochs do not produce step_size == 0.
        step_size = max(1, num_epochs // 4)
        sched = optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=0.3)
    elif scheduler == 'exponential':
        sched = optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)
    else:
        raise ValueError(f"Invalid scheduler: {scheduler}")

    # Losses are recorded ten times per run, and only for runs longer than 10 epochs.
    log_every = num_epochs // 10 if num_epochs > 10 else 0
    # Built once and reused for every periodic evaluation and the final accuracy.
    test_loader = None
    if testset is not None:
        test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Training
    losses, test_losses = [], []
    for epoch in tqdm(range(num_epochs), desc="Training Epochs", disable=not verbose):
        epoch_loss = 0.0
        for (x, y) in dataloader:
            x, y = x.to(device), y.to(device)
            # Compute the loss
            outputs = model(x.float())
            loss = loss_fn(outputs, y.long())
            # Weight by batch size so the epoch average is per-sample.
            epoch_loss += loss.item() * x.size(0)

            # Backward pass and optimization
            opt.zero_grad()
            loss.backward()
            opt.step()
            # OneCycleLR is stepped per batch; the other schedulers per epoch.
            if scheduler == 'onecycle':
                sched.step()
        if scheduler in ('step', 'exponential'):
            sched.step()

        if log_every and (epoch + 1) % log_every == 0:
            losses.append(epoch_loss / len(dataloader.dataset))
            if test_loader is not None:
                test_losses.append(eval_loop(model, test_loader, loss_fn, device))
                # Restore training mode (presumably eval_loop switches to eval mode).
                model.train()
            if verbose:
                test_str = f', Test Loss: {test_losses[-1]:.4f}' if testset is not None else ''
                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {losses[-1]:.4f}{test_str}')

    losses, test_losses = torch.tensor(losses), torch.tensor(test_losses)
    if use_model_cache:
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model, model_path)
        torch.save(torch.tensor(weights).bool(), os.path.join(model_dir, 'training_mask.pt'))
        torch.save(losses, os.path.join(model_dir, 'losses.pt'))
        torch.save(test_losses, os.path.join(model_dir, 'test_losses.pt'))

    # Get accuracy
    if verbose:
        acc_fn = lambda outputs, y: (outputs.argmax(dim=-1) == y).float().mean()
        # The training loader is unshuffled, so it can be reused for evaluation.
        accuracy = eval_loop(model, dataloader, acc_fn, device)
        print(f'Train accuracy: {100 * accuracy:.2f}%')
        if test_loader is not None:
            accuracy = eval_loop(model, test_loader, acc_fn, device)
            print(f'Test accuracy: {100 * accuracy:.2f}%')

    return (losses, test_losses) if return_losses else model