import torch
from torch import nn
from torch.nn import functional as F
import numpy as np


class Perceptron(nn.Module):
    """Single linear unit followed by a sigmoid.

    Maps inputs whose last dimension is ``input_size`` to one sigmoid-squashed
    output per row (trailing dimension of size 1).
    """

    def __init__(self, input_size):
        super().__init__()
        # The attribute name "fc" determines the state_dict keys
        # ("fc.weight", "fc.bias") used when saving/loading models.
        self.fc = nn.Linear(input_size, 1)

    def forward(self, history_ids):
        """Return ``sigmoid(fc(history_ids))``."""
        logits = self.fc(history_ids)
        return logits.sigmoid()


class LinearRegression(torch.nn.Module):
    """Plain affine map from ``inputSize`` to ``outputSize`` features.

    No activation is applied; the output is ``x @ W.T + b``.
    """

    def __init__(self, inputSize, outputSize):
        super().__init__()
        # Attribute name "linear" fixes the state_dict keys.
        self.linear = torch.nn.Linear(inputSize, outputSize)

    def forward(self, x):
        """Apply the affine layer to ``x``."""
        return self.linear(x)


import numpy as np
import random


class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy action policy.

    Q-values live in ``self.q_table``, a dict mapping each seen state to a list
    holding one value per entry of ``self.actions``, in the same order. States
    are used as dict keys, so they must be hashable.
    """

    def __init__(self, actions, alpha=0.5, gamma=0.99, epsilon=0.1):
        """
        Initialize the Q-Learning Agent.

        Parameters:
        - actions: list of possible actions
        - alpha: learning rate (0 < alpha <= 1)
        - gamma: discount factor (0 <= gamma < 1)
        - epsilon: exploration rate (0 <= epsilon <= 1)
        """
        # {state: [Q-value per action, ordered like self.actions]}
        self.q_table = {}
        self.actions = actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def get_q_values(self, state):
        """
        Return the Q-value list for ``state``, inserting all-zero values for an
        unseen state. The returned list is the one stored in the table, so
        in-place updates to it persist.
        """
        if state not in self.q_table:
            self.q_table[state] = [0.0 for _ in self.actions]
        return self.q_table[state]

    def _action_index(self, action):
        """
        Map an action to its column in the Q-value lists.

        ``get_action`` returns elements of ``self.actions``, so those are
        looked up by position. For backward compatibility, an integer that is
        not itself an action but is a valid column (0 <= action < number of
        actions) is used directly as the column.

        Raises:
        - ValueError: if ``action`` is neither an action nor a valid column.
        """
        for idx, candidate in enumerate(self.actions):
            if candidate == action:
                return idx
        if hasattr(action, "__index__") and 0 <= action < len(self.actions):
            return int(action)
        raise ValueError(f"unknown action: {action!r}")

    def get_action(self, state):
        """
        Select an action based on an epsilon-greedy policy.
        """
        if random.random() < self.epsilon:
            # Explore: select a random action
            return random.choice(self.actions)
        # Exploit: select the action with the highest Q-value
        q_values = self.get_q_values(state)
        max_q_value = max(q_values)
        # In case multiple actions have the same Q-value, choose randomly among them
        return random.choice(
            [
                action
                for action, q_value in zip(self.actions, q_values)
                if q_value == max_q_value
            ]
        )

    def learn(self, state, action, reward, next_state, done=False):
        """
        Apply one Q-learning update for an observed transition.

        Q(state, action) += alpha * (target - Q(state, action)), where
        target = reward + gamma * max Q(next_state, .), or just ``reward``
        when ``done`` is true (no bootstrapping from a terminal transition).

        ``action`` should be an element of ``self.actions`` (as returned by
        ``get_action``); a raw integer column index is also accepted.
        """
        current_q_values = self.get_q_values(state)
        if done:
            target_q = reward
        else:
            # Take the maximum Q-value for the next state (this also registers
            # next_state in the table).
            target_q = reward + self.gamma * max(self.get_q_values(next_state))

        col = self._action_index(action)
        current_q_values[col] += self.alpha * (target_q - current_q_values[col])

    def update_epsilon(self, decrement=0.01, min_epsilon=0.01):
        """
        Decrease epsilon by ``decrement``, never going below ``min_epsilon``,
        to reduce the amount of exploration over time.
        """
        self.epsilon = max(min_epsilon, self.epsilon - decrement)


def transform_to_2d_with_labels(data, n):
    """Cut a length-``n`` window ending at every 2 or 3 in a 1-D sequence.

    For each position ``i`` where ``data[i]`` is 2 or 3, the window
    ``data[i - n + 1 : i + 1]`` becomes one row of the output matrix and
    ``data[i]`` becomes its label. Windows that would start before the
    beginning of ``data`` are left-padded with ``data[0]``.

    Parameters:
    - data: 1-D tensor (or anything ``torch.as_tensor`` accepts).
    - n: window length, must be >= 1.

    Returns:
    - (windows, labels): ``windows`` has shape ``(k, n)`` and ``labels`` has
      shape ``(k,)``, where ``k`` is the number of 2/3 entries (``k`` may be
      0). Integer input yields int64 outputs; floating input yields the
      default float dtype.

    Raises:
    - ValueError: if ``data`` is not 1-D or ``n < 1``.
    """
    data = torch.as_tensor(data)
    if data.dim() != 1:
        raise ValueError(f"data must be 1-D, got shape {tuple(data.shape)}")
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    # Indices where a sequence ends (value 2 or 3).
    ends = torch.nonzero((data == 2) | (data == 3), as_tuple=True)[0]
    # Positions of each window relative to its end index: -n+1, ..., 0.
    offsets = torch.arange(-n + 1, 1, device=data.device)
    # Clamping negative positions to 0 reads data[0], which is exactly the
    # left-padding value, so no separate padding step is needed.
    positions = (ends.unsqueeze(1) + offsets).clamp(min=0)
    windows = data[positions]
    labels = data[ends]

    # Match the dtypes produced by building tensors from Python lists
    # (ints -> int64, floats -> default float dtype).
    dtype = torch.get_default_dtype() if windows.is_floating_point() else torch.int64
    return windows.to(dtype), labels.to(dtype)


def trainer(dataset_name, gamma, n=20, epochs=10, lr=0.01):
    """Train a Perceptron offline on recorded history ids and save it.

    Loads ``data/{dataset_name}_history_ids_{gamma}.npy`` and splits its rows
    80/20 into train/test. It drops ``-1`` entries and turns each split into
    length-``n`` windows ending at a 2/3 value (label 2 -> 0, 3 -> 1). It then
    trains with full-batch SGD on BCE loss and prints losses and accuracies.
    Finally it saves the state dict to ``models/{dataset_name}_{gamma}_{n}.pt``,
    creating ``models/`` if needed.

    Parameters:
    - dataset_name, gamma: only used to build the input/output file names.
    - n: window length; also the Perceptron input size.
    - epochs: number of full-batch SGD steps.
    - lr: SGD learning rate.

    Returns:
    - the trained Perceptron.
    """
    import os

    def windows(rows):
        # Flatten rows, drop -1 entries, cut labelled windows of length n, and
        # return float inputs plus 0/1 labels shaped (k, 1) like the model output.
        flat = rows.flatten()
        flat = torch.tensor(flat[flat != -1])
        mat, labels = transform_to_2d_with_labels(flat, n=n)
        return mat.float(), labels.unsqueeze(1) - 2

    def accuracy(probs, labels):
        # Fraction of rows where (probs > 0.5) matches the 0/1 label; nan if empty.
        total = labels.size(0)
        if total == 0:
            return float("nan")
        return ((probs > 0.5) == labels).sum().item() / total

    # Do offline training in preparation for pp
    datapath = f"data/{dataset_name}_history_ids_{gamma}.npy"
    data = np.load(datapath)

    split = int(0.8 * len(data))
    train_data_mat, train_data_labels = windows(data[:split])
    test_data_mat, test_data_labels = windows(data[split:])

    model = Perceptron(n)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    train_targets = train_data_labels.float()
    for epoch in range(epochs):
        y_pred = model(train_data_mat)
        loss = criterion(y_pred, train_targets)
        print(f"Epoch {epoch + 1}: train loss: {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # print the label balance
    print("Label balance")
    print(train_data_labels.sum() / len(train_data_labels))

    with torch.no_grad():
        # Re-predict: the last in-loop prediction predates the final optimizer step.
        print(f"Train accuracy: {accuracy(model(train_data_mat), train_data_labels)}")

        y_pred = model(test_data_mat)
        loss = criterion(y_pred, test_data_labels.float())
        print(f"Test loss: {loss.item()}")
        print(f"Accuracy: {accuracy(y_pred, test_data_labels)}")

    # save model to disk
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), f"models/{dataset_name}_{gamma}_{n}.pt")
    return model


# Values in the history-id arrays (-1 entries are filtered out before training):
# 1 = accepted, marked green
# 2 = rejected, marked red
# 3 = resampled, marked blue
if __name__ == "__main__":
    # Standalone training run: fit a Perceptron on one recorded history-id
    # file, report train/test metrics, and save the weights. The defaults
    # reproduce the previously hard-coded finance-alpaca configuration.
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Train a Perceptron on recorded history ids."
    )
    parser.add_argument(
        "--datapath",
        default="/jet/home/bpark1/llm-inference/data/finance-alpaca_history_ids_4.npy",
        help="input .npy file of history ids (e.g. ../data/gsm8k_history_ids_4.npy)",
    )
    parser.add_argument(
        "--out",
        default="../models/finance-alpaca_4.pt",
        help="where to save the trained state dict (e.g. ../models/gsm8k.pt)",
    )
    parser.add_argument("--n", type=int, default=20, help="window length / model input size")
    parser.add_argument("--epochs", type=int, default=1, help="full-batch SGD steps")
    parser.add_argument("--lr", type=float, default=0.01, help="SGD learning rate")
    args = parser.parse_args()

    data = np.load(args.datapath)
    print(data[0])

    # 80/20 split over rows.
    split = int(0.8 * len(data))
    train_rows, test_rows = data[:split], data[split:]
    print(train_rows.shape)
    print(test_rows.shape)

    # Flatten each split and drop the -1 entries.
    train_flat = train_rows.flatten()
    train_flat = torch.tensor(train_flat[train_flat != -1])
    test_flat = test_rows.flatten()
    test_flat = torch.tensor(test_flat[test_flat != -1])

    # Windows of length n ending at each 2/3 value; labels are still 2/3 here.
    train_data_mat, train_data_labels = transform_to_2d_with_labels(train_flat, args.n)
    test_data_mat, test_data_labels = transform_to_2d_with_labels(test_flat, args.n)
    print(train_data_mat)
    print(train_data_labels)

    # Float inputs for the model; labels 2 -> 0 and 3 -> 1, shaped (k, 1).
    train_data_mat = train_data_mat.float()
    test_data_mat = test_data_mat.float()
    train_data_labels = train_data_labels.unsqueeze(1) - 2
    test_data_labels = test_data_labels.unsqueeze(1) - 2
    print(test_data_mat)
    print(test_data_labels)

    print("SHAPES")
    print(train_data_mat.shape, train_data_labels.shape)
    print(test_data_mat.shape, test_data_labels.shape)

    model = Perceptron(args.n)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        y_pred = model(train_data_mat)
        loss = criterion(y_pred, train_data_labels.float())
        print(f"Epoch {epoch + 1}: train loss: {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Label balance")
    print(train_data_labels.sum() / len(train_data_labels))

    with torch.no_grad():
        # Re-predict: the last in-loop y_pred predates the final optimizer step.
        correct = ((model(train_data_mat) > 0.5) == train_data_labels).sum().item()
        print(f"Train accuracy: {correct / train_data_labels.size(0)}")

        y_pred = model(test_data_mat)
        loss = criterion(y_pred, test_data_labels.float())
        print(f"Test loss: {loss.item()}")
        correct = ((y_pred > 0.5) == test_data_labels).sum().item()
        print(f"Accuracy: {correct / test_data_labels.size(0)}")

    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), args.out)
