import random
import re
import tqdm
from datasets import load_dataset
from datasets import Dataset, DatasetDict
import argparse
import os
import json
try:
  import yaml
except Exception:
  yaml = None
# Command-line interface. Both dataset options are optional on the command
# line because they may be supplied by the YAML config file instead.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", type=str, default="rrm_config.yaml",
    help="YAML file that may define 'input_dataset' and 'output_dataset'.")
parser.add_argument(
    "--input_dataset", type=str,
    help="Local .json/.jsonl file, or a dataset name passed to load_dataset.")
parser.add_argument(
    "--output_dataset", type=str,
    help="Output path; '.json' writes one JSON list, anything else JSONL.")
args = parser.parse_args()

# Load YAML config if available and merge with CLI args (CLI takes precedence if provided)
config_input = None
config_output = None
if args.config and os.path.exists(args.config):
  if yaml is None:
    # The optional PyYAML import failed; say so instead of silently
    # ignoring a config file the user may be relying on.
    print(f"Warning: PyYAML is not installed; ignoring config file {args.config}")
  else:
    with open(args.config, "r", encoding="utf-8") as f:
      try:
        cfg = yaml.safe_load(f) or {}
        config_input = cfg.get("input_dataset")
        config_output = cfg.get("output_dataset")
      except Exception as exc:
        # Best effort: a broken or non-mapping config must not abort the run,
        # but the user should learn that its values were not applied.
        cfg = {}
        print(f"Warning: could not read settings from {args.config}: {exc}")

# A falsy CLI value (option omitted or empty string) falls back to the config.
input_dataset_arg = args.input_dataset or config_input
output_dataset_arg = args.output_dataset or config_output


def get_fields(messages) -> dict[str, object]:
  """Parse one labelled comparison into context / winner / loser fields.

  Args:
    messages: Indexable collection of message dicts. ``messages[0]['content']``
      is split on the ``[CONTEXT]``, ``[RESPONSE A]`` and ``[RESPONSE B]``
      markers and must produce exactly three non-empty segments.
      ``messages[1]['content']`` must be ``'A'``, ``'B'`` or ``'Same'``.

  Returns:
    A dict with the string entries ``context``, ``response_w`` and
    ``response_l`` and the boolean entry ``neutral``. For label ``'B'`` the
    second response is ``response_w``; for ``'A'`` and ``'Same'`` the first
    one is. ``neutral`` is True only for ``'Same'``.

  Raises:
    ValueError: If the prompt does not split into exactly three non-empty
      segments or the label is not one of the accepted values. (Explicit
      raises are used instead of ``assert`` so the checks survive ``-O``.)
  """
  delimiters = r'\[CONTEXT\]|\[RESPONSE A\]|\[RESPONSE B\]'
  # Empty segments (e.g. the one before the leading marker) are dropped, so
  # an empty context or response makes the record invalid as well.
  segments = [s for s in re.split(delimiters, messages[0]['content']) if s]
  if len(segments) != 3:
    raise ValueError(
        f"expected 3 non-empty segments in prompt, found {len(segments)}")
  label = messages[1]['content']
  if label not in ('A', 'B', 'Same'):
    raise ValueError(f"unexpected preference label: {label!r}")

  context, first_response, second_response = segments
  if label == 'B':
    response_w, response_l = second_response, first_response
  else:
    # 'A' and 'Same' keep the prompt order; only the neutral flag differs.
    response_w, response_l = first_response, second_response
  return {
      'context': context,
      'response_w': response_w,
      'response_l': response_l,
      'neutral': label == 'Same',
  }


def to_messages(fields: dict[str, object]) -> list[dict[str, str]]:
  """Render a comparison record as a two-message user/assistant exchange.

  The winning response is placed in slot A or slot B according to one
  ``random.randint(0, 1)`` draw from the global ``random`` module, so the
  result is only reproducible if the caller seeds that module.

  Args:
    fields: Dict with entries ``context``, ``response_w``, ``response_l``
      and a truthy/falsy ``neutral`` flag.

  Returns:
    ``[user_message, assistant_message]``. The user content is
    ``[CONTEXT]...[RESPONSE A]...[RESPONSE B]...``. The assistant content is
    ``"Same"`` when ``neutral`` is truthy, otherwise ``"A"`` or ``"B"``
    naming the slot that holds ``response_w``.
  """
  context = fields['context']
  neutral = fields["neutral"]
  # The slot order is randomised for neutral records too.
  winner_first = bool(random.randint(0, 1))
  if winner_first:
    response_a, response_b = fields['response_w'], fields['response_l']
  else:
    response_a, response_b = fields['response_l'], fields['response_w']

  if neutral:
    label = "Same"
  else:
    label = "A" if winner_first else "B"

  prompt = (f"[CONTEXT]{context}"
            f"[RESPONSE A]{response_a}"
            f"[RESPONSE B]{response_b}")
  return [
      {"role": "user", "content": prompt},
      {"role": "assistant", "content": label},
  ]


def get_augmented(data):
  """Yield augmented comparison records that mix responses across examples.

  Every example ``i`` acts as the anchor and supplies the context. Two
  partner examples ``j`` and ``k`` are assigned through random permutations
  of ``data`` (the ``k`` permutation is a reshuffle of the ``j`` one). Per
  anchor, up to 14 records are produced, in a fixed order:

    * 8 non-neutral records: one of the anchor's own responses is
      ``response_w`` and a response taken from ``j`` or ``k`` is
      ``response_l``.
    * 6 neutral records: both responses are taken from ``j`` and/or ``k``.

  A shuffle may map a position onto itself or give ``j`` and ``k`` the same
  example. Such collisions would produce identical-response pairs or wrongly
  labelled ones (e.g. the anchor's ``response_l`` "beating" its own
  ``response_w``), so every candidate that draws on a colliding partner is
  skipped. An anchor can therefore yield fewer than 14 records, and a
  dataset with fewer than two examples yields none.

  The global ``random`` module is consumed exactly as by shuffling two
  copies of ``data``: two ``random.shuffle`` calls on lists of
  ``len(data)`` items.

  Args:
    data: List of dicts with ``context``, ``response_w`` and ``response_l``
      entries. It is not modified.

  Yields:
    Dicts with ``context``, ``response_w``, ``response_l`` and ``neutral``.
  """
  # (winner source, winner field, loser source, loser field, neutral)
  specs = (
      ("i", "response_w", "j", "response_w", False),  # xi_ywi_ywj
      ("i", "response_w", "k", "response_w", False),  # xi_ywi_ywk
      ("i", "response_w", "j", "response_l", False),  # xi_ywi_ylj
      ("i", "response_w", "k", "response_l", False),  # xi_ywi_ylk
      ("i", "response_l", "j", "response_w", False),  # xi_yli_ywj
      ("i", "response_l", "k", "response_w", False),  # xi_yli_ywk
      ("i", "response_l", "j", "response_l", False),  # xi_yli_ylj
      ("i", "response_l", "k", "response_l", False),  # xi_yli_ylk
      ("j", "response_w", "j", "response_l", True),   # xi_ywj_ylj
      ("k", "response_w", "k", "response_l", True),   # xi_ywk_ylk
      ("j", "response_w", "k", "response_w", True),   # xi_ywj_ywk
      ("j", "response_w", "k", "response_l", True),   # xi_ywj_ylk
      ("k", "response_w", "j", "response_l", True),   # xi_ywk_ylj
      ("j", "response_l", "k", "response_l", True),   # xi_ylj_ylk
  )

  # Shuffle index lists rather than copies of the data: the permutation is
  # the same, and collisions can then be detected by comparing indices.
  perm_j = list(range(len(data)))
  random.shuffle(perm_j)
  perm_k = perm_j.copy()
  random.shuffle(perm_k)

  for i, (j, k) in enumerate(zip(perm_j, perm_k)):
    anchor = data[i]
    sources = {"i": anchor}
    if j != i:
      sources["j"] = data[j]
    # k must differ from j as well, otherwise its records would only repeat
    # the j-based ones or pair a response with itself.
    if k != i and k != j:
      sources["k"] = data[k]

    context = anchor['context']
    for w_src, w_field, l_src, l_field, neutral in specs:
      if w_src not in sources or l_src not in sources:
        continue
      yield {
          "context": context,
          "response_w": sources[w_src][w_field],
          "response_l": sources[l_src][l_field],
          "neutral": neutral
      }


def process_data(data):
  """Parse, augment and re-render a collection of labelled comparisons.

  This is a generator: nothing is parsed until the result is iterated.

  Args:
    data: Iterable of records, each expected to hold a ``'messages'`` entry
      that ``get_fields`` can parse. Records that cannot be parsed are
      reported on stdout and skipped.

  Yields:
    One ``to_messages`` result for every record produced by
    ``get_augmented`` from the successfully parsed inputs.
  """
  all_fields = []
  for d in tqdm.tqdm(data):
    try:
      all_fields.append(get_fields(d['messages']))
    except Exception as exc:
      # Catch Exception rather than everything so Ctrl-C still stops the run.
      # Print the whole record: looking up d['messages'] again could itself
      # be the operation that failed.
      print(f"Skipping record ({type(exc).__name__}: {exc}): {d!r}")
  for fields in tqdm.tqdm(get_augmented(all_fields)):
    yield to_messages(fields)


if not input_dataset_arg:
  raise ValueError("No input dataset provided. Specify via --input_dataset or rrm_config.yaml")

# An existing local .json/.jsonl file is read through the 'json' loader;
# any other value is handed to load_dataset unchanged.
if input_dataset_arg.endswith((".json", ".jsonl")) and os.path.exists(input_dataset_arg):
  ds = load_dataset('json', data_files=input_dataset_arg, split='train')
else:
  ds = load_dataset(input_dataset_arg, split='train')

processed_messages = list(process_data(list(ds)))
# Create dataset directly from the processed messages
dataset_data = [{'messages': m} for m in processed_messages]
dataset = Dataset.from_list(dataset_data)

if not output_dataset_arg:
  # Default to local JSONL next to input when no output specified
  output_dataset_arg = os.path.splitext(input_dataset_arg)[0] + ".augmented.jsonl"

# Print first 2 samples
print("\n📝 First 2 augmented samples:")
print(f"Dataset length: {len(dataset)}")

for sample_number, row in enumerate(dataset_data[:2], start=1):
  print(f"\n--- Sample {sample_number} ---")
  messages = row['messages']
  print(f"Message 0 (user): {messages[0]['content'][:200]}...")
  print(f"Message 1 (assistant): {messages[1]['content']}")

# A '.json' target receives a single JSON list. Every other target is written
# as JSONL, with '.jsonl' appended when the path does not already end in it.
write_json_list = output_dataset_arg.endswith('.json')
if write_json_list or output_dataset_arg.endswith('.jsonl'):
  output_path = output_dataset_arg
else:
  output_path = output_dataset_arg + ".jsonl"

# The default path derived from a non-local dataset name (e.g. "org/name")
# points into a directory that may not exist; create it so the run does not
# fail only after all processing is done.
output_dir = os.path.dirname(output_path)
if output_dir:
  os.makedirs(output_dir, exist_ok=True)

# ensure_ascii=False emits raw non-ASCII characters, so the file encoding is
# fixed to UTF-8 instead of depending on the platform default.
with open(output_path, 'w', encoding='utf-8') as f:
  if write_json_list:
    json.dump([{'messages': row['messages']} for row in dataset], f, ensure_ascii=False)
  else:
    # One JSON object per line.
    for row in dataset:
      f.write(json.dumps({'messages': row['messages']}, ensure_ascii=False) + "\n")
print(f"\n✅ Saved {len(dataset)} augmented samples to: {output_path}")

