import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import QuantileTransformer
from tabularbert import TabularBERTTrainer
from tabularbert.utils.metrics import ClassificationError, RMSE
from tabularbert.utils.data import UniformDiscretize

import json
import os
import random

def seed_everything(seed=0):
    '''
    Seed every random number generator the script relies on, for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's RNG,
    plus the CUDA RNGs of every visible device when CUDA is available.
    It also exports ``PYTHONHASHSEED``. That variable only influences
    interpreters started after this call (e.g. worker subprocesses). The hash
    seed of the already-running interpreter is fixed at its startup.

    Args:
        seed: Integer seed applied to every generator (default 0).
    '''
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators: stdlib, NumPy, PyTorch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    # GPU-side generators: the current device and then all devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
        
# Load hyperparameters (embedding size, layer/head counts, dropouts, learning
# rates, ...) shared by the pretraining and fine-tuning phases below.
with open('config.json', 'r', encoding='utf-8') as f:
    config = json.load(f)

# Seed every RNG before any model weights are initialised or batches are drawn,
# so runs are reproducible. An optional 'seed' key in config.json overrides the
# default seed of 0.
seed_everything(config.get('seed', 0))

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Datasets are available at: https://github.com/jyansir/t2g-former
# Load the pre-split numerical features and labels (train / validation / test).
DATA_DIR = os.path.join('data', 'jannis')

train_X = np.load(os.path.join(DATA_DIR, 'X_num_train.npy'))
valid_X = np.load(os.path.join(DATA_DIR, 'X_num_val.npy'))
test_X = np.load(os.path.join(DATA_DIR, 'X_num_test.npy'))

train_labels = np.load(os.path.join(DATA_DIR, 'y_train.npy'))
valid_labels = np.load(os.path.join(DATA_DIR, 'y_val.npy'))
test_labels = np.load(os.path.join(DATA_DIR, 'y_test.npy'))

# Map every feature to a uniform [0, 1] distribution via its empirical quantiles.
# Roughly one quantile per 30 training rows, clamped to the range [10, 1000].
# subsample=None makes the fit use every training row.
scaler = QuantileTransformer(n_quantiles=max(min(train_X.shape[0] // 30, 1000), 10),
                             output_distribution='uniform',
                             subsample=None)

# Fit on the training split only, so validation/test statistics never leak into
# the transform; then apply the same fitted transform to every split.
train_XX = scaler.fit_transform(train_X)
valid_XX = scaler.transform(valid_X)
test_XX = scaler.transform(test_X)

# No per-column encoding metadata is passed to tabularbert (only numerical
# feature files are loaded above). NOTE(review): the exact meaning of
# encoding_info=None is defined by tabularbert -- confirm there.
encoding_info = None

# Number of discretisation bins. Pretraining and test-time discretisation must
# use the same value, so it is defined once here and reused below.
NUM_BINS = 50

# Mini-batch size used for pretraining, fine-tuning and test-time inference.
BATCH_SIZE = 1024

# Checkpoint locations, shared by the save sites and the later load sites.
# NOTE(review): 'version0' assumes setup_directories_and_logging creates that
# sub-directory on a fresh run; if ./pretraining or ./fine-tuning already hold
# results from an earlier run, these paths may point at a stale checkpoint --
# confirm against TabularBERTTrainer's directory versioning.
PRETRAIN_DIR = './pretraining'
FINETUNE_DIR = './fine-tuning'
PRETRAINED_CKPT = os.path.join(PRETRAIN_DIR, 'version0', 'model_checkpoint.pt')
FINETUNED_CKPT = os.path.join(FINETUNE_DIR, 'version0', 'model_checkpoint.pt')

# Pretraining on the scaled training features only (no labels are passed).
trainer = TabularBERTTrainer(x=train_XX,
                             num_bins=NUM_BINS,
                             encoding_info=encoding_info,
                             device=device)
trainer.setup_directories_and_logging(save_dir=PRETRAIN_DIR,
                                    phase='pretraining',
                                    project_name='JA data pretraining',
                                    use_wandb=False)
trainer.set_bert(embedding_dim=config['embedding_dim'],
                 n_layers=config['n_layers'],
                 n_heads=config['n_heads'],
                 dropout=config['pretraining_dropout'],
                 mode=config['mode'])
trainer.set_optimizer(lr=2e-4, weight_decay=config['weight_decay'])
trainer.pretrain(lamb=config['lamb'],
                penalty='squaredL2',
                epochs=1000,
                batch_size=BATCH_SIZE,
                mask_token_prob=0.15,
                random_token_prob=0.1,
                unchanged_token_prob=0.1,
                num_workers=0)


# Fine-tuning: reload the pretrained checkpoint and train a classification head
# with 4 output logits (CrossEntropyLoss), early-stopped on the validation split.
trainer = TabularBERTTrainer.from_pretrained(save_path=PRETRAINED_CKPT,
                                             device=device)
trainer.setup_directories_and_logging(save_dir=FINETUNE_DIR,
                                    phase='fine-tuning',
                                    project_name='JA data fine-tuning',
                                    use_wandb=False)
trainer.set_head(output_dim=4,
                 activation='ReLU',
                 dropouts=config['fine_tuning_dropout'],
                 hidden_layers=[config['embedding_dim']] * config['head_n_hidden_layers'])
trainer.set_optimizer(lr=config['fine_tuning_lr'],
                     weight_decay=config['weight_decay'])
trainer.finetune(x=train_XX,
                y=train_labels,
                valid_x=valid_XX,
                valid_y=valid_labels,
                epochs=2000,
                penalty=config['fine_tuning_penalty'],
                batch_size=BATCH_SIZE,
                criterion=nn.CrossEntropyLoss(),
                metric=ClassificationError(),
                patience=100,
                num_workers=0)


# Evaluation: discretise the test features with the same bin count, fitted on
# the same training data, then report test accuracy.
discretizer = UniformDiscretize(num_bins=NUM_BINS, encoding_info=encoding_info)
discretizer.fit(train_XX)
test_XXX = discretizer.discretize(test_XX)

model = TabularBERTTrainer.from_finetuned(FINETUNED_CKPT, device=device)
model.eval()

# Inference only: disable autograd so no computation graph is kept. Run in
# mini-batches so the whole test set never has to fit through the model at once.
test_inputs = torch.tensor(test_XXX)
with torch.no_grad():
    predictions = torch.cat([model(batch.to(device))
                             for batch in torch.split(test_inputs, BATCH_SIZE)])
pred_class = predictions.argmax(dim=1)
# Flatten the labels so an (N, 1) label array cannot silently broadcast against
# the (N,) predictions into an (N, N) comparison.
accuracy = (np.ravel(test_labels) == pred_class.cpu().numpy()).mean()
print(accuracy)