import os, zipfile
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMultipleChoice,
    Trainer,
    TrainingArguments,
    default_data_collator,
    DataCollatorForLanguageModeling
)
import numpy as np

from safetensors.torch import load_file
from datasets import Dataset
import torch
from datasets import load_dataset
import json

import os, json
import pandas as pd
from datasets import Dataset, load_dataset

def _read_table_any(path: str) -> pd.DataFrame:
    ext = os.path.splitext(path)[1].lower()
    if ext == ".tsv":
        return pd.read_csv(path, sep="\t", dtype=str, keep_default_na=False)
    if ext == ".csv":
        return pd.read_csv(path, sep=",", dtype=str, keep_default_na=False)
    if ext == ".jsonl":
        return pd.read_json(path, lines=True)
    if ext == ".json":
        df = pd.read_json(path)
        return df if isinstance(df, pd.DataFrame) else pd.json_normalize(df)
    if ext in [".parquet", ".pq"]:
        return pd.read_parquet(path)
    raise ValueError(f"Unsupported file format: {path}")

def _tolist(x):  # Series/ndarray/iterable -> list
    return x.tolist() if hasattr(x, "tolist") else list(x)

def _tok_single(tokenizer, text, max_length):
    enc = tokenizer(text, padding="max_length", truncation=True, max_length=max_length)
    return enc["input_ids"], enc["attention_mask"]

def _tok_pair(tokenizer, t1, t2, max_length):
    enc = tokenizer(t1, t2, padding="max_length", truncation=True, max_length=max_length)
    return enc["input_ids"], enc["attention_mask"]



def convert_glue_sst2(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for SST-2 from two table files.

    Each table must provide a ``sentence`` column and a ``label`` column.
    Labels are parsed with ``int(float(...))`` so values such as ``"1"`` or
    ``1.0`` are accepted.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_single``.
        max_length: Padding/truncation length for every sentence.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """

    def _frame_to_dataset(frame):
        texts = _tolist(frame["sentence"])
        targets = [int(float(raw)) for raw in _tolist(frame["label"])]
        encoded = [_tok_single(tokenizer, text, max_length) for text in texts]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_cola(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for CoLA from two table files.

    Each table must provide a ``sentence`` column and a ``label`` column;
    labels are parsed with ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_single``.
        max_length: Padding/truncation length for every sentence.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """

    def _to_dataset(table):
        sentences = _tolist(table["sentence"])
        label_ids = [int(float(value)) for value in _tolist(table["label"])]
        pairs = [_tok_single(tokenizer, sentence, max_length) for sentence in sentences]
        token_ids, masks = zip(*pairs)
        return Dataset.from_dict(
            {
                "input_ids": list(token_ids),
                "attention_mask": list(masks),
                "label": label_ids,
            }
        )

    train_split = _to_dataset(_read_table_any(train_path))
    dev_split = _to_dataset(_read_table_any(dev_path))
    return train_split, dev_split

def convert_glue_mrpc(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for MRPC from two table files.

    Each table must provide ``sentence1``, ``sentence2`` and ``label``
    columns. The two sentences are tokenized together as a pair; labels are
    parsed with ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """

    def _frame_to_dataset(frame):
        first = _tolist(frame["sentence1"])
        second = _tolist(frame["sentence2"])
        targets = [int(float(raw)) for raw in _tolist(frame["label"])]
        encoded = [
            _tok_pair(tokenizer, left, right, max_length)
            for left, right in zip(first, second)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_qqp(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for QQP from two table files.

    Questions come from the ``question1``/``question2`` columns when both are
    present, otherwise from the first two columns by position. Labels come
    from ``label`` or, failing that, ``is_duplicate``, and are parsed with
    ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.

    Raises:
        KeyError: If a table has neither a ``label`` nor an ``is_duplicate``
            column.
    """

    def _frame_to_dataset(frame):
        has_named_questions = "question1" in frame and "question2" in frame
        if has_named_questions:
            first = _tolist(frame["question1"])
            second = _tolist(frame["question2"])
        else:
            # No named question columns: fall back to the first two by position.
            first = _tolist(frame.iloc[:, 0])
            second = _tolist(frame.iloc[:, 1])

        for column in ("label", "is_duplicate"):
            if column in frame:
                raw_labels = _tolist(frame[column])
                break
        else:
            raise KeyError("QQP file missing 'label' or 'is_duplicate'")

        targets = [int(float(raw)) for raw in raw_labels]
        encoded = [
            _tok_pair(tokenizer, left, right, max_length)
            for left, right in zip(first, second)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_mnli(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for MNLI from two table files.

    Each table must provide ``premise``, ``hypothesis`` and ``label``
    columns. A label may be one of the names ``entailment`` (0), ``neutral``
    (1) or ``contradiction`` (2), matched after stripping and lower-casing;
    any other value is parsed with ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """
    name_to_id = {"entailment": 0, "neutral": 1, "contradiction": 2}

    def _encode_label(value):
        key = str(value).strip().lower()
        if key in name_to_id:
            return name_to_id[key]
        return int(float(value))

    def _frame_to_dataset(frame):
        premises = _tolist(frame["premise"])
        hypotheses = _tolist(frame["hypothesis"])
        targets = [_encode_label(raw) for raw in _tolist(frame["label"])]
        encoded = [
            _tok_pair(tokenizer, premise, hypothesis, max_length)
            for premise, hypothesis in zip(premises, hypotheses)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_qnli(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for QNLI from two table files.

    Each table must provide ``question``, ``sentence`` and ``label`` columns.
    A label may be the name ``entailment`` (1) or ``not_entailment`` (0),
    matched after stripping and lower-casing; any other value is parsed with
    ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """
    name_to_id = {"entailment": 1, "not_entailment": 0}

    def _encode_label(value):
        key = str(value).strip().lower()
        if key in name_to_id:
            return name_to_id[key]
        return int(float(value))

    def _frame_to_dataset(frame):
        questions = _tolist(frame["question"])
        sentences = _tolist(frame["sentence"])
        targets = [_encode_label(raw) for raw in _tolist(frame["label"])]
        encoded = [
            _tok_pair(tokenizer, question, sentence, max_length)
            for question, sentence in zip(questions, sentences)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_rte(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for RTE from two table files.

    Each table must provide ``sentence1``, ``sentence2`` and ``label``
    columns. A label may be the name ``entailment`` (1) or
    ``not_entailment`` (0), matched after stripping and lower-casing; any
    other value is parsed with ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """
    name_to_id = {"entailment": 1, "not_entailment": 0}

    def _encode_label(value):
        key = str(value).strip().lower()
        if key in name_to_id:
            return name_to_id[key]
        return int(float(value))

    def _frame_to_dataset(frame):
        first = _tolist(frame["sentence1"])
        second = _tolist(frame["sentence2"])
        targets = [_encode_label(raw) for raw in _tolist(frame["label"])]
        encoded = [
            _tok_pair(tokenizer, left, right, max_length)
            for left, right in zip(first, second)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_wnli(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for WNLI from two table files.

    Each table must provide ``sentence1``, ``sentence2`` and ``label``
    columns. The two sentences are tokenized together as a pair; labels are
    parsed with ``int(float(...))``.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and ``label``.
    """

    def _to_dataset(table):
        lefts = _tolist(table["sentence1"])
        rights = _tolist(table["sentence2"])
        label_ids = [int(float(value)) for value in _tolist(table["label"])]
        pairs = [
            _tok_pair(tokenizer, left, right, max_length)
            for left, right in zip(lefts, rights)
        ]
        token_ids, masks = zip(*pairs)
        return Dataset.from_dict(
            {
                "input_ids": list(token_ids),
                "attention_mask": list(masks),
                "label": label_ids,
            }
        )

    train_split = _to_dataset(_read_table_any(train_path))
    dev_split = _to_dataset(_read_table_any(dev_path))
    return train_split, dev_split


def convert_glue_stsb_reg(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for STS-B as a regression task.

    Each table must provide ``sentence1`` and ``sentence2`` columns plus a
    score column named ``score`` or, failing that, ``label``. Scores are kept
    as floats in the output ``label`` column.

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and a float ``label``.

    Raises:
        KeyError: If a table has neither a ``score`` nor a ``label`` column.
    """

    def _score_values(frame):
        for column in ("score", "label"):
            if column in frame:
                return _tolist(frame[column])
        raise KeyError("STS-B file missing 'score' or 'label'")

    def _frame_to_dataset(frame):
        first = _tolist(frame["sentence1"])
        second = _tolist(frame["sentence2"])
        targets = [float(raw) for raw in _score_values(frame)]
        encoded = [
            _tok_pair(tokenizer, left, right, max_length)
            for left, right in zip(first, second)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,  # float regression target
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset

def convert_glue_stsb_cls(train_path, dev_path, tokenizer, max_length=256):
    """Build tokenized train/dev datasets for STS-B as a 6-way classification.

    Each table must provide ``sentence1`` and ``sentence2`` columns plus a
    score column named ``score`` or, failing that, ``label``. Every score is
    rounded to the nearest integer and clamped to the range 0..5 to form the
    class id. Python's ``round`` sends exact halves to the even neighbour
    (``2.5 -> 2``, ``3.5 -> 4``).

    Args:
        train_path: Training table; any format ``_read_table_any`` can read.
        dev_path: Evaluation table, same formats.
        tokenizer: Tokenizer callable forwarded to ``_tok_pair``.
        max_length: Padding/truncation length for every pair.

    Returns:
        ``(train_dataset, dev_dataset)``; each ``Dataset`` has the columns
        ``input_ids``, ``attention_mask`` and an integer ``label``.

    Raises:
        KeyError: If a table has neither a ``score`` nor a ``label`` column.
    """

    def _score_values(frame):
        for column in ("score", "label"):
            if column in frame:
                return _tolist(frame[column])
        raise KeyError("STS-B file missing 'score' or 'label'")

    def _to_class(raw):
        rounded = int(round(float(raw)))
        return min(max(rounded, 0), 5)

    def _frame_to_dataset(frame):
        first = _tolist(frame["sentence1"])
        second = _tolist(frame["sentence2"])
        targets = [_to_class(raw) for raw in _score_values(frame)]
        encoded = [
            _tok_pair(tokenizer, left, right, max_length)
            for left, right in zip(first, second)
        ]
        input_ids, attention_masks = zip(*encoded)
        columns = {
            "input_ids": list(input_ids),
            "attention_mask": list(attention_masks),
            "label": targets,
        }
        return Dataset.from_dict(columns)

    train_dataset = _frame_to_dataset(_read_table_any(train_path))
    dev_dataset = _frame_to_dataset(_read_table_any(dev_path))
    return train_dataset, dev_dataset



def preprocess_arc(example, tokenizer, max_length=128):
    """Tokenize one ARC-style multiple-choice example.

    Every choice text is appended to the question stem (separated by one
    space) and all candidates are tokenized in a single call, padded and
    truncated to ``max_length``. The gold label is the position of the choice
    whose ``label`` equals ``example["answerKey"]``.

    Args:
        example: Mapping with ``question`` (holding ``stem`` and a list of
            ``choices``, each with ``label`` and ``text``) and ``answerKey``.
        tokenizer: Tokenizer callable accepting a list of texts.
        max_length: Padding/truncation length for every candidate.

    Returns:
        The tokenizer output with an added integer ``label`` entry.

    Raises:
        KeyError: If ``answerKey`` matches none of the choice labels.
    """
    question = example["question"]
    choices = question["choices"]
    stem = question["stem"]

    # Map each choice's label to its position in the choices list.
    label_index = {choice["label"]: position for position, choice in enumerate(choices)}
    candidates = [stem + " " + choice["text"] for choice in choices]
    label = label_index[example["answerKey"]]

    tokenized = tokenizer(candidates, padding="max_length", truncation=True, max_length=max_length)
    tokenized["label"] = label
    return tokenized


def preprocess_piqa(example, label, tokenizer, max_length=128):
    """Tokenize the two candidate solutions of one PIQA example.

    Each of ``sol1`` and ``sol2`` is appended to ``goal`` (separated by one
    space); both candidates are tokenized together, padded and truncated to
    ``max_length``.

    Args:
        example: Mapping with ``goal``, ``sol1`` and ``sol2`` strings.
        label: Gold label, stored unchanged under ``label``.
        tokenizer: Tokenizer callable accepting a list of texts.
        max_length: Padding/truncation length for every candidate.

    Returns:
        The tokenizer output with an added ``label`` entry.
    """
    candidates = [
        example["goal"] + " " + example[solution_key]
        for solution_key in ("sol1", "sol2")
    ]
    encoded = tokenizer(
        candidates,
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
    encoded["label"] = label
    return encoded


def preprocess_winogrande(example, tokenizer, max_length=128):
    """Tokenize the two filled-in variants of one WinoGrande example.

    Every ``_`` in ``sentence`` is replaced by ``option1`` and by ``option2``
    to form the two candidates, which are tokenized together, padded and
    truncated to ``max_length``.

    Args:
        example: Mapping with ``sentence``, ``option1``, ``option2`` and
            ``answer``.
        tokenizer: Tokenizer callable accepting a list of texts.
        max_length: Padding/truncation length for every candidate.

    Returns:
        The tokenizer output with an added ``label`` entry: ``0`` when the
        answer is ``1`` (first option), otherwise ``1``.
    """
    sentence = example["sentence"]
    candidates = [
        sentence.replace("_", example[option_key])
        for option_key in ("option1", "option2")
    ]
    encoding = tokenizer(candidates, padding="max_length", truncation=True, max_length=max_length)

    # Compare as stripped text: a loader that yields the integer 1 (or a
    # padded "1 ") would otherwise never equal "1" and every row would
    # silently receive label 1.
    answer = str(example["answer"]).strip()
    encoding["label"] = 0 if answer == "1" else 1
    return encoding

def preprocess_hellaswag(example, tokenizer, max_length=128):
    """Tokenize every candidate ending of one HellaSwag example.

    The context is ``ctx``, followed by a space and ``ctx_b`` when ``ctx_b``
    is present and non-empty. Each ending is appended to that context
    (separated by one space) and all candidates are tokenized together,
    padded and truncated to ``max_length``.

    NOTE(review): if ``ctx`` in the data files already ends with ``ctx_b``,
    this repeats that fragment -- confirm against the dataset schema.

    Args:
        example: Mapping with ``ctx``, ``endings``, ``label`` and optionally
            ``ctx_b``.
        tokenizer: Tokenizer callable accepting a list of texts.
        max_length: Padding/truncation length for every candidate.

    Returns:
        A dict with ``input_ids`` and ``attention_mask`` (one entry per
        ending) and ``labels`` (the example's ``label``, unchanged).
    """
    prefix = example['ctx']
    if example.get('ctx_b', ''):
        prefix = prefix + " " + example['ctx_b']

    candidates = []
    for ending in example['endings']:
        candidates.append(prefix + " " + ending)

    encoded = tokenizer(
        candidates,
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    result = {}
    result["input_ids"] = encoded["input_ids"]  # one id list per ending
    result["attention_mask"] = encoded["attention_mask"]
    result["labels"] = example['label']  # passed through unchanged
    return result


def has_four_choices(example):
    """Return True when the example's question lists exactly four choices."""
    choices = example["question"]["choices"]
    return len(choices) == 4

def load_data(jsonl_file, labels_file):
    """Read a JSONL file of examples and its companion one-label-per-line file.

    Both files are read as UTF-8. Blank lines (for example a trailing empty
    line) are skipped in both files instead of crashing the parser.

    Args:
        jsonl_file: Path of a file holding one JSON document per line.
        labels_file: Path of a file holding one integer label per line.

    Returns:
        ``(data, labels)``: the parsed JSON objects and the integer labels,
        in file order.

    Raises:
        ValueError: If a label line is not an integer, or if the number of
            examples differs from the number of labels.
    """
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    with open(labels_file, 'r', encoding='utf-8') as f:
        labels = [int(line.strip()) for line in f if line.strip()]

    # Examples and labels are paired by position, so a count mismatch means
    # the two files do not belong together.
    if len(data) != len(labels):
        raise ValueError(
            f"{jsonl_file} has {len(data)} examples but "
            f"{labels_file} has {len(labels)} labels"
        )

    return data, labels



def _load_json_splits(data_files, preprocess, tokenizer, max_length, keep=None):
    """Load JSON-lines train/test splits, optionally filter them, and tokenize each example."""
    dataset = load_dataset("json", data_files=data_files)
    if keep is not None:
        dataset = dataset.filter(keep)
    dataset = dataset.map(
        lambda example: preprocess(example, tokenizer, max_length),
        remove_columns=dataset["train"].column_names,
    )
    return dataset["train"], dataset["test"]


def _encode_piqa_split(examples, labels, tokenizer, max_length):
    """Tokenize each PIQA example once and assemble the split into a Dataset."""
    encodings = [
        preprocess_piqa(example, label, tokenizer, max_length)
        for example, label in zip(examples, labels)
    ]
    return Dataset.from_dict({
        "input_ids": [encoding["input_ids"] for encoding in encodings],
        "attention_mask": [encoding["attention_mask"] for encoding in encodings],
        "label": labels,
    })


def load_and_process_dataset(dataset_name, tokenizer, max_length=128):
    """Load one of the supported benchmarks from local files and tokenize it.

    Supported names: ``arc_easy``, ``arc_challenge``, ``openbqa`` (only
    questions with exactly four choices are kept), ``piqa``, ``winogrande``,
    ``hellaswag``, and the GLUE tasks ``glue_sst2``, ``glue_cola``,
    ``glue_mrpc``, ``glue_qqp``, ``glue_mnli``, ``glue_qnli``, ``glue_rte``
    and ``glue_wnli``. All paths are relative to the current working
    directory, under ``datasets/``.

    Args:
        dataset_name: One of the names listed above.
        tokenizer: Tokenizer callable handed to the per-dataset preprocessors.
        max_length: Padding/truncation length for every tokenized sequence.

    Returns:
        ``(train_dataset, valid_dataset)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'arc_easy':
        data_files = {
            "train": "datasets/ARC-V1-Feb2018/ARC-Easy/ARC-Easy-Train.jsonl",
            "test": "datasets/ARC-V1-Feb2018/ARC-Easy/ARC-Easy-Test.jsonl"
        }
        train_dataset, valid_dataset = _load_json_splits(
            data_files, preprocess_arc, tokenizer, max_length, keep=has_four_choices
        )

    elif dataset_name == "arc_challenge":
        data_files = {
            "train": "datasets/ARC-V1-Feb2018/ARC-Challenge/ARC-Challenge-Train.jsonl",
            "test": "datasets/ARC-V1-Feb2018/ARC-Challenge/ARC-Challenge-Test.jsonl"
        }
        train_dataset, valid_dataset = _load_json_splits(
            data_files, preprocess_arc, tokenizer, max_length, keep=has_four_choices
        )

    elif dataset_name == 'piqa':
        # All four PIQA files live under datasets/PIQA. The previous paths
        # mixed 'PIQA/...' and an absolute '/datasets/...' with the relative
        # 'datasets/PIQA/...' form used for the other two files.
        train_data, train_labels = load_data('datasets/PIQA/train.jsonl', 'datasets/PIQA/train-labels.lst')
        valid_data, valid_labels = load_data('datasets/PIQA/valid.jsonl', 'datasets/PIQA/valid-labels.lst')

        train_dataset = _encode_piqa_split(train_data, train_labels, tokenizer, max_length)
        valid_dataset = _encode_piqa_split(valid_data, valid_labels, tokenizer, max_length)

    elif dataset_name == 'winogrande':
        data_files = {
            "train": "datasets/winogrande/winogrande_1.1/train_debiased.jsonl",
            "test": "datasets/winogrande/winogrande_1.1/dev.jsonl"
        }
        train_dataset, valid_dataset = _load_json_splits(
            data_files, preprocess_winogrande, tokenizer, max_length
        )

    elif dataset_name == "hellaswag":
        data_files = {
            "train": "datasets/hellaswag/hellaswag_train.jsonl",
            "test": "datasets/hellaswag/hellaswag_val.jsonl"
        }
        train_dataset, valid_dataset = _load_json_splits(
            data_files, preprocess_hellaswag, tokenizer, max_length
        )

    elif dataset_name == "openbqa":
        data_files = {
            "train": "datasets/openbqa/train.jsonl",
            "test": "datasets/openbqa/test.jsonl"
        }
        train_dataset, valid_dataset = _load_json_splits(
            data_files, preprocess_arc, tokenizer, max_length, keep=has_four_choices
        )

    # NOTE(review): several GLUE tasks below evaluate on a `test-*` parquet
    # while MNLI and RTE use `validation*`. Public GLUE test splits usually
    # ship without gold labels -- confirm these files carry real labels.
    elif dataset_name == "glue_sst2":
        train_file = "datasets/glue/sst2/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/sst2/test-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_sst2(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_cola":
        # Prefixed with datasets/ like every other GLUE task (it was missing).
        train_file = "datasets/glue/cola/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/cola/test-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_cola(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_mrpc":
        train_file = "datasets/glue/mrpc/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/mrpc/test-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_mrpc(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_qqp":
        train_file = "datasets/glue/qqp/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/qqp/test-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_qqp(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_mnli":
        train_file = "datasets/glue/mnli/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/mnli/validation_matched-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_mnli(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_qnli":
        train_file = "datasets/glue/qnli/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/qnli/test-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_qnli(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_rte":
        train_file = "datasets/glue/rte/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/rte/validation-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_rte(train_file, dev_file, tokenizer, max_length)

    elif dataset_name == "glue_wnli":
        train_file = "datasets/glue/wnli/train-00000-of-00001.parquet"
        dev_file = "datasets/glue/wnli/test-00000-of-00001.parquet"
        train_dataset, valid_dataset = convert_glue_wnli(train_file, dev_file, tokenizer, max_length)

    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    return train_dataset, valid_dataset
