import argparse
import copy
import itertools

import numpy as np
import polars as pl
import torch
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report
from torch.utils.data import DataLoader
from tqdm import tqdm

from utilities import check_path, balance_df, get_torch_device, flatten_images

# ============== Global variables ========================


# ============== Model design ============================
class NNClassifier(torch.nn.Module):
    """Feed-forward classifier whose layer widths taper from input to output.

    The widths come from ``space_function(num_coordinates, num_classes,
    2 + n_inner_layers)`` truncated to integers. The resulting stack is a
    ``Linear`` layer followed, for every further width, by
    ``Dropout`` -> (optional activation) -> ``Linear``.

    Args:
        num_coordinates: Length of one input vector (``in_features`` of the
            first linear layer).
        num_classes: Number of output scores (``out_features`` of the last
            linear layer).
        n_inner_layers: Number of intermediate widths between the input and
            the output width.
        space_function: Callable ``(start, stop, num)`` returning a numpy
            array of widths, e.g. ``np.linspace`` or ``np.geomspace``.
        dropout: Dropout probability used in front of every linear layer
            except the first one.
        activation: Whether an activation module is inserted after each
            dropout.
        activation_function: Zero-argument factory of the activation module.
    """

    def __init__(self, num_coordinates, num_classes, n_inner_layers=3, space_function=np.linspace, dropout=0.3, activation=True, activation_function=torch.nn.ReLU):
        super().__init__()
        n_steps     = 2 + n_inner_layers
        self.steps  = space_function(num_coordinates, num_classes, n_steps).astype(int)
        # ``astype(int)`` truncates, so an endpoint that comes back as e.g.
        # 1.9999999 would silently become 1. Pin both ends to the requested
        # sizes so the network always matches its input and its label space.
        self.steps[0]   = num_coordinates
        self.steps[-1]  = num_classes
        self.layers = torch.nn.ModuleList()
        self.dropout    = dropout
        self.activation = activation
        self.act_func   = activation_function
        for i in range(n_steps - 1):
            if i != 0:
                # Regularise (and optionally activate) the previous layer's output
                self.layers.append(torch.nn.Dropout(p=dropout))
                if activation:
                    self.layers.append(self.act_func())
            # Plain Python ints keep torch independent of numpy scalar types
            self.layers.append(torch.nn.Linear(int(self.steps[i]), int(self.steps[i + 1])))

    def forward(self, X):
        """Apply every registered layer in order and return the raw scores."""
        out = X
        for layer in self.layers:
            out = layer(out)
        return out

def train_nn_classifier(model, train_dataloader, eval_dataloader, optimizer, num_epochs, loss_fn, learning_rate, verbose=True, device=torch.device("cpu")):
    """Train a deep copy of ``model`` and evaluate it after every epoch.

    Args:
        model: Module to train; it is deep-copied, the argument itself is not
            modified.
        train_dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        eval_dataloader: Iterable of ``(inputs, labels)`` batches handed to
            ``eval_nn_classifier``.
        optimizer: Optimizer class (not an instance); it is instantiated with
            the copy's parameters and ``lr=learning_rate``.
        num_epochs: Number of passes over ``train_dataloader``.
        loss_fn: Callable ``(predictions, labels) -> loss tensor``.
        learning_rate: Learning rate passed to the optimizer.
        verbose: Print the accuracy before training and the mean batch loss
            and accuracy after every epoch.
        device: Device the model copy and every batch are moved to.

    Returns:
        Tuple ``(trained_model, loss_all_epochs, test_acc_all_epochs)`` where
        ``loss_all_epochs`` holds the summed batch losses of each epoch and
        ``test_acc_all_epochs`` the accuracy (in percent) after each epoch.
    """
    model_tr    = copy.deepcopy(model)
    model_tr.train()
    model_tr.to(device)
    optimizer   = optimizer(model_tr.parameters(), lr=learning_rate)
    loss_all_epochs = []
    test_acc_all_epochs = []

    if verbose:
        print(f"Epoch [0/{num_epochs}], Loss: N/A, Test acc: {eval_nn_classifier(model_tr, eval_dataloader, device=device):.4f}%")
    total = len(train_dataloader)
    for epoch in range(num_epochs):
        # eval_nn_classifier leaves the model in eval mode; switch back at the
        # start of every epoch, otherwise dropout stays disabled while training.
        model_tr.train()
        loss_current_epoch = 0
        for sentences, labels in tqdm(train_dataloader, total=total):
            sentences = sentences.to(device)
            labels = labels.to(device)
            y_predicted = model_tr(sentences)
            # Drop singleton dimensions, but do not let a batch of one sample
            # lose its batch axis (it would no longer line up with ``labels``).
            keep_batch_axis = y_predicted.dim() > 1 and y_predicted.size(0) == 1 and labels.dim() > 0
            y_predicted = y_predicted.squeeze()
            if keep_batch_axis:
                y_predicted = y_predicted.unsqueeze(0)
            loss = loss_fn(y_predicted, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_current_epoch += loss.item()

        loss_all_epochs.append(loss_current_epoch)
        test_acc = eval_nn_classifier(model_tr, eval_dataloader, device=device)
        test_acc_all_epochs.append(test_acc)
        if verbose:
            # max(..., 1) avoids a ZeroDivisionError on an empty train loader
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_current_epoch / max(total, 1):.4f}, Test acc: {test_acc:.4f}%")

    return model_tr, loss_all_epochs, test_acc_all_epochs

def eval_nn_classifier(model, eval_dataloader, device=torch.device("cpu")):
    """Return the classification accuracy of ``model`` in percent.

    The model is switched to eval mode and moved to ``device``; the mode is
    NOT restored afterwards, so callers that keep training must call
    ``model.train()`` again.

    Args:
        model: Module mapping a batch of inputs to per-class scores laid out
            as ``(batch, classes)``.
        eval_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    model.to(device)
    correct = 0
    total   = 0
    with torch.no_grad():
        for analogies, labels in eval_dataloader:
            # Inputs must live on the same device as the model
            analogies   = analogies.to(device)
            labels      = labels.to(device)
            y_predicted = model(analogies)
            # Index of the highest score along the class axis
            _, label_predicted  = torch.max(y_predicted, 1)
            label_predicted     = label_predicted.squeeze()
            total       += labels.size(0)
            correct     += (label_predicted == labels).sum().item()
    if total == 0:
        # Nothing to evaluate: report 0 instead of dividing by zero
        return 0.0
    return 100 * correct / total

def test_nn_classifier(model, test_analogies):
    """Run ``model`` on each item of ``test_analogies`` and collect the winners.

    The model is put into eval mode and gradients are disabled. Every item is
    passed through the model on its own; the index of the highest score along
    dimension 0 of the output is kept.

    Returns:
        List with one (squeezed) index tensor per item, in input order.
    """
    model.eval()
    with torch.no_grad():
        return [
            torch.max(model(sample).data, 0)[1].squeeze()
            for sample in tqdm(test_analogies, desc="Predicting")
        ]

def test_nn_classifier_on_df(df_test:pl.DataFrame, model:torch.nn.Module, device="cpu"):
    """Score ``model`` on a data frame of labelled analogies.

    The frame must provide an ``is_valid`` column and either an ``analogy``
    column or an ``encoded_analogy`` column; in the latter case ``analogy`` is
    derived from it with ``flatten_images``.

    Returns:
        Tuple ``((precision, recall, f1, support), accuracy, report)`` where
        the first element is computed for the labels ``[0, 1]`` and ``report``
        is the text produced by ``classification_report``.
    """
    columns = df_test.columns
    needs_flattening = "encoded_analogy" in columns and "analogy" not in columns
    if needs_flattening:
        encoded = df_test.select("encoded_analogy").to_series().to_list()
        df_test = df_test.with_columns(pl.Series(flatten_images(encoded)).alias("analogy"))

    # Inputs go to the device as one tensor; the gold labels stay a Series
    features = torch.Tensor(df_test.select("analogy").to_series()).to(device)
    gold = df_test.select("is_valid").to_series()

    # The predictor returns tensors; sklearn wants plain Python numbers
    predictions = [pred.item() for pred in test_nn_classifier(model.to(device), features)]

    precision, recall, f_score, support = precision_recall_fscore_support(gold, predictions, labels=[0,1])
    accuracy = accuracy_score(gold, predictions)
    report = classification_report(gold, predictions)
    return (precision, recall, f_score, support), accuracy, report

# ============== Utility functions ======================

def init_weights(model):
    """Initialise a single ``Linear``/``Conv2d`` module in place.

    Intended for ``Module.apply``: weights get Xavier-uniform values and the
    bias, when the layer has one, is filled with 0.01. Modules of any other
    type are left untouched.

    Returns:
        None.
    """
    if isinstance(model, (torch.nn.Linear, torch.nn.Conv2d)):
        torch.nn.init.xavier_uniform_(model.weight.data)
        # Layers created with bias=False carry ``bias = None``
        if model.bias is not None:
            model.bias.data.fill_(0.01)
    return

# ============= Parameterise the script =================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Try NN classification approaches")
    # nargs="?" makes the positional optional; without it argparse always
    # requires the argument and the default is never used.
    parser.add_argument("data_dir", type=str, nargs="?", default="data/11/", help="Data directory")
    parser.add_argument("--n_checkpoints", type=int, default=1, help="Number of checkpoints. 1 if using the whole dataset at once")
    parser.add_argument("--n_epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for data loader")
    parser.add_argument("--export", type=int, default=1, help="Choose whether or not to export the model")
    parser.add_argument("--export_path", type=str, default="models/cls/", help="Export model path")

    args = parser.parse_args()

    arg_data_dir        = args.data_dir
    arg_n_checkpoints   = args.n_checkpoints
    arg_n_epochs        = args.n_epochs
    arg_learning_rate   = args.learning_rate
    arg_batch_size      = args.batch_size
    arg_export          = args.export
    arg_export_path     = args.export_path
    # File names are appended to both directories by plain concatenation, so
    # make sure a non-empty directory ends with a separator.
    if arg_data_dir and not arg_data_dir.endswith("/"):
        arg_data_dir += "/"
    if arg_export_path and not arg_export_path.endswith("/"):
        arg_export_path += "/"

    # ============== Main script ==============================
    print("Revving up...")
    # Fixed seed so that runs are reproducible
    random_state    = 14
    torch.manual_seed(random_state)
    # The splits are read from JSON files inside the data directory
    print("Loading train set from arrows...")
    df_train    = pl.read_json(f"{arg_data_dir}train.json")
    df_dev      = pl.read_json(f"{arg_data_dir}dev.json")
    device      = get_torch_device()

    # ---- Dev set: derive the flat "analogy" column if only the encoded one exists
    if "analogy" not in df_dev.columns and "encoded_analogy" in df_dev.columns:
        analogy_dev = pl.Series(flatten_images(df_dev.select("encoded_analogy").to_series().to_list())).alias("analogy")
        df_dev  = df_dev.with_columns(analogy_dev)
    X_eval  = df_dev.select("analogy").to_series()
    y_eval  = df_dev.select("is_valid").to_series()
    # Move the eval tensors to the device
    X_eval      = torch.Tensor(X_eval).to(device)
    y_eval      = torch.LongTensor(y_eval).to(device)
    eval_set    = list(zip(X_eval, y_eval))

    # ---- Model search space (one value each => a single configuration)
    num_classes = 2
    num_coordinates = X_eval[0].shape[0]
    nils        = [2]
    spfs        = [np.geomspace]
    dropouts    = [0.3]
    activations = [True]
    act_funcs   = [torch.nn.ReLU]
    # Uncomment to debug NaN/inf gradients:
    # torch.autograd.set_detect_anomaly(True)

    # ---- Hyperparameters
    batch_size  = arg_batch_size
    n_epochs    = arg_n_epochs
    learning_rate   = arg_learning_rate
    loss_fn     = torch.nn.CrossEntropyLoss()
    optimizer   = torch.optim.Adam  # class; instantiated inside train_nn_classifier

    eval_dataloader = DataLoader(eval_set, batch_size=batch_size)

    # ---- Train set: balance the classes, shuffle once, derive "analogy" if needed
    df_train    = balance_df(df_train, on="is_valid").sample(fraction=1, shuffle=True, seed=random_state)
    if "analogy" not in df_train.columns and "encoded_analogy" in df_train.columns:
        analogy_train   = pl.Series(flatten_images(df_train.select("encoded_analogy").to_series().to_list())).alias("analogy")
        df_train    = df_train.with_columns(analogy_train)
    X_train = df_train.select("analogy").to_series()
    y_train = df_train.select("is_valid").to_series()

    # Collect the run arguments for a future export
    # TODO: write these together with the per-model hyperparameters
    # (n_inner_layers, space function, dropout, activation) as args.json
    # next to the exported weights.
    df_args     = pl.from_dict(vars(args))
    df_args     = df_args.with_columns(pl.Series([num_coordinates]).alias("num_coordinates"))

    # ---- Actual training: one model per point of the search grid.
    # itertools.product iterates in the same order as the equivalent nested loops.
    for n_inner_layers, func, dropout, activation, act_func in itertools.product(nils, spfs, dropouts, activations, act_funcs):
        # Name the export after the space function actually used, so that
        # different functions do not overwrite each other's files.
        func_name   = getattr(func, "__name__", type(func).__name__)
        model_name  = f"model_{n_inner_layers}_{func_name}_{dropout}"
        # Defined up front so the stats export below never hits an unbound name
        model_path  = arg_export_path + f"{model_name}/"
        # For tracking the training process across checkpoints
        training_stats  = list()
        model       = NNClassifier(num_coordinates, num_classes, n_inner_layers=n_inner_layers,
                                   space_function=func, dropout=dropout, activation=activation,
                                   activation_function=act_func)
        model.apply(init_weights)
        # Model for the initial training; replaced by the trained copy after each checkpoint
        trained_model   = model
        print("Total number of parameters: ", sum(p.numel() for p in model.parameters()))
        # Boundaries of the consecutive data portions, one portion per checkpoint
        indices    = np.linspace(0, X_train.shape[0], arg_n_checkpoints+1, dtype=int)
        for i in range(arg_n_checkpoints):
            print(f"At checkpoint {i+1}/{arg_n_checkpoints}")
            onset   = indices[i]
            offset  = indices[i+1]
            print(f"Currently working on the portion from {onset} to {offset}...")
            X_p_train   = torch.Tensor(X_train[onset:offset]).to(device)
            y_p_train   = torch.LongTensor(y_train[onset:offset]).to(device)
            train_p_set = list(zip(X_p_train, y_p_train))
            # Create dataloader and ignore the last (incomplete) batch
            train_dataloader    = DataLoader(train_p_set, batch_size=batch_size, drop_last=True)
            # Continue training from the model produced by the previous checkpoint
            trained_model, loss_total, test_acc_total   = train_nn_classifier(trained_model, train_dataloader, eval_dataloader, optimizer, n_epochs, loss_fn, learning_rate, verbose=True, device=device)
            stat_record = {
                "checkpoint": i,
                "losses"    : loss_total,
                "accuracy"  : test_acc_total
            }
            training_stats.append(stat_record)
            if arg_export:
                check_path(model_path)
                # Overwritten at every checkpoint, so the file holds the latest weights
                torch.save(trained_model.state_dict(), model_path + "model.pt")

        if arg_export:
            check_path(model_path)
            pl.DataFrame(training_stats).write_json(model_path + "stats.json")

    # ---- Testing (only the configuration trained last is evaluated)
    df_test = pl.read_json(f"{arg_data_dir}test.json")
    prfs, accuracy, cls_report = test_nn_classifier_on_df(df_test, trained_model, device=device)
    # Report the result; it was previously computed and silently dropped
    print(f"Test accuracy: {accuracy:.4f}")
    print(cls_report)

        
