"""
GLUE Dataset Download Script

This script downloads and prepares the GLUE benchmark datasets for training.
MRPC is handled as a special case because of its licensing: it is either downloaded
from the copy hosted by the SentEval team or formatted from files you supply via --path_to_mrpc.

Note: For legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized, 
or you can download the original data from:
https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi

Windows users can run the .msi installer directly. Mac and Linux users can instead
extract it with an external tool such as 'cabextract' (see the example below).

Example MRPC extraction:
mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi

Author: Feature Distillation Research Team
License: Apache 2.0
"""

import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile

# GLUE task definitions.
# "MRPC" must be listed here so it can be selected via --tasks (and is part of
# "all"); main() routes it to format_mrpc() instead of a TASK2PATH download.
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]

# URL mappings for GLUE datasets. Every entry except "diagnostic" (a single
# TSV file) points at a zip archive. MRPC intentionally has no entry: its raw
# data comes from the SentEval URLs below, and without an "MRPC" key here the
# standard train/dev ID list is unavailable, so format_mrpc() reports that
# the training data must be split manually.
TASK2PATH = {
    "CoLA": 'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
    "SST": 'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
    "QQP": 'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
    "STS": 'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
    "MNLI": 'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
    "QNLI": 'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
    "RTE": 'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
    "WNLI": 'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
    "diagnostic": 'https://dl.fbaipublicfiles.com/glue/data/AX.tsv'
}

# Raw MRPC train/test files as mirrored by the SentEval team (MRPC itself
# cannot be hosted alongside the other tasks for licensing reasons).
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'


def download_and_extract(task, data_dir):
    """
    Download a GLUE task's zip archive and extract it into ``data_dir``.

    The archive is staged in a temporary directory inside ``data_dir`` and
    removed afterwards, even if the download or extraction raises. This
    means no stray ``<task>.zip`` is left in the current working directory.

    Args:
        task: Name of the GLUE task; must be a key of ``TASK2PATH`` whose URL
            points at a zip archive
        data_dir: Directory to extract the dataset into (created if missing)
    """
    print(f"Downloading and extracting {task}...")

    if task == "MNLI":
        print("\tNote (12/10/20): This script no longer downloads SNLI. "
              "You will need to manually download and format the data to use SNLI.")

    os.makedirs(data_dir, exist_ok=True)
    # Staging next to the output keeps the (possibly large) archive on the
    # same filesystem as the extracted data. TemporaryDirectory deletes it on
    # both success and failure.
    with tempfile.TemporaryDirectory(dir=data_dir) as tmp_dir:
        data_file = os.path.join(tmp_dir, f"{task}.zip")
        urllib.request.urlretrieve(TASK2PATH[task], data_file)

        with zipfile.ZipFile(data_file) as zip_ref:
            zip_ref.extractall(data_dir)

    print("\tCompleted!")


def format_mrpc(data_dir, path_to_data):
    """
    Build GLUE-style MRPC TSV files under ``<data_dir>/MRPC``.

    Where the raw files come from:

    * If ``path_to_data`` is non-empty, the raw ``msr_paraphrase_train.txt``
      and ``msr_paraphrase_test.txt`` files are read from there.
    * Otherwise they are downloaded from MRPC_TRAIN / MRPC_TEST into the MRPC
      output directory. If that download fails, a message is printed and
      the function returns early.

    What gets written:

    * Once the raw files are present, ``test.tsv`` is written with an index
      column.
    * ``train.tsv`` and ``dev.tsv`` are written only when the standard
      development-set IDs can be fetched from ``TASK2PATH["MRPC"]``.
      Otherwise a message asks the user to split the data manually.

    Args:
        data_dir: Directory under which the ``MRPC`` folder is created
        path_to_data: Directory containing the extracted raw MRPC files, or
            an empty string to download them instead

    Raises:
        FileNotFoundError: If either raw MRPC file is missing.
    """
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    os.makedirs(mrpc_dir, exist_ok=True)

    if path_to_data:
        # Use user-provided MRPC data
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        try:
            urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.URLError:
            # URLError is the base of HTTPError and also covers DNS and
            # connection failures.
            print("Error downloading MRPC")
            return

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(mrpc_train_file):
        raise FileNotFoundError(f"Train data not found at {mrpc_train_file}")
    if not os.path.isfile(mrpc_test_file):
        raise FileNotFoundError(f"Test data not found at {mrpc_test_file}")

    # Format test file: drop the label column and prepend a running index.
    with open(mrpc_test_file, encoding='utf-8') as data_fh, \
         open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
        data_fh.readline()  # skip the raw header; a GLUE-style header replaces it
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        # Filter blank lines before enumerate so indices stay contiguous.
        rows = (line for line in data_fh if line.strip())
        for idx, row in enumerate(rows):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write(f"{idx}\t{id1}\t{id2}\t{s1}\t{s2}\n")

    # Download development IDs. A missing "MRPC" key in TASK2PATH is treated
    # the same as a failed download.
    dev_ids_file = os.path.join(mrpc_dir, "dev_ids.tsv")
    try:
        urllib.request.urlretrieve(TASK2PATH["MRPC"], dev_ids_file)
    except (KeyError, urllib.error.URLError):
        print("\tError downloading standard development IDs for MRPC. "
              "You will need to manually split your data.")
        return

    # Set of (id1, id2) tuples gives O(1) membership checks per training row.
    with open(dev_ids_file, encoding='utf-8') as ids_fh:
        dev_ids = {tuple(row.strip().split('\t')) for row in ids_fh}

    # Split train and development sets; both keep the raw file's header.
    with open(mrpc_train_file, encoding='utf-8') as data_fh, \
         open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
         open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)

        for row in data_fh:
            if not row.strip():
                continue
            label, id1, id2, s1, s2 = row.strip().split('\t')
            out_fh = dev_fh if (id1, id2) in dev_ids else train_fh
            out_fh.write(f"{label}\t{id1}\t{id2}\t{s1}\t{s2}\n")

    print("\tCompleted!")


def download_diagnostic(data_dir):
    """
    Fetch the GLUE diagnostic TSV into ``<data_dir>/diagnostic/diagnostic.tsv``.

    The ``diagnostic`` sub-folder is created first when it is not already a
    directory. The file is saved as-is; no extraction is performed.

    Args:
        data_dir: Existing directory that will hold the ``diagnostic`` folder
    """
    print("Downloading and extracting diagnostic...")

    diag_dir = os.path.join(data_dir, "diagnostic")
    if not os.path.isdir(diag_dir):
        os.mkdir(diag_dir)

    target_path = os.path.join(diag_dir, "diagnostic.tsv")
    source_url = TASK2PATH["diagnostic"]
    urllib.request.urlretrieve(source_url, target_path)

    print("\tCompleted!")


def get_tasks(task_names, valid_tasks=None):
    """
    Parse a comma-separated task string into the list of tasks to download.

    Whitespace around each name is ignored, empty entries (e.g. from a
    trailing comma) are skipped, and duplicates are dropped while keeping
    first-seen order. If any entry is ``all``, every valid task is returned.

    Args:
        task_names: Comma-separated string of task names, or 'all'
        valid_tasks: Task names that may be requested; defaults to the
            module-level ``TASKS`` list

    Returns:
        A new list of task names to download

    Raises:
        ValueError: If no task is named, or a name is not in ``valid_tasks``.
    """
    available = TASKS if valid_tasks is None else valid_tasks

    # dict.fromkeys de-duplicates while preserving input order.
    requested = list(dict.fromkeys(
        name.strip() for name in task_names.split(',') if name.strip()
    ))
    if not requested:
        raise ValueError("No tasks specified!")

    if "all" in requested:
        # Return a copy so callers cannot mutate the shared task list.
        return list(available)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    for task_name in requested:
        if task_name not in available:
            raise ValueError(f"Task {task_name} not found!")

    return requested


def main(arguments):
    """
    Parse command-line arguments and download the requested GLUE tasks.

    Args:
        arguments: Command-line argument strings without the program name,
            e.g. ``sys.argv[1:]``

    Returns:
        None, so ``sys.exit(main(...))`` exits with status 0 on success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        help='directory to save data to',
        type=str,
        default='glue_data'
    )
    parser.add_argument(
        '--tasks',
        help='tasks to download data for as a comma separated string',
        type=str,
        default='all'
    )
    parser.add_argument(
        '--path_to_mrpc',
        help='path to directory containing extracted MRPC data, '
             'msr_paraphrase_train.txt and msr_paraphrase_test.txt',
        type=str,
        default=''
    )

    args = parser.parse_args(arguments)

    # makedirs (unlike mkdir) also creates missing parent directories, e.g.
    # for "--data_dir data/glue", and is a no-op if the directory exists.
    os.makedirs(args.data_dir, exist_ok=True)

    # MRPC and the diagnostic set need dedicated handlers; every other task
    # is a zip archive listed in TASK2PATH.
    for task in get_tasks(args.tasks):
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status
    # (None maps to 0); this is exactly what sys.exit() does.
    raise SystemExit(main(sys.argv[1:]))