import json
import re
import tempfile
import subprocess
import os
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm import tqdm
import gc
import psutil
import time
from code_utils import sanitize_code
import sys
# Responses per item in the input file; selects the entry point in the
# __main__ guard (8 -> main(), which reads a .jsonl; 1 -> process_1(), which
# reads a .json).
response_num = 8
# Model/batch tag interpolated into the input file name and into the
# save_video/, code/ and log/ output paths.
process_batch = 'gemini-2.5-pro'
# NOTE(review): not referenced anywhere in the code visible here — confirm it
# is still needed.
num = 0

def clean_response(text):
    """Return *text* with every ``<think>...</think>`` section removed.

    The tags are removed together with everything between them. The match
    spans newlines and is non-greedy, so several separate think sections are
    each removed without swallowing the text between them.
    """
    think_section = re.compile(r'<think>.*?</think>', re.DOTALL)
    return think_section.sub('', text)

def extract_triple_single_quote_json(text):
    """Return the bodies of all ```` ```json ```` fenced blocks in *text*, in order.

    Despite the function name, this matches Markdown triple-backtick fences
    whose tag ``json`` is immediately followed by a newline. It does not match
    ``'''`` quotes. The newline after the tag is not included in a body. Each
    body runs, non-greedily, up to the next triple backtick.
    """
    json_fence = re.compile(r"```json\n(.*?)```", re.DOTALL)
    return [found.group(1) for found in json_fence.finditer(text)]

def extract_triple_single_quote_code(text):
    """Return the bodies of all ```` ```python ```` fenced blocks in *text*, in order.

    Despite the function name, this matches Markdown triple-backtick fences
    whose tag ``python`` is immediately followed by a newline. It does not
    match ``'''`` quotes. The newline after the tag is not included in a body.
    Each body runs, non-greedily, up to the next triple backtick.
    """
    python_fence = re.compile(r"```python\n(.*?)```", re.DOTALL)
    return [found.group(1) for found in python_fence.finditer(text)]

def no_install(*args, **kwargs):
    """Refuse to run anything; a drop-in stand-in for ``subprocess.check_call``.

    Accepts and ignores any arguments, then always raises ``RuntimeError``.
    """
    refusal = "No install"
    raise RuntimeError(refusal)

def _tree_rss_mb(process):
    """Return the RSS, in MiB, of *process* plus all of its live descendants.

    RSS counts shared pages once per process, so the sum can overstate the
    true footprint of the tree. It errs on the side of killing early.
    Raises psutil errors for *process* itself; descendants that vanish or
    cannot be inspected mid-sample are skipped.
    """
    total = process.memory_info().rss
    for child in process.children(recursive=True):
        try:
            total += child.memory_info().rss
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass
    return total / (1024 * 1024)


def _kill_tree(proc):
    """Kill the ``Popen`` object *proc* and every descendant it spawned, then reap it."""
    # List descendants before killing the parent: once it dies they are
    # re-parented and can no longer be found via children().
    try:
        descendants = psutil.Process(proc.pid).children(recursive=True)
    except psutil.Error:
        descendants = []
    for child in descendants:
        try:
            child.kill()
        except psutil.Error:
            pass
    try:
        proc.kill()
    except OSError:
        pass  # already gone
    try:
        proc.wait(timeout=2)
    except subprocess.TimeoutExpired:
        pass


def _read_captured(stream):
    """Return everything written to the binary temp file *stream*, decoded as text."""
    stream.seek(0)
    return stream.read().decode("utf-8", errors="replace")


def _print_captured(out_file, err_file):
    """Print the child's captured stdout/stderr, if any, after a forced kill."""
    out = _read_captured(out_file)
    err = _read_captured(err_file)
    if out:
        print("[STDOUT]\n", out)
    if err:
        print("[STDERR]\n", err)


def run_limited_subprocess(cmd, mem_limit_mb=2048, timeout=300, cwd=None, env=None):
    """Run *cmd* as a child process, killing it if it exceeds a memory or time budget.

    The resident memory (RSS) of the child and all of its descendants is
    sampled every 0.25 s. The child's stdout/stderr are spooled to temporary
    files instead of pipes. A pipe that nobody reads while we poll fills up
    and blocks a chatty child until the timeout fires; a file cannot.

    Args:
        cmd: Argument list passed to ``subprocess.Popen`` (no shell).
        mem_limit_mb: RSS ceiling in MiB. The process tree is killed above it.
        timeout: Wall-clock limit in seconds. The process tree is killed after it.
        cwd: Working directory for the child (``None`` = inherit).
        env: Environment for the child (``None`` = a copy of ``os.environ``).

    Returns:
        True if the child exited with status 0. False if it was killed for
        exceeding a limit, or exited non-zero without writing to stderr.

    Raises:
        Exception: if the child exited non-zero and wrote to stderr. The
            message is the captured stderr text.
    """
    if env is None:
        env = os.environ.copy()

    with tempfile.TemporaryFile() as out_file, tempfile.TemporaryFile() as err_file:
        proc = subprocess.Popen(cmd, cwd=cwd, env=env, stdout=out_file, stderr=err_file)
        try:
            try:
                monitor = psutil.Process(proc.pid)
            except psutil.Error:
                monitor = None
            deadline = time.monotonic() + timeout  # immune to wall-clock jumps

            while proc.poll() is None:
                if monitor is not None:
                    try:
                        mem_mb = _tree_rss_mb(monitor)
                    except psutil.Error:
                        # Exited between poll() and sampling, or we may not
                        # inspect it: stop sampling but keep enforcing the timeout.
                        monitor = None
                    else:
                        if mem_mb > mem_limit_mb:
                            _kill_tree(proc)
                            print(f"[MEM KILL] PID {proc.pid} memory {mem_mb:.2f}MB > {mem_limit_mb}MB")
                            _print_captured(out_file, err_file)
                            return False

                if time.monotonic() > deadline:
                    _kill_tree(proc)
                    print(f"[TIMEOUT] PID {proc.pid} exceeded {timeout}s")
                    _print_captured(out_file, err_file)
                    return False
                time.sleep(0.25)

            if proc.returncode != 0:
                err = _read_captured(err_file)
                if err:
                    raise Exception(err)
                return False
            return True
        finally:
            # Never leave a live child behind (e.g. on KeyboardInterrupt).
            if proc.poll() is None:
                _kill_tree(proc)


def _pick_longest_block(blocks):
    """Return the longest string in *blocks*; on ties, the first one wins."""
    return max(blocks, key=len)


def _neutralize_pip_lines(code):
    """Replace every line that mentions ``pip``/``pip3`` as a whole word with ``pass``.

    The leading indentation is kept so the enclosing block stays valid. The
    word boundaries stop identifiers that merely contain the letters
    (``pipeline``, ``subprocess.PIPE``) from being blanked out.
    """
    return re.sub(r"^([ \t]*).*\bpip3?\b.*$", r"\1pass", code, flags=re.M)


def _append_log(path, line):
    """Append *line* to the UTF-8 text file at *path*, ending it with a newline."""
    if not line.endswith("\n"):
        line += "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)


def process_video(args):
    """Extract, sanitize and run the generation script found in one model response.

    The longest ```` ```python ```` block of the response (after removing
    ``<think>`` sections) is passed through ``sanitize_code``. Lines invoking
    pip are replaced with ``pass``. The result is written to a temp file under
    ``code/<batch>/`` (kept for inspection). That file is run with the target
    ``save_video/<batch>/video_<index>/video_<index>_<response>.mp4`` path as
    its only argument.

    Args:
        args: ``(video_index, response_index, response_text)`` tuple, as built
            by ``main``/``process_1``.

    Returns:
        True if the script ran and exited with status 0. Otherwise False:
        there was no code block, the script was killed or failed, or an error
        was raised while preparing or running it. Such errors are appended as
        one line to ``log/code_error_<batch>.txt``.
    """
    video_index, video_id, response = args
    save_path = f"save_video/{process_batch}/video_{video_index}"
    video_name = f"video_{video_index}_{video_id}.mp4"
    output_path = os.path.join(save_path, video_name)
    code_dir = f'code/{process_batch}'

    blocks = extract_triple_single_quote_code(clean_response(response))
    if not blocks:
        return False

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(code_dir, exist_ok=True)
    os.makedirs('log', exist_ok=True)  # both log files below live here
    try:
        # Sanitize before creating the temp file so a failure leaves no empty file.
        code = _neutralize_pip_lines(sanitize_code(_pick_longest_block(blocks)))
        # python3 decodes source files as UTF-8, so write them that way
        # regardless of the locale.
        with tempfile.NamedTemporaryFile(suffix=".py", mode='w', encoding='utf-8',
                                         delete=False, dir=code_dir) as tmp:
            tmp.write(code)
            tmp_path = tmp.name
        _append_log(f'log/code_running_{process_batch}.txt', f'python3 {tmp_path} {output_path}')

        result = run_limited_subprocess(
            [sys.executable, tmp_path, output_path],
            mem_limit_mb=3584,
            timeout=300,
            env=os.environ.copy(),
        )
        return bool(result)
    except Exception as e:
        # SyntaxError-like exceptions carry the offending source line in .text.
        if hasattr(e, "text"):
            message = f'video_{video_index}_{video_id}: {e.text}: {str(e)}'
        else:
            message = f'video_{video_index}_{video_id}: {str(e)}'
        # Collapse to a single line so each failure is exactly one log record.
        _append_log(f'log/code_error_{process_batch}.txt', message.replace('\n', ' '))
        return False


def main():
    """Run every non-empty response from the 8-responses-per-item JSONL input.

    Each non-blank line of ``viphy_val_getcode_responses8_<batch>.jsonl`` is
    parsed as a JSON object. Its ``video_index`` and ``responses`` list are
    read, and each truthy response becomes one ``process_video`` task. Tasks
    run on a process pool sized to 75% of the CPUs (at least one).
    """
    all_tasks = []
    path = f'viphy_val_getcode_responses8_{process_batch}.jsonl'

    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f):
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL file
            item = json.loads(line)
            print(f"Processing item {line_no} with {len(item['responses'])} responses")
            all_tasks.extend(
                (item['video_index'], video_id, response)
                for video_id, response in enumerate(item['responses'])
                if response
            )

    print(f"Total tasks to process: {len(all_tasks)}")
    num_processes = max(1, int(cpu_count() * 0.75))
    print(f"Using {num_processes} processes")

    # maxtasksperchild=1 gives every task a fresh worker process, so memory
    # from one render cannot leak into the next.
    with Pool(processes=num_processes, maxtasksperchild=1) as pool:
        for _ in tqdm(pool.imap(process_video, all_tasks, chunksize=1), total=len(all_tasks)):
            pass


def process_1():
    """Run the single response stored for each entry of the 1-response JSON input.

    ``viphy_val_getcode_responses1_<batch>.json`` is loaded as a list of
    objects with ``id`` and ``response`` keys. The ``video_`` text is
    stripped from each ``id`` (presumably of the form ``video_<n>`` — verify
    against the producer). Each truthy response is run as response 0. Empty
    responses are skipped, as in ``main``. Tasks run on a process pool sized
    to 75% of the CPUs (at least one).
    """
    path = f'viphy_val_getcode_responses1_{process_batch}.json'
    with open(path, 'r', encoding='utf-8') as f:
        entries = json.load(f)

    all_tasks = [
        (entry['id'].replace('video_', ''), 0, entry['response'])
        for entry in entries
        if entry['response']
    ]

    print(f"Total tasks to process: {len(all_tasks)}")

    num_processes = max(1, int(cpu_count() * 0.75))
    print(f"Using {num_processes} processes")

    # maxtasksperchild=1 gives every task a fresh worker process.
    with Pool(processes=num_processes, maxtasksperchild=1) as pool:
        for _ in tqdm(pool.imap(process_video, all_tasks, chunksize=1), total=len(all_tasks)):
            pass

if __name__ == "__main__":
    # NOTE: this patch only affects subprocess.check_call inside this
    # interpreter (and Pool workers only if they are forked). The generated
    # scripts are launched as separate sys.executable processes by
    # run_limited_subprocess and are not covered by it.
    subprocess.check_call = no_install
    entry_points = {8: main, 1: process_1}
    if response_num not in entry_points:
        raise SystemExit(
            f"Unsupported response_num={response_num!r}; expected one of {sorted(entry_points)}"
        )
    entry_points[response_num]()
