import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm
from data.qnli.model import model_dict
from data.utils import generate_mask_hash, eval_loop, set_seed
from typing import List, Optional
import os

from transformers import default_data_collator
from data.utils import process_batch

# Load QNLI testset for cache
from data.qnli import QNLI
# QNLI evaluation split (train=False), instantiated once at import time and
# shared by every training run as the held-out set.
testset = QNLI(train=False)

# Default hyperparameters per model name. Each entry is expanded into keyword
# arguments of ``train_model`` by ``create_train_fn``, so every key must be a
# ``train_model`` parameter name.
train_params = {
    'bert': dict(
        loss_fn=nn.CrossEntropyLoss(),
        optimizer='adam',
        scheduler='linear',  # LinearLR ramping the LR from 0.1 * lr up to lr
        lr=2e-5,
        num_epochs=3,
        weight_decay=0,
        batch_size=512,
        testset=testset,  # evaluated after every epoch inside train_model
    ),
}

def create_train_fn(model_name: str):
    """Return a training callable bound to ``model_name``'s preset hyperparameters.

    The returned ``train_fn`` forwards its own arguments plus every entry of
    ``train_params[model_name]`` to ``train_model``. The preset lookup happens
    on each call, not when the closure is created.
    """
    def train_fn(dataset, weights=None, device='cuda', use_model_cache=False, verbose=False, return_model_dir=False):
        call_kwargs = dict(
            device=device,
            use_model_cache=use_model_cache,
            verbose=verbose,
            return_model_dir=return_model_dir,
            **train_params[model_name],
        )
        return train_model(model_name, dataset, weights, **call_kwargs)

    return train_fn

# One ready-to-call training function per configured model, keyed by model name.
train_fns = {name: create_train_fn(name) for name in train_params}

def _load_torch_object(path: str, map_location=None):
    """Load an object previously written with ``torch.save`` from ``path``.

    The model cache stores whole pickled ``nn.Module`` objects. Since PyTorch 2.6,
    ``torch.load`` defaults to ``weights_only=True``, which rejects such pickles.
    ``weights_only=False`` is therefore requested explicitly whenever the
    installed ``torch.load`` accepts that keyword (older releases do not).

    NOTE(security): this unpickles arbitrary Python objects; only use it on
    trusted, locally written cache files.
    """
    import inspect

    kwargs = {'map_location': map_location}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def _accuracy(outputs, y):
    """Fraction of rows whose arg-max over the last dim equals the label ``y``."""
    return (outputs.argmax(dim=-1) == y).float().mean()


def _build_optimizer(name: str, params, lr: float, momentum: float, weight_decay: float):
    """Create the optimizer called ``name`` ('adam' or 'sgd') over ``params``.

    :raises ValueError: if ``name`` is not a supported optimizer.
    """
    if name == 'adam':
        return optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == 'sgd':
        return optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    raise ValueError(f"Invalid optimizer: {name}")


def _build_scheduler(name: str, opt, lr: float, num_epochs: int, steps_per_epoch: int):
    """Create the per-batch LR scheduler called ``name``, or return None for 'none'.

    :raises ValueError: if ``name`` is not a supported scheduler.
    """
    if name == 'onecycle':
        return optim.lr_scheduler.OneCycleLR(opt, max_lr=lr, epochs=num_epochs,
                                             steps_per_epoch=steps_per_epoch)
    if name == 'linear':
        # Linear warm-up over the whole run: the LR rises from 0.1 * lr to lr.
        return optim.lr_scheduler.LinearLR(opt, start_factor=0.1, end_factor=1.0,
                                           total_iters=num_epochs * steps_per_epoch)
    if name == 'none':
        return None
    raise ValueError(f"Invalid scheduler: {name}")


def _save_to_cache(model_dir: str, model_path: str, model_name: str, model, weights,
                   losses, test_losses, accuracy, test_accuracy=None):
    """Write the model, training mask, loss curves and a loss plot into ``model_dir``.

    ``test_accuracy`` is None when no test set was used; in that case no test
    curve is plotted.
    """
    os.makedirs(model_dir, exist_ok=True)
    torch.save(model, model_path)
    torch.save(torch.tensor(weights).bool(), os.path.join(model_dir, 'training_mask.pt'))
    torch.save(losses, os.path.join(model_dir, 'losses.pt'))
    torch.save(test_losses, os.path.join(model_dir, 'test_losses.pt'))

    # Plot losses and save. Use a private figure and always close it, so curves
    # from earlier calls are not drawn into this plot and figures do not leak.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    try:
        ax.plot(losses, label=f'Train Loss (Acc = {100*accuracy:.2f}%)')
        if test_accuracy is not None:
            ax.plot(test_losses, label=f'Test Loss (Acc = {100*test_accuracy:.2f}%)')
        ax.legend()
        ax.set_title(f'QNLI Noisy {model_name.upper()} Losses')
        fig.savefig(os.path.join(model_dir, 'losses.png'))
    finally:
        plt.close(fig)


def _format_result(model, model_dir, losses, test_losses, return_losses, return_model_dir):
    """Shape ``train_model``'s return value according to its ``return_*`` flags.

    ``return_losses`` takes precedence over ``return_model_dir``.
    """
    if return_losses:
        return losses, test_losses
    if return_model_dir:
        return model, model_dir
    return model


def train_model(model_name: str,
                dataset: Dataset,
                weights: Optional[List[bool]] = None,
                device: str = 'cuda',
                use_model_cache: bool = False,
                loss_fn: nn.Module = nn.CrossEntropyLoss(),
                lr: float = 0.1,
                optimizer: str = 'adam',
                scheduler: str = 'onecycle',
                momentum: float = 0.9,
                num_epochs: int = 2,
                weight_decay: float = 1e-3,
                batch_size: int = 256,
                verbose: bool = False,
                return_losses: Optional[bool] = False,
                return_model_dir: Optional[bool] = False,
                seed: Optional[int] = None,
                testset: Optional[Dataset] = None) -> nn.Module:
    """
    Train a fresh ``model_dict[model_name]()`` on ``dataset``, or on the subset
    selected by ``weights``, optionally caching the result on disk.

    :param model_name: key into ``model_dict``; also names the cache sub-directory
    :param dataset: the full training dataset
    :param weights: per-example selection mask. Examples whose weight equals 1
        are trained on; ``None`` means all examples. When a mask is given,
        ``num_epochs`` is scaled by ``len(dataset) / len(subset)``.
    :param device: device the model is moved to (and onto which a cached model is loaded)
    :param use_model_cache: if True, reuse a model previously saved for the same
        ``(model_name, weights)``. Otherwise train it and save it under
        ``saved_models/<model_name>/<mask hash>`` next to this file.
        NOTE: the cache key ignores every other hyperparameter (lr, epochs, seed, ...).
    :param loss_fn: called as ``loss_fn(outputs, labels)`` for each batch
    :param lr: learning rate (also ``max_lr`` for the 'onecycle' scheduler)
    :param optimizer: 'adam' or 'sgd'
    :param scheduler: 'onecycle', 'linear' (warm-up from 0.1 * lr to lr) or 'none';
        stepped once per batch
    :param momentum: SGD momentum (unused by Adam)
    :param num_epochs: number of training epochs before subset scaling
    :param weight_decay: weight decay passed to the optimizer
    :param batch_size: batch size of the train and test loaders
    :param verbose: print per-epoch loss and accuracy
    :param return_losses: return ``(train_losses, test_losses)`` tensors instead of the model
    :param return_model_dir: return ``(model, model_dir)``; ``model_dir`` is None
        when ``use_model_cache`` is False
    :param seed: if given, passed to ``set_seed`` before the model is built
    :param testset: optional held-out dataset evaluated after every epoch
    :return: the model by default, otherwise as selected by ``return_losses`` /
        ``return_model_dir``. ``train_losses`` holds one entry per epoch: the
        batch losses weighted by batch size, divided by the subset size (a
        per-example mean if ``loss_fn`` averages over the batch). If nothing
        is selected for training, the untrained model and empty loss tensors
        are used.
    :raises ValueError: for an unknown ``optimizer`` or ``scheduler`` name
    """
    model_dir = None  # only set when the model cache is in use
    model_path = None
    if use_model_cache:
        if weights is None:
            weights = [1] * len(dataset)
        mask_hash = generate_mask_hash(weights)
        model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 'saved_models', model_name, mask_hash)
        model_path = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_path):
            if verbose:
                print("Model exists for provided training mask, loading from", model_dir)
            if return_losses:
                # Honour return_losses on a cache hit exactly like a fresh run does.
                return (_load_torch_object(os.path.join(model_dir, 'losses.pt'), map_location='cpu'),
                        _load_torch_object(os.path.join(model_dir, 'test_losses.pt'), map_location='cpu'))
            model = _load_torch_object(model_path, map_location=device)
            return (model, model_dir) if return_model_dir else model

    if seed is not None:
        set_seed(seed)
    model = model_dict[model_name]().to(device)
    model.train()

    print("device", device)
    print("learning rate", lr)
    torch.cuda.empty_cache()

    total_data_size = len(dataset)
    nothing_to_train = False
    if weights is not None:
        indices = [index for index, value in enumerate(weights) if value == 1]
        if sum(weights) == 0 or not indices:
            nothing_to_train = True
        else:
            dataset = Subset(dataset, indices)
            # Scale epochs by the inverse fraction of data used, so a subset run
            # takes roughly as many optimisation steps as a full-data run.
            num_epochs = int(num_epochs * (total_data_size / len(dataset)))
            print(f"Scaled epochs to {num_epochs}")
    elif total_data_size == 0:
        nothing_to_train = True
    if nothing_to_train:
        # Return the freshly initialised (untrained) model, shaped per the return_* flags.
        return _format_result(model, model_dir, torch.tensor([]), torch.tensor([]),
                              return_losses, return_model_dir)

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                              collate_fn=default_data_collator)
    test_loader = None
    if testset is not None:
        test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                                 collate_fn=default_data_collator)

    opt = _build_optimizer(optimizer, model.parameters(), lr, momentum, weight_decay)

    if verbose:
        print("Training Data Subset Size = ", len(train_loader.dataset), "/", total_data_size)

    sched = _build_scheduler(scheduler, opt, lr, num_epochs, len(train_loader))

    losses, test_losses = [], []
    for epoch in tqdm(range(num_epochs), desc='Training'):
        epoch_loss = 0.0
        for batch in tqdm(train_loader, desc='Batch', leave=False):
            inputs, labels = process_batch(batch, device)
            outputs = model(**inputs)
            loss = loss_fn(outputs, labels)
            # Weight by batch size so the epoch average is exact for a ragged last batch.
            epoch_loss += loss.item() * labels.size(0)

            opt.zero_grad()
            loss.backward()
            opt.step()
            if sched is not None:
                sched.step()

        losses.append(epoch_loss / len(train_loader.dataset))

        if test_loader is not None:
            test_losses.append(eval_loop(model, test_loader, loss_fn, device))

        if verbose:
            loss_str = f', Loss: {losses[-1]:.4f}'
            if test_loader is not None:
                loss_str += f', Test Loss: {test_losses[-1]:.4f}'
            print(f'Epoch [{epoch+1}/{num_epochs}]{loss_str}')
            accuracy = eval_loop(model, train_loader, _accuracy, device)
            acc_str = f'Train accuracy: {100 * accuracy:.2f}%'
            if test_loader is not None:
                test_accuracy = eval_loop(model, test_loader, _accuracy, device)
                acc_str += f', Test accuracy: {100 * test_accuracy:.2f}%'
            print(acc_str)

        # eval_loop appears to leave the model in eval mode (the model was
        # switched back to train mode right after it). Restore training mode
        # only after *all* of this epoch's evaluations, verbose ones included.
        model.train()

    accuracy = eval_loop(model, train_loader, _accuracy, device)
    if verbose:
        print(f'Train accuracy: {100 * accuracy:.2f}%')
    test_accuracy = None
    if test_loader is not None:
        test_accuracy = eval_loop(model, test_loader, _accuracy, device)
        if verbose:
            print(f'Test accuracy: {100 * test_accuracy:.2f}%')

    losses, test_losses = torch.tensor(losses), torch.tensor(test_losses)
    if use_model_cache:
        _save_to_cache(model_dir, model_path, model_name, model, weights,
                       losses, test_losses, accuracy, test_accuracy)

    return _format_result(model, model_dir, losses, test_losses, return_losses, return_model_dir)