import os.path
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications import ResNet50  # Replace with your desired model
from tensorflow.keras.applications.resnet50 import preprocess_input

import cfg


def do(df, x_data, y_data, tag_col, args):
    """Build the model for this iteration, train it on the tagged data and return it.

    Args:
        df: DataFrame whose ``tag_col`` column tags every sample of ``x_data``/``y_data``.
        x_data: Sample array (embeddings or raw images).
        y_data: Label array aligned with ``x_data``.
        tag_col: Name of the tag column in ``df`` used for this iteration.
        args: 11-tuple of experiment settings. Only dataset, dataset_type, train_paradigm,
            weight_init and training are used here; the rest is ignored.

    Returns:
        The trained model.
    """
    # Unpacking keeps the strict 11-element check; unused settings go to `_`-prefixed names.
    (dataset, dataset_type, _random_seed, _init_train_samples, _add_train_samples,
     _max_train_samples, train_paradigm, weight_init, training, _qm, _fname_result) = args

    trained = create_model(x_data=x_data, y_data=y_data, weight_init=weight_init, training=training,
                           dataset=dataset, dataset_type=dataset_type, df=df, tag_col=tag_col)

    labeled_x, labeled_y, unlabeled_x = get_training_data(
        x_data=x_data, y_data=y_data, df=df, tag_col=tag_col, train_paradigm=train_paradigm)

    train_model(model=trained, dataset=dataset, x_labeled=labeled_x, y_labeled=labeled_y,
                x_unlabeled=unlabeled_x, train_paradigm=train_paradigm, training=training)
    return trained


def create_model(x_data, y_data, weight_init, training, dataset, dataset_type, df, tag_col):
    """Build and compile the classifier for the current experiment.

    Embedding inputs (``x_data.ndim == 2``) get a single dense layer. Raw image inputs of
    'mscoco'/'cifar10' get a ResNet50 backbone initialised according to ``weight_init``.

    Args:
        x_data: Sample array; ``x_data[0]`` defines the input shape.
        y_data: 2-D label array; ``y_data.shape[1]`` is the number of outputs.
        weight_init: 'tl' (ImageNet weights), 'random', or 'self-sl' (rotation-prediction
            pre-training, cached on disk under ``cfg.path_model``).
        training: Training mode forwarded to the backbone ('frozen', 'last-N', ...).
        dataset: Dataset name; selects loss/activation and the model family.
        dataset_type: Only used in the self-supervised progress message.
        df, tag_col: Unused here; kept for interface compatibility.

    Returns:
        A compiled ``tf.keras.Model``.

    Raises:
        ValueError: Unknown dataset, no model defined for this dataset/input combination,
            or unknown ``weight_init`` for image datasets.
    """
    num_outputs = y_data.shape[1]
    input_shape = np.shape(x_data[0])

    # multi-label datasets -> sigmoid/binary CE, multi-class datasets -> softmax/categorical CE
    loss, activation = _get_loss_and_activation(dataset)

    # shallow classifier if embeddings are used
    if x_data.ndim == 2:
        return get_shallow_classifier(num_outputs, input_shape, loss, activation)

    # more complex models are only defined for the image domain
    if dataset not in ['mscoco', 'cifar10']:
        raise ValueError(f'No model defined for raw inputs of dataset: {dataset}')

    if weight_init in ['tl', 'random']:
        weights = 'imagenet' if weight_init == 'tl' else None
        return get_resnet_18(weights, num_outputs, training, activation, loss)

    if weight_init == 'self-sl':
        return _get_selfsl_model(dataset, dataset_type, num_outputs, training, activation, loss)

    raise ValueError('Weight Init for image datasets not defined')


def _get_loss_and_activation(dataset):
    """Return ``(loss, output_activation)`` for multi-label vs. multi-class datasets."""
    if dataset in ['carina', 'mscoco', 'reuters', 'scene']:
        return 'binary_crossentropy', 'sigmoid'
    if dataset in ['cifar10', 'urbansound8k', 'agnews', 'letter']:
        return 'categorical_crossentropy', 'softmax'
    raise ValueError(f'Unknown dataset: {dataset}')


def _get_selfsl_model(dataset, dataset_type, num_outputs, training, activation, loss):
    """Load the cached self-supervised model for ``dataset``, or pre-train, adapt and cache it."""
    path_model = Path(cfg.path_model, f'ResNet50_{dataset}_selfsl.keras')

    # NOTE(review): the cache key only contains the dataset name. The saved model also bakes in
    # the classification head (num_outputs, activation, loss) and the frozen state, so a later
    # run with a different `training` mode silently reuses the first configuration.
    if os.path.exists(path_model):
        return tf.keras.models.load_model(path_model)

    print(f'START TRAINING SELF-SUPERVISED MODEL: {dataset} {dataset_type}')
    pretext_model = _train_rotation_model(dataset)

    # exchange the 4-way rotation head for the downstream classification head
    backbone = tf.keras.Model(inputs=pretext_model.input, outputs=pretext_model.layers[-2].output)
    outputs = tf.keras.layers.Dense(num_outputs, activation=activation)(backbone.output)
    model = tf.keras.Model(inputs=backbone.input, outputs=outputs)

    # freeze the backbone for 'frozen'; every other mode (including 'last-N') trains all layers
    backbone.trainable = training != 'frozen'

    model.compile(optimizer='Adam', loss=loss, metrics=['accuracy'])

    cfg.path_model.mkdir(parents=True, exist_ok=True)
    model.save(path_model)
    return model


def _train_rotation_model(dataset, max_samples=50000, batch_size=64, validation_split=0.2):
    """Pre-train a ResNet50 to predict the rotation (0/90/180/270 deg) of unlabelled images."""
    # load data (all NON-evaluation data)
    df_all = pd.read_csv(Path(cfg.path_data, dataset, 'metadata.csv'))
    x_data_all = np.load(Path(cfg.path_data, dataset, 'data.npy'))
    x_ssl = x_data_all[df_all['subset'] == cfg.tag_unlabelled][:max_samples]
    if len(x_ssl) == 0:
        raise ValueError(f'No unlabelled samples for self-supervised training of {dataset}')

    x_rotated, y_rotated = _make_rotation_data(x_ssl)

    # Split at full-batch boundaries so every generated batch has exactly `batch_size` samples
    # (the TensorSpecs are fixed-size). With batch_size a multiple of 4, all rotations of one
    # image end up in the same split.
    n_rotated = len(x_rotated)
    index_split = int((n_rotated * (1 - validation_split) // batch_size) * batch_size)
    index_last_val = int(index_split + ((n_rotated - index_split) // batch_size) * batch_size)

    x_train, y_train = x_rotated[:index_split], y_rotated[:index_split]
    x_val, y_val = x_rotated[index_split:index_last_val], y_rotated[index_split:index_last_val]

    # shuffle both splits once (the generator iterates them in order)
    perm_train = np.random.permutation(len(x_train))
    x_train, y_train = x_train[perm_train], y_train[perm_train]
    perm_val = np.random.permutation(len(x_val))
    x_val, y_val = x_val[perm_val], y_val[perm_val]
    del x_rotated, y_rotated  # the shuffled copies own their data; free the unshuffled arrays

    # Pixel preprocessing happens per batch inside data_generator. Running preprocess_input
    # here as well would apply it twice.
    train_dataset = _make_batched_dataset(x_train, y_train, batch_size)
    val_dataset = _make_batched_dataset(x_val, y_val, batch_size)

    steps_per_epoch = len(x_train) // batch_size
    validation_steps = len(x_val) // batch_size

    model = get_resnet_18(None, 4, 'finetune', 'softmax', 'categorical_crossentropy')

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',  # stop on validation loss
        patience=10,  # epochs without improvement before stopping
        restore_best_weights=True,  # roll back to the best epoch seen
        start_from_epoch=300,  # no early stopping during the first 300 epochs (warm start)
        min_delta=0.1,  # smaller val_loss decreases do not count as improvement
        mode='auto',  # direction inferred from the monitored metric
        verbose=0,
    )
    model.fit(train_dataset, epochs=500, steps_per_epoch=steps_per_epoch, validation_data=val_dataset,
              validation_steps=validation_steps, callbacks=[early_stopping])
    return model


def _make_rotation_data(x_ssl):
    """Return every image of ``x_ssl`` in four rotations plus one-hot rotation labels.

    Image ``i`` rotated by ``k * 90`` degrees is stored at index ``4 * i + k``; images must be
    square so that all rotations share one shape.
    """
    nr_samples = len(x_ssl)
    x_rotated = np.zeros((4 * nr_samples, *x_ssl.shape[1:]), dtype=np.float32)
    y_rotated = np.zeros((4 * nr_samples, 4), dtype=np.float32)
    for k in range(4):  # rotation 0, 90, 180, 270
        # rotate the (height, width) axes of the whole batch, same convention as tf.image.rot90
        x_rotated[k::4] = np.rot90(x_ssl, k=k, axes=(1, 2))
        y_rotated[k::4, k] = 1
    return x_rotated, y_rotated


def _make_batched_dataset(x, y, batch_size):
    """Wrap data_generator in a prefetched tf.data pipeline of fixed-size float32 batches."""
    return tf.data.Dataset.from_generator(
        lambda: data_generator(x, y, batch_size),
        output_signature=(
            tf.TensorSpec(shape=(batch_size, *x.shape[1:]), dtype=tf.float32),
            tf.TensorSpec(shape=(batch_size, *y.shape[1:]), dtype=tf.float32),
        ),
    ).prefetch(tf.data.AUTOTUNE)


def data_generator(x_data, y_data, batch_size):
    """Yield ``(preprocess_input(x_batch), y_batch)`` pairs, cycling over the data forever.

    Batches are taken in order. The last batch of a pass is smaller when ``len(x_data)`` is not
    a multiple of ``batch_size``.

    Raises:
        ValueError: If ``x_data`` is empty (raised on the first ``next()``). Without this check
            the loop would spin forever without yielding anything.
    """
    n_samples = len(x_data)
    if n_samples == 0:
        raise ValueError('data_generator: x_data is empty')

    while True:
        for offset in range(0, n_samples, batch_size):
            # Copy before preprocessing. A slice is a view into x_data, and preprocess_input may
            # work in place on float arrays, which would re-apply it to the source on every pass.
            batch_x = preprocess_input(np.array(x_data[offset:offset + batch_size]))
            batch_y = y_data[offset:offset + batch_size]
            yield batch_x, batch_y



def get_shallow_classifier(num_outputs, input_shape, loss, activation):
    """Return a compiled single-dense-layer classifier (used for precomputed embeddings).

    Args:
        num_outputs: Number of output units.
        input_shape: Shape of a single input sample.
        loss: Keras loss name used for compilation.
        activation: Activation of the output layer.
    """
    classifier = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Dense(num_outputs, activation=activation),
    ])
    classifier.compile(optimizer='Adam', loss=loss, metrics=['accuracy'])
    return classifier


def get_resnet_18(weights, num_outputs, training, activation, loss):
    """Build and compile a ResNet50 backbone with a new dense classification head.

    NOTE: despite its name this function builds Keras' ResNet50 (name kept for callers).

    Args:
        weights: Passed to ``ResNet50`` (e.g. 'imagenet', or None for random init).
        num_outputs: Number of units of the new output layer.
        training: 'frozen' trains only the new head. 'last-N' trains only the last N layers
            that own trainable weights (the head counts as one of them). Any other value trains
            all layers.
        activation: Activation of the output layer.
        loss: Keras loss name used for compilation.

    Returns:
        The compiled ``tf.keras.Model``.

    Raises:
        ValueError: If N in 'last-N' is negative.
    """
    original_model = ResNet50(weights=weights, include_top=False, pooling='avg')
    outputs = tf.keras.layers.Dense(num_outputs, activation=activation)(original_model.output)
    model = tf.keras.Model(inputs=original_model.input, outputs=outputs)

    if training == 'frozen':
        # train only the last (new) layer
        original_model.trainable = False
    elif 'last' in training:
        nr_layers_train = int(training.split('-')[-1])
        if nr_layers_train < 0:
            raise ValueError(f'Number of trainable layers must be >= 0, got: {training}')

        # layers that own weights, in model order
        trainable_layers = [layer for layer in model.layers if layer.trainable_weights]

        # freeze everything, then unfreeze the last N weighted layers
        for layer in model.layers:
            layer.trainable = False

        # Use an explicit start index: `[-nr_layers_train:]` with N == 0 would be `[-0:]`,
        # i.e. every layer, instead of none.
        start = max(0, len(trainable_layers) - nr_layers_train)
        for layer in trainable_layers[start:]:
            layer.trainable = True
    else:
        # train all layers
        original_model.trainable = True

    model.compile(optimizer='Adam', loss=loss, metrics=['accuracy'])
    return model


def get_training_data(x_data, y_data, df, tag_col, train_paradigm):
    """Select the labelled training subset and, for semi-SL, the unlabelled pool.

    Args:
        x_data: Sample array aligned row-by-row with ``df``.
        y_data: Label array aligned row-by-row with ``df``.
        df: DataFrame holding the per-sample tags.
        tag_col: Column of ``df`` with this iteration's tags.
        train_paradigm: Only 'semi-sl' triggers extraction of the unlabelled pool.

    Returns:
        ``(x_train, y_train, x_unl)``; ``x_unl`` is None unless ``train_paradigm == 'semi-sl'``.
    """
    tags = df[tag_col].to_numpy()
    is_train = tags == cfg.tag_train
    x_train, y_train = x_data[is_train], y_data[is_train]

    # Filtering the unlabelled pool can take long, so do it only when semi-SL needs it.
    x_unl = x_data[tags == cfg.tag_unlabelled] if train_paradigm == 'semi-sl' else None

    return x_train, y_train, x_unl


def train_model(model, dataset, x_labeled, y_labeled, x_unlabeled, train_paradigm, training):
    """Train ``model`` in place according to ``train_paradigm``.

    Paradigms:
        'sl': fit on the labelled data (20 % validation split).
        'semi-sl': fit on the labelled data, then pseudo-label the unlabelled samples whose every
            output is > 0.9 or < 0.1, and refit on labelled + pseudo-labelled data.
        'augmented': image datasets only; fit on randomly augmented copies of the labelled data.

    For 'mscoco'/'cifar10' image inputs, ResNet50 ``preprocess_input`` is applied to every
    array fed to the model, including the unlabelled pool.

    Raises:
        ValueError: Unknown paradigm, or augmentation requested for a non-image dataset.
    """
    # define min and max number of epochs based on training mode
    if 'finetune' in training:
        min_epochs, max_epochs = 300, 500
    else:
        min_epochs, max_epochs = 50, 150

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',  # stop on validation loss
        patience=10,  # epochs without improvement before stopping
        restore_best_weights=True,  # roll back to the best epoch seen
        start_from_epoch=min_epochs,  # no early stopping during the first `min_epochs` epochs
        min_delta=0.1,  # smaller val_loss decreases do not count as improvement
        mode='auto',  # direction inferred from the monitored metric
        verbose=0,
    )

    if train_paradigm == 'sl':
        x_labeled = _preprocess_images(dataset, x_labeled)
        model.fit(x_labeled, y_labeled, epochs=max_epochs, batch_size=64, validation_split=0.2,
                  callbacks=[early_stopping], verbose=1)

    elif train_paradigm == 'semi-sl':
        x_labeled = _preprocess_images(dataset, x_labeled)
        model.fit(x_labeled, y_labeled, epochs=max_epochs, batch_size=64, validation_split=0.2,
                  callbacks=[early_stopping], verbose=1)

        if x_unlabeled is None or len(x_unlabeled) == 0:
            return

        # The model was trained on preprocessed images, so the unlabelled pool must match.
        x_unlabeled = _preprocess_images(dataset, x_unlabeled)

        # batched prediction instead of pushing the whole pool through the model at once
        y_unl_pred = np.asarray(model.predict(x_unlabeled, batch_size=64, verbose=0))

        confident_indices = _confident_pseudo_label_indices(y_unl_pred, threshold=0.9)
        if confident_indices.size == 0:
            return

        # pseudo-labels: confident positives -> 1, confident negatives -> 0
        x_semisl = x_unlabeled[confident_indices]
        y_semisl = (y_unl_pred[confident_indices] >= 0.5).astype(np.float32)

        x_combined = np.concatenate([x_labeled, x_semisl], axis=0)
        y_combined = np.concatenate([y_labeled, y_semisl], axis=0)

        model.fit(x_combined, y_combined, epochs=max_epochs, batch_size=64, validation_split=0.2,
                  callbacks=[early_stopping], verbose=0)

    elif train_paradigm == 'augmented':
        if dataset in ['mscoco', 'cifar10']:
            x_train, y_train, x_val, y_val = augment_image_data(x_labeled, y_labeled)

            x_train = preprocess_input(x_train)
            x_val = preprocess_input(x_val)

            model.fit(x_train, y_train, epochs=max_epochs, batch_size=64, validation_data=(x_val, y_val),
                      callbacks=[early_stopping], verbose=1)
        else:
            raise ValueError('Training: Augmentation for dataset domain not implemented')
    else:
        raise ValueError('Unknown training paradigm')


def _preprocess_images(dataset, x):
    """Apply ResNet50 ``preprocess_input`` to raw image arrays of the image datasets only."""
    if dataset in ['mscoco', 'cifar10'] and x.ndim > 3:
        return preprocess_input(x)
    return x


def _confident_pseudo_label_indices(y_pred, threshold):
    """Return indices of rows where EVERY output is > threshold or < 1 - threshold.

    Confidence is checked per output. Requiring all outputs of a row to be high (or all
    low) at once can never hold for softmax outputs.
    """
    confident = (y_pred > threshold) | (y_pred < 1 - threshold)
    return np.flatnonzero(np.all(confident, axis=1))


def augment_image_data(x_data, y_data, augment_factor=2, val_split=0.2):
    """Randomly split images into train/validation sets and add augmented copies.

    Every image is followed by ``augment_factor`` randomly augmented copies in the same split,
    so augmented versions of a validation image never leak into training. Augmentations:
    horizontal flip, 200x200 random crop resized to 224x224, brightness/contrast/saturation/hue
    jitter, and Gaussian noise (std 20), clipped to [0, 255].

    Args:
        x_data: Images; each must be at least 200x200 with 3 channels (random crop size).
        y_data: Labels aligned with ``x_data``.
        augment_factor: Number of augmented copies per image.
        val_split: Fraction of original images assigned to the validation set.

    Returns:
        ``(x_train, y_train, x_val, y_val)`` as float32 numpy arrays.
    """
    num_samples = len(x_data)

    # pick the train images (augmented copies follow their original into the same split)
    indices = np.arange(num_samples)
    np.random.shuffle(indices)
    split_point = int(num_samples * (1 - val_split))
    # boolean lookup table: O(1) per sample instead of scanning the index array each time
    is_train = np.zeros(num_samples, dtype=bool)
    is_train[indices[:split_point]] = True

    x_train_final, y_train_final = [], []
    x_val_final, y_val_final = [], []

    for i in range(num_samples):
        img = x_data[i]
        label = y_data[i]

        imgs = [img]  # original image
        labels = [label]

        # convert once; every augmentation starts from the same original tensor
        img_orig = tf.convert_to_tensor(img, dtype=tf.float32)
        for _ in range(augment_factor):
            img_tf = tf.image.random_flip_left_right(img_orig)
            img_tf = tf.image.random_crop(img_tf, size=[200, 200, 3])
            img_tf = tf.image.resize(img_tf, [224, 224])
            img_tf = tf.image.random_brightness(img_tf, max_delta=0.2)
            img_tf = tf.image.random_contrast(img_tf, lower=0.8, upper=1.2)
            img_tf = tf.image.random_saturation(img_tf, lower=0.8, upper=1.2)
            img_tf = tf.image.random_hue(img_tf, max_delta=0.1)
            noise = tf.random.normal(shape=tf.shape(img_tf), mean=0, stddev=20)
            img_tf = tf.clip_by_value(img_tf + noise, 0.0, 255.0)

            imgs.append(img_tf.numpy())
            labels.append(label)

        if is_train[i]:
            x_train_final.extend(imgs)
            y_train_final.extend(labels)
        else:
            x_val_final.extend(imgs)
            y_val_final.extend(labels)

    x_train_final = np.array(x_train_final, dtype=np.float32)
    y_train_final = np.array(y_train_final, dtype=np.float32)
    x_val_final = np.array(x_val_final, dtype=np.float32)
    y_val_final = np.array(y_val_final, dtype=np.float32)

    return x_train_final, y_train_final, x_val_final, y_val_final
