import os
import time
import re
import csv

from together import Together
from together.error import RateLimitError
from tqdm import tqdm

from tasks import task_registry

import numpy as np

# Together AI client; raises KeyError at import time if the key is not exported.
client = Together(api_key=os.environ["TOGETHER_AI_API_KEY"])

# Fixed seed so the sequence of synthetic problems is reproducible across runs.
rng = np.random.default_rng(seed=42)

# Range arguments handed to the task generator for the two operands
# (presumably min/max digit counts — confirm against the task definition).
la, lb = (3, 9), (3, 9)

# Resumable generation of SFT examples. Each row pairs a synthetic multiplication
# prompt with a DeepSeek-R1 completion; the model is re-queried (up to
# MAX_ATTEMPTS scored responses) until the first <ans>...</ans> tag in its
# completion equals the ground-truth answer.
OUT_PATH = 'data/sft_mult_1000.tsv'
NUM_EXAMPLES = 1000
MAX_ATTEMPTS = 3
FIELDNAMES = ['gt_answer', 'total_tokens', 'attempts', 'correct', 'prompt', 'target']
ANS_RE = re.compile(r'(?:<ans>)(\d*)(?:</ans>)')


def _count_records(path):
    """Return the number of TSV records in *path* (header included), or 0 if it is missing."""
    if not os.path.exists(path):
        return 0
    # newline='' lets the csv module correctly parse newlines inside quoted fields
    # (the model completions stored in 'target' are multi-line).
    with open(path, 'r', newline='', encoding='utf-8') as fh:
        return sum(1 for _ in csv.reader(fh, delimiter='\t'))


def _solve(prompt, s):
    """Query DeepSeek-R1 on *prompt* until it answers *s* or MAX_ATTEMPTS responses are scored.

    Returns ``(response_txt, total_tokens, attempts, correct)``, where
    ``response_txt`` is the last completion received (kept even when wrong so
    the caller can record it). A RateLimitError is retried after a 10 s pause
    and does not count as an attempt; any other API error propagates.
    """
    correct = False
    attempts = 0
    total_tokens = 0
    response_txt = ''
    while not correct and attempts < MAX_ATTEMPTS:
        try:
            response = client.chat.completions.create(
                model='deepseek-ai/DeepSeek-R1',
                messages=[{"role": "user", "content": prompt}],
                max_tokens=32768
            )
        except RateLimitError:
            time.sleep(10)
            continue

        # Guard: ANS_RE.search would raise TypeError if content were None.
        response_txt = response.choices[0].message.content or ''
        print(response_txt)
        ans = ANS_RE.search(response_txt)
        # NOTE(review): assumes s is the ground-truth answer as a digit string.
        correct = ans is not None and ans.group(1) == s

        print(ans, s, correct)
        print('-' * 80)
        attempts += 1
        total_tokens += response.usage.total_tokens
    return response_txt, total_tokens, attempts, correct


num_records = _count_records(OUT_PATH)
# The header is one of the records; it is not a finished example.
num_skip = max(num_records - 1, 0)

with open(OUT_PATH, 'a', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=FIELDNAMES, delimiter='\t')
    if num_skip < NUM_EXAMPLES:
        # Only a brand-new (missing or empty) file needs a header.
        if num_records == 0:
            writer.writeheader()

        task = task_registry['synthetic_COT_mult']
        # rng is re-seeded on every run, so replay the problems already written
        # to advance it past them; otherwise a resumed run would regenerate the
        # first prompts again. Assumes the task draws all its randomness from rng.
        for _ in range(num_skip):
            task(rng, la, lb)

        for _ in tqdm(range(num_skip, NUM_EXAMPLES), initial=num_skip, total=NUM_EXAMPLES):
            prompt, s, _ = task(rng, la, lb)
            response_txt, total_tokens, attempts, correct = _solve(prompt, s)

            row = {
                'prompt': prompt,
                'target': response_txt,
                'gt_answer': s,
                'total_tokens': total_tokens,
                'attempts': attempts,
                'correct': correct
            }

            writer.writerow(row)
            # Flush per row so a hard kill loses at most the example in flight
            # and the record count used for resuming stays accurate.
            f.flush()
            print(row)
            print('=' * 80)
            print()

from datasets import Dataset, Features, Value
import pandas as pd

# Remove any rows that are repeated header lines (their 'total_tokens' cell is
# the literal column name), save the cleaned TSV, then load it back as a typed
# Hugging Face Dataset and persist it to disk.
df = pd.read_csv('data/sft_mult_1000.tsv', delimiter='\t')
is_header_row = df['total_tokens'] == 'total_tokens'
df = df[~is_header_row]

cleaned_path = 'data/sft_mult_1000_cleaned.tsv'
df.to_csv(cleaned_path, sep='\t', index=False)

sft_features = Features({
    'gt_answer': Value('string'),
    'total_tokens': Value('int64'),
    'attempts': Value('int64'),
    'correct': Value('bool'),
    'prompt': Value('string'),
    'target': Value('string')
})
ds = Dataset.from_csv(cleaned_path, features=sft_features, delimiter='\t')
ds.save_to_disk('data/sft_mult_1000')
