import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import string
import lda
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

class LinearClassifier(torch.nn.Module):
    """Single fully-connected layer that maps feature vectors to class logits.

    No activation is applied; the raw logits are returned so they can be fed
    straight into ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # The attribute name fixes the state_dict keys ("linear1.weight",
        # "linear1.bias"), so it must not change.
        self.linear1 = torch.nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the affine projection of `x` (one logit per output unit)."""
        logits = self.linear1(x)
        return logits

def _feature_scale(features):
    """Return the per-feature (dim 0) maximum of `features`, for use as a divisor.

    Zero maxima are replaced with 1 so an all-zero feature stays 0 instead of
    turning into NaN (0 / 0).
    """
    scale = features.max(0, keepdim=True)[0]
    return torch.where(scale == 0, torch.ones_like(scale), scale)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, accuracy)`` of `model` over `loader`.

    Runs in eval mode without gradient tracking. Both values are 0.0 when the
    loader yields nothing, instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            total_loss += criterion(logits, batch_y).item()
            correct += (logits.argmax(dim=1) == batch_y).sum().item()
            total += batch_y.size(0)
    mean_loss = total_loss / len(loader) if len(loader) > 0 else 0.0
    accuracy = correct / total if total > 0 else 0.0
    return mean_loss, accuracy


def train_model_LR(input, Dim_in, Dim_out, labels, test_input, test_labels, LR):
    """Train a LinearClassifier at one learning rate and report its test accuracy.

    Features are scaled by the per-feature training maximum. The labelled data
    is split 80/20 (fixed seed 42) into train/validation. The model is trained
    with Adam for 100 epochs, and the epoch with the highest validation
    accuracy is kept and evaluated on the test set.

    Args:
        input: training features; presumably a 2-D float tensor of shape
            (n_samples, Dim_in) -- TODO confirm against callers.
        Dim_in: number of input features.
        Dim_out: number of output logits; should be at least the number of
            distinct labels, since CrossEntropyLoss targets index the logits.
        labels: training labels (any hashable values), one per row of `input`.
        test_input: test features with the same column layout as `input`.
        test_labels: test labels; every value must also occur in `labels`.
        LR: Adam learning rate.

    Returns:
        ``(test_accuracy, best_model, best_acc)``: test accuracy of the
        best-validation checkpoint, that checkpoint's state dict (a detached
        copy), and its validation accuracy.

    Raises:
        KeyError: if `test_labels` contains a label not present in `labels`.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 0-1 normalize each feature by its *training* maximum (assumes
    # non-negative features). The same scale is reused for the test set so
    # both splits share one feature space.
    scale = _feature_scale(input)
    x = input / scale

    # Map labels to class indices in first-appearance order. Iterating a set
    # of strings gives a different order on every run (hash randomization).
    label_map = {label: i for i, label in enumerate(dict.fromkeys(labels))}
    y = torch.tensor([label_map[label] for label in labels], dtype=torch.long)

    # Make train and val sets; the fixed seed gives the same split for every LR.
    dataset = TensorDataset(x, y)
    train_size = int(0.8 * len(dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, len(dataset) - train_size],
        generator=torch.Generator().manual_seed(42),
    )
    loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    model = LinearClassifier(Dim_in, Dim_out).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    best_acc = 0
    best_model = None

    for epoch in range(100):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(loader, total=len(loader), desc="Epoch {}".format(epoch+1), leave=False)
        for batch_x, batch_y in progress_bar:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            loss = criterion(model(batch_x), batch_y)
            running_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Training loss is accumulated during the epoch; accuracy is measured after it.
        _, train_acc = _evaluate(model, loader, criterion, device)
        print(f"Train Loss: {running_loss/max(len(loader), 1)}, Train Accuracy: {train_acc}")

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss}, Val Accuracy: {val_acc}")

        # state_dict() returns live references that the optimizer keeps
        # mutating, so snapshot detached copies. Otherwise "best" silently
        # becomes "last". Always keep at least the first epoch.
        if best_model is None or val_acc > best_acc:
            best_acc = val_acc
            best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Report test accuracy for the checkpoint we return, not the last epoch.
    model.load_state_dict(best_model)

    test_x = test_input / scale
    test_y = torch.tensor([label_map[label] for label in test_labels], dtype=torch.long)
    test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=64, shuffle=False)
    print('getting test accuracy')
    test_loss, test_accuracy = _evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")

    return test_accuracy, best_model, best_acc


def train_model(input, Dim_in, Dim_out, labels, test_input, test_labels,
                LR_range=(1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6, 1e-6)):
    """Sweep learning rates and keep the run with the best validation accuracy.

    Calls ``train_model_LR`` once per learning rate in `LR_range` and selects
    the run with the highest validation accuracy; the earliest run wins ties.

    Args:
        input, Dim_in, Dim_out, labels, test_input, test_labels: forwarded
            unchanged to ``train_model_LR``.
        LR_range: learning rates to try, in order. The default is a 1/5
            log-spaced grid from 1e-2 down to 1e-6 (the original list repeated
            5e-5 and skipped 5e-4).

    Returns:
        ``(best_test_acc, best_model)``: the test accuracy and state dict
        returned by the selected run, or ``(0, None)`` if `LR_range` is empty.
    """
    # Starting at -inf guarantees a run is selected even if every
    # validation accuracy is 0.
    best_acc = float('-inf')
    best_model = None
    best_test_acc = 0

    for LR in LR_range:
        test_acc, model, acc = train_model_LR(input, Dim_in, Dim_out, labels, test_input, test_labels, LR)
        if acc > best_acc:  # strict '>' keeps the earliest LR on ties
            best_acc = acc
            best_model = model
            best_test_acc = test_acc

    return best_test_acc, best_model