import time
import tqdm
import os
import sys
import json

# import re-arcs
# Add the directory containing re_arc to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 're_arc')))

import argparse

import re_arc.dsl
from re_arc.dsl import *

import re_arc.utils
from re_arc.utils import *

import re_arc.generators
import re_arc.verifiers

from re_arc.main import *

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
import re
from prettytable import PrettyTable

from opencv_contour import *

import random
import numpy as np
import torch
from transformers import set_seed

def set_random_seed(seed: int):
    """Apply ``seed`` to every random number generator used by this script.

    The seeders are invoked in this order: Python's ``random``, NumPy,
    ``torch.manual_seed``, ``torch.cuda.manual_seed_all`` and finally the
    ``set_seed`` helper imported from ``transformers``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Set the seed to ensure reproducibility
# NOTE: these statements execute at module import time, so importing this
# file reseeds random / NumPy / torch / transformers via set_random_seed.
seed = 1230
set_random_seed(seed)

# Tokenizer shared by reformat_arc() and preprocess_example().
# NOTE(review): the "./" prefix suggests a directory relative to the current
# working directory -- confirm it exists wherever this script is launched.
tokenizer_name = "./tokenizer_vs22_extendarctokens"
tokenizer = AutoTokenizer.from_pretrained(f"{tokenizer_name}")


def repad_2d_list(grid, pad_token='<arc_pad>', target_size=32):
    """Pad a 2D token grid and add ARC end-of-grid / newline markers.

    Every row of ``grid`` is extended with ``'<arc_endxgrid>'``, enough
    ``pad_token`` entries to reach ``target_size`` columns, and a trailing
    ``'<arc_nl>'``.  A single marker row built from ``'<arc_endygrid>'`` (with
    ``'<arc_endxygrid>'`` in the end-of-row column) follows the data rows, and
    blank padding rows are appended until there are ``target_size + 1`` rows.

    The width is taken from ``grid[0]``, so ``grid`` must be non-empty, and
    its rows must be lists (they are concatenated with lists).

    Returns:
        A new list of rows; ``grid`` and its rows are not modified.
    """
    width = len(grid[0])
    # A negative count simply yields an empty list, as list repetition does.
    filler = [pad_token] * (target_size - width)
    row_tail = ['<arc_endxgrid>'] + filler + ['<arc_nl>']

    result = []
    for row in grid:
        result.append(row + row_tail)

    # Marker row that closes the grid vertically.
    end_marker_row = (
        ['<arc_endygrid>'] * width
        + ['<arc_endxygrid>']
        + ['<arc_endygrid>'] * (target_size - width)
        + ['<arc_nl>']
    )
    result.append(end_marker_row)

    # The same blank-row object is reused for every remaining slot.
    blank_row = [pad_token] * width + row_tail
    result.extend([blank_row] * (target_size + 1 - len(result)))
    return result

def reshape_to_32x32(input_list):
    """Split a flat sequence of 1024 items into 32 consecutive rows of 32.

    Raises:
        ValueError: if ``input_list`` does not contain exactly 32*32 items.
    """
    side = 32
    total = side * side
    if len(input_list) != total:
        raise ValueError("Input list must have exactly 32*32 elements")

    rows = []
    for start in range(0, total, side):
        rows.append(input_list[start:start + side])
    return rows

def print_table(grid):
    """Print ``grid`` as a PrettyTable, one table row per grid row."""
    rendered = PrettyTable()
    for cells in grid:
        rendered.add_row(cells)
    print(rendered)

def reformat_arc(grid, print_grids=False):
    """Tokenize ``grid`` and re-emit it in the re-padded ARC layout.

    ``grid`` is tokenized with the module-level ``tokenizer``; anything past
    the first 32*32 tokens is dropped.  The tokens are reshaped to 32x32
    (``reshape_to_32x32`` raises ``ValueError`` when fewer than 1024 tokens
    remain), passed through ``unpad_2d_list`` and then ``repad_2d_list``.

    NOTE(review): ``unpad_2d_list`` is not defined in this part of the file;
    presumably it arrives through one of the star imports -- confirm.

    Args:
        grid: text handed to ``tokenizer.tokenize``.
        print_grids: when true, print both the unpadded and the re-padded
            token grids as tables.

    Returns:
        All re-padded tokens concatenated, row by row, into a single string.
    """
    token_limit = 32 * 32
    # Slicing leaves sequences that are already short enough unchanged.
    tokens_1d = tokenizer.tokenize(grid)[:token_limit]

    tokens_2d = unpad_2d_list(reshape_to_32x32(tokens_1d))
    padded_tokens_2d = repad_2d_list(tokens_2d)

    if print_grids:
        print("2D Tokens:")
        print_table(tokens_2d)
        print("\nRe-padded 2D Tokens:")
        print_table(padded_tokens_2d)

    # Join each row, then join the rows, to get one flat string.
    return ''.join(''.join(row) for row in padded_tokens_2d)

def reformat_arc_tokens(grid, print_grids=False):
    """Re-pad an already tokenized 2D grid and flatten it to one string.

    Args:
        grid: 2D list of token strings; passed to ``repad_2d_list`` as-is.
        print_grids: when true, print the grid before and after re-padding.

    Returns:
        The re-padded tokens concatenated row by row into a single string.
    """
    repadded = repad_2d_list(grid)

    if print_grids:
        print("2D Tokens:")
        print_table(grid)
        print("\nRe-padded 2D Tokens:")
        print_table(repadded)

    # Join each row, then join the rows, to get one flat string.
    return ''.join(''.join(row) for row in repadded)

def replace_digits_with_arc(grid):
    """Map every cell ``c`` of a 2D grid to the token string ``'<arc_c>'``.

    Returns a new list of lists with the same row structure as ``grid``.
    """
    converted = []
    for row in grid:
        converted.append([f'<arc_{cell}>' for cell in row])
    return converted

def pad_2d_list(grid, pad_token='<pad>', target_size=32):
    """Pad a 2D list to ``target_size`` x ``target_size`` with ``pad_token``.

    Rows shorter than ``target_size`` are right-padded; rows that are already
    that long or longer are copied unchanged.  Rows made entirely of
    ``pad_token`` are appended until the grid has ``target_size`` rows.

    Returns:
        A new list of rows; ``grid`` itself is not modified.
    """
    padded = []
    for row in grid:
        missing = target_size - len(row)
        padded.append(row + [pad_token] * missing)

    # range() over a non-positive count performs no iterations.
    for _ in range(target_size - len(padded)):
        padded.append([pad_token] * target_size)
    return padded


# Fixed sequence lengths passed to the tokenizer (max_length) and used as the
# expected output length of pad_and_flatten().
# 1124 = 33 * 34 (the flattened target_shape below) + 2 extra positions, one
# at each end.
max_input_length = 1124
max_target_length = 1124
# (rows, cols) that type-id arrays are zero-padded to before being flattened.
target_shape = (33, 34)

def pad_and_flatten(input_type_ids, target_shape, final_length):
    """Zero-pad a 2D id array to ``target_shape`` and flatten it with guards.

    ``input_type_ids`` is placed in the top-left corner of an integer array of
    zeros shaped ``target_shape``.  That array is flattened row by row and a
    single 0 is added in front of and behind it.

    Args:
        input_type_ids: 2D NumPy array that fits inside ``target_shape``.
        target_shape: ``(rows, cols)`` of the zero-padded canvas.
        final_length: required length of the result.

    Returns:
        1D integer NumPy array of length ``final_length``.

    Raises:
        ValueError: if the padded, flattened array is not ``final_length``
            long.
    """
    source_shape = input_type_ids.shape
    n_rows, n_cols = source_shape[0], source_shape[1]

    canvas = np.zeros(target_shape, dtype=int)
    canvas[:n_rows, :n_cols] = input_type_ids

    # Allocate the output directly: zeros everywhere, canvas in the middle.
    padded_and_flattened = np.zeros(canvas.size + 2, dtype=int)
    padded_and_flattened[1:-1] = canvas.ravel()

    actual_length = len(padded_and_flattened)
    if actual_length != final_length:
        raise ValueError(f"Final length of padded array is {actual_length}, expected {final_length}.")

    return padded_and_flattened

def preprocess_example(example):
    """Turn one raw grid example into tokenized model features.

    ``example['input']`` and ``example['output']`` are 2D grids.  Each grid is
    converted to ``<arc_N>`` tokens, re-padded, flattened to a string and
    wrapped in ``<s>`` / ``</s>`` before being tokenized to a fixed length.

    Returns:
        The tokenizer output for the input string, extended with:
        ``labels`` (token ids of the output string; padding positions are
        kept as-is, not masked), ``input`` / ``output`` (the formatted
        strings), and ``input_type_ids`` / ``output_type_ids`` (type-id
        arrays padded and flattened to the same fixed length).
    """
    source_grid = example['input']
    target_grid = example['output']

    # Grid cells -> '<arc_N>' tokens -> one padded string between <s> and </s>.
    result_string_input = "<s>" + reformat_arc_tokens(replace_digits_with_arc(source_grid), print_grids=False) + "</s>"
    result_string_output = "<s>" + reformat_arc_tokens(replace_digits_with_arc(target_grid)) + "</s>"

    model_inputs = tokenizer(
        result_string_input,
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )

    # Tokenize the target grid; its ids become the labels without any masking.
    target_encoding = tokenizer(
        result_string_output,
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    model_inputs["labels"] = target_encoding.input_ids
    model_inputs["input"] = result_string_input
    model_inputs["output"] = result_string_output

    # NOTE(review): generate_input_type_ids_multi is not defined in this part
    # of the file (presumably provided by a star import) -- confirm.
    # Its result is padded to target_shape and flattened so that it lines up
    # with the fixed token sequence length.
    type_id_jobs = (
        ("input_type_ids", source_grid, max_input_length),
        ("output_type_ids", target_grid, max_target_length),
    )
    for column, raw_grid, length in type_id_jobs:
        type_ids = generate_input_type_ids_multi(np.array(raw_grid), visualize=False)
        model_inputs[column] = pad_and_flatten(type_ids, target_shape, length)

    return model_inputs

def is_grid_extra(grid: Any) -> bool:
    """
    returns True if and only if argument is a valid grid

    A valid grid is a non-empty tuple of at most 30 rows, where every row is
    a non-empty tuple of at most 30 ints in the range 0..9 and all rows have
    the same length.
    """
    # Outer container: a tuple holding between 1 and 30 rows.
    if not isinstance(grid, tuple) or not 0 < len(grid) <= 30:
        return False

    # Every row has to be a tuple as well.
    if any(not isinstance(row, tuple) for row in grid):
        return False

    # All rows share a single width, and that width is between 1 and 30.
    widths = {len(row) for row in grid}
    if len(widths) != 1:
        return False
    (width,) = widths
    if not 0 < width <= 30:
        return False

    # Cells are ints restricted to the range 0..9.
    return all(
        isinstance(cell, int) and 0 <= cell <= 9
        for row in grid
        for cell in row
    )


def _generate_verified_example(generator, verifier, diff_lb, diff_ub, stats):
    """Run one generation attempt; return the example, or None if rejected.

    An example is accepted when the generator succeeds, both grids pass
    ``is_grid_extra``, the verifier maps the input onto the output, and the
    input differs from the output.  The ``n_generations``, ``n_verified`` and
    ``n_nondegenerate`` counters in ``stats`` are incremented as each stage
    is passed.

    Explicit checks are used instead of ``assert`` so that validation still
    runs under ``python -O``, and only ``Exception`` is caught so that
    Ctrl-C / ``SystemExit`` can still stop a long run.
    """
    try:
        example = generator(diff_lb, diff_ub)
        grids_valid = is_grid_extra(example['input']) and is_grid_extra(example['output'])
    except Exception:
        return None
    if not grids_valid:
        return None
    stats['n_generations'] += 1

    try:
        verified = verifier(example['input']) == example['output']
    except Exception:
        return None
    if not verified:
        return None
    stats['n_verified'] += 1

    # Degenerate examples (output identical to input) are discarded.
    if example['input'] == example['output']:
        return None
    stats['n_nondegenerate'] += 1
    return example


def generate_single_dataset_hf(
    task_idx: int = 0,
    path: str = 're_arc',
    seed: int = 42,
    n_examples: int = 1000000,
    testsize: int = 1000,
    diff_lb: float = 0,
    diff_ub: float = 1
) -> tuple:
    """
    generates the dataset for a single task and splits it into
    train / validation / test

    task_idx: index of the task within the sorted generator keys
    path: currently unused; kept so existing callers keep working
    seed: passed to random.seed for deterministic generation / reproducibility
    n_examples: number of unique examples to generate for the task
    testsize: number of examples in each of the validation and test splits
    diff_lb: lower bound for difficulty
    diff_ub: upper bound for difficulty

    returns: (task key, DatasetDict with 'train', 'validation' and 'test')

    Generation stops after 40 * n_examples attempts even if fewer than
    n_examples unique examples were produced; a warning is printed then.
    """
    random.seed(seed)

    generators_mapper = get_generators()
    verifiers_mapper = get_verifiers()
    key = sorted(generators_mapper.keys())[task_idx]
    print(f"TaskID:{key}")

    generator = generators_mapper[key]
    verifier = verifiers_mapper[key]
    # Input grids accepted so far. The grids themselves are stored rather
    # than their hashes, so a hash collision cannot drop a distinct example.
    seen = set()
    examples = []
    stats = {
        'n_generations': 0, 'n_verified': 0, 'n_nondegenerate': 0,
        'rng_difficulties': [], 'pso_difficulties': []
    }

    max_attempts = 40 * n_examples
    attempts = 0
    start = time.time()
    # The progress bar counts accepted examples; attempts are shown as postfix.
    with tqdm.tqdm(total=n_examples) as pbar:
        while len(examples) < n_examples and attempts < max_attempts:
            attempts += 1
            example = _generate_verified_example(
                generator, verifier, diff_lb, diff_ub, stats
            )
            if example is not None and example['input'] not in seen:
                seen.add(example['input'])
                examples.append(example)
                pbar.update(1)

            pbar.set_postfix(attempts=attempts)

    stats['runtime'] = time.time() - start

    print(stats)
    if len(examples) < n_examples:
        print(
            f"Warning: only {len(examples)} of {n_examples} examples were "
            f"generated after {attempts} attempts"
        )

    # Convert to Hugging Face Dataset
    dataset = Dataset.from_list(examples)
    print(dataset)

    # Apply preprocessing
    processed_dataset = dataset.map(preprocess_example, batched=False)

    # Hold out 2 * testsize examples, then split them evenly into the
    # validation and test sets.
    train_test_split = processed_dataset.train_test_split(test_size=testsize*2, shuffle=True, seed=42)
    test_val_split = train_test_split['test'].train_test_split(test_size=testsize, shuffle=True, seed=42)

    final_datasets = DatasetDict({
        'train': train_test_split['train'],
        'validation': test_val_split['train'],
        'test': test_val_split['test']
    })

    return key, final_datasets


def test_gen1M():
    """Generate a 10-example trial dataset for task 0, save it, and time it.

    The dataset is written to ``./<task_id>_10_trial`` and the elapsed wall
    clock time is printed both in seconds and as hours / minutes / seconds.
    """
    started_at = time.time()

    task_id, final_datasets = generate_single_dataset_hf(
        task_idx=0,
        path='/scratch/',
        seed=1230,
        n_examples=10,
        testsize=1,
    )
    final_datasets.save_to_disk(f'./{task_id}_10_trial')

    sbatch_runtime = time.time() - started_at
    print(f"Total runtime: {sbatch_runtime} seconds")

    # Break the runtime down into hours, minutes and seconds.
    hours, leftover = divmod(sbatch_runtime, 3600)
    minutes, seconds = divmod(leftover, 60)
    print(f"Total runtime: {int(hours)} hrs {int(minutes)} min {seconds:.2f} seconds")

#test_gen1M()

def run_gen(task_idx):
    """Generate the 1M-example dataset for ``task_idx`` and save it to disk.

    The splits are written to ``./arc_x2y_datasets/<task_idx>_1M_withrep``
    and the resulting DatasetDict is printed afterwards.
    """
    generation_result = generate_single_dataset_hf(
        task_idx=task_idx,
        path='/scratch/',
        seed=1230,
        n_examples=1000000,
        testsize=1000,
    )
    # The returned task id is not needed here; only the datasets are saved.
    _task_id, final_datasets = generation_result

    final_datasets.save_to_disk(f'./arc_x2y_datasets/{task_idx}_1M_withrep')

    print(final_datasets)

