import argparse
import os

import numpy as np
import torch
import xgboost as xgb
from scipy.stats import kendalltau
from sklearn.metrics import r2_score

from common.data_utils import SmallMoeZooDatasetAugmented, SmallMoEZooDataset
from common.weight_space import (
    MoEWeightSpaceFeatures,
    LinearWeightSpaceFeatures,
)


def flatten_input(embedding, classifier, encoder):
    """
    Flattens the embedding, classifier, and encoder into a single 1D float32 array.

    Args:
        embedding: Either ``None`` (embedding is left out of the features) or a
            mapping with ``'weight'`` and ``'bias'`` entries, each an iterable of
            array-likes.
        classifier: Iterable of ``(weight, bias)`` pairs of array-likes.
        encoder: Iterable of 10-tuples
            ``(W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)`` of array-likes.

    Returns:
        1D ``np.float32`` array laid out as
        ``[embedding weights, embedding biases,] classifier, encoder``.

    Raises:
        ValueError: If an encoder entry does not unpack into exactly 10 items, or
            if ``classifier``/``encoder`` is empty (``np.concatenate`` rejects an
            empty list).
    """
    def _flat(values):
        # Convert to a float32 numpy array and collapse it to one dimension.
        return np.asarray(values, dtype=np.float32).ravel()

    classifier_features = []
    for weight, bias in classifier:
        classifier_features.append(_flat(weight))
        classifier_features.append(_flat(bias))
    classifier_flattened = np.concatenate(classifier_features)

    encoder_features = []
    # The explicit 10-way unpack doubles as a check on the layer layout.
    for W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B in encoder:
        encoder_features.extend(
            _flat(part)
            for part in (W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)
        )
    encoder_flattened = np.concatenate(encoder_features)

    # Identity check: `!= None` would dispatch to the object's own __ne__.
    if embedding is not None:
        embedding_weight = np.concatenate([_flat(w) for w in embedding['weight']])
        embedding_bias = np.concatenate([_flat(b) for b in embedding['bias']])
        return np.concatenate(
            [embedding_weight, embedding_bias, classifier_flattened, encoder_flattened]
        )

    return np.concatenate([classifier_flattened, encoder_flattened])


def prepare_data(dataset, data_name='ag_news'):
    """
    Builds the feature matrix and target vector by indexing the dataset directly
    (no DataLoader) and flattening every sample.

    Args:
        dataset: Object supporting ``len()`` and integer indexing; each sample is
            a mapping with ``"embedding"``, ``"classifier"``, ``"encoder"`` and
            ``"accuracy"`` entries.
        data_name: Dataset name; for ``'ag_news'`` the embedding is excluded from
            the flattened features.

    Returns:
        Tuple ``(X, y)`` of float32 numpy arrays: flattened features per sample
        and the first element of each sample's ``"accuracy"``.
    """
    feature_rows = []
    targets = []
    drop_embedding = data_name == 'ag_news'

    for idx in range(len(dataset)):
        item = dataset[idx]

        emb = item["embedding"]
        clf_params = item["classifier"]
        clf = LinearWeightSpaceFeatures(clf_params['weight'], clf_params['bias']).to("cpu")
        enc = MoEWeightSpaceFeatures(**item["encoder"]).to("cpu")
        accuracy = item["accuracy"]

        # Embedding is passed as None when it must not contribute features.
        feature_rows.append(flatten_input(None if drop_embedding else emb, clf, enc))
        targets.append(accuracy[0])

    return (
        np.array(feature_rows, dtype=np.float32),
        np.array(targets, dtype=np.float32),
    )


def train_xgb(X_train, y_train, X_val, y_val, X_test, y_test, X_augment, y_augment, mode='gbtree', seed=3):
    """
    Trains XGBoost regressors of increasing size and returns the one with the
    best validation RMSE.

    Args:
        X_train, y_train: Training features and targets.
        X_val, y_val: Validation features and targets (used for model selection).
        X_test, y_test: Test features and targets (metrics are only logged).
        X_augment, y_augment: Augmented-set features and targets (metrics are
            only logged).
        mode: XGBoost booster type, forwarded as the ``booster`` parameter
            (e.g. ``'gbtree'`` or ``'dart'``).
        seed: Random seed passed to XGBoost.

    Returns:
        The trained booster with the lowest validation RMSE, or ``None`` if no
        run produced a score below infinity.
    """
    # Set up the DMatrix for XGBoost
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)
    dtest = xgb.DMatrix(X_test, label=y_test)
    daugment = xgb.DMatrix(X_augment, label=y_augment)

    # Set XGBoost parameters
    params = {
        # Previously `mode` was accepted but never used, so a 'dart' request
        # silently trained the default booster.
        'booster': mode,
        'nthread': 4,
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'learning_rate': 0.1,
        'max_depth': 10,
        'min_child_weight': 50,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'lambda': 0.01,
        'alpha': 0.02,
        'seed': seed,
        'tree_method': 'hist',
        'device': "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
    }

    # Custom evaluation function for logging metrics
    def log_metrics(epoch, model, dtrain, dval, dtest, daugment):
        """Print R2 and Kendall tau of `model` on the train/val/test/augment sets."""
        train_pred = model.predict(dtrain)
        val_pred = model.predict(dval)
        test_pred = model.predict(dtest)
        augment_pred = model.predict(daugment)

        train_r2 = r2_score(y_train, train_pred)
        train_tau, _ = kendalltau(y_train, train_pred)
        val_r2 = r2_score(y_val, val_pred)
        val_tau, _ = kendalltau(y_val, val_pred)

        # Calculate R2 and Kendall Tau for test set
        test_r2 = r2_score(y_test, test_pred)
        test_tau, _ = kendalltau(y_test, test_pred)

        # Calculate R2 and Kendall Tau for augmented set
        augment_r2 = r2_score(y_augment, augment_pred)
        augment_tau, _ = kendalltau(y_augment, augment_pred)

        # Log metrics to console
        print(f"Epoch {epoch}: Train R2 = {train_r2:.4f}, Train Tau = {train_tau:.4f}, "
              f"Val R2 = {val_r2:.4f}, Val Tau = {val_tau:.4f}, Test R2 = {test_r2:.4f}, Test Tau = {test_tau:.4f}, "
              f"Augment R2 = {augment_r2:.4f}, Augment Tau = {augment_tau:.4f}")

    best_model = None
    best_val_rmse = float("inf")

    # Each "epoch" trains a fresh model with `epoch + 1` boosting rounds so that
    # metrics can be logged for every ensemble size.
    # NOTE(review): this retrains from scratch every iteration (1 + 2 + ... + 50
    # rounds in total); a training callback could log the same metrics in a
    # single 50-round run — confirm before changing, as it affects the logs.
    for epoch in range(50):
        # With at most 50 rounds, a patience of 50 is never exhausted;
        # presumably it is passed so that `best_score` gets populated — confirm.
        model = xgb.train(params, dtrain, num_boost_round=epoch + 1, evals=[(dval, 'validation')],
                          early_stopping_rounds=50, verbose_eval=False)

        val_rmse = model.best_score
        log_metrics(epoch, model, dtrain, dval, dtest, daugment)

        # Strict comparison keeps the earliest (smallest) model on ties.
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            best_model = model

    return best_model


def main(args):
    """
    Loads the train/val/test/augmented splits, flattens them into feature
    matrices, and trains the XGBoost model.

    Args:
        args: Parsed command-line namespace; reads ``data_path``, ``n_heads``,
            ``dataset``, ``cut_off``, ``augment_factor``, ``keep_original``,
            ``model`` and ``seed``.
    """
    print("Start to load dataset")
    # Load the entire dataset directly
    train_set = SmallMoEZooDataset(data_path=args.data_path, n_heads=args.n_heads, split="train",
                                   name=args.dataset, cut_off=args.cut_off)
    val_set = SmallMoeZooDatasetAugmented(data_path=args.data_path, n_heads=args.n_heads, split="val",
                                          name=args.dataset, cut_off=args.cut_off,
                                          augment_factor=args.augment_factor)
    test_set = SmallMoeZooDatasetAugmented(data_path=args.data_path, n_heads=args.n_heads, split="test",
                                           name=args.dataset, cut_off=args.cut_off,
                                           augment_factor=args.augment_factor)
    # Same split as `test_set`, but additionally forwards `keep_original`.
    augment_set = SmallMoeZooDatasetAugmented(data_path=args.data_path, n_heads=args.n_heads, split="test",
                                              name=args.dataset, cut_off=args.cut_off,
                                              augment_factor=args.augment_factor,
                                              keep_original=args.keep_original)

    # Prepare data for XGBoost
    print("Preparing data for XGBoost...")
    X_train, y_train = prepare_data(train_set, args.dataset)
    X_val, y_val = prepare_data(val_set, args.dataset)
    X_test, y_test = prepare_data(test_set, args.dataset)
    X_augment, y_augment = prepare_data(augment_set, args.dataset)
    print("Feature shape:", X_train.shape)
    print("Memory usage (MB):", X_train.nbytes / 1e6)

    # Train XGBoost model
    print("Training XGBoost model...")
    xgb_model = train_xgb(X_train, y_train, X_val, y_val, X_test, y_test, X_augment, y_augment,
                          mode=args.model, seed=args.seed)

    # Optional: persist test predictions next to the ground truth.
    # dtest = xgb.DMatrix(X_test, label=y_test)
    # y_pred = xgb_model.predict(dtest)
    # results = np.array([y_pred, y_test])
    # output_file = f'{args.model}_{args.dataset}.npy'
    # np.save(output_file, results)
    # print(f"Results saved to {output_file}")

if __name__ == '__main__':
    # The script trains an XGBoost model; the help text previously said LightGBM.
    parser = argparse.ArgumentParser(description='XGBoost training from Transformer data')

    # Training Arguments
    parser.add_argument('--seed', type=int, default=3, help='random seed')
    # NOTE(review): --device is parsed but not read by any code visible in this file.
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--model', type=str, default="gbtree", choices=['gbtree', "dart"], help="model type")

    # Data arguments
    parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', "ag_news"], help='dataset to use')
    parser.add_argument('--n_heads', type=int, default=2, help='Number of heads in transformer input network')
    parser.add_argument('--data_path', type=str, default='nfn-dataset', help='path to dataset')
    parser.add_argument('--cut_off', type=float, default=0.0, help='cut off rate for accuracy')
    parser.add_argument('--augment_factor', type=int, default=1, help='Augment factor of dataset')
    parser.add_argument('--keep_original', type=int, default=1, help='Choose to keep the original data or not')

    args = parser.parse_args()

    # Set random seed for reproducibility
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    main(args)