import os
import numpy as np
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pathlib
import shap_dataset.shap_dataset
import utils
from global_variables import TEST_SET_RATIO, RANDOM_SEED




class NNDataset(torch.utils.data.Dataset):
    """Torch ``Dataset`` adapter around an object exposing ``x`` and ``y``.

    The wrapped object's ``x`` must offer a ``shape`` with at least two
    entries (samples, features).  A ``torch.Generator`` seeded with
    ``RANDOM_SEED`` is kept on the instance and drives the train/test split.
    """

    def __init__(self, dataset, device='cpu'):
        features = dataset.x
        labels = dataset.y
        self.X = features
        self.y = labels
        self.no_samples = features.shape[0]
        self.no_features = features.shape[1]
        # manual_seed returns the generator itself, so it can be chained
        self.generator = torch.Generator(device=device).manual_seed(RANDOM_SEED)

    def __len__(self):
        """Number of samples (rows of ``X``)."""
        return self.no_samples

    def __getitem__(self, idx):
        """Return the ``[features, target]`` pair stored at ``idx``."""
        sample = self.X[idx]
        label = self.y[idx]
        return [sample, label]

    def get_splits(self, test_ratio=TEST_SET_RATIO):
        """Randomly partition the dataset into train and test subsets.

        ``test_ratio`` is the fraction of rows assigned to the test subset
        (rounded to the nearest whole row); the remainder goes to train.
        Returns the result of ``torch.utils.data.random_split`` in
        ``[train, test]`` order.
        """
        total = len(self.X)
        n_test = round(test_ratio * total)
        n_train = total - n_test
        return torch.utils.data.random_split(
            self, [n_train, n_test], generator=self.generator
        )


class Net(nn.Module):
    """Fully connected network with ReLU after every layer except the last.

    Args:
        name: Identifier stored on the model as ``self.name``.
        layer: Sequence of layer widths, input width first and output width
            last; ``len(layer) - 1`` linear layers are created.

    All linear layers use ``torch.float64`` weights.  ``self.logs`` starts
    as an empty list.
    """

    def __init__(self, name, layer):
        super().__init__()
        self.no_layers = len(layer) - 1
        # ModuleList is the container intended for sub-modules; it produces
        # the same "linear_layers.<i>.weight/bias" state-dict keys as before.
        self.linear_layers = nn.ModuleList(
            [
                nn.Linear(n_in, n_out, dtype=torch.float64)
                for n_in, n_out in zip(layer, layer[1:])
            ]
        )
        self.name = name
        self.logs = []

    def forward(self, x):
        """Apply the linear layers in order; the final layer has no ReLU."""
        last = self.no_layers - 1
        for i, linear in enumerate(self.linear_layers):
            x = linear(x)
            if i != last:
                x = F.relu(x)
        return x

    def custom_forward(self, x):
        """This function is passed to kernel shap which passes in numpy arrays to the forward pass not torch tensors

        The input is copied into a tensor with the dtype and device of the
        model's parameters, run through ``forward`` without gradient
        tracking, and returned as a NumPy array on the CPU.
        """
        reference = next(self.parameters())
        with torch.no_grad():
            # Match the weights' dtype/device so float32 or integer arrays
            # and models on any device are handled.
            x = torch.tensor(x, dtype=reference.dtype, device=reference.device)
            return self.forward(x).cpu().numpy()

def prepare_data(dataset, test_ratio=None, train_batch_size=128, test_batch_size=2**18):
    """Split ``dataset`` and wrap both parts in data loaders.

    Args:
        dataset: Object providing ``get_splits(test_ratio=...)`` that returns
            a ``(train, test)`` pair usable by ``DataLoader``.
        test_ratio: Fraction passed to ``get_splits``; ``None`` (default)
            means ``TEST_SET_RATIO``.
        train_batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the unshuffled test loader.

    Returns:
        ``(train_dl, test_dl)``.
    """
    if test_ratio is None:
        test_ratio = TEST_SET_RATIO
    # calculate split
    train, test = dataset.get_splits(test_ratio=test_ratio)
    # prepare data loaders
    train_dl = DataLoader(train, batch_size=train_batch_size, shuffle=True)
    test_dl = DataLoader(test, batch_size=test_batch_size, shuffle=False)
    return train_dl, test_dl


def train_model(train_dl, test_dl, model, save_interval=20, epochs=100, lr=0.01):
    """Train ``model`` with Adam on an MSE objective, checkpointing as it goes.

    After every epoch the model is scored on both loaders with
    ``evaluate_model`` and the metrics are appended to ``model.logs``.
    Whenever the MSE on ``test_dl`` improves, a "best" checkpoint is saved
    and the log plot is refreshed; a numbered checkpoint is additionally
    written every ``save_interval`` epochs.

    Args:
        train_dl: Iterable of ``(inputs, targets)`` training batches that
            supports ``len``.
        test_dl: Iterable of ``(inputs, targets)`` batches used as the
            validation set for model selection.
        model: Network optimised in place; must expose a ``logs`` list.
        save_interval: Period, in epochs, of the numbered checkpoints.
        epochs: Number of passes over ``train_dl``.
        lr: Learning rate handed to Adam.
    """
    # define the optimization
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Move batches to wherever the model lives rather than assuming CUDA.
    device = next(model.parameters()).device
    n_batches = len(train_dl)
    best_val_mse, best_val_r2 = float('inf'), float('-inf')

    for epoch in tqdm.tqdm(range(epochs)):
        # running mean of the mini-batch losses for this epoch
        total_loss = 0
        for inputs, targets in train_dl:
            inputs, targets = inputs.to(device), targets.to(device)
            # clear the gradients
            optimizer.zero_grad()
            # Drop only the trailing output axis: a plain squeeze() would turn
            # a one-sample batch into a 0-d tensor that mismatches `targets`.
            yhat = model(inputs).squeeze(-1)
            loss = criterion(yhat, targets)
            # credit assignment
            loss.backward()
            # update model weights
            optimizer.step()
            total_loss += loss.item() / n_batches

        # Evaluate after an epoch
        train_r2, train_mse = evaluate_model(train_dl, model)
        val_r2, val_mse = evaluate_model(test_dl, model)
        print(f"Epoch {epoch}", total_loss, val_r2, val_mse)

        # Update logs
        model.logs.append(
            {
                "train_mse_loss": train_mse,
                "val_mse_loss": val_mse,
                "train_r2": train_r2,
                "val_r2": val_r2,
            }
        )

        # Save model
        if val_mse < best_val_mse:
            save_model(model, epoch, best=True)
            plot_logs(model)
            best_val_mse = val_mse
            best_val_r2 = val_r2
        if (epoch + 1) % save_interval == 0:
            save_model(model, epoch)

    print("Best scores:", best_val_mse, best_val_r2)


def save_model(model, epoch, best=False):
    """Write a checkpoint for ``model`` under ``cache/<model.name>/``.

    The cache directory sits next to this source file.  The checkpoint is a
    dict with the keys ``'epoch'``, ``'model'`` (the state dict) and
    ``'logs'``.

    Args:
        model: Module exposing ``name``, ``logs`` and ``state_dict()``.
        epoch: Epoch index stored in the checkpoint; also the file stem when
            ``best`` is false.
        best: When true the file is ``best.pth`` (overwritten on every call),
            otherwise ``<epoch>.pth``.
    """
    this_directory = pathlib.Path(__file__).parent.resolve()
    model_dir = f"{this_directory}/cache/{model.name}"
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(model_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'logs': model.logs,
    }

    file_name = 'best.pth' if best else f'{epoch}.pth'
    torch.save(checkpoint, f'{model_dir}/{file_name}')


def load_model(task_name, epoch=None, best=False, device="cpu"):
    """Rebuild a ``Net`` and restore it from a checkpoint in ``cache/<task_name>/``.

    The network is created with three hidden layers of 300 units; its input
    width is read from ``utils.get_task_settings()["no_features"][task_name]``.

    Args:
        task_name: Name of the cache sub-directory, also used as model name.
        epoch: Selects ``<epoch>.pth`` when ``best`` is false.
        best: When true, ``best.pth`` is loaded and ``epoch`` is ignored.
        device: ``"cpu"`` loads onto the CPU.  Any other value is made the
            current CUDA device and the model is moved to it.

    Returns:
        The restored model, with ``logs`` taken from the checkpoint.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    this_directory = pathlib.Path(__file__).parent.resolve()
    model_dir = f"{this_directory}/cache/{task_name}"
    # NOTE(review): with best=False and epoch left at None this resolves to
    # "None.pth" - callers are expected to pass one of the two.
    file_name = 'best.pth' if best else f'{epoch}.pth'
    checkpoint_file = f'{model_dir}/{file_name}'

    # Fail fast, before building the network or touching CUDA state.
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"Could not find {checkpoint_file} to load from.")

    dataset_no_features = utils.get_task_settings()["no_features"]
    model = Net(task_name, [dataset_no_features[task_name], 300, 300, 300, 1])

    if device == "cpu":
        map_location = torch.device('cpu')
    else:
        # Selecting the device first makes the bare .cuda() below target it.
        torch.cuda.set_device(device)
        map_location = device
        model = model.cuda()

    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["model"])
    model.logs = checkpoint["logs"]
    return model

def evaluate_model(test_dl, model):
    """Score ``model`` on every batch of ``test_dl``.

    Args:
        test_dl: Iterable of ``(inputs, targets)`` batches.
        model: Module producing one prediction per input row.

    Returns:
        ``(r2, mse)`` computed over all batches with scikit-learn, treating
        the targets as ground truth and the model outputs as predictions.
    """
    # Run on the device that holds the model instead of assuming CUDA.
    device = next(model.parameters()).device
    predictions, actuals = [], []
    with torch.no_grad():
        for inputs, targets in test_dl:
            yhat = model(inputs.to(device)).squeeze()
            # retrieve numpy arrays
            predictions.append(yhat.cpu().numpy())
            actuals.append(targets.cpu().numpy())
    predictions, actuals = np.hstack(predictions), np.hstack(actuals)
    # scikit-learn metrics take (y_true, y_pred); R^2 is not symmetric in its
    # arguments, so the order matters.
    r2 = sklearn.metrics.r2_score(actuals, predictions)
    mse = sklearn.metrics.mean_squared_error(actuals, predictions)
    return r2, mse

def compute_output(dataloader, model):
    """Run ``model`` over ``dataloader`` and collect inputs with predictions.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; the targets
            are ignored.
        model: Module applied to each input batch without gradient tracking.

    Returns:
        ``(inputs, predictions)`` as NumPy arrays: the input batches stacked
        row-wise and the squeezed model outputs concatenated in the same
        order.
    """
    # Use the device that holds the model instead of assuming CUDA.
    device = next(model.parameters()).device
    inputs, predictions = [], []
    with torch.no_grad():
        for batch_input, _targets in dataloader:
            batch_input = batch_input.to(device)
            yhat = model(batch_input).squeeze()
            # retrieve numpy arrays
            predictions.append(yhat.cpu().numpy())
            inputs.append(batch_input.cpu().numpy())
    inputs = np.vstack(inputs)
    predictions = np.hstack(predictions)
    return inputs, predictions



def plot_logs(model):
    """Plot the logged MSE and R^2 curves of ``model`` and save them as a PNG.

    The figure has two panels (losses on the left, R^2 on the right with the
    y-axis fixed to ``[0, 1]``) and is written to
    ``<base>/models/<model.name>/plot.png``, where ``<base>`` is the
    ``EXPERIMENT_CACHE`` environment variable if set and the current working
    directory otherwise.

    Args:
        model: Object with a ``name`` and a ``logs`` list of dicts holding the
            keys ``train_mse_loss``, ``val_mse_loss``, ``train_r2`` and
            ``val_r2``.
    """
    logs = model.logs
    loss_data = {
        "train_mse_loss": [entry["train_mse_loss"] for entry in logs],
        "val_mse_loss": [entry["val_mse_loss"] for entry in logs],
    }
    r2_data = {
        "train_r2": [entry["train_r2"] for entry in logs],
        "val_r2": [entry["val_r2"] for entry in logs],
    }

    fig, axes = plt.subplots(1, 2, figsize=(8, 3), sharex=True)
    try:
        fig.suptitle(model.name)

        sns.lineplot(data=loss_data, ax=axes[0])
        sns.lineplot(data=r2_data, ax=axes[1])
        axes[1].set_ylim(0, 1)

        # Save plot
        data_directory = os.environ.get("EXPERIMENT_CACHE", os.getcwd())
        model_dir = f"{data_directory}/models/{model.name}"
        # The target directory is not created anywhere else, so make sure it
        # exists before writing into it.
        os.makedirs(model_dir, exist_ok=True)
        fig.savefig(f"{model_dir}/plot.png")
    finally:
        # This function runs repeatedly during training; close the figure so
        # open figures do not pile up in memory.
        plt.close(fig)

if __name__ == "__main__":
    # Script entry point: trains one network on the SGEMM dataset.  The
    # commented-out sections below are alternative experiment setups that
    # follow the same steps for other datasets.

    # # Harvard Dataset
    # num_features = [20, 30, 40 ,50, 60]
    # for feature_count in num_features:
    #     model_name = f"harvard{feature_count}"
    #     dataset = shap_dataset.shap_dataset.HarvardCleanEnergyDataset(feature_count)
    #     # Dataset in torch format
    #     nn_dataset = NNDataset(dataset)
    #     # Dataloaders
    #     train_dl, test_dl = prepare_data(nn_dataset)
    #     model = Net(model_name, [nn_dataset.no_features, 300, 300, 300, 1]).cuda()
    #     # train model
    #     train_model(train_dl, test_dl, model)

    # num_features = [20, 30, 40, 50, 60]
    # for feature_count in num_features:
    #     model_name = f"avGFP{feature_count}"
    #     dataset = shap_dataset.shap_dataset.avGFPDataset()
    #     dataset.select_important_features(feature_count)
    #     # Dataset in torch format
    #     nn_dataset = NNDataset(dataset)
    #     # Dataloaders
    #     train_dl, test_dl = prepare_data(nn_dataset)
    #     model = Net(model_name, [nn_dataset.no_features, 300, 300, 300, 1]).cuda()
    #     # train model
    #     train_model(train_dl, test_dl, model)

    # # Entacmaea Dataset
    # model_name = "entacmaea"
    # dataset = shap_dataset.shap_dataset.EntacmaeaDataset()
    # # Dataset in torch format
    # nn_dataset = NNDataset(dataset)
    # # Dataloaders
    # train_dl, test_dl = prepare_data(nn_dataset)
    # model = Net(model_name, [nn_dataset.no_features, 300, 300, 300, 1]).cuda()
    # # train model
    # train_model(train_dl, test_dl, model)


    # # Crimes Dataset
    # model_name = "crimes"
    # dataset = shap_dataset.shap_dataset.CrimesDataset()
    # # Dataset in torch format
    # nn_dataset = NNDataset(dataset)
    # # Dataloaders
    # train_dl, test_dl = prepare_data(nn_dataset)
    # print(model_name, len(train_dl), len(test_dl), nn_dataset.no_features)
    # model = Net(model_name, [nn_dataset.no_features, 300, 300, 300, 1]).cuda()
    # # train model
    # train_model(train_dl, test_dl, model)

    # SGEMM Dataset
    task = "sgemm"
    # Wrap the raw dataset so it can be indexed and split by torch
    torch_dataset = NNDataset(shap_dataset.shap_dataset.SGEMMDataset())
    # Train / validation loaders
    train_loader, val_loader = prepare_data(torch_dataset)
    # Input width from the data, three hidden layers of 300, scalar output
    layer_widths = [torch_dataset.no_features, 300, 300, 300, 1]
    network = Net(task, layer_widths).cuda()
    # Fit the network
    train_model(train_loader, val_loader, network)
