import datasets
import torch
from torch import nn
from prodigyopt import Prodigy
from sklearn.model_selection import train_test_split
import argparse
import trackexp
import pyfamilywise
import copy

# Command-line interface for a single training run.
parser = argparse.ArgumentParser()
# --dataset feeds both the experiment name and the dataset loaders, so a
# missing value is rejected here instead of propagating as None.
parser.add_argument('--dataset', type=str, required=True,
                    help='Name of the dataset to load')
parser.add_argument('--loss_func', type=str, choices=['ce', 'fw'], required=True,
                    help="Loss function to train with: 'ce' or 'fw'")
parser.add_argument('--patience', type=int, default=10,
                    help='Early-stopping patience, in epochs')
parser.add_argument('--base_dir', type=str, default='trackexp_out', help='Base directory for trackexp outputs')
args = parser.parse_args()


# Early-stopping patience, as supplied on the command line.
patience = args.patience

# The experiment name combines the dataset and the loss function.
experiment_name = f"{args.dataset}__{args.loss_func}"
trackexp.init(base_dir=args.base_dir, experiment_name=experiment_name)

# Settings recorded as metadata for this experiment.
experimental_config = {
    "optimizer": "prodigy",
    "patience": patience,
    "loss_func": args.loss_func,
}
trackexp.metadata(experimental_config)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device_str = "cuda"
else:
    device_str = "cpu"
device = torch.device(device_str)
print(f"Using device: {device}")

# Load the training and test portions of the requested dataset.
X, y = datasets.load_trn(args.dataset, return_X_y=True)
X_test, y_test = datasets.load_tst(args.dataset, return_X_y=True)
# Hold out 20% of the training portion for validation; the fixed seed keeps
# the split identical across runs.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=True)

# Features become float32 and labels int64 tensors on the selected device.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

# Size the output layer from the largest class index in train + validation
# rather than from the number of distinct training labels. The split above is
# not stratified, so a class can be absent from the training part; counting
# unique training labels would then give too few outputs and leave labels
# outside [0, output_dim), which nn.CrossEntropyLoss rejects.
# NOTE(review): assumes labels are 0-based class indices — confirm this also
# matches what pyfamilywise.FWLoss expects for num_classes.
output_dim = int(torch.cat((y_train_tensor, y_val_tensor)).max().item()) + 1
input_dim = X.shape[1]

# A single linear layer maps the input features to one score per class.
model = nn.Linear(in_features=input_dim, out_features=output_dim)
model = model.to(device)

# Prodigy optimizer over all model parameters, constructed with decouple=True.
optimizer = Prodigy(model.parameters(), decouple=True)

# Pick the loss selected on the command line.
if args.loss_func == 'fw':
    loss_fn = pyfamilywise.FWLoss(num_classes=output_dim, device=device_str)
elif args.loss_func == 'ce':
    loss_fn = nn.CrossEntropyLoss()

# Early stopping is only considered once the epoch index exceeds this value.
MIN_EPOCHS = 100

# Early-stopping state: best validation accuracy so far, number of
# consecutive epochs without improvement, and a snapshot of the best weights.
best_accuracy = 0
waiting = 0
best_model_state = copy.deepcopy(model.state_dict())

epoch = 0
while True:
    trackexp.start_timer("training", epoch)

    # One full-batch gradient step on the training set.
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    loss = loss_fn(outputs, y_train_tensor)
    loss.backward()
    optimizer.step()

    trackexp.log("training", "loss", epoch, loss.item())

    # Validation accuracy of the freshly updated model.
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        _, predicted = torch.max(val_outputs, 1)
        accuracy = (predicted == y_val_tensor).float().mean().item()
    trackexp.log("validation", "accuracy", epoch, accuracy)

    if accuracy > best_accuracy:
        # Strict improvement: snapshot the weights and reset the counter.
        best_accuracy = accuracy
        best_model_state = copy.deepcopy(model.state_dict())
        waiting = 0
    else:
        waiting += 1

    should_stop = waiting >= patience and epoch > MIN_EPOCHS

    # Stop the timer before possibly leaving the loop, so the final epoch's
    # timer is closed like every other epoch's. The timed span still covers
    # both the training step and the validation pass.
    trackexp.stop_timer("training", epoch)
    if should_stop:
        break
    epoch += 1

# Restore the weights that achieved the best validation accuracy.
model.load_state_dict(best_model_state)

# Score the restored model on the test set without tracking gradients.
with torch.no_grad():
    test_outputs = model(X_test_tensor)
    predicted = torch.max(test_outputs, 1).indices
    is_correct = torch.eq(predicted, y_test_tensor)
    accuracy_at_best_val = is_correct.float().mean().item()

# Record the test accuracy; no epoch index is attached to this entry.
trackexp.log("testing", "accuracy", None, accuracy_at_best_val)
