import logging
# Set the root logger threshold to ERROR.
# NOTE(review): this runs before the remaining imports, presumably so that
# libraries imported below start out with this threshold - confirm.
logging.basicConfig(level='ERROR')

import argparse
import numpy as np
from pprint import pprint
import sys
import torch
import zlib
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm
import json
import os
from pathlib import Path

# Use the CUDA device when torch reports one as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        

def parse_commoncrawl(wet_file):
    """
    Quick and ugly parsing of a WET file.
    Tested for the May 2021 crawl.

    A record starts at every line containing "WARC/1.0". A record is kept
    when the line 7 positions after its start contains
    "WARC-Identified-Content-Language: eng"; its body is taken to begin
    10 lines after the start and to run up to the next record (or the end
    of the file for the final record).

    Args:
        wet_file: Path to the WET file. It is opened as plain text, so it
            must already be decompressed.

    Returns:
        str: The body lines of all kept records concatenated in file
        order, or "" when no record qualifies.
    """
    # Decode explicitly instead of relying on the platform locale; undecodable
    # bytes are replaced rather than aborting the whole parse.
    with open(wet_file, encoding="utf-8", errors="replace") as f:
        lines = f.readlines()

    record_starts = [i for i, line in enumerate(lines) if "WARC/1.0" in line]
    # Append a sentinel so the final record (which has no following
    # "WARC/1.0" line) is parsed too instead of being silently dropped.
    boundaries = record_starts + [len(lines)]

    eng_chunks = []
    for start, end in zip(boundaries, boundaries[1:]):
        lang_idx = start + 7
        # A record too short to contain the language header has no body
        # either; skipping it also avoids indexing past the end of the file.
        if lang_idx >= end:
            continue
        if "WARC-Identified-Content-Language: eng" in lines[lang_idx]:
            eng_chunks.extend(lines[start + 10:end])

    # Single join instead of repeated += keeps this linear in the file size.
    return "".join(eng_chunks)

def write_list_to_json_file(directory, filename, data_list):
    """
    Serialize a Python list as JSON into a file inside a given directory.

    The directory (and any missing parents) is created first; an existing
    file with the same name is overwritten. A confirmation line with the
    resulting path is printed afterwards.

    Args:
        directory (str): Target directory for the JSON file
        filename (str): File name to create, including the .json extension
        data_list (list): The list to serialize
    """
    target_dir = Path(directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Render the document up front, then write it out in one call.
    serialized = json.dumps(data_list, indent=4)
    filepath = os.path.join(directory, filename)
    with open(filepath, 'w') as out:
        out.write(serialized)

    print(f"List successfully written to {filepath}")

def main():
    """
    Sample fixed-length prompts from a Common Crawl WET file and save them.

    Reads the module-level `args` namespace (N, batch_size,
    internet_sampling, wet_file). When --internet-sampling is set, each
    batch draws random 100-character windows from the parsed crawl text,
    keeps those that tokenize to exactly 35 GPT-2 tokens, decodes them
    back to text and stores the first prompt of the batch with a trailing
    "?". Without --internet-sampling no prompts are produced and an empty
    list is written.

    The collected prompts are written to "Random_Crawl_10000.json" in the
    current working directory.

    Raises:
        ValueError: If --internet-sampling is set but no WET file was
            given, or the WET file yields no English text to sample from.
    """
    print(f"using device: {device}")

    cc = None
    if args.internet_sampling:
        # Fail early with a readable message instead of an opaque error
        # from open(None) deep inside the parser.
        if args.wet_file is None:
            raise ValueError("--internet-sampling requires --wet-file")
        print("Loading common crawl...")
        cc = parse_commoncrawl(args.wet_file)
        # np.random.randint(0, 0) below would otherwise fail mid-run.
        if not cc:
            raise ValueError(f"no English text found in {args.wet_file}")

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    # Number of tokens every accepted prompt must have.
    input_len = 35

    samples = []

    num_batches = int(np.ceil(args.N / args.batch_size))
    with tqdm(total=args.N) as pbar:
        for _ in range(num_batches):
            if args.internet_sampling:
                input_ids = []

                # Rejection-sample until the batch is full: only windows
                # that reach the full input_len after truncation are kept.
                while len(input_ids) < args.batch_size:
                    r = np.random.randint(0, len(cc))
                    # Drop the first and last space-separated pieces of the
                    # window, which are most likely cut-off words.
                    prompt = " ".join(cc[r:r+100].split(" ")[1:-1])

                    inputs = tokenizer(prompt, return_tensors="pt", max_length=input_len, truncation=True)
                    if len(inputs['input_ids'][0]) == input_len:
                        input_ids.append(inputs['input_ids'][0])

                prompts = tokenizer.batch_decode(torch.stack(input_ids), skip_special_tokens=True)
                # NOTE(review): only the first prompt of each batch is kept,
                # so with --batch-size > 1 fewer than N samples are written
                # - confirm this is intended.
                samples.append(prompts[0] + "?")

            # Advance the progress bar; it was previously never updated.
            pbar.update(args.batch_size)

    # An empty directory makes the file land in the current working directory.
    output_dir = ""
    json_filename = "Random_Crawl_10000.json"
    write_list_to_json_file(output_dir, json_filename, samples)

def parse_arguments(argv):
    """
    Parse the command-line options of this script.

    Args:
        argv: List of argument strings, without the program name.

    Returns:
        argparse.Namespace with the attributes N, batch_size,
        internet_sampling and wet_file.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ('--N', dict(type=int, default=1000, help="Number of samples to generate")),
        ('--batch-size', dict(type=int, default=1, help="Batch size for generation")),
        ('--internet-sampling', dict(action='store_true', help="condition the generation using commoncrawl")),
        ('--wet-file', dict(type=str, default=None, help="path to a commoncrawl WET file")),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args(argv)

# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # `args` is bound at module level rather than passed as a parameter
    # because main() reads it as a global; it must be assigned before the call.
    args = parse_arguments(sys.argv[1:])
    main()
