import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from sklearn.metrics import accuracy_score

from utils import torch_predict


def eval(model, criterion, device, X_train, y_train, X_val, y_val, train_acc_, val_loss_, val_acc_, avg_rew_, avg_dist_, iter, env=None):
    """Evaluate ``model`` and append the results to the given history dicts.

    The model is switched to evaluation mode and is left in that mode on
    return; all forward passes run under ``torch.no_grad()``.

    Each history argument (``train_acc_``, ``val_loss_``, ``val_acc_``,
    ``avg_rew_``, ``avg_dist_``) is a dict with a ``"curve"`` list of values
    and an ``"iter"`` list of the iteration numbers they were recorded at.
    They are mutated in place.

    Args:
        model: Callable module with an ``eval()`` method.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the inputs and targets are moved to for the loss.
        X_train, y_train: Full training inputs and labels.
        X_val, y_val: Validation inputs and labels; validation metrics are
            skipped when ``X_val`` is ``None``.
        iter: Iteration number stored alongside every recorded value.
        env: Optional object exposing ``evaluate_policy``; when given, the
            average reward and distance it returns are recorded too.

    Returns:
        ``(full_train_loss, train_acc, val_acc)``. The training loss over the
        whole training set is returned but not recorded in any history;
        ``val_acc`` is ``None`` when no validation data was supplied.
    """
    def _record(history, value):
        # Keep the value and the iteration it belongs to side by side.
        history["curve"].append(value)
        history["iter"].append(iter)

    model.eval()
    val_acc = None

    with torch.no_grad():
        train_out = model(X_train.to(device))
        full_train_loss = criterion(train_out, y_train.to(device)).item()

        if X_val is not None:
            val_out = model(X_val.to(device))
            _record(val_loss_, criterion(val_out, y_val.to(device)).item())

            # Predictions come from the project helper `torch_predict`
            # (imported from utils), not from `val_out` above.
            val_acc = accuracy_score(y_val, torch_predict(model, X_val))
            _record(val_acc_, val_acc)

        train_acc = accuracy_score(y_train, torch_predict(model, X_train))
        _record(train_acc_, train_acc)

        if env is not None:
            avg_rew, avg_dist = env.evaluate_policy(model, num_episodes=25, device=device)
            _record(avg_rew_, avg_rew)
            _record(avg_dist_, avg_dist)

    return full_train_loss, train_acc, val_acc


def train_torch_model(model, X_train, y_train, X_val=None, y_val=None, batch_size="full", lr=1e-3, max_iter=np.inf, env=None, val_freq=500):
    """Train ``model`` with Adam on a cross-entropy loss until a stop criterion fires.

    Training runs on CUDA when available, otherwise on the CPU. Inputs are
    converted to ``float32`` tensors and labels to ``long`` tensors. The
    training data is reshuffled and iterated in mini-batches indefinitely;
    the per-batch loss is recorded at every iteration.

    Every ``val_freq`` iterations (starting at iteration 0) the module-level
    ``eval`` helper is called to record metrics and a progress line is
    printed. Training stops, right after such an evaluation, when either

    * the iteration counter has reached ``max_iter``, or
    * the loss over the full training set drops below ``1e-7``.

    ``max_iter`` is checked at every iteration, so the limit is honoured even
    when it is not a multiple of ``val_freq``; a final evaluation is run at
    the stopping iteration.

    Args:
        model: ``torch.nn.Module`` to train; moved to the selected device.
        X_train, y_train: Training inputs and integer class labels
            (anything ``torch.tensor`` accepts).
        X_val, y_val: Optional validation inputs and labels.
        batch_size: Mini-batch size, or ``"full"`` to use the whole training
            set as a single batch.
        lr: Learning rate passed to Adam.
        max_iter: Iteration index at which training stops.
        env: Optional environment forwarded to ``eval``.
        val_freq: Number of iterations between evaluations (must be >= 1).

    Returns:
        ``(model, history)`` where ``history`` maps ``"train_loss"``,
        ``"train_acc"``, ``"test_loss"``, ``"test_acc"``, ``"avg_rew"`` and
        ``"avg_dist"`` to dicts with ``"curve"`` and ``"iter"`` lists. The
        ``"test_*"`` entries hold the metrics computed on the validation data.

    Raises:
        ValueError: If ``val_freq`` is smaller than 1 or the training set is
            empty.
    """
    if val_freq < 1:
        raise ValueError(f"val_freq must be at least 1, got {val_freq!r}")

    X_train = torch.tensor(X_train, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.long)
    num_samples = X_train.shape[0]
    if num_samples == 0:
        # An empty loader would make the endless epoch loop below spin forever.
        raise ValueError("X_train is empty; at least one training sample is required")
    if X_val is not None:
        X_val = torch.tensor(X_val, dtype=torch.float32)
        y_val = torch.tensor(y_val, dtype=torch.long)

    if batch_size == "full":
        batch_size = num_samples  # full batch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_loss_ = {"curve": [], "iter": []}
    train_acc_ = {"curve": [], "iter": []}
    val_loss_ = {"curve": [], "iter": []}
    val_acc_ = {"curve": [], "iter": []}
    avg_rew_ = {"curve": [], "iter": []}
    avg_dist_ = {"curve": [], "iter": []}
    history = {
        "train_loss": train_loss_,
        "train_acc": train_acc_,
        "test_loss": val_loss_,
        "test_acc": val_acc_,
        "avg_rew": avg_rew_,
        "avg_dist": avg_dist_,
    }

    step = 0
    while True:  # one pass of the inner loop per epoch; exits via return
        for xb, yb in train_loader:
            # `eval` leaves the model in evaluation mode, so re-enable
            # training mode before every optimisation step.
            model.train()
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss_["curve"].append(batch_loss)
            train_loss_["iter"].append(step)

            # Check the iteration limit on every step so training cannot
            # overshoot max_iter while waiting for the next evaluation slot.
            reached_limit = step >= max_iter
            if reached_limit or step % val_freq == 0:
                full_train_loss, train_acc, val_acc = eval(model, criterion, device, X_train, y_train, X_val, y_val, train_acc_, val_loss_, val_acc_, avg_rew_, avg_dist_, step, env=env)
                print(f"iter: {step}, loss: {batch_loss}, train acc: {train_acc}, test acc: {val_acc}, out: {out[np.random.randint(0, out.shape[0])]}")

                if reached_limit or full_train_loss < 1e-7:
                    return model, history

            step += 1