import argparse
import os

import numpy as np
import torch
from scipy.stats import kendalltau
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVR
from sklearn.pipeline import make_pipeline

from common.data_utils import SmallMoEZooDataset, SmallMoeZooDatasetAugmented
from common.weight_space import (
    MoEWeightSpaceFeatures,
    LinearWeightSpaceFeatures,
)
from sklearn.decomposition import PCA
def flatten_input(embedding, classifier, encoder):
    """Flatten a model's parameters into one 1-D float32 feature vector.

    Args:
        embedding: ``None`` to omit embedding parameters, otherwise a mapping
            whose ``'weight'`` and ``'bias'`` entries are iterables of
            array-likes.
        classifier: iterable of ``(weight, bias)`` pairs (exactly two items
            per element).
        encoder: iterable of 10-tuples
            ``(W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)``.

    Returns:
        A 1-D ``np.float32`` array laid out as
        ``[embedding weights, embedding biases,] classifier params, encoder params``.

    Raises:
        ValueError: if any parameter group is empty (``np.concatenate`` needs at
            least one array) or an element does not unpack to the expected arity.
    """
    def to_vec(param):
        # Copy to float32 and ravel to 1-D.
        return np.array(param, dtype=np.float32).flatten()

    classifier_vec = np.concatenate([
        vec
        for weight, bias in classifier
        for vec in (to_vec(weight), to_vec(bias))
    ])

    # The explicit 10-name unpack enforces the per-layer arity and fixes the order.
    encoder_vec = np.concatenate([
        to_vec(param)
        for W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B in encoder
        for param in (W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)
    ])

    groups = [classifier_vec, encoder_vec]
    if embedding is not None:
        # Embedding parameters, when present, go in front of everything else.
        emb_weight_vec = np.concatenate([to_vec(w) for w in embedding['weight']])
        emb_bias_vec = np.concatenate([to_vec(b) for b in embedding['bias']])
        groups = [emb_weight_vec, emb_bias_vec] + groups
    return np.concatenate(groups)

def prepare_data(dataset, data_name='ag_news'):
    """Build ``(X, y)`` arrays for the SVM from a model-zoo dataset.

    Each sample's classifier and encoder parameters are wrapped in the
    weight-space feature containers, moved to CPU and flattened with
    ``flatten_input``. Embedding parameters are included for every dataset
    except ``'ag_news'``.

    Args:
        dataset: object supporting ``len()`` and integer indexing; each item is a
            mapping with ``"embedding"``, ``"classifier"`` (``'weight'`` /
            ``'bias'``), ``"encoder"`` (keyword arguments for
            ``MoEWeightSpaceFeatures``) and ``"accuracy"`` entries.
        data_name: dataset name; ``'ag_news'`` drops the embedding features.

    Returns:
        ``(X, y)`` as float32 arrays, where ``y`` holds element ``[0]`` of each
        sample's ``"accuracy"``.
    """
    include_embedding = data_name != 'ag_news'
    features, targets = [], []

    # Index explicitly rather than iterating: the dataset is only relied on for
    # len() and __getitem__.
    for idx in range(len(dataset)):
        record = dataset[idx]
        raw_embedding = record["embedding"]
        clf_params = record["classifier"]
        classifier = LinearWeightSpaceFeatures(clf_params['weight'], clf_params['bias']).to("cpu")
        encoder = MoEWeightSpaceFeatures(**record["encoder"]).to("cpu")
        accuracy = record["accuracy"]

        chosen_embedding = raw_embedding if include_embedding else None
        features.append(flatten_input(chosen_embedding, classifier, encoder))
        targets.append(accuracy[0])

    return np.array(features, dtype=np.float32), np.array(targets, dtype=np.float32)

def train_svm(X_train, y_train, X_val, y_val, X_test, y_test, X_augment, y_augment,
              n_components=100, random_state=42):
    """Fit a StandardScaler -> PCA -> LinearSVR regressor and print per-split metrics.

    The pipeline is fitted on the training split only. Train, validation, test
    and augmented splits are then scored with R^2 and Kendall's tau, and the
    scores are printed.

    Args:
        X_train, y_train: training features (2-D) and targets.
        X_val, y_val: validation features and targets.
        X_test, y_test: test features and targets.
        X_augment, y_augment: features and targets of the augmented evaluation set.
        n_components: number of PCA components. An integer larger than
            ``min(n_samples, n_features)`` of ``X_train`` is clamped to that
            limit, because PCA would otherwise raise. Any other value is passed
            to ``PCA`` unchanged.
        random_state: seed forwarded to ``LinearSVR``.

    Returns:
        The fitted sklearn pipeline.
    """
    # PCA rejects an integer n_components above min(n_samples, n_features). The
    # script's --n_components default can easily exceed the number of training
    # models, so clamp instead of failing after all the data has been prepared.
    train_shape = np.shape(X_train)
    if (
        isinstance(n_components, (int, np.integer))
        and not isinstance(n_components, bool)
        and len(train_shape) == 2
    ):
        max_components = min(train_shape)
        if n_components > max_components:
            print(f"n_components={n_components} exceeds min(n_samples, n_features)="
                  f"{max_components}; using {max_components} instead.")
            n_components = max_components

    model = make_pipeline(
        StandardScaler(),
        PCA(n_components=n_components),
        LinearSVR(random_state=random_state, max_iter=10000, verbose=1),
    )

    model.fit(X_train, y_train)

    def eval_model(X, y):
        """Return ``(R^2, Kendall tau)`` of the fitted model's predictions on ``(X, y)``."""
        y_pred = model.predict(X)
        r2 = r2_score(y, y_pred)
        # kendalltau returns (statistic, p-value); only the correlation is reported.
        tau, _ = kendalltau(y, y_pred)
        return r2, tau

    train_r2, train_tau = eval_model(X_train, y_train)
    val_r2, val_tau = eval_model(X_val, y_val)
    test_r2, test_tau = eval_model(X_test, y_test)
    augment_r2, augment_tau = eval_model(X_augment, y_augment)

    print(f"Train R2 = {train_r2:.4f}, Tau = {train_tau:.4f}")
    print(f"Val   R2 = {val_r2:.4f}, Tau = {val_tau:.4f}")
    print(f"Test  R2 = {test_r2:.4f}, Tau = {test_tau:.4f}")
    print(f"Augment R2 = {augment_r2:.4f}, Tau = {augment_tau:.4f}")

    return model

def main(args):
    """Load the model-zoo splits, flatten them into feature matrices and fit the SVM.

    Args:
        args: parsed CLI namespace. Reads ``data_path``, ``n_heads``,
            ``dataset``, ``cut_off``, ``augment_factor``, ``keep_original`` and
            ``n_components``.
    """
    print("Start to load dataset")
    shared_kwargs = dict(
        data_path=args.data_path,
        n_heads=args.n_heads,
        name=args.dataset,
        cut_off=args.cut_off,
    )

    # Training uses the plain dataset class. Validation and test go through the
    # augmented class with augment_factor=1. Earlier variants instead built
    # val/test with SmallMoEZooDataset or augmented the training split.
    train_set = SmallMoEZooDataset(split="train", **shared_kwargs)
    val_set = SmallMoeZooDatasetAugmented(split="val", augment_factor=1, **shared_kwargs)
    test_set = SmallMoeZooDatasetAugmented(split="test", augment_factor=1, **shared_kwargs)
    # The separately evaluated augmented set is derived from the *test* split.
    augment_set = SmallMoeZooDatasetAugmented(
        split="test",
        augment_factor=args.augment_factor,
        keep_original=args.keep_original,
        **shared_kwargs,
    )

    print("Preparing data for SVM...")
    (X_train, y_train), (X_val, y_val), (X_test, y_test), (X_augment, y_augment) = (
        prepare_data(split_data, args.dataset)
        for split_data in (train_set, val_set, test_set, augment_set)
    )

    print("Training SVM model...")
    train_svm(
        X_train, y_train,
        X_val, y_val,
        X_test, y_test,
        X_augment, y_augment,
        n_components=args.n_components,
    )

# (flag, type, default, choices) for every CLI option, in the order they are registered.
# NOTE(review): --device, --project and --entity are parsed but not read by any
# code visible in this file.
_CLI_OPTIONS = (
    ('--seed', int, 3, None),
    ('--n_components', int, 1000, None),
    ('--device', str, "cuda", None),
    ('--dataset', str, 'mnist', ['mnist', "ag_news"]),
    ('--n_heads', int, 2, None),
    ('--data_path', str, 'nfn-dataset', None),
    ('--cut_off', float, 0.1, None),
    ('--project', str, None, None),
    ('--entity', str, None, None),
    ('--augment_factor', int, 1, None),
    # Integer flag forwarded as-is to SmallMoeZooDatasetAugmented(keep_original=...).
    ('--keep_original', int, 1, None),
)


def _build_arg_parser():
    """Return the command-line parser for this SVM baseline script."""
    cli = argparse.ArgumentParser(description='SVM training from Transformer data')
    for flag, arg_type, default, choices in _CLI_OPTIONS:
        cli.add_argument(flag, type=arg_type, default=default, choices=choices)
    return cli


def _seed_everything(seed):
    """Seed NumPy's global RNG and PyTorch, plus all CUDA devices when CUDA is available."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    _seed_everything(args.seed)
    main(args)
