import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import string
import lda
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

#--------------------------------------------#

class ImageNetClassifier(torch.nn.Module):
    """One fully-connected layer mapping ``input_size`` features to ``output_size`` scores.

    No activation is applied, so the outputs are raw logits (``train_model`` feeds
    them straight into ``CrossEntropyLoss``).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute name is part of the state_dict keys ("linear1.weight",
        # "linear1.bias"), so it must stay stable for saved checkpoints.
        self.linear1 = torch.nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return logits of shape (..., output_size) for x of shape (..., input_size)."""
        logits = self.linear1(x)
        return logits


def _evaluate(model, data_loader, criterion, device):
    """Return (mean per-batch loss, accuracy) of ``model`` over ``data_loader`` without gradients.

    Both values are 0.0 for an empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            y_pred = model(batch_x)
            total_loss += criterion(y_pred, batch_y).item()
            correct += (y_pred.argmax(dim=1) == batch_y).sum().item()
            total += batch_y.size(0)
    mean_loss = total_loss / len(data_loader) if len(data_loader) else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy


def train_model(input, Dim_in, Dim_out, labels, epochs=100, lr=1e-4, batch_size=32):
    """Train an ``ImageNetClassifier`` and return the weights with the best test accuracy.

    Feature columns are scaled by their maximum, labels are mapped to integer
    class ids, and the data are split 80/20 (fixed seed 42) into train and test
    sets. After each epoch the train and test loss/accuracy are printed.

    Args:
        input: 2-D feature matrix, one row per sample (a torch tensor or anything
            ``torch.as_tensor`` accepts, e.g. a NumPy array). Max-scaling maps to
            [0, 1] only for non-negative features — presumably counts such as the
            output of ``get_cluster_memberships``; confirm for other callers.
        Dim_in: number of input features (columns of ``input``).
        Dim_out: number of output classes; must be >= the number of distinct labels.
        labels: sequence of hashable labels, one per row of ``input``.
        epochs, lr, batch_size: training hyper-parameters (defaults are the
            previously hard-coded values).

    Returns:
        (best_acc, best_model): the best test accuracy seen and a deep copy of the
        model ``state_dict`` from that epoch. ``best_model`` is ``None`` only when
        ``epochs == 0``.
    """
    import copy

    # Check if CUDA is available and set PyTorch to use GPU or CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device: ', device)

    # torch.nn.Linear holds float32 weights; convert up front so float64 (NumPy)
    # or integer inputs do not trigger a dtype mismatch in the forward pass.
    x = torch.as_tensor(input, dtype=torch.float32)
    col_max = x.max(0, keepdim=True).values
    # An all-zero column would give 0/0 = NaN and poison training; divide it by 1.
    col_max = torch.where(col_max == 0, torch.ones_like(col_max), col_max)
    x = x / col_max

    # Class ids follow first-appearance order so the mapping is reproducible
    # (iteration order of a set of strings changes with hash randomisation).
    label_map = {label: i for i, label in enumerate(dict.fromkeys(labels))}
    y = torch.tensor([label_map[label] for label in labels], dtype=torch.long)

    # make train and test sets
    dataset = TensorDataset(x, y)
    train_size = int(0.8 * len(dataset))
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, len(dataset) - train_size],
        generator=torch.Generator().manual_seed(42),
    )
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Evaluation is order-independent, so the test set is not shuffled.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = ImageNetClassifier(Dim_in, Dim_out).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_acc = 0
    best_model = None

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(loader, total=len(loader), desc="Epoch {}".format(epoch + 1), leave=False)
        # Keep gradients on even if the caller wrapped us in torch.no_grad().
        with torch.enable_grad():
            for batch_x, batch_y in progress_bar:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                loss = criterion(model(batch_x), batch_y)
                running_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # print training loss and accuracy
        train_loss = running_loss / len(loader) if len(loader) else 0.0
        _, train_acc = _evaluate(model, loader, criterion, device)
        print(f"Train Loss: {train_loss}, Train Accuracy: {train_acc}")

        # evaluate on test set
        test_loss, test_acc = _evaluate(model, test_loader, criterion, device)
        print(f"Test Loss: {test_loss}, Test Accuracy: {test_acc}")

        # save best model; always keep the first epoch so a model is returned
        if best_model is None or test_acc > best_acc:
            best_acc = test_acc
            # state_dict() returns references to the live parameters, which later
            # optimizer steps modify in place; copy to freeze this epoch's weights.
            best_model = copy.deepcopy(model.state_dict())

    return best_acc, best_model

#--------------------------------------------#

def _random_cluster_words(clusters, length=6):
    """Map each cluster to a distinct random alphanumeric word of ``length`` characters."""
    alphabet = string.ascii_letters + string.digits
    words = {}
    used = set()
    for cluster in clusters:
        word = ''.join(random.choices(alphabet, k=length))
        # Re-draw on collision: two clusters sharing a word would be merged
        # into a single token by anything that tokenizes the messages.
        while word in used:
            word = ''.join(random.choices(alphabet, k=length))
        used.add(word)
        words[cluster] = word
    return words


def load_data(PATH, NAME):
    """Build one "document" of cluster words per image from a segment TSV.

    The TSV at ``PATH`` must have ``image_path`` and ``cluster_label`` columns,
    one row per segment. Every distinct cluster label gets a unique random
    6-character alphanumeric word (not reproducible unless ``random`` is seeded).
    Each image's message is its segments' words joined by spaces, in file order.
    Its label is the first ``/``-separated component of the image path.

    Side effects: writes ``cluster_words_<NAME>.csv`` and ``messages_<NAME>.csv``
    to the current directory and prints the head of the messages table.

    Returns:
        DataFrame with columns ``message_id``, ``image_id``, ``message``, ``label``,
        one row per image, ordered by sorted image path.
    """
    # read tsv
    df = pd.read_csv(PATH, sep='\t')

    clusters = np.unique(df["cluster_label"])
    cluster_words = _random_cluster_words(clusters)

    # Cluster labels of each image's segments. groupby sorts the image paths and
    # keeps the file's row order within each image.
    segments = df.groupby("image_path", sort=True)["cluster_label"].agg(list)

    image_ids = list(segments.index)
    labels = [path.split("/")[0] for path in image_ids]
    # Look labels up directly (no str -> int round trip), so non-integer
    # cluster labels such as "a" or 3.0 work too.
    messages = [" ".join(cluster_words[c] for c in seg_clusters) for seg_clusters in segments]

    df_messages = pd.DataFrame({"message_id": range(len(image_ids)), "image_id": image_ids, "message": messages, "label": labels})

    # write cluster_words to csv file
    df_cluster_words = pd.DataFrame({"cluster_label": list(cluster_words.keys()), "cluster_word": list(cluster_words.values())})
    df_cluster_words.to_csv("cluster_words_" + NAME + ".csv", index=False)
    df_messages.to_csv("messages_" + NAME + ".csv", index=False)

    print(df_messages.head())

    return df_messages

#--------------------------------------------#

def get_cluster_memberships(df_messages):
    """Build a document-by-cluster count matrix from space-separated messages.

    Args:
        df_messages: DataFrame with a ``message`` column of strings whose tokens
            are separated by single spaces (e.g. the output of ``load_data``).
            Any index is accepted; rows are filled by position.

    Returns:
        float64 ndarray of shape (len(df_messages), n_clusters). Entry [r, c] is the
        number of times cluster word c occurs in the r-th message. Columns follow
        the sorted order of the distinct tokens (``np.unique``).
    """
    messages = list(df_messages["message"])
    num_documents = len(messages)
    clusters = np.unique(" ".join(messages).split(" "))
    num_clusters = len(clusters)

    print("num_documents", num_documents)
    print("num_clusters:", num_clusters)

    cluster_to_index = {cluster: i for i, cluster in enumerate(clusters)}

    cluster_membership_matrix = np.zeros((num_documents, num_clusters))
    # Use the row position, not the DataFrame index label, so frames with a
    # filtered or custom index fill the correct rows.
    for row, message in enumerate(messages):
        for cluster in message.split(" "):
            cluster_membership_matrix[row, cluster_to_index[cluster]] += 1

    return cluster_membership_matrix


