import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from scipy import stats
from sklearn.mixture import GaussianMixture
from transformers import AutoModel, AutoTokenizer, AutoConfig
import argparse
from utils import apply_pooler, dotdict
import logging

from simcse.models import BertForCL
from stealing.steal_MLP import load_BERT_for_CL, get_sent_features, trainer_eval

def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag declared that way can
    never be switched off explicitly from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what ``bool("")`` returned before.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the script; parsed inside ``main``.
parser = argparse.ArgumentParser(description='PyTorch SimCSE')
parser.add_argument('--model',
                    default="",
                    help='Path to Victim')
parser.add_argument('--train1',
                    default="",
                    help='Path to Train1 Data')
parser.add_argument('--train2',
                    default="",
                    help='Path to Train2 Data')
parser.add_argument('--test',
                    default="",
                    help='Path to Test Data')
# 128 is the width ``main`` configures for "tiny" base models; 768 is the
# width it configures for every other base model.
parser.add_argument('--num_dims',
                    type=int,
                    default=128,
                    help='Dimensions of the Output')
parser.add_argument('--bs',
                    type=int,
                    default=50,
                    help='Batch Size')
parser.add_argument('--gmm_components',
                    type=int,
                    default=3,
                    help='How many components in the GMM')
# Boolean options take an explicit value, e.g. ``--unconverted True``.
parser.add_argument('--unconverted',
                    type=_str2bool,
                    default=False,
                    help="Are we trying to load an unconverted model?")
parser.add_argument('--subset',
                    type=_str2bool,
                    default=False,
                    help="When we use subset, we use all the DI data for train 1 and 2")
# NOTE(review): 'roberta-base' used to be listed twice in ``choices``; the
# second entry may have been meant to be 'roberta-large' -- confirm before
# adding it.
parser.add_argument('--basemodel',
                    default='prajjwal1/bert-tiny',
                    help='What was the model under evaluation based on?',
                    choices=['bert-base-uncased', 'bert-large-uncased',
                             'roberta-base',
                             'prajjwal1/bert-tiny']
                    )
parser.add_argument('--use_pooler',
                    type=_str2bool,
                    default=False,
                    help="Whether or not to use the pooler in sent_emb forward (victims use it, stolen models don't)")

def return_di_embeddings(model, sent_features, use_pooler, return_sent_emb=False):
    """Run ``model`` on a tokenized batch and return its pooled embeddings.

    Args:
        model: callable model; invoked with the entries of ``sent_features``
            as keyword arguments.
        sent_features: mapping of model inputs for one batch.
        use_pooler: forwarded to the model as ``use_pooler``; only used when
            ``return_sent_emb`` is true.
        return_sent_emb: when true, the model is additionally called with
            ``sent_emb=True`` and ``use_pooler=use_pooler``.

    Returns:
        Whatever ``apply_pooler`` produces for the model outputs with the
        pooler option set to ``'cls'``.
    """
    forward_kwargs = {
        'output_hidden_states': True,
        'return_dict': True,
    }
    if return_sent_emb:
        # Only the sentence-embedding forward path understands these two flags.
        forward_kwargs['sent_emb'] = True
        forward_kwargs['use_pooler'] = use_pooler

    model_outputs = model(**sent_features, **forward_kwargs)

    pooling_options = dotdict()
    pooling_options.pooler = 'cls'
    return apply_pooler(outputs=model_outputs,
                        batch=sent_features,
                        args=pooling_options)


# Value passed as ``sent_features_type`` to ``get_sent_features`` for every
# batch tokenized in ``main``. The commented-out line below is the alternative
# setting ("evaluation").
PROCESSING_METHOD = "standard"
#PROCESSING_METHOD = "evaluation"

def _sample_sizes(train1_path, subset):
    """Return the (train1, train2, test) row counts to sample for this run."""
    is_flickr = "flickr" in train1_path
    if subset:
        return (20000, 3000, 3000) if is_flickr else (35000, 15000, 15000)
    return (45000, 14000, 14000) if is_flickr else (45000, 20000, 20000)


def _load_sentences(path, num_samples):
    """Read a header-less CSV and return ``num_samples`` randomly drawn rows as strings."""
    frame = pd.read_csv(path, sep=',', header=None)
    return frame.sample(num_samples).squeeze().astype(dtype=str).to_list()


def _embed_in_batches(args, model, tokenizer, sentences, num_dimension, batch_size):
    """Embed ``sentences`` batch by batch into a ``(n, num_dimension)`` tensor.

    Only full batches are processed: ``len(sentences) // batch_size`` batches
    are embedded and any trailing partial batch is dropped.
    """
    num_batches = len(sentences) // batch_size
    representations = torch.zeros(num_batches * batch_size, num_dimension)
    # Pure inference: no autograd bookkeeping is needed for these forward passes.
    with torch.no_grad():
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            stop = start + batch_size
            sent_features = get_sent_features(args, sentences[start:stop], tokenizer,
                                              sent_features_type=PROCESSING_METHOD)
            embeddings = return_di_embeddings(model, sent_features, args.use_pooler,
                                              return_sent_emb=args.unconverted)
            representations[start:stop] = embeddings
    return representations


def _standardize_and_normalize(representations):
    """Standardize every dimension over the rows, then L2-normalize each row.

    Returns the result as a NumPy array on the CPU.
    """
    mean = torch.mean(representations, dim=0)
    std = torch.std(representations, dim=0)
    # A constant dimension has std == 0; dividing would give 0/0 = NaN and
    # poison the whole matrix. Dividing by 1 leaves such a dimension at 0.
    std = torch.where(std == 0, torch.ones_like(std), std)
    standardized = (representations - mean) / std
    return F.normalize(standardized).cpu().detach().numpy()


def main():
    """Fit a GMM on train1 embeddings and compare train2 vs. test likelihoods.

    Steps:
      1. Parse the CLI arguments, configure logging and pick the device.
      2. Load the tokenizer and the model (``load_BERT_for_CL`` when
         ``--unconverted`` is set, ``AutoModel`` otherwise) and freeze it.
      3. Sample sentences from the three CSV files and embed them.
      4. Standardize and L2-normalize each set of embeddings independently.
      5. Fit a diagonal-covariance ``GaussianMixture`` on the train1
         embeddings, print the mean ``score_samples`` value of each set, and
         print the result of ``scipy.stats.ttest_ind`` between the train2 and
         test scores.
    """
    args = parser.parse_args()

    logname = 'training.log'
    logging.basicConfig(
        filename=logname,
        filemode='a',
        level=logging.INFO,
        force=True,  # automatically remove the root handlers
    )
    # Mirror log records to the console in addition to the log file.
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(f"Using pooler {args.use_pooler}.")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    # Extra attributes read by the project helpers that receive ``args``.
    args.seqlength = 32
    args.device = device

    tokenizer = AutoTokenizer.from_pretrained(args.basemodel)
    print("loading tokenizer")

    if args.unconverted:
        config = AutoConfig.from_pretrained(args.basemodel)
        # If the victim is also tiny, the same number of dimensions is kept.
        # NOTE(review): every non-tiny base model (including
        # 'bert-large-uncased') is given width 768 here -- confirm this is
        # intended.
        width = 128 if "tiny" in args.basemodel else 768
        config.hidden_size = width
        config.output_size = width
        args.model_to_load = args.model
        args.pooler = "cls"
        model = load_BERT_for_CL(args.model, args, config, tokenizer)
    else:
        model = AutoModel.from_pretrained(args.model)

    # The model is only used for inference, so freeze every parameter.
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    model.eval()
    print("Finish loading models")

    # Parameters needed for the evaluator but not important for us.
    args.outputdir = "."
    args.logdir = "."
    args.seed = 42

    train1_size, train2_size, test_size = _sample_sizes(args.train1, args.subset)

    train_1 = _load_sentences(args.train1, train1_size)
    train_2 = _load_sentences(args.train2, train2_size)
    test = _load_sentences(args.test, test_size)

    print(len(train_1), len(train_2), len(test))
    print("Finish loading data")

    num_dimension = args.num_dims
    batch_size = args.bs

    training_representations_1 = _embed_in_batches(
        args, model, tokenizer, train_1, num_dimension, batch_size)
    print("Finish generating training representations 1")

    training_representations_2 = _embed_in_batches(
        args, model, tokenizer, train_2, num_dimension, batch_size)
    print("Finish generating training representations 2")

    test_representations = _embed_in_batches(
        args, model, tokenizer, test, num_dimension, batch_size)
    print("Finish generating test representations")

    # Each set is standardized with its own statistics, then row-normalized.
    training_representations_1 = _standardize_and_normalize(training_representations_1)
    training_representations_2 = _standardize_and_normalize(training_representations_2)
    test_representations = _standardize_and_normalize(test_representations)

    print("Finish normalizing representations")

    n_components = args.gmm_components
    print(f"N_COMPONENTS: {n_components}")
    gm = GaussianMixture(n_components=n_components, max_iter=1000, covariance_type="diag")
    gm.fit(training_representations_1)

    print("Finish fitting GMM")

    training_likelihood_1 = gm.score_samples(training_representations_1)
    training_likelihood_2 = gm.score_samples(training_representations_2)
    test_likelihood = gm.score_samples(test_representations)

    print("training likelihood 1: " + str(np.mean(training_likelihood_1)))
    print("training likelihood 2: " + str(np.mean(training_likelihood_2)))
    print("test likelihood: " + str(np.mean(test_likelihood)))
    # Two-sample t-test between the held-out training split and the test split.
    print(stats.ttest_ind(training_likelihood_2, test_likelihood))

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
