import json
import torch
import random
import numpy as np # type: ignore
import pandas as pd # type: ignore
from torch import nn
import torch.nn.functional as F
from torchvision import datasets # type: ignore
import torchvision.transforms as transforms # type: ignore

def set_seed(seed=42):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA generators) and
    configures cuDNN so that repeated runs select the same kernels.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    random.seed(seed)                     # Python
    np.random.seed(seed)                  # NumPy
    torch.manual_seed(seed)               # PyTorch (CPU)
    torch.cuda.manual_seed(seed)          # PyTorch (current GPU)
    torch.cuda.manual_seed_all(seed)      # PyTorch (all GPUs)
    torch.backends.cudnn.deterministic = True
    # Deterministic kernels are not enough on their own: with benchmark mode
    # enabled cuDNN may auto-tune to a different algorithm on every run.
    torch.backends.cudnn.benchmark = False


######################################################################## MODELS ########################################################################
class fashionConvNet(nn.Module):
    """Convolutional classifier: two conv/pool stages followed by three dense layers.

    The flattened feature size (64 * 5 * 5) corresponds to single-channel
    28x28 inputs. ``forward`` returns 10 raw logits per sample; no softmax is
    applied, and no activation is applied between the dense layers.
    """

    def __init__(self):
        super().__init__()

        # 1st convolution layer: 1 -> 32 channels, spatial size kept, then halved
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 2nd convolution layer: 32 -> 64 channels, unpadded 5x5 kernel, then halved
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # Flatten layer
        self.flatten = nn.Flatten()

        # Dense layer with dropout.
        # nn.Dropout (not nn.Dropout2d): the activations here are 2-D
        # (batch, features); Dropout2d is meant for channel-wise dropout on
        # 3-D/4-D inputs and warns about 2-D inputs being deprecated.
        self.fc1 = nn.Linear(in_features=64*5*5, out_features=512)
        self.drop1 = nn.Dropout(p=0.3)

        # Dense layer with dropout again
        self.fc2 = nn.Linear(in_features=512, out_features=128)
        self.drop2 = nn.Dropout(p=0.2)

        # Layer normalization and output layer
        self.layerNorm = nn.LayerNorm(128)  # Use LayerNorm instead of BatchNorm
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: Image batch fed to the first convolution (1 input channel).

        Returns:
            Tensor with 10 logits per sample.
        """
        out = self.conv1(x)  # Conv layer
        out = self.conv2(out)  # Conv layer
        out = self.flatten(out)  # Flattening

        out = self.fc1(out)  # Dense layer
        out = self.drop1(out)  # Dropout

        out = self.fc2(out)  # Dense layer
        out = self.drop2(out)  # Dropout

        out = self.layerNorm(out)  # Layer normalization
        out = self.fc3(out)  # Final output

        return out

######################################################################## Automata utils ########################################################################

def print_automata(automata, labels):
    """Print the automaton's transition table, one row per state.

    Each row shows the state (tagged ``(Accepting)`` or ``(Deadlock)`` when it
    is listed in the corresponding optional key), its distance and the target
    state for every label.

    Args:
        automata: Mapping with the keys ``"states"``, ``"transition_matrix"``
            and ``"distances"``; ``"accepting_states"`` and
            ``"deadlock_states"`` are optional.
        labels: Column titles, one per entry of a transition-matrix row.
    """
    # The header and the rows must share these widths, otherwise the columns
    # drift apart (the header used to be one character narrower).
    state_width, distance_width, cell_width = 16, 10, 13

    header = (
        f"{'state':<{state_width}}"
        + f"{'distance':<{distance_width}}| "
        + "".join(f"{label:<{cell_width}}" for label in labels)
    )
    print(header)
    print("-" * len(header))

    accepting_states = automata.get("accepting_states", [])
    deadlock_states = automata.get("deadlock_states", [])

    rows = zip(automata["states"], automata["transition_matrix"])
    for position, (state, row) in enumerate(rows, start=1):
        if state in accepting_states:
            state_info = "(Accepting)"
        elif state in deadlock_states:
            state_info = "(Deadlock)"
        else:
            state_info = ""

        row_str = "".join(f"{val}".ljust(cell_width) for val in row)
        # Distances are keyed by the 1-based row position, stored as a string.
        distance = automata["distances"][str(position)]

        print(f"{str(state) + state_info:<{state_width}}{str(distance):<{distance_width}}| {row_str}")

def import_automata(filename):
    """Load an automaton description from a JSON file.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded JSON document.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale
    # encoding, which is not UTF-8 everywhere.
    with open(filename, "r", encoding="utf-8") as json_file:
        return json.load(json_file)

######################################################################## Ordered Fashion MNIST utils ########################################################################

def import_ordered_fashion_mnist(fashion_mnist_path, ordered_fashion_mnist_path, train=True):
    """Load Fashion-MNIST plus the ordered sequences described in a CSV file.

    Args:
        fashion_mnist_path: Root directory handed to ``datasets.FashionMNIST``;
            the data is downloaded there when missing.
        ordered_fashion_mnist_path: CSV file whose first column is evaluated
            into the index entries and whose second column is evaluated into
            the label entries (presumably list literals - verify against the
            file that is actually used).
        train: Forwarded to ``datasets.FashionMNIST`` to pick the split.

    Returns:
        Tuple ``(dataset, custom_indices, custom_labels)``.
    """
    # Images become tensors and are normalised with mean 0.5 / std 0.5
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    dataset = datasets.FashionMNIST(
        root=fashion_mnist_path,
        train=train,
        transform=preprocessing,
        download=True,
    )

    rows = pd.read_csv(ordered_fashion_mnist_path).values

    # NOTE(review): eval() executes arbitrary expressions found in the CSV.
    # Only load files you trust; if the cells are plain literals,
    # ast.literal_eval would be the safe replacement.
    custom_indices = [eval(row[0]) for row in rows]
    custom_labels = [eval(row[1]) for row in rows]

    return dataset, custom_indices, custom_labels

def ordered_fashion_mnist_accuracy(labels, pred_labels):
    """Compute label-level and sequence-level accuracy, both in percent.

    Args:
        labels: Sequence of ground-truth label sequences.
        pred_labels: Sequence of predicted label sequences, paired with
            ``labels`` position by position.

    Returns:
        Tuple ``(label_accuracy, sequence_accuracy)``. The first is the share
        of matching items after flattening both inputs, the second the share
        of sequences that match as a whole.
    """
    # Label level: flatten both inputs and compare them item by item
    expected_items = []
    for sequence in labels:
        expected_items.extend(sequence)

    predicted_items = []
    for sequence in pred_labels:
        predicted_items.extend(sequence)

    item_hits = 0
    for expected, predicted in zip(expected_items, predicted_items):
        item_hits = item_hits + (expected == predicted)
    label_accuracy = item_hits / len(expected_items)

    # Sequence level: a sequence only counts when it matches entirely
    sequence_hits = 0
    for expected_seq, predicted_seq in zip(labels, pred_labels):
        sequence_hits = sequence_hits + (expected_seq == predicted_seq)
    sequence_accuracy = sequence_hits / len(labels)

    return label_accuracy*100, sequence_accuracy*100

######################################################################## CNN experiments ########################################################################

def train_cnn_without_automata(dataloader, model, loss_fn, optimizer, device="cpu"):
    """Run one training pass over ``dataloader`` and log progress.

    Args:
        dataloader: Iterable of ``(image, label)`` batches exposing a
            ``dataset`` attribute with a length.
        model: Network to optimise; it is switched to training mode.
        loss_fn: Callable taking ``(predictions, labels)`` and returning the loss.
        optimizer: Optimizer that updates the model's parameters.
        device: Device the batches are moved to before the forward pass.
    """
    total_samples = len(dataloader.dataset)

    model.train()

    for step, batch in enumerate(dataloader):
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        batch_loss = loss_fn(model(inputs), targets)

        # Standard update: clear stale gradients, backpropagate, step
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Only report every 100th batch
        if step % 100 != 0:
            continue

        loss_value = batch_loss.item()
        seen = step * len(inputs)
        print(f"loss: {loss_value:.3f} | progress: {(seen / total_samples) * 100:.2f}%")

def test_cnn_without_automata(test_custom_indices, test_custom_labels, model, test_dataset, device="cpu"):
    
    size = len([idx for sublist in test_custom_indices for idx in sublist])
    sequence_size = len(test_custom_indices)
    model.eval()
    model.to(device)
    correct = 0
    correct_sequence = 0

    for current_outfit in range(len(test_custom_indices)):

        current_sequence = []

        for i, idx in enumerate(test_custom_indices[current_outfit]):

            image, label = test_dataset[idx]
            image, label  = image.to(device), label#.to(device)
            image = image.unsqueeze(0)
            
            pred = model(image)

            correct += (pred.argmax(1) == label).type(torch.float).sum().item()
            current_sequence.append(pred.argmax(1).item())

        if current_sequence == test_custom_labels[current_outfit]:
            correct_sequence += 1

    accuracy = (correct/size)*100
    sequence_accuracy = (correct_sequence/sequence_size)*100
    print(f"{'Test Accuracy:':<25}{(accuracy):>9.2f}%")
    print(f"{'Sequence Test Accuracy:':<25}{(sequence_accuracy):>9.2f}%")

    return accuracy, sequence_accuracy

def constraints_pred(pred, current_state, automata):
    """Mask the scores of classes that lead into a deadlock state.

    Args:
        pred: Tensor of class scores whose last dimension has one entry per
            column of the transition matrix.
        current_state: 1-based automaton state; selects row
            ``current_state - 1`` of ``automata["transition_matrix"]``.
        automata: Mapping with ``"transition_matrix"`` and, optionally,
            ``"deadlock_states"``.

    Returns:
        A new tensor equal to ``pred`` with ``-inf`` added to every class
        whose target state is a deadlock state. ``pred`` is not modified.
    """
    row = automata["transition_matrix"][current_state - 1]
    deadlock_states = automata.get("deadlock_states", [])

    # Build the mask on pred's device so the addition also works when the
    # model runs on a GPU.
    constraints = torch.tensor(
        [0 if target not in deadlock_states else float('-inf') for target in row],
        device=pred.device,
    )

    return pred + constraints

def test_cnn_with_automata(test_custom_indices, model, test_dataset, automata, k, device="cpu"):
    """Decode every sequence with an automaton-constrained beam search.

    For each image, the model output is masked with ``constraints_pred``
    according to the automaton state of each beam. Every beam is expanded with
    its ``k`` highest-probability classes after masking, and the ``k``
    candidates with the lowest constrained negative log-likelihood survive.
    When the sequence is consumed, the surviving beam with the lowest
    unconstrained (plain CNN) negative log-likelihood is returned.

    Args:
        test_custom_indices: Iterable of index sequences into ``test_dataset``.
        model: Network to evaluate; switched to eval mode and moved to
            ``device``, as ``test_cnn_without_automata`` does.
        test_dataset: Indexable returning ``(image, label)`` pairs; the label
            is ignored.
        automata: Mapping with ``"transition_matrix"`` indexed as
            ``[state - 1][class]`` plus whatever ``constraints_pred`` reads.
            Decoding starts in state 1.
        k: Beam width, at least 1. Values above the number of classes are
            allowed; each expansion then simply uses every class.
        device: Device used for inference.

    Returns:
        List with one list of predicted class indices per input sequence.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    model.eval()
    model.to(device)

    transition_matrix = automata["transition_matrix"]
    results = []

    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        for indices in test_custom_indices:

            # Beam entries: (state, cnn_score, automata_score, pred_history)
            paths = [(1, 0.0, 0.0, [])]

            for idx in indices:
                image, _ = test_dataset[idx]
                pred = model(image.unsqueeze(0).to(device))

                # NOTE(review): presumably guards models that return several
                # outputs/rows; only the first one is used - confirm.
                if len(pred) > 1:
                    pred = pred[0]

                # The unconstrained costs do not depend on the beam: compute once.
                # A zero (or NaN) probability costs +inf.
                probs_cnn = torch.nn.functional.softmax(pred, dim=-1).reshape(-1).tolist()
                cnn_costs = [-np.log(p) if p > 0 else np.inf for p in probs_cnn]

                candidates = []

                for state, score, score_automata, pred_history in paths:
                    constrained_pred = constraints_pred(pred, state, automata)
                    probs_automata = torch.nn.functional.softmax(constrained_pred, dim=-1).reshape(-1)

                    # reshape(-1) keeps a 1-D result even for k == 1, and the
                    # clamp keeps topk valid when k exceeds the class count.
                    beam = min(k, probs_automata.numel())
                    top_classes = torch.topk(probs_automata, beam).indices.tolist()
                    automata_probs = probs_automata.tolist()

                    for class_idx in top_classes:
                        p = automata_probs[class_idx]
                        automata_cost = -np.log(p) if p > 0 else np.inf
                        candidates.append((
                            transition_matrix[state - 1][class_idx],
                            score + cnn_costs[class_idx],
                            score_automata + automata_cost,
                            pred_history + [class_idx],
                        ))

                # Keep the k best beams according to the constrained score
                paths = sorted(candidates, key=lambda path: path[2])[:k]

            # At the last step we want the best beam according to the natural score
            results.append(min(paths, key=lambda path: path[1])[3])

    return results

######################################################################## Additional experiments ########################################################################
#### beam search

def test_cnn_with_beam_search(test_custom_indices, model, test_dataset, k, device="cpu"):

    results = []

    for current_outfit in range(len(test_custom_indices)):

        paths = [(0.0, [])]
        
        for idx in test_custom_indices[current_outfit]:

            candidates = []

            image, label = test_dataset[idx]

            image = image.unsqueeze(0)
            image, label = (image.to(device), label.to(device)) if type(label) == torch.Tensor else (image.to(device), torch.tensor(label).to(device))
            pred = model(image)

            if len(pred) > 1:
                pred = pred[0]

            for score, pred_history in paths:

                probs_cnn = torch.nn.functional.softmax(pred, dim=-1)
                _, top_indices = torch.topk(probs_cnn, k)

                for index in range(len(top_indices.squeeze().tolist())):
                    
                    current_idx = top_indices.squeeze()[index].item()
                    new_score = score - np.log(probs_cnn.squeeze(0)[current_idx].item())

                    if np.isnan(new_score):
                        new_score = np.float64(np.inf) 
                    
                    candidates.append((new_score, pred_history + [top_indices.squeeze()[index].item()]))

            paths = sorted(candidates, key=lambda x: x[0])[:k]

        paths = sorted(paths, key=lambda x: x[0])
        results.append(paths[0][1])

    return results

#### automata scores
def test_cnn_with_automata_scores(test_custom_indices, model, test_dataset, automata, k, device="cpu"):
    """Automaton-constrained beam search ranked by the constrained score only.

    For each image, the model output is masked with ``constraints_pred``
    according to the automaton state of each beam. Every beam is expanded with
    its ``k`` highest-probability classes after masking, and the ``k``
    candidates with the lowest constrained negative log-likelihood survive.
    The final answer is the surviving beam with the lowest constrained score;
    the unconstrained CNN score plays no role in this variant.

    Args:
        test_custom_indices: Iterable of index sequences into ``test_dataset``.
        model: Network to evaluate; switched to eval mode and moved to
            ``device``, as ``test_cnn_without_automata`` does.
        test_dataset: Indexable returning ``(image, label)`` pairs; the label
            is ignored.
        automata: Mapping with ``"transition_matrix"`` indexed as
            ``[state - 1][class]`` plus whatever ``constraints_pred`` reads.
            Decoding starts in state 1.
        k: Beam width, at least 1. Values above the number of classes are
            allowed; each expansion then simply uses every class.
        device: Device used for inference.

    Returns:
        List with one list of predicted class indices per input sequence.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    model.eval()
    model.to(device)

    transition_matrix = automata["transition_matrix"]
    results = []

    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        for indices in test_custom_indices:

            # Beam entries: (state, automata_score, pred_history)
            paths = [(1, 0.0, [])]

            for idx in indices:
                image, _ = test_dataset[idx]
                pred = model(image.unsqueeze(0).to(device))

                # NOTE(review): presumably guards models that return several
                # outputs/rows; only the first one is used - confirm.
                if len(pred) > 1:
                    pred = pred[0]

                candidates = []

                for state, score_automata, pred_history in paths:
                    constrained_pred = constraints_pred(pred, state, automata)
                    probs_automata = torch.nn.functional.softmax(constrained_pred, dim=-1).reshape(-1)

                    # reshape(-1) keeps a 1-D result even for k == 1, and the
                    # clamp keeps topk valid when k exceeds the class count.
                    beam = min(k, probs_automata.numel())
                    top_classes = torch.topk(probs_automata, beam).indices.tolist()
                    automata_probs = probs_automata.tolist()

                    for class_idx in top_classes:
                        p = automata_probs[class_idx]
                        # A zero (or NaN) probability costs +inf
                        automata_cost = -np.log(p) if p > 0 else np.inf
                        candidates.append((
                            transition_matrix[state - 1][class_idx],
                            score_automata + automata_cost,
                            pred_history + [class_idx],
                        ))

                paths = sorted(candidates, key=lambda path: path[1])[:k]

            results.append(min(paths, key=lambda path: path[1])[2])

    return results

#### stepwise output: report a prediction after every image, not only at the end of the sequence
def test_cnn_with_automata_stepwise(test_custom_indices, model, test_dataset, automata, k, device="cpu"):
    """Automaton-constrained beam search that emits a prediction at every step.

    The beams evolve as in the other automaton-constrained searches: the model
    output is masked with ``constraints_pred`` for the state of each beam,
    each beam is expanded with its ``k`` highest-probability classes after
    masking, and the ``k`` candidates with the lowest constrained negative
    log-likelihood survive. For every image, the class reported is the most
    probable masked class of the currently best beam; earlier outputs are
    never revised.

    Args:
        test_custom_indices: Iterable of index sequences into ``test_dataset``.
        model: Network to evaluate; switched to eval mode and moved to
            ``device``, as ``test_cnn_without_automata`` does.
        test_dataset: Indexable returning ``(image, label)`` pairs; the label
            is ignored.
        automata: Mapping with ``"transition_matrix"`` indexed as
            ``[state - 1][class]`` plus whatever ``constraints_pred`` reads.
            Decoding starts in state 1.
        k: Beam width, at least 1. Values above the number of classes are
            allowed; each expansion then simply uses every class.
        device: Device used for inference.

    Returns:
        List with one list of per-step class indices per input sequence.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    model.eval()
    model.to(device)

    transition_matrix = automata["transition_matrix"]
    results = []

    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        for indices in test_custom_indices:
            stepwise_output = []

            # Beam entries: (state, automata_score); sorted best-first
            paths = [(1, 0.0)]

            for idx in indices:
                image, _ = test_dataset[idx]
                pred = model(image.unsqueeze(0).to(device))

                # NOTE(review): presumably guards models that return several
                # outputs/rows; only the first one is used - confirm.
                if len(pred) > 1:
                    pred = pred[0]

                candidates = []

                for rank, (state, score_automata) in enumerate(paths):
                    constrained_pred = constraints_pred(pred, state, automata)
                    probs_automata = torch.nn.functional.softmax(constrained_pred, dim=-1).reshape(-1)

                    # reshape(-1) keeps a 1-D result even for k == 1, and the
                    # clamp keeps topk valid when k exceeds the class count.
                    beam = min(k, probs_automata.numel())
                    top_classes = torch.topk(probs_automata, beam).indices.tolist()
                    automata_probs = probs_automata.tolist()

                    # paths[0] is the best beam so far: report its top class
                    if rank == 0:
                        stepwise_output.append(top_classes[0])

                    for class_idx in top_classes:
                        p = automata_probs[class_idx]
                        # A zero (or NaN) probability costs +inf
                        automata_cost = -np.log(p) if p > 0 else np.inf
                        candidates.append((
                            transition_matrix[state - 1][class_idx],
                            score_automata + automata_cost,
                        ))

                paths = sorted(candidates, key=lambda path: path[1])[:k]

            results.append(stepwise_output)

    return results