import polars as pl
import numpy as np
import torch
import argparse
import copy
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
from matplotlib import pyplot as plt

from torch.utils.data import DataLoader
from utilities import check_path, generate_image_prompts_from_onehots, get_torch_device
from collections import Counter
import const_prompts as cp
# ============== Global variables ========================

# ============== Classifier design ============================
class NNClassifier(torch.nn.Module):
    """Feed-forward classifier with layer widths interpolated from input to output size.

    The ``n_inner_layers + 2`` widths come from
    ``space_function(num_coordinates, num_classes, n_steps)`` truncated to
    integers, with the first width forced back to ``num_coordinates``.
    Every linear layer except the first one is preceded by a dropout layer
    and, when ``activation`` is true, by an activation layer.

    Args:
        num_coordinates: Size of the input feature vector.
        num_classes: Size of the output vector (one score per class).
        n_inner_layers: Number of intermediate widths between input and output.
        space_function: Callable ``(start, stop, num)`` returning a NumPy
            array of widths (e.g. ``np.linspace`` or ``np.geomspace``).
        dropout: Dropout probability used in front of the non-first linear layers.
        activation: Whether an activation layer follows each dropout layer.
        activation_function: Zero-argument factory for the activation module.
    """

    def __init__(self, num_coordinates, num_classes, n_inner_layers=3, space_function=np.linspace, dropout=0.3, activation=True, activation_function=torch.nn.ReLU):
        super().__init__()
        self.dropout = dropout
        self.activation = activation
        self.act_func = activation_function

        # Input width + inner widths + output width
        n_steps = n_inner_layers + 2
        widths = space_function(num_coordinates, num_classes, n_steps).astype(int)
        widths[0] = num_coordinates
        self.steps = widths

        stack = []
        for position in range(n_steps - 1):
            is_first = position == 0
            if not is_first:
                # Regularise (and optionally activate) the previous linear output
                stack.append(torch.nn.Dropout(p=dropout))
                if activation:
                    stack.append(activation_function())
            stack.append(torch.nn.Linear(widths[position], widths[position + 1]))
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, X):
        """Apply every layer in order to ``X`` and return the final output."""
        for module in self.layers:
            X = module(X)
        return X

def train_nn_classifier(model, train_dataloader, eval_dataloader, optimizer, num_epochs, loss_fn, learning_rate, verbose=True, device=torch.device("cpu")):
    """Train a deep copy of ``model`` and evaluate it after every epoch.

    Args:
        model: Module to train; it is deep-copied, the argument itself is not updated.
        train_dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        eval_dataloader: Iterable of ``(inputs, labels)`` batches passed to
            ``eval_nn_classifier``.
        optimizer: Optimizer class/factory called as ``optimizer(params, lr=...)``.
        num_epochs: Number of passes over ``train_dataloader``.
        loss_fn: Callable ``(predictions, labels) -> loss tensor``.
        learning_rate: Learning rate handed to the optimizer.
        verbose: Print the accuracy before training and a summary per epoch.
        device: Device on which the model copy and the batches are placed.

    Returns:
        Tuple ``(trained_model, loss_all_epochs, test_acc_all_epochs)`` where
        ``loss_all_epochs`` holds the *summed* batch loss of each epoch and
        ``test_acc_all_epochs`` the accuracy (in percent) after each epoch.
    """
    model_tr    = copy.deepcopy(model)
    model_tr.to(device)
    optimizer   = optimizer(model_tr.parameters(), lr=learning_rate)
    loss_all_epochs = []
    test_acc_all_epochs = []

    if verbose:
        print(f"Epoch [0/{num_epochs}], Loss: N/A, Test acc: {eval_nn_classifier(model_tr, eval_dataloader, device=device):.4f}%")
    total = len(train_dataloader)
    for epoch in range(num_epochs):
        # eval_nn_classifier leaves the model in eval mode, so training mode
        # has to be re-enabled every epoch; otherwise dropout stays disabled.
        model_tr.train()
        loss_current_epoch = 0
        for sentences, labels in tqdm(train_dataloader, total=total):
            # Keep inputs on the same device as the model and the labels
            if torch.is_tensor(sentences):
                sentences = sentences.to(device)
            labels = labels.to(device)
            y_predicted = model_tr(sentences)
            # Drops every size-1 dimension of the output (including a batch
            # dimension of size one).
            y_predicted = y_predicted.squeeze_()
            loss = loss_fn(y_predicted, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_current_epoch += loss.item()

        loss_all_epochs.append(loss_current_epoch)
        test_acc = eval_nn_classifier(model_tr, eval_dataloader, device=device)
        test_acc_all_epochs.append(test_acc)
        if verbose:
            # Guard against an empty training loader when reporting the mean loss
            mean_loss = loss_current_epoch / total if total else float("nan")
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}, Test acc: {test_acc:.4f}%")

    return model_tr, loss_all_epochs, test_acc_all_epochs

def eval_nn_classifier(model, eval_dataloader, device=torch.device("cpu")):
    """Compute the classification accuracy of ``model`` in percent.

    The model is switched to eval mode and moved to ``device``; it is left in
    eval mode when the function returns.

    Args:
        model: Module mapping a batch of inputs to per-class scores along
            dimension 1.
        eval_dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device on which the model and the batches are placed.

    Returns:
        ``100 * correct / total`` as a float, or ``nan`` when the dataloader
        yields no examples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    model.to(device)
    correct = 0
    total   = 0
    with torch.no_grad():
        for analogies, labels in eval_dataloader:
            # Keep inputs on the same device as the model and the labels
            if torch.is_tensor(analogies):
                analogies = analogies.to(device)
            labels      = labels.to(device)
            y_predicted = model(analogies)
            # Index of the highest score along the class dimension
            label_predicted = y_predicted.argmax(dim=1).squeeze()
            total       += labels.size(0)
            correct     += (label_predicted == labels).sum().item()
    if total == 0:
        # Nothing to evaluate, e.g. every batch was dropped by the loader
        return float("nan")
    return 100 * correct / total

def test_nn_classifier(model, test_analogies):
    """Predict one label per element of ``test_analogies``.

    The model is put in eval mode and each element is passed through it on its
    own; the prediction is the index of the largest value along dimension 0 of
    the output.

    Returns:
        List with one index tensor per input element.
    """
    model.eval()
    with torch.no_grad():
        return [
            torch.max(model(analogy).data, 0).indices.squeeze()
            for analogy in tqdm(test_analogies, desc="Predicting")
        ]

# ============== Text generator design ==================

class NNGenerator(torch.nn.Module):
    """Wrapper that applies a classifier to each of four slices of its input.

    Args:
        classifier: Module (or callable) applied to every slice.
        *args, **kwargs: Accepted for signature compatibility; not used.
    """

    def __init__(self, classifier, *args, **kwargs):
        # The original passed NNClassifier to super(), which raises TypeError
        # because an NNGenerator instance is not an NNClassifier.
        super().__init__()
        self.classifier = classifier

    def forward(self, X):
        """Transpose ``X``, reshape it to four rows and classify each row.

        Returns:
            List of four classifier outputs, one per row of
            ``X.T.reshape(4, -1)``.
        """
        # NOTE(review): presumably each row corresponds to one image of an
        # analogy; confirm the expected layout of X with the caller.
        return [self.classifier(element) for element in X.T.reshape(4, -1)]

def test_nn_generator(classifier, test_triplets):
    """Predict index lists for every element of ``test_triplets``.

    The classifier is put in eval mode and each element is passed through it
    on its own; the prediction is the list of indices of the largest values
    along dimension 1 of the output.

    Returns:
        List with one list of predicted indices per input element.
    """
    classifier.eval()
    with torch.no_grad():
        return [
            torch.max(classifier(triplet).data, 1).indices.tolist()
            for triplet in tqdm(test_triplets, desc="Generating")
        ]

# ============== Utility functions ======================

def init_weights(model):
    """Initialise a ``Linear`` or ``Conv2d`` module in place.

    Weights are drawn with Xavier-uniform initialisation and the bias, when
    the layer has one, is filled with ``0.01``. Any other module type is left
    untouched, which makes the function suitable for ``Module.apply``.

    Args:
        model: A single module, typically visited through ``Module.apply``.

    Returns:
        None.
    """
    if isinstance(model, (torch.nn.Linear, torch.nn.Conv2d)):
        torch.nn.init.xavier_uniform_(model.weight.data)
        # Layers created with bias=False expose ``bias`` as None
        if model.bias is not None:
            model.bias.data.fill_(0.01)
    return


def flatten_images(images):
    """Flatten every image from four nesting levels into one flat list.

    Each image is iterated as cells -> properties -> groups -> items, and the
    items are collected in iteration order.

    Returns:
        List with one flat list of items per input image.
    """
    flat_images = []
    for image in images:
        values = []
        for cell in image:
            for attribute in cell:
                for group in attribute:
                    values.extend(group)
        flat_images.append(values)
    return flat_images

def extract_unit_X_y(df: pl.DataFrame):
    """Split every analogy into its first three parts and its fourth part.

    When the frame has an ``encoded_analogy`` column but no ``analogy``
    column, the latter is derived by flattening the former. Each analogy is
    then reshaped into four equally sized rows.

    Returns:
        ``((X_, y_), unique_rows)`` where ``X_`` holds rows 0-2 of every
        analogy, ``y_`` holds row 3 (with the singleton dimension removed) and
        ``unique_rows`` holds the distinct rows over all analogies, shaped
        ``(-1, 1, row_size)``.
    """
    needs_flattening = "encoded_analogy" in df.columns and "analogy" not in df.columns
    if needs_flattening:
        nested = df.select("encoded_analogy").to_series().to_list()
        df = df.with_columns(pl.Series(flatten_images(nested)).alias("analogy"))

    analogy = df.select("analogy").to_series()
    n_rows = analogy.len()
    all_tensors = torch.Tensor(analogy).reshape(n_rows, 4, -1)

    premise_idx = torch.IntTensor([0, 1, 2])
    target_idx = torch.IntTensor([3])
    X_ = all_tensors.index_select(1, premise_idx)
    y_ = all_tensors.index_select(1, target_idx).squeeze(dim=1)

    row_size = all_tensors.shape[2]
    unique_rows = all_tensors.reshape(-1, 1, row_size).unique(dim=0)
    return (X_, y_), unique_rows


# ============= Parameterise the script =================
if __name__ == "__main__":
    # ---- Command-line interface ----
    parser = argparse.ArgumentParser(description="Try NN classification approaches")
    # nargs="?" makes the positional optional; without it argparse always
    # requires the argument and the default below is never used.
    parser.add_argument("data_dir", type=str, nargs="?", default="data/12/", help="Encoded arrows directory")
    parser.add_argument("--n_checkpoints", type=int, default=1, help="Number of checkpoints")
    parser.add_argument("--n_epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for data loader")
    parser.add_argument("--export", type=int, default=1, help="Choose whether or not to export the model")
    parser.add_argument("--export_path", type=str, default="models/ff/", help="Export model path")

    args = parser.parse_args()

    arg_data_dir        = args.data_dir
    arg_n_checkpoints   = args.n_checkpoints
    arg_n_epochs        = args.n_epochs
    arg_learning_rate   = args.learning_rate
    arg_batch_size      = args.batch_size
    arg_export          = args.export
    arg_export_path     = args.export_path

    # ---- Setup ----
    print("Revving up...")
    # Fixed seed for reproducibility
    random_state    = 14
    torch.manual_seed(random_state)
    device      = get_torch_device()

    # ---- Data loading ----
    # Paths are built by plain string concatenation, so data_dir must end with
    # a path separator. Only rows with is_valid == 1 are kept.
    print("Loading train set from arrows...")
    df_train    = pl.read_json(f"{arg_data_dir}train.json").filter(pl.col("is_valid")==1)
    df_dev      = pl.read_json(f"{arg_data_dir}dev.json").filter(pl.col("is_valid")==1)
    # The row index added here ("index") is used below as the class id
    df_dict     = pl.read_json(f"{arg_data_dir}dict.json").with_row_index()
    # Index tensors for the first three parts and the fourth part of an analogy
    idx_abc = torch.IntTensor([0,1,2])
    idx_d   = torch.IntTensor([3])
    # Training inputs: the three leading parts of each analogy, concatenated
    (X_train, labels_train), _  = extract_unit_X_y(df_train)
    X_train     = X_train.reshape(X_train.shape[0],1,-1).squeeze(1)
    # Training targets: dictionary row index of the fourth part
    label_train_idx = [df_dict.filter(pl.col("encoded_image")==item)["index"][0] for item in labels_train.tolist()]
    y_train     = torch.LongTensor(label_train_idx)
    # Same preparation for the dev set, which is used for validation
    (X_val, labels_val), _  = extract_unit_X_y(df_dev)
    X_val       = X_val.reshape(X_val.shape[0],1,-1).squeeze(1)
    label_val_idx  = [df_dict.filter(pl.col("encoded_image")==item)["index"][0] for item in labels_val.tolist()]
    y_val       = torch.LongTensor(label_val_idx)
    val_set     = list(zip(X_val.to(device), y_val.to(device)))

    print("Now training")
    # ---- Model search space ----
    # One class per dictionary row; input width taken from the dev features
    num_classes = df_dict.height
    num_coordinates = X_val[0].shape[0]
    nils        = [1]
    spfs        = [np.geomspace]
    dropouts    = [0.3]
    activations = [True]
    act_funcs   = [torch.nn.ReLU]

    # ---- Hyperparameters ----
    batch_size  = arg_batch_size
    n_epochs    = arg_n_epochs
    learning_rate   = arg_learning_rate
    loss_fn     = torch.nn.CrossEntropyLoss()
    optimizer   = torch.optim.Adam

    # Validation data: keep the final partial batch so that accuracy covers
    # the whole dev set and the loader is not empty when the dev set is
    # smaller than one batch.
    val_dataloader = DataLoader(val_set, batch_size=batch_size, drop_last=False)
    # Save the arguments for future uses
    df_args     = pl.from_dict(vars(args))
    df_args     = df_args.with_columns(pl.Series([num_coordinates]).alias("num_coordinates"))

    # ---- Training ----
    for n_inner_layers in nils:
        for func in spfs:
            for dropout in dropouts:
                for activation in activations:
                    for act_func in act_funcs:
                        # NOTE: the name hard-codes "geomspace" and omits the
                        # activation settings, so other grid values would
                        # share (and overwrite) the same export directory.
                        model_name  = f"model_{n_inner_layers}_geomspace_{dropout}"
                        # Defined up front so it is bound even when the
                        # checkpoint loop below does not run
                        model_path  = arg_export_path + f"{model_name}/"
                        # For tracking the training process across checkpoints
                        training_stats  = list()
                        space_function  = func
                        model       = NNClassifier(num_coordinates, num_classes, n_inner_layers=n_inner_layers,
                                                space_function=space_function, dropout=dropout, activation=activation,
                                                activation_function=act_func)
                        model.apply(init_weights)
                        # Model for the initial training
                        trained_model   = model
                        print("Total number of parameters: ",
                            sum(p.numel() for p in model.parameters()))
                        # Boundaries that split the training set into one
                        # contiguous portion per checkpoint
                        indices    = np.linspace(0, X_train.shape[0], arg_n_checkpoints+1, dtype=int)
                        for i in range(arg_n_checkpoints):
                            print(f"At checkpoint {i+1}/{arg_n_checkpoints}")
                            onset   = indices[i]
                            offset  = indices[i+1]
                            print(f"Currently working on the portion from {onset} to {offset}...")
                            X_p_train   = torch.Tensor(X_train[onset:offset]).to(device)
                            y_p_train   = torch.LongTensor(y_train[onset:offset]).to(device)
                            train_p_set = list(zip(X_p_train,y_p_train))
                            # Create dataloader and ignore the last batch
                            train_dataloader    = DataLoader(train_p_set, batch_size=batch_size, drop_last=True)
                            # Continue training from the model of the previous checkpoint
                            trained_model, loss_total, test_acc_total   = train_nn_classifier(trained_model, train_dataloader, val_dataloader, optimizer, n_epochs, loss_fn, learning_rate, verbose=True, device=device)
                            stat_record = {
                                "checkpoint": i,
                                "losses"    : loss_total,
                                "accuracy"  : test_acc_total
                            }
                            training_stats.append(stat_record)
                            if arg_export:
                                # NOTE(review): check_path presumably ensures
                                # the directory exists; confirm in utilities.
                                check_path(model_path)
                                # Each checkpoint overwrites the previous model file
                                torch.save(trained_model.state_dict(), model_path + "model.pt")

                        if arg_export:
                            pl.DataFrame(training_stats).write_json(model_path + "stats.json")

    # ---- Testing ----
    # Uses the model of the last trained configuration
    print("Testing classifier on images")
    df_test = pl.read_json(f"{arg_data_dir}test.json").filter(pl.col("is_valid")==1)
    (X_test, labels_test), _ = extract_unit_X_y(df_test)
    X_test  = X_test.reshape(X_test.shape[0],1,-1).squeeze(1)
    y_test = [df_dict.filter(pl.col("encoded_image")==item)["index"][0] for item in labels_test.tolist()]
    # Move the eval tensors to the device
    X_test  = torch.Tensor(X_test).to(device)
    y_pred  = test_nn_classifier(trained_model, X_test)
    y_pred  = [item.item() for item in y_pred]
    p, r, f1, s = precision_recall_fscore_support(y_test, y_pred)
    accuracy    = accuracy_score(y_test, y_pred)
    # Map class ids back to the original images and render them as text
    y_test_embeddings   = [df_dict.filter(pl.col("index")==item)["original_image"][0] for item in y_test]
    y_pred_embeddings   = [df_dict.filter(pl.col("index")==item)["original_image"][0] for item in y_pred]
    y_test_texts    = [generate_image_prompts_from_onehots(embedding,cp.IMAGE_TEMPLATES[2], cp.SHAPES, cp.COLOURS) for embedding in y_test_embeddings]
    y_pred_texts    = [generate_image_prompts_from_onehots(embedding,cp.IMAGE_TEMPLATES[2], cp.SHAPES, cp.COLOURS) for embedding in y_pred_embeddings]
    df  = pl.DataFrame([y_test_texts,y_pred_texts], schema=["ref","pred"])
    

    