import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import numpy as np
from omegaconf import OmegaConf
# Project configuration, loaded once at import time from 'config.yaml'
# (a relative path, so it is resolved against the current working directory).
config = OmegaConf.load('config.yaml')
# GPU index used to build the f"cuda:{cuda}" device in the training helpers of
# this module; those helpers fall back to "cpu" when CUDA is unavailable.
cuda = config.cuda


def train(model, train_loader, device):
    """Run one training epoch over ``train_loader``.

    Element ``[0]`` of each batch is moved to ``device`` and used both as the
    model input and as the reconstruction target. ``model(x)`` must return
    ``(outputs, latent_vecs)``, and ``model.optimizer`` / ``model.criterion``
    must already be set on the model.

    Returns:
        ``(mean_loss, latent_vecs)``: the mean of the per-batch losses and the
        latent output of the last batch (detached from the autograd graph when
        it is a tensor).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    latent_vecs = None
    n_batches = 0
    for batch_samples in train_loader:
        train_batch = batch_samples[0].to(device)
        model.optimizer.zero_grad()
        outputs, latent_vecs = model(train_batch)
        loss = model.criterion(outputs, train_batch)
        loss.backward()
        model.optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError / unbound name.
        raise ValueError('train_loader yielded no batches')
    # Detach so callers get a plain tensor (e.g. usable with .numpy()) that no
    # longer references this batch's autograd graph.
    if isinstance(latent_vecs, torch.Tensor):
        latent_vecs = latent_vecs.detach()
    return running_loss / n_batches, latent_vecs


def val(model, val_loader, device):
    """Evaluate ``model`` in eval mode without gradient tracking.

    Two calling conventions are supported:

    * ``val_loader`` is a ``list`` of exactly two loaders ``[regular, anomaly]``
      (the 'anomalies' scenario). Batches of the two loaders are consumed in
      lockstep and ``(regular_loss, anomaly_loss, roc_auc)`` is returned, each
      averaged over the batch pairs actually evaluated.
    * Otherwise ``val_loader`` is a single iterable of batches (sample_noise,
      feature_noise and domain_shift scenarios) and ``(mean_loss, latent_vecs)``
      is returned, where ``latent_vecs`` is the latent output of the LAST
      batch only.

    In both cases ``model(x)`` must return ``(reconstruction, latent)`` and
    ``model.criterion`` must already be set. Element ``[0]`` of each batch is
    the model input; in the anomalies scenario element ``[1]`` holds the labels
    given to ``roc_auc_score``.

    Raises:
        ValueError: if no batches were evaluated (empty loader(s)).
    """
    model.eval()
    with torch.no_grad():
        # A two-element list is how callers signal the 'anomalies' scenario.
        if isinstance(val_loader, list) and len(val_loader) == 2:
            return _val_anomalies(model, val_loader[0], val_loader[1], device)
        return _val_reconstruction(model, val_loader, device)


def _val_anomalies(model, regular_loader, anomaly_loader, device):
    """Score paired regular/anomaly batches; return mean losses and mean ROC AUC."""
    mse_no_reduction = nn.MSELoss(reduction='none')
    regular_loss = 0.0
    anomaly_loss = 0.0
    roc_auc = 0.0
    n_batches = 0
    for regular_batch, anomaly_batch in zip(regular_loader, anomaly_loader):
        regular = regular_batch[0].to(device)
        anomaly = anomaly_batch[0].to(device)
        labels = torch.cat((regular_batch[1], anomaly_batch[1]))
        outputs_regular, _ = model(regular)
        outputs_anomaly, _ = model(anomaly)
        regular_loss += model.criterion(outputs_regular, regular).item()
        anomaly_loss += model.criterion(outputs_anomaly, anomaly).item()
        # Anomaly score = per-sample mean squared reconstruction error. Flattening
        # first keeps one score per sample even for inputs with >2 dimensions.
        scores = torch.cat((
            mse_no_reduction(outputs_regular, regular).flatten(start_dim=1).mean(dim=1),
            mse_no_reduction(outputs_anomaly, anomaly).flatten(start_dim=1).mean(dim=1),
        ))
        roc_auc += roc_auc_score(labels.cpu().numpy(), scores.cpu().numpy())
        n_batches += 1
    if n_batches == 0:
        raise ValueError('val_loader yielded no batches')
    # zip() stops at the shorter loader, so average over the pairs actually
    # evaluated rather than over len() of either loader.
    return regular_loss / n_batches, anomaly_loss / n_batches, roc_auc / n_batches


def _val_reconstruction(model, loader, device):
    """Return the mean reconstruction loss over ``loader`` and the last batch's latents."""
    running_loss = 0.0
    latent_vecs = None
    n_batches = 0
    for batch_samples in loader:
        val_batch = batch_samples[0].to(device)
        outputs, latent_vecs = model(val_batch)
        running_loss += model.criterion(outputs, val_batch).item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError('val_loader yielded no batches')
    return running_loss / n_batches, latent_vecs


def train_and_eval(train_loader, val_loader, test_loader, model, opt, lr, epochs, reg, model_check_point, test=True):
    """Train ``model`` for ``epochs`` epochs, printing progress, then optionally test it.

    Args:
        train_loader: iterable of training batches (element ``[0]`` is the input).
        val_loader: single validation loader evaluated every epoch. It must not
            be the two-loader 'anomalies' list, because ``val`` then returns
            three values.
        test_loader: loader evaluated once after training when ``test`` is True.
        model: module whose forward returns ``(reconstruction, latent)``; this
            function sets ``model.optimizer`` and ``model.criterion`` on it.
        opt: ``'SGD'`` or ``'Adam'``.
        lr: learning rate.
        epochs: number of training epochs (must be >= 1).
        reg: weight decay passed to the optimizer.
        model_check_point: currently unused (checkpointing is commented out).
        test: whether to evaluate on ``test_loader`` after training.

    Returns:
        ``(train_loss, test_loss, test_outputs)`` when ``test`` is True, else
        ``train_loss``. ``train_loss`` is the mean batch loss of the last epoch.

    Raises:
        ValueError: if ``opt`` is not supported or ``epochs`` < 1.
    """
    if epochs < 1:
        raise ValueError('epochs must be at least 1')
    device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")
    if opt == 'SGD':
        model.optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=reg)  # TODO: add momentum
    elif opt == 'Adam':
        model.optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)
    else:
        raise ValueError('Optimizer should be "Adam" or "SGD"')

    model.to(device)
    model.criterion = nn.MSELoss()
    # decayRate = 0.96
    # lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=model.optimizer, gamma=decayRate, verbose=True)
    # now_str = str(datetime.datetime.now()).replace(':', '_').replace('-', '_').replace(' ', '_').split('.')[0]
    # os.makedirs(f'models/{now_str}')
    # mlflow.log_param("now str", now_str)
    for epoch in tqdm(range(epochs)):
        # train() returns (mean_loss, last_batch_latents); only the loss is used here.
        train_loss, _ = train(model, train_loader, device)
        #mlflow.log_metric('Train loss', train_loss, epoch+1)
        val_loss, _ = val(model, val_loader, device)
        #mlflow.log_metric('Val loss', val_loss, epoch+1)
        # lr_scheduler.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch: {epoch+1}, Train loss: {train_loss}, Val loss: {val_loss}')

        # if (epoch + 1) % model_check_point == 0:
        #     filename = f'epoch_{epoch+1}_train_loss_{round(train_loss, 3)}_val_loss_{round(val_loss, 3)}.pt'
        #     torch.save({'epoch': epoch + 1,
        #                 'model_state_dict': model.state_dict(),
        #                 'optimizer_state_dict': model.optimizer.state_dict(),
        #                 'train_loss': train_loss,
        #                 'val_loss': val_loss},
        #                f"./models/{now_str}/" + filename)
        #     mlflow.log_artifacts('models/' + now_str)

    if test:
        test_loss, outputs = val(model, test_loader, device)
        #mlflow.log_metric('Test loss', test_loss, epoch+1)
        # NOTE(review): ``outputs`` is the latent tensor of the LAST test batch
        # only (that is what ``val`` returns), so ``test_outputs`` holds values
        # from that batch alone — confirm this is intended.
        test_outputs = np.array([sample for batch in outputs for sample in batch.detach().cpu().numpy()])
        return train_loss, test_loss, test_outputs
    return train_loss


def train_and_eval_per_epoch(train_loader, test_loader, model, opt, lr):
    """Run one training pass over ``train_loader`` and one evaluation pass over ``test_loader``.

    On every call a new optimizer of type ``opt`` ('SGD' or 'Adam', learning
    rate ``lr``, no weight decay) is attached as ``model.optimizer``, and
    ``model.criterion`` is reset to ``nn.MSELoss()``.

    Returns:
        ``(train_loss, test_loss, latent_vecs_train, latent_vecs_test)`` where
        the latent tensors are those returned by ``train`` / ``val`` (the last
        batch of each pass).

    NOTE(review): because the optimizer is rebuilt on each call, Adam's moment
    estimates restart from zero every time this is called on the same model
    (e.g. once per epoch) — confirm that is intended.
    """
    device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")
    if opt == 'SGD':
        optimizer_cls = torch.optim.SGD
    elif opt == 'Adam':
        optimizer_cls = torch.optim.Adam
    else:
        raise Exception('Optimizer should be "Adam" or "SGD"')
    model.optimizer = optimizer_cls(model.parameters(), lr=lr)

    model.to(device)
    model.criterion = nn.MSELoss()

    epoch_train_loss, train_latents = train(model, train_loader, device)
    epoch_test_loss, test_latents = val(model, test_loader, device)
    return epoch_train_loss, epoch_test_loss, train_latents, test_latents


def train_and_eval_per_epoch_anomalies(train_loader, test_regular_loader, test_anomalies_loader, model, opt, lr):
    """Train for one pass, then score regular vs. anomalous test data.

    A new optimizer of type ``opt`` ('SGD' or 'Adam', learning rate ``lr``, no
    weight decay) is attached as ``model.optimizer`` and ``model.criterion`` is
    reset to ``nn.MSELoss()`` on every call. The two test loaders are passed to
    ``val`` as a two-element list, which selects its 'anomalies' scenario.

    Returns:
        ``(train_loss, regular_loss, anomaly_loss, roc_auc)``.

    NOTE(review): the optimizer (and therefore any Adam state) is recreated on
    each call — confirm that is intended when this runs once per epoch.
    """
    device = torch.device(f"cuda:{cuda}" if torch.cuda.is_available() else "cpu")
    if opt == 'SGD':
        make_optimizer = torch.optim.SGD
    elif opt == 'Adam':
        make_optimizer = torch.optim.Adam
    else:
        raise Exception('Optimizer should be "Adam" or "SGD"')
    model.optimizer = make_optimizer(model.parameters(), lr=lr)

    model.to(device)
    model.criterion = nn.MSELoss()

    epoch_train_loss = train(model, train_loader, device)[0]
    paired_loaders = [test_regular_loader, test_anomalies_loader]
    regular_loss, anomaly_loss, roc_auc = val(model, paired_loaders, device)
    return epoch_train_loss, regular_loss, anomaly_loss, roc_auc

