import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import random
import string
import lda
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import os

#--------------------------------------------#

class ImageNetClassifier(torch.nn.Module):
    """Linear probe: one fully-connected layer from ``input_size`` features to ``output_size`` logits."""

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute name determines the state_dict keys ("linear1.weight", "linear1.bias").
        self.linear1 = torch.nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return unnormalized class scores for the feature batch ``x``."""
        logits = self.linear1(x)
        return logits


def train_model_lrsearch(input_train, labels_train, input_val, labels_val, input_test, labels_test, Dim_in, Dim_out, classes, return_val_acc=False,
                         lr_range=None):
    """Grid-search the learning rate and keep the run with the best validation accuracy.

    One ``train_model`` run is made per learning rate. The returned test
    accuracy and model belong to the run with the highest validation accuracy;
    on ties the earliest learning rate in ``lr_range`` wins.

    Args:
        input_train, labels_train, input_val, labels_val, input_test, labels_test,
        Dim_in, Dim_out, classes: forwarded unchanged to ``train_model``.
        return_val_acc: if True, also return the best validation accuracy.
        lr_range: learning rates to try. Defaults to 1e-2 down to 1e-6 in
            1 / 5 steps.

    Returns:
        ``(best_test_acc, best_model)``, or
        ``(best_test_acc, best_model, best_val_acc)`` when ``return_val_acc``.

    Raises:
        ValueError: if ``lr_range`` is empty.
    """
    if lr_range is None:
        # Descending 1 / 5 grid (the old list repeated 5e-5 where 5e-4 was intended).
        lr_range = (1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6, 1e-6)
    lr_range = list(lr_range)
    if not lr_range:
        raise ValueError("lr_range must contain at least one learning rate")

    # Start below any achievable accuracy so the first run is always kept;
    # with 0 a run of all-zero val accuracies used to leave best_model as None.
    best_acc = float("-inf")
    best_model = None
    best_test_acc = 0

    for lr in lr_range:
        test_acc, model, val_acc = train_model(input_train, labels_train, input_val, labels_val,
                                               input_test, labels_test, Dim_in, Dim_out, classes,
                                               lr=lr, return_model=True)
        if val_acc > best_acc:
            best_acc = val_acc
            best_model = model
            best_test_acc = test_acc

    if return_val_acc:
        return best_test_acc, best_model, best_acc
    return best_test_acc, best_model


def _max_normalize(*parts):
    """Cast each part to float32 and divide every column by its max over all parts combined.

    Columns whose max is exactly 0 are divided by 1 instead, so they stay 0
    rather than turning into NaN. The result lies in [0, 1] only for
    non-negative inputs.
    """
    tensors = [torch.as_tensor(part, dtype=torch.float32) for part in parts]
    col_max = torch.cat(tensors).max(dim=0, keepdim=True).values
    col_max = torch.where(col_max == 0, torch.ones_like(col_max), col_max)
    return [t / col_max for t in tensors]


def _evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, accuracy)`` of ``model`` on ``loader``; both are 0.0 if it is empty."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            total_loss += criterion(logits, batch_y).item()
            correct += (logits.argmax(dim=1) == batch_y).sum().item()
            total += batch_y.size(0)
    n_batches = len(loader)
    mean_loss = total_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy


def train_model(input_train, labels_train, input_val, labels_val, input_test, labels_test, Dim_in, Dim_out, classes, 
                lr=5e-3, epochs=100, verbose=0, return_model=False):
    """Train a linear classifier, keep the best-validation epoch, and score it on the test set.

    Features are scaled per column by the maximum over train, val and test
    together, so the test set influences the scale. Training uses Adam with
    cross-entropy and batch size 32. After every epoch the weights are
    snapshotted if validation accuracy strictly improves; the earliest best
    epoch wins ties. The snapshot is then reloaded before the test pass.

    Args:
        input_train, input_val, input_test: 2-D feature matrices (tensors or
            array-likes) with ``Dim_in`` columns; cast to float32.
        labels_train, labels_val, labels_test: label sequences; every label
            must appear in ``classes``.
        Dim_in, Dim_out: input / output size of the linear classifier.
        classes: ordered class labels; a label's position is its class index.
        lr: Adam learning rate.
        epochs: number of passes over the training set.
        verbose: if truthy, print per-epoch train/val loss and accuracy.
        return_model: selects the return signature (see Returns).

    Returns:
        ``(test_acc, model, best_val_acc)`` if ``return_model``, else
        ``(test_acc, state_dict)``. ``model`` holds the best-validation weights;
        ``state_dict`` is an independent copy of them.

    Raises:
        KeyError: if a label is missing from ``classes``.
    """
    import copy  # only needed to snapshot the best weights

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device: ', device)
    print('verbose', verbose)

    x_train, x_val, x_test = _max_normalize(input_train, input_val, input_test)

    # Map each label to its index in ``classes``.
    label_map = {label: i for i, label in enumerate(classes)}
    y_train, y_val, y_test = (
        torch.tensor([label_map[label] for label in labels], dtype=torch.long)
        for labels in (labels_train, labels_val, labels_test)
    )

    train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=32, shuffle=False)
    test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=32, shuffle=False)

    model = ImageNetClassifier(Dim_in, Dim_out).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    best_state = None

    for _ in range(epochs):
        model.train()
        running_loss = 0.0
        # Force gradients on even if the caller wrapped us in torch.no_grad().
        with torch.set_grad_enabled(True):
            for batch_x, batch_y in train_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                loss = criterion(model(batch_x), batch_y)
                running_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if verbose:
            # The extra pass over the training set only feeds this log line.
            _, train_acc = _evaluate(model, train_loader, criterion, device)
            print(f"Train Loss: {running_loss/max(len(train_loader), 1)}, Train Accuracy: {train_acc}")

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        if verbose:
            print(f"Val Loss: {val_loss}, Val Accuracy: {val_acc}")

        # Deep-copy the weights: state_dict() returns live references that later
        # optimizer steps would overwrite, which previously made the "best" model
        # silently equal to the final one.
        if best_state is None or val_acc > best_acc:
            best_acc = val_acc
            best_state = copy.deepcopy(model.state_dict())

    if best_state is None:  # epochs == 0: fall back to the initial weights
        best_state = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)

    test_loss, test_acc = _evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss}, Test Accuracy: {test_acc}")

    if return_model:
        return test_acc, model, best_acc
    return test_acc, best_state

#--------------------------------------------#

def load_data(PATH, NAME, TEST_PATH, CLASSES_PATH, save_message=False):
    """Build one pseudo-text "message" per image and split the messages into train/val/test.

    Every distinct ``cluster_label`` in the TSV at ``PATH`` gets a random
    6-character alphanumeric word, and the words are unique across clusters.
    An image's message is the space-joined words of its rows in file order.
    Its label is the ``class_label`` of its first row. Images are ordered by
    sorted ``image_path``.

    Split membership comes from tab-separated files in ``TEST_PATH`` with an
    ``image_path`` column. If ``trainval.txt`` exists, it is split 80/20 into
    train/val with a fixed seed (42); otherwise ``train.txt`` and ``val.txt``
    are used. ``test.txt`` always supplies the test images.

    Side effects: writes ``ldas/cluster_words_<NAME>.csv``, plus
    ``ldas/messages_<NAME>.csv`` when ``save_message``, creating ``ldas/`` if
    needed. Word generation draws from the global ``random`` state.

    Args:
        PATH: TSV with ``image_path``, ``class_label`` and ``cluster_label`` columns.
        NAME: tag used in the output file names.
        TEST_PATH: directory holding the split files.
        CLASSES_PATH: header-less CSV whose first column lists the class names.
        save_message: also write the full message table.

    Returns:
        ``(df_train, df_val, df_test, classes)``. The three frames are row
        subsets of the message table, with columns ``message_id``, ``image_id``,
        ``message`` and ``label``. ``classes`` is an array of class names.
    """
    df = pd.read_csv(PATH, sep='\t')

    # Assign each cluster a distinct random word. A collision is re-drawn so two
    # clusters never share a word (which would silently merge them downstream).
    alphabet = string.ascii_letters + string.digits
    cluster_words = {}
    used_words = set()
    for cluster in np.unique(df["cluster_label"]):
        word = ''.join(random.choices(alphabet, k=6))
        while word in used_words:
            word = ''.join(random.choices(alphabet, k=6))
        used_words.add(word)
        cluster_words[cluster] = word

    # One document per image (sorted by path); segment order within an image is preserved.
    documents = df.groupby("image_path", sort=True)["cluster_label"].agg(list)
    image2label = df.drop_duplicates("image_path").set_index("image_path")["class_label"]
    image_ids = documents.index.tolist()

    df_messages = pd.DataFrame({"message_id": range(len(image_ids)),
                                "image_id": image_ids,
                                "message": [" ".join(cluster_words[c] for c in segment_clusters)
                                            for segment_clusters in documents],
                                "label": image2label.loc[image_ids].tolist()})

    trainval_file = os.path.join(TEST_PATH, "trainval.txt")
    if os.path.exists(trainval_file):
        trainval_paths = pd.read_csv(trainval_file, sep='\t')['image_path'].tolist()
        train_size = int(0.8 * len(trainval_paths))
        train_subset, val_subset = torch.utils.data.random_split(
            trainval_paths, [train_size, len(trainval_paths) - train_size],
            generator=torch.Generator().manual_seed(42))
        train_image_paths = [trainval_paths[i] for i in train_subset.indices]
        val_image_paths = [trainval_paths[i] for i in val_subset.indices]
    else:
        train_image_paths = pd.read_csv(os.path.join(TEST_PATH, "train.txt"), sep='\t')['image_path'].values
        val_image_paths = pd.read_csv(os.path.join(TEST_PATH, "val.txt"), sep='\t')['image_path'].values
    test_image_paths = pd.read_csv(os.path.join(TEST_PATH, "test.txt"), sep='\t')['image_path'].values

    print("train examples size: ", len(train_image_paths))
    print("val examples size: ", len(val_image_paths))
    print("test examples size: ", len(test_image_paths))

    # Select each split's messages by image path.
    df_train = df_messages[df_messages["image_id"].isin(train_image_paths)]
    df_val = df_messages[df_messages["image_id"].isin(val_image_paths)]
    df_test = df_messages[df_messages["image_id"].isin(test_image_paths)]

    # Persist the cluster -> word mapping so messages can be decoded later.
    os.makedirs("ldas", exist_ok=True)
    df_cluster_words = pd.DataFrame({"cluster_label": list(cluster_words.keys()),
                                     "cluster_word": list(cluster_words.values())})
    df_cluster_words.to_csv("ldas/cluster_words_" + NAME + ".csv", index=False)
    if save_message:
        df_messages.to_csv("ldas/messages_" + NAME + ".csv", index=False)

    print(df_messages.head())

    classes = pd.read_csv(CLASSES_PATH, header=None)[0].values

    return df_train, df_val, df_test, classes

#--------------------------------------------#

def get_cluster_memberships(df_train, df_val, df_test):
    """Build bag-of-clusters count matrices for the train, val and test messages.

    Each ``message`` is a space-separated string of cluster words. The
    vocabulary is the sorted set of distinct words over all three frames, so
    the three matrices share the same columns. Entry ``[i, j]`` is the number
    of times word ``j`` occurs in message ``i``, stored as float64.

    Args:
        df_train, df_val, df_test: frames with a string ``message`` column.

    Returns:
        ``(train_matrix, val_matrix, test_matrix)`` with ``len(df_*)`` rows
        each, in the row order of the corresponding frame.
    """
    messages = pd.concat([df_train, df_val, df_test], ignore_index=True)["message"].tolist()

    num_documents = len(messages)
    clusters = np.unique(" ".join(messages).split(" "))
    num_clusters = len(clusters)

    print("num_documents", num_documents)
    print("num_clusters:", num_clusters)

    column_of = {cluster: j for j, cluster in enumerate(clusters)}

    # Count occurrences directly from the message list; iterrows() would build a
    # pandas Series for every row just to read one field.
    cluster_membership_matrix = np.zeros((num_documents, num_clusters))
    for row, message in enumerate(messages):
        for cluster in message.split(" "):
            cluster_membership_matrix[row, column_of[cluster]] += 1

    n_train = len(df_train)
    n_train_val = n_train + len(df_val)
    return (cluster_membership_matrix[:n_train],
            cluster_membership_matrix[n_train:n_train_val],
            cluster_membership_matrix[n_train_val:])


