import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import torch.nn as nn

class TransformerCharPredictor(nn.Module):
    """Causal Transformer-decoder model scoring ``k`` output symbols per input position.

    The input token ids are embedded, passed through a stack of
    ``nn.TransformerDecoderLayer`` blocks under a causal self-attention mask,
    and projected to ``k`` scores per position. The decoder's cross-attention
    ``memory`` is an all-zero tensor shaped like the embedded input, so the
    model carries no encoder-side information.

    Args:
        vocab_size: Number of distinct input token ids.
        dim: Embedding / model width (``d_model``).
        heads: Number of attention heads.
        layers: Number of stacked decoder layers.
        k: Number of output symbols scored at every position.
    """

    def __init__(self, vocab_size, dim=5, heads=1, layers=3, k=2):
        super().__init__()
        # Construction order is kept stable: it determines how the global RNG
        # is consumed during parameter initialisation.
        self.embedding = nn.Embedding(vocab_size, dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=layers)

        # Final linear layer projecting each position to k output symbols.
        self.output = nn.Linear(dim, k)

    def generate_future_mask(self, size, device=None):
        """Return a ``(size, size)`` boolean causal mask.

        Entry ``[i, j]`` is ``True`` when ``j > i``, i.e. position ``i`` must
        not attend to the later position ``j``.

        Args:
            size: Sequence length.
            device: Optional device on which to allocate the mask. ``None``
                keeps the default device, as before.
        """
        # Allocating directly on the target device avoids building the mask
        # on the CPU and copying it over on every forward pass.
        return torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()

    def forward(self, x):
        """Score every position of ``x``.

        Args:
            x: Token ids, either ``(seq_len,)`` or ``(batch, seq_len)``. A 1-D
                input gets a leading batch dimension of size 1.

        Returns:
            A tuple ``(probs, x_decoded)``: ``probs`` has shape
            ``(batch, seq_len, k)`` and holds the element-wise sigmoid of the
            projected scores; ``x_decoded`` has shape ``(batch, seq_len, dim)``
            and is the decoder output before the projection.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # (1, seq_len)

        x_embed = self.embedding(x)  # (batch, seq_len, dim)
        tgt_mask = self.generate_future_mask(x.size(1), device=x.device)
        # There is no encoder: cross-attention sees an all-zero memory.
        x_decoded = self.decoder(
            tgt=x_embed,
            memory=torch.zeros_like(x_embed),
            tgt_mask=tgt_mask,
        )

        logits = self.output(x_decoded)  # (batch, seq_len, k)
        # NOTE(review): the sigmoid squashes every score into (0, 1). argmax
        # is unaffected, but a loss that expects unnormalised logits (e.g.
        # CrossEntropyLoss) receives a compressed range. Kept as-is because
        # callers receive these values directly.
        probs = torch.sigmoid(logits)

        return probs, x_decoded


def train(model, optimizer, device, source, target, k, epochs=10):
    """Train ``model`` one example at a time and print per-epoch statistics.

    Each ``(src, tgt)`` pair from ``zip(source, target)`` is one optimisation
    step. An example counts as correct only if the arg-max prediction matches
    the target at every position.

    Args:
        model: Callable returning ``(scores, _)`` where ``scores`` can be
            reshaped to ``(-1, k)``.
        optimizer: Optimiser over the model's parameters.
        device: Device the example tensors are moved to.
        source: Sized sequence of input tensors.
        target: Sequence of target class-index tensors, aligned with ``source``.
        k: Number of output classes per position.
        epochs: Number of passes over the data.

    NOTE(review): ``TransformerCharPredictor.forward`` in this file returns
    sigmoid-squashed scores, while ``CrossEntropyLoss`` expects unnormalised
    logits. Behaviour is preserved here; confirm whether this is intended.
    """
    model.train()
    criterion = CrossEntropyLoss()
    num_examples = len(source)

    for epoch in range(epochs):
        total_loss = 0.0
        total_correct = 0
        total_elements = 0
        print(f"Epoch {epoch + 1}/{epochs}")
        for src, tgt in tqdm(zip(source, target), total=num_examples):
            src = src.to(device)
            # Flatten to (batch * seq_len,) so loss and accuracy see the same shape.
            tgt = tgt.to(device).reshape(-1)

            optimizer.zero_grad()
            output, _ = model(src)
            output = output.reshape(-1, k)  # (batch * seq_len, k)

            loss = criterion(output, tgt)
            loss.backward()
            optimizer.step()

            # Both sides are 1-D here. The previous ``.squeeze(0)`` collapsed a
            # length-1 prediction to a 0-dim tensor, which torch.equal always
            # reported as different from the (1,) target.
            predictions = output.argmax(dim=-1)
            if torch.equal(predictions, tgt):
                total_correct += 1
            total_elements += 1

            total_loss += loss.item()

        # Guard the averages so an empty dataset does not raise ZeroDivisionError.
        accuracy = total_correct / total_elements * 100 if total_elements else 0.0
        mean_loss = total_loss / num_examples if num_examples else 0.0

        print(f"Epoch {epoch + 1} Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")

def evaluate(model, source, target, device):
    """Return the exact-match accuracy (in percent) of ``model`` on the data.

    An example counts as correct only if the arg-max prediction equals the
    target at every position. The model is switched to eval mode and left
    there; callers that continue training must call ``model.train()`` again.

    Args:
        model: Module returning ``(scores, _)`` with the class scores on the
            last dimension.
        source: Sized sequence of input tensors.
        target: Sequence of target tensors, aligned with ``source``.
        device: Device the example tensors are moved to.

    Returns:
        Percentage of fully correct examples, or ``0.0`` when there are no
        examples.
    """
    model.eval()
    total_correct = 0
    total_elements = 0

    with torch.no_grad():
        for src, tgt in tqdm(zip(source, target), total=len(source)):
            src = src.to(device)
            tgt = tgt.to(device)

            output, _ = model(src)
            # Compare flattened tensors so the result does not depend on
            # whether the model added a batch dimension that the target lacks.
            predictions = torch.argmax(output, dim=-1).reshape(-1)

            if torch.equal(predictions, tgt.reshape(-1)):
                total_correct += 1
            total_elements += 1

    # An empty evaluation set previously raised ZeroDivisionError.
    total_accuracy = total_correct / total_elements * 100 if total_elements else 0.0
    print(f"Total Accuracy: {total_accuracy:.2f}%")
    return total_accuracy

# Helper function to load and preprocess data
def load_and_preprocess_data(file_name, range_name):
    """Load one length bin of a dataset as lists of integer tensors.

    Reads ``data/{file_name}/{range_name}_src.txt`` and
    ``data/{file_name}/{range_name}_tgt.txt``. Every line is stripped and
    encoded character by character: in the source file ``'a'`` maps to 0 and
    any other character to 1; in the target file ``'0'`` maps to 0 and any
    other character to 1.

    Args:
        file_name: Dataset directory name under ``data/``.
        range_name: Bin prefix of the files, e.g. ``"0_to_50"``.

    Returns:
        ``(source, target)``: two lists of 1-D ``torch.long`` tensors, one
        tensor per line of the respective file.
    """

    def _read_encoded(path, zero_char):
        # dtype is forced to long: an empty line would otherwise produce a
        # float32 tensor, which cannot be used as embedding indices or as
        # class targets.
        with open(path, "r") as f:
            return [
                torch.tensor(
                    [0 if char == zero_char else 1 for char in line.strip()],
                    dtype=torch.long,
                )
                for line in f
            ]

    source = _read_encoded(f"data/{file_name}/{range_name}_src.txt", "a")
    target = _read_encoded(f"data/{file_name}/{range_name}_tgt.txt", "0")
    return source, target

# Length bins evaluated once the search has finished:
# (file-name prefix, label written to the log).
_LENGTH_BINS = (
    ("0_to_50", "[0, 50]"),
    ("50_to_100", "[50, 100]"),
    ("100_to_150", "[100, 150]"),
    ("150_to_200", "[150, 200]"),
)


def _train_one_epoch(model, optimizer, criterion, device, sources, targets, num_classes):
    """Run one pass over the training pairs; return ``(mean_loss, accuracy_percent)``."""
    # evaluate() leaves the model in eval mode, so training mode must be
    # re-enabled every epoch; otherwise dropout is silently disabled from the
    # second epoch onwards.
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_elements = 0

    for src, tgt in tqdm(zip(sources, targets), total=len(sources)):
        src = src.to(device)
        tgt = tgt.to(device).reshape(-1)  # (batch * seq_len,)

        optimizer.zero_grad()
        output, _ = model(src)
        output = output.reshape(-1, num_classes)  # (batch * seq_len, k)

        loss = criterion(output, tgt)
        loss.backward()
        optimizer.step()

        # Exact-match accuracy on flattened tensors (no shape-dependent squeeze,
        # which used to mark every length-1 example as wrong).
        if torch.equal(output.argmax(dim=-1), tgt):
            total_correct += 1
        total_elements += 1
        total_loss += loss.item()

    mean_loss = total_loss / len(sources) if len(sources) else 0.0
    accuracy = total_correct / total_elements * 100 if total_elements else 0.0
    return mean_loss, accuracy


def _search_hyperparameters(log_file, device, train_data, eval_data,
                            heads, dims, lrs, vocab_size, output_size, layers, epochs):
    """Grid-search (heads, dim, lr) until a model reaches 100% held-out accuracy.

    Returns ``(success, model, model_name, head, dim, lr)``. On success the
    model is the one that reached 100%. Otherwise it is a fresh model loaded
    with the weights of the best-scoring epoch seen during the search, or
    ``None`` if nothing was trained at all.
    """
    source_train, target_train = train_data
    source_eval, target_eval = eval_data
    criterion = CrossEntropyLoss()
    best = None  # (accuracy, CPU copy of the state dict, head, dim, lr)

    for head in heads:
        for dim in dims:
            for lr in lrs:
                model = TransformerCharPredictor(vocab_size, dim, head, layers, output_size)
                model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr)
                model_name = f"transformer_model_dim_{dim}_heads_{head}_layers_{layers}.pth"

                print("Testing with dim:", dim, "heads:", head, "lr:", lr)
                log_file.write(f"Testing with dim: {dim}, heads: {head}, lr: {lr}\n")
                log_file.flush()

                for epoch in range(epochs):
                    print(f"Epoch {epoch + 1}/{epochs}")
                    mean_loss, accuracy = _train_one_epoch(
                        model, optimizer, criterion, device,
                        source_train, target_train, output_size,
                    )
                    print(f"Epoch {epoch + 1} Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")
                    log_file.write(f"Epoch {epoch + 1}, Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%\n")
                    log_file.flush()

                    # Evaluate on the held-out part of the [0, 50] bin.
                    accuracy = evaluate(model, source_eval, target_eval, device)

                    if accuracy == 100:
                        print(f"Model {model_name} reached 100% accuracy")
                        return True, model, model_name, head, dim, lr
                    if best is None or accuracy >= best[0]:
                        # Snapshot the weights: keeping only a reference to the
                        # live model would let later epochs overwrite the
                        # "best" parameters.
                        snapshot = {
                            name: tensor.detach().cpu().clone()
                            for name, tensor in model.state_dict().items()
                        }
                        best = (accuracy, snapshot, head, dim, lr)

    if best is None:
        return False, None, None, 0, 0, 0

    _, snapshot, head, dim, lr = best
    model = TransformerCharPredictor(vocab_size, dim, head, layers, output_size)
    model.load_state_dict(snapshot)
    model.to(device)
    model_name = f"transformer_model_dim_{dim}_heads_{head}_layers_{layers}.pth"
    return False, model, model_name, head, dim, lr


def _log_bin_accuracies(log_file, model, file_name, device):
    """Evaluate ``model`` on every length bin and append the results to the log."""
    datasets = [load_and_preprocess_data(file_name, stem) for stem, _ in _LENGTH_BINS]
    accuracies = [evaluate(model, src, tgt, device) for src, tgt in datasets]
    for (_, label), accuracy in zip(_LENGTH_BINS, accuracies):
        log_file.write(f"Final Accuracy on bin {label}: {accuracy:.2f}%\n")


def _main():
    """Run the hyper-parameter search for every dataset ``L3`` .. ``L12``."""
    import os  # local import: only the script entry point needs it

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    heads = [1]
    dims = [16, 64, 512]
    lrs = [0.0001, 0.00001]
    vocab_size = 2
    output_size = 2
    epochs = 10
    layers = 10
    train_size = 8000  # the first 8000 examples train, the rest validate

    for k in range(3, 13):
        # Re-seed per dataset so every search starts from the same RNG state.
        torch.manual_seed(42)

        file_name = f"L{k}"
        print(f"Training with k = {k} and layers = {layers}")

        # 'a' -> 0 / other -> 1 for sources, '0' -> 0 / other -> 1 for targets.
        source, target = load_and_preprocess_data(file_name, "0_to_50")
        train_data = (source[:train_size], target[:train_size])
        eval_data = (source[train_size:], target[train_size:])

        os.makedirs(f"models/{file_name}", exist_ok=True)
        # "w" truncates any previous log; the context manager closes the file
        # even if training raises.
        with open(f"models/{file_name}/training_log.txt", "w") as log_file:
            success, model, model_name, head, dim, lr = _search_hyperparameters(
                log_file, device, train_data, eval_data,
                heads, dims, lrs, vocab_size, output_size, layers, epochs,
            )

            if success:
                torch.save(model.state_dict(), f"models/{file_name}/{model_name}")
                log_file.write(f"Model {model_name} reached 100% accuracy\n")
            else:
                log_file.write("All models failed to reach 100% accuracy\n")
                if model is None:
                    # Nothing was trained (empty grid or zero epochs).
                    continue
                # Save the same weights that are evaluated below, under the
                # configuration that actually produced them.
                torch.save(model.state_dict(), f"models/{file_name}/{model_name}_best_on_training")

            log_file.write(f"heads: {head}, dim: {dim}, lr: {lr}\n")
            _log_bin_accuracies(log_file, model, file_name, device)


if __name__ == "__main__":
    _main()

