import json
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

### Automata functions ###

def print_automata(automata, labels):
    """Print the automaton's transition table to stdout, one row per state.

    Each row shows the state (suffixed with "(Accepting)" or "(Deadlock)" when it
    appears in ``automata["accepting_states"]`` / ``automata["deadlock_states"]``),
    its entry in ``automata["distances"]`` and the destination for every label.

    Args:
        automata: dict with "states", "transition_matrix" and "distances" keys and
            optional "accepting_states" / "deadlock_states" lists. The i-th row
            (0-based) reads its distance from ``automata["distances"][str(i + 1)]``.
        labels: column headers, one per transition-matrix column.
    """
    state_width = 16
    distance_width = 10
    cell_width = 13

    # Header and rows use the same column widths so the "|" separator lines up.
    header = (
        f"{'state':<{state_width}}"
        + f"{'distance':<{distance_width}}| "
        + "".join(f"{label:<{cell_width}}" for label in labels)
    )
    print(header)
    print("-" * len(header))

    accepting = automata.get("accepting_states", [])
    deadlocks = automata.get("deadlock_states", [])

    for idx, (state, row) in enumerate(zip(automata["states"], automata["transition_matrix"])):
        if state in accepting:
            state_info = "(Accepting)"
        elif state in deadlocks:
            state_info = "(Deadlock)"
        else:
            state_info = ""

        row_str = "".join(f"{val}".ljust(cell_width) for val in row)
        distance = automata["distances"][str(idx + 1)]

        print(f"{str(state) + state_info:<{state_width}}{str(distance):<{distance_width}}| {row_str}")

def import_automata(filename):
    """Load an automaton description from a JSON file.

    The file is decoded explicitly as UTF-8 (the encoding JSON requires) rather
    than with the platform's locale encoding, which is not UTF-8 everywhere.

    Args:
        filename: path to the JSON file.

    Returns:
        The decoded JSON value. The other helpers in this module expect a dict
        with keys such as "states", "transition_matrix" and "distances".
    """
    with open(filename, "r", encoding="utf-8") as json_file:
        return json.load(json_file)

### Ordered Fashion MNIST functions ###

def import_ordered_fashion_mnist(fashion_mnist_path, ordered_fashion_mnist_path, train=True, seed=42):
    """Load Fashion-MNIST together with the ordered-sequence definitions from a CSV.

    Args:
        fashion_mnist_path: root directory passed to ``datasets.FashionMNIST``
            (the dataset is downloaded there if missing).
        ordered_fashion_mnist_path: CSV whose first column holds, per row, the
            dataset indices of one sequence and whose second column holds that
            sequence's labels, both written as Python literals.
        train: load the training split if True, the test split otherwise.
        seed: seed passed to ``random.seed`` before shuffling the sequences.

    Returns:
        (dataset, custom_indices, custom_labels): the transformed Fashion-MNIST
        dataset and the shuffled, still-paired lists of index/label sequences.
        Both lists are empty when the CSV has no data rows.
    """
    # ToTensor followed by Normalize with mean 0.5 / std 0.5 on the single channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset = datasets.FashionMNIST(root=fashion_mnist_path, train=train, transform=transform, download=True)

    csv_data = pd.read_csv(ordered_fashion_mnist_path)

    # SECURITY NOTE(review): eval() executes arbitrary code contained in the CSV.
    # If the cells are plain list literals, ast.literal_eval is a safe drop-in;
    # confirm the file format before switching.
    custom_indices = [eval(row[0]) for row in csv_data.values]
    custom_labels = [eval(row[1]) for row in csv_data.values]

    combined = list(zip(custom_indices, custom_labels))
    # Seeds the module-level `random` generator (a global side effect kept as-is).
    random.seed(seed)
    random.shuffle(combined)

    # Unzip with comprehensions: `a, b = zip(*combined)` raises ValueError when
    # the CSV has no data rows, whereas this simply yields two empty lists.
    custom_indices = [pair[0] for pair in combined]
    custom_labels = [pair[1] for pair in combined]

    return dataset, custom_indices, custom_labels

class FashionSequenceDataset(Dataset):
    """Map-style dataset whose items are whole image sequences.

    Item ``idx`` gathers ``dataset[j][0]`` for every ``j`` in
    ``list_of_indices[idx]`` and stacks them along a new leading axis; the
    matching ``list_of_labels[idx]`` is returned as an int64 tensor.
    """

    def __init__(self, dataset, list_of_indices, list_of_labels):
        self.dataset = dataset
        self.list_of_indices = list_of_indices
        self.list_of_labels = list_of_labels

    def __len__(self):
        return len(self.list_of_indices)

    def __getitem__(self, idx):
        targets = self.list_of_labels[idx]
        # Only the image part of each base-dataset sample is kept.
        frames = torch.stack([self.dataset[j][0] for j in self.list_of_indices[idx]])
        return frames, torch.tensor(targets, dtype=torch.long)

def collate_fn(batch):
    """Pad a batch of variable-length (images, labels) sequences for a DataLoader.

    Returns:
        (padded_imgs, padded_lbls, seq_lengths): images zero-padded along the
        time axis with the batch dimension first, labels padded with -100 (the
        value the loss and metrics in this module ignore), and the original
        length of each sequence as a list of ints.
    """
    image_seqs, label_seqs = zip(*batch)
    lengths = [seq.size(0) for seq in image_seqs]
    images = pad_sequence(list(image_seqs), batch_first=True)
    labels = pad_sequence(list(label_seqs), batch_first=True, padding_value=-100)
    return images, labels, lengths

### TRIDENT utils ###

def constraints_pred(pred, current_state, automata):
    """Mask out labels whose transition from ``current_state`` reaches a deadlock.

    Args:
        pred: logits tensor; its last dimension must match the length of the
            transition-matrix row (one entry per label).
        current_state: 1-based state id; row ``current_state - 1`` of
            ``automata["transition_matrix"]`` gives the next state per label.
        automata: dict with "transition_matrix" and an optional
            "deadlock_states" list (treated as empty when absent).

    Returns:
        ``pred`` plus a mask that is 0 for allowed labels and -inf for labels
        leading into a deadlock state. The mask lives on ``pred``'s device and,
        for floating-point ``pred``, has ``pred``'s dtype so no promotion occurs.
    """
    deadlocks = automata.get("deadlock_states", [])
    next_states = automata["transition_matrix"][current_state - 1]
    mask_dtype = pred.dtype if pred.is_floating_point() else torch.get_default_dtype()
    mask = torch.tensor(
        [float("-inf") if nxt in deadlocks else 0.0 for nxt in next_states],
        dtype=mask_dtype,
        device=pred.device,
    )
    return pred + mask

### Models ###

class FashionConvFeatureExtractor(nn.Module):
    """CNN mapping single-channel 28x28 images to 128-d layer-normalised features.

    Spatial sizes: conv1 ('same') keeps 28x28, pooling gives 14x14; conv2 (k=5)
    gives 10x10, pooling gives 5x5 -- hence fc1's ``64*5*5`` input size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64*5*5, 512)
        # Plain element-wise Dropout: the activations here are 2-D (batch, features).
        # Dropout2d on 2-D input is deprecated, and PyTorch documents nn.Dropout as
        # the behaviour-preserving replacement. Dropout has no parameters, so the
        # state-dict keys are unchanged.
        self.drop1 = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(512, 128)
        self.drop2 = nn.Dropout(p=0.2)
        self.layerNorm = nn.LayerNorm(128)

    def forward(self, x):
        """Return features of shape (batch, 128) for input of shape (batch, 1, 28, 28)."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.flatten(out)
        # NOTE(review): no non-linearity between fc1 and fc2; kept as-is to
        # preserve the existing forward computation.
        out = self.fc1(out)
        out = self.drop1(out)
        out = self.fc2(out)
        out = self.drop2(out)
        out = self.layerNorm(out)
        return out  # shape: (batch, 128)

class CNNLSTM(nn.Module):
    """Per-frame CNN feature extractor followed by an LSTM and a per-step classifier.

    Args:
        feature_dim: size of the CNN feature vector fed to the LSTM.
        hidden_dim: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output classes per time step.
        cnn_weights_path: checkpoint loaded (non-strictly) into the CNN at
            construction time; pass None to skip loading.
    """

    def __init__(self, feature_dim=128, hidden_dim=128, num_layers=1, num_classes=10,
                 cnn_weights_path='../data/CNN/fashionConvNet_model_on_original_fmnist.pth'):
        super().__init__()
        feature_extractor = FashionConvFeatureExtractor()
        if cnn_weights_path is not None:
            # The freshly built module lives on CPU; mapping the checkpoint there
            # lets GPU-saved weights load on CPU-only machines.
            state_dict = torch.load(cnn_weights_path, map_location="cpu")
            # strict=False: missing/unexpected keys are silently ignored.
            feature_extractor.load_state_dict(state_dict, strict=False)
        self.cnn = feature_extractor
        self.lstm = nn.LSTM(feature_dim, hidden_dim, num_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, seq_lengths):
        """Return logits of shape (batch, seq_len, num_classes).

        Args:
            x: padded image sequences; assumed (batch, seq_len, 1, 28, 28) as
                produced by ``collate_fn`` (the CNN's fc1 size fixes 28x28).
            seq_lengths: true length of each sequence (list of ints or CPU tensor).
        """
        batch_size, seq_len = x.shape[:2]
        # Fold time into the batch axis so the CNN sees individual images;
        # reshape (unlike view) also accepts non-contiguous input.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        features = self.cnn(frames).reshape(batch_size, seq_len, -1)  # (batch, seq_len, feature_dim)
        packed = nn.utils.rnn.pack_padded_sequence(features, seq_lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length keeps the time axis equal to the input's seq_len even when
        # every sequence in the batch is shorter than the padded length.
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)
        logits = self.classifier(out)  # (batch, seq_len, num_classes)
        return logits

### Train functions ###

def train_cnn_lstm_with_frozen_cnn(
    model, dataset, custom_indices, custom_labels,
    cnn_weights_path, num_epochs=10, batch_size=16, lr=1e-3,
    save_path="../data/CNN-LSTM/model_weights.pth",
):
    """Train the LSTM/classifier of ``model`` on top of a frozen, pre-trained CNN.

    Args:
        model: module exposing a ``cnn`` submodule and called as
            ``model(padded_imgs, seq_lengths)`` -> (B, S, num_classes) logits.
        dataset: base image dataset indexed by ``custom_indices``.
        custom_indices: list of index sequences (one per training sequence).
        custom_labels: list of label sequences aligned with ``custom_indices``.
        cnn_weights_path: checkpoint loaded (non-strictly) into ``model.cnn``.
        num_epochs, batch_size, lr: usual training hyper-parameters (Adam).
        save_path: where the final ``model.state_dict()`` is written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    cnn_state_dict = torch.load(cnn_weights_path, map_location=device)
    model.cnn.load_state_dict(cnn_state_dict, strict=False)

    for param in model.cnn.parameters():
        param.requires_grad = False

    seq_dataset = FashionSequenceDataset(dataset, custom_indices, custom_labels)
    dataloader = DataLoader(seq_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    # Padded label positions (-100, see collate_fn) do not contribute to the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

    for epoch in range(num_epochs):
        model.train()
        # The CNN is frozen: keep it in eval mode so its dropout layers do not
        # perturb the features the LSTM learns from (evaluation runs it in eval
        # mode too, via model.eval()).
        model.cnn.eval()
        running_loss = 0.0
        for padded_imgs, padded_lbls, seq_lengths in dataloader:
            padded_imgs = padded_imgs.to(device)      # (B, S, 1, 28, 28)
            padded_lbls = padded_lbls.to(device)      # (B, S)
            logits = model(padded_imgs, seq_lengths)  # (B, S, num_classes)
            loss = criterion(logits.reshape(-1, logits.size(-1)), padded_lbls.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Guard against an empty dataset (zero batches).
        avg_loss = running_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model trained and weights saved to '{save_path}'.")

### Testing functions ###

def test_cnn_lstm(model, dataset, custom_indices, custom_labels, batch_size=16):
    """Evaluate ``model`` without automata constraints and print its accuracies.

    Single-label accuracy is computed over all non-padded positions (label
    != -100). Sequence accuracy counts, among sequences with at least one real
    label, those whose every real label is predicted correctly.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)

    loader = DataLoader(
        FashionSequenceDataset(dataset, custom_indices, custom_labels),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    label_hits = label_count = 0
    seq_hits = seq_count = 0

    with torch.no_grad():
        for padded_imgs, padded_lbls, seq_lengths in loader:
            padded_imgs, padded_lbls = padded_imgs.to(device), padded_lbls.to(device)
            preds = model(padded_imgs, seq_lengths).argmax(dim=-1)  # (B, S)

            valid = padded_lbls != -100
            matches = preds == padded_lbls
            label_hits += (matches & valid).sum().item()
            label_count += valid.sum().item()

            # Padding positions are forced to "match" so `.all` only looks at
            # real labels; sequences with no real label are excluded entirely.
            has_labels = valid.any(dim=1)
            fully_correct = (matches | ~valid).all(dim=1) & has_labels
            seq_hits += fully_correct.sum().item()
            seq_count += has_labels.sum().item()

    single_label_acc = label_hits / label_count if label_count > 0 else 0.0
    sequence_acc = seq_hits / seq_count if seq_count > 0 else 0.0

    print(f"Single label accuracy: {single_label_acc:.4f}")
    print(f"Sequence accuracy:    {sequence_acc:.4f}")

def _automata_beam_search(logits, automata, k):
    """Decode one sequence of per-step logits with an automata-constrained beam search.

    Args:
        logits: tensor of shape (seq_len, num_classes) for a single sequence.
        automata: automaton dict; "transition_matrix" is indexed by 1-based
            state id (row ``state - 1``) and label, "distances" is keyed by
            ``str(state)``. Deadlock masking is delegated to ``constraints_pred``.
        k: beam width (>= 1). Each kept path is expanded with its top-k labels
            under the constrained distribution (capped at num_classes), and the
            k candidates with the lowest constrained score survive each step.

    Returns:
        list[int]: label path of the surviving entry with the lowest CNN-only
        score (ties resolved in favour of the earlier entry).
    """
    num_expand = min(k, logits.shape[-1])  # topk raises if asked for > num_classes
    start_state = 1
    # Beam entry: (state history, CNN-only score, constrained score,
    #              distance of last state, predicted labels).
    # Scores accumulate negative log-probabilities: lower is better.
    paths = [([start_state], 0.0, 0.0, automata["distances"][str(start_state)], [])]

    for step_logits in logits:  # (num_classes,)
        # The unconstrained distribution is identical for every beam entry.
        log_probs_cnn = torch.nn.functional.log_softmax(step_logits, dim=-1)
        candidates = []
        for state_history, score, score_automata, _, pred_history in paths:
            last_state = state_history[-1]
            constrained = constraints_pred(step_logits.unsqueeze(0), last_state, automata).squeeze(0)
            # log_softmax avoids log(softmax(.)) underflowing to log(0) = -inf
            # for small but non-zero probabilities.
            log_probs_automata = torch.nn.functional.log_softmax(constrained, dim=-1)
            # .indices is 1-D, so .tolist() is a list even when num_expand == 1.
            top_labels = torch.topk(log_probs_automata, num_expand).indices.tolist()

            for label in top_labels:
                new_score_automata = score_automata - log_probs_automata[label].item()
                new_score = score - log_probs_cnn[label].item()
                # NaN arises when every label from this state is masked to -inf;
                # rank such candidates last.
                if np.isnan(new_score_automata):
                    new_score_automata = np.float64(np.inf)
                if np.isnan(new_score):
                    new_score = np.float64(np.inf)

                new_state = automata["transition_matrix"][last_state - 1][label]
                candidates.append((
                    state_history + [new_state],
                    new_score,
                    new_score_automata,
                    automata["distances"][str(new_state)],
                    pred_history + [label],
                ))

        # Prune on the automata-constrained score...
        paths = sorted(candidates, key=lambda c: c[2])[:k]

    # ...but choose the final path by its CNN-only score.
    return min(paths, key=lambda c: c[1])[4]


def test_cnn_lstm_with_automata(test_custom_indices, model, test_dataset, automata, k, device="cpu"):
    """Evaluate ``model`` with automata-constrained beam-search decoding.

    Args:
        test_custom_indices: list of index sequences into ``test_dataset``.
        model: called as ``model(imgs, [seq_len])`` on a batch of one sequence,
            returning (1, seq_len, num_classes) logits.
        test_dataset: dataset whose items are (image, label) pairs.
        automata: automaton dict (see ``_automata_beam_search``).
        k: beam width; must be >= 1.
        device: device the model and images are moved to.

    Returns:
        list[list[int]]: decoded label sequence for each entry of
        ``test_custom_indices``. Single-label and sequence accuracies are printed.

    Raises:
        ValueError: if ``k`` < 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    device = torch.device(device)
    model.eval()
    model.to(device)

    total_labels = 0
    correct_labels = 0
    total_sequences = 0
    correct_sequences = 0

    results = []

    for seq_indices in test_custom_indices:
        # Fetch each sample once and reuse it for both the image and the label.
        samples = [test_dataset[i] for i in seq_indices]
        imgs = torch.stack([sample[0] for sample in samples]).unsqueeze(0).to(device)  # (1, seq_len, ...)
        true_labels = [sample[1] for sample in samples]
        seq_len = imgs.shape[1]

        with torch.no_grad():
            logits = model(imgs, [seq_len]).squeeze(0)  # (seq_len, num_classes)

        best_pred_labels = _automata_beam_search(logits, automata, k)
        results.append(best_pred_labels)

        # Single label accuracy
        for pred, true in zip(best_pred_labels, true_labels):
            total_labels += 1
            if pred == true:
                correct_labels += 1
        # Sequence accuracy
        total_sequences += 1
        if best_pred_labels == true_labels:
            correct_sequences += 1

    single_label_acc = correct_labels / total_labels if total_labels > 0 else 0.0
    sequence_acc = correct_sequences / total_sequences if total_sequences > 0 else 0.0

    print(f"Single label accuracy: {single_label_acc:.4f}")
    print(f"Sequence accuracy:    {sequence_acc:.4f}")

    return results