import pandas as pd
import numpy as np
import torch
from params import *
class Feature:
    """Plain record holding one tokenized example.

    Stores the four constructor arguments unchanged as attributes; no
    validation or conversion is done. ``gen_llm`` reads ``.id`` from the
    objects returned by ``torch.load``, which are presumably pickled
    ``Feature`` instances; having the class importable here is likely what
    lets that unpickling succeed (TODO confirm).
    """

    def __init__(self, id, input_ids, input_mask, segment_ids):
        # ``id`` shadows the builtin, but it is both a caller-visible keyword
        # and the attribute name read elsewhere, so it is kept as-is.
        self.id, self.input_ids, self.input_mask, self.segment_ids = (
            id,
            input_ids,
            input_mask,
            segment_ids,
        )
    
def gen(n_test, dataset="waterbirds", name="test"):
    """Write a group-stratified train/val/test split for one dataset.

    Reads the dataset's metadata CSV and marks every row as validation
    (``split == 1``). Then, for each ``(y, a)`` group in ``comb_list``, it
    draws rows without replacement: a hard-coded number of training rows
    (``split == 0``) plus ``n_test[i]`` test rows (``split == 2``). It prints
    the row count of each split and writes the relabelled metadata to
    ``<data dir>/<name>.csv``.

    Args:
        n_test: Test-set size per ``(y, a)`` group, in ``comb_list`` order.
            Must have exactly one entry per group.
        dataset: Case-insensitive dataset name: ``waterbirds``, ``celeba``,
            ``metashift``, ``multinli`` or ``chexpert``.
        name: Stem of the output CSV. It is formatted with an f-string, so an
            int works too.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the names above.
        ValueError: If ``n_test`` has the wrong number of entries, or a group
            has fewer rows than the requested train + test draw.
    """
    # Reseeds NumPy's *global* RNG so repeated calls give identical splits.
    np.random.seed(42)
    dataset = dataset.lower()
    data_dir = f"dataset/{dataset}"
    comb_list = [(0, 0), (0, 1), (1, 0), (1, 1)]
    # Per-group training-set sizes, listed in comb_list order.
    if dataset == "waterbirds":
        n_train = [3000, 1400, 400, 900]
    elif dataset == "celeba":
        n_train = [8000, 7500, 3000, 500]
    elif dataset == "metashift":
        n_train = [400, 260, 100, 420]
    elif dataset == "multinli":
        n_train = [5000, 1000, 6000, 150, 6000, 200]
        comb_list = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
    elif dataset == "chexpert":
        n_train = [6889, 4491, 539, 517, 4485, 3122, 717, 463, 72, 67, 517, 394]
        comb_list = [(y, a) for y in range(2) for a in range(6)]
    else:
        raise NotImplementedError(f"unsupported dataset: {dataset!r}")

    # Fail fast. Previously, a short n_test crashed mid-loop with IndexError
    # and a long one had its extra entries silently ignored.
    if len(n_test) != len(comb_list):
        raise ValueError(
            f"n_test has {len(n_test)} entries but {dataset!r} has "
            f"{len(comb_list)} (y, a) groups"
        )

    if dataset == "chexpert":
        data_dir = "dataset/chexpert/subpop_bench_meta"
        metadata = pd.read_csv(f"{data_dir}/metadata_no_finding.csv")
    else:
        metadata = pd.read_csv(f"{data_dir}/metadata_{dataset}.csv")

    metadata["split"] = 1
    for (y, a), k_train, k_test in zip(comb_list, n_train, n_test):
        sub = metadata[(metadata["y"] == y) & (metadata["a"] == a)]
        k_total = k_train + k_test
        if k_total > len(sub):
            raise ValueError(
                f"group (y={y}, a={a}) has {len(sub)} rows; cannot draw "
                f"{k_train} train + {k_test} test rows without replacement"
            )
        # One draw per group: the first k_train positions become train and
        # the remainder test, so the two sets never overlap.
        picks = np.random.choice(len(sub), k_total, replace=False)
        metadata.loc[sub.index[picks[:k_train]], "split"] = 0
        metadata.loc[sub.index[picks[k_train:]], "split"] = 2
    for split_id in range(3):
        print((metadata["split"] == split_id).sum())
    metadata.to_csv(f"{data_dir}/{name}.csv", index=False)

def gen_snli(n_test, name="0"):
    """Write an SNLI split file that subsamples the existing test split.

    All rows already marked train (``split == 0``) and validation
    (``split == 1``) are kept, in that order. They are followed by
    ``n_test[i]`` rows drawn without replacement from the existing test rows
    (``split == 2``) of each ``(y, a)`` group, in ``comb_list`` order. Unlike
    ``gen``, test rows that are not chosen are dropped rather than
    relabelled. The result is written to ``dataset/snli/<name>.csv``.

    Args:
        n_test: Test-set size per ``(y, a)`` group. Exactly 6 entries are
            required, one per group.
        name: Stem of the output CSV.

    Raises:
        ValueError: If ``n_test`` does not have one entry per group, or a
            group has fewer test rows than requested.
    """
    # Reseeds NumPy's *global* RNG so repeated calls give identical picks.
    np.random.seed(42)
    comb_list = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
    if len(n_test) != len(comb_list):
        raise ValueError(
            f"n_test has {len(n_test)} entries but snli has "
            f"{len(comb_list)} (y, a) groups"
        )
    data_dir = "dataset/snli"
    metadata = pd.read_csv(f"{data_dir}/metadata.csv")

    # Train rows first, then validation rows, then the sampled test rows.
    # Collecting the pieces and concatenating once avoids re-copying the
    # growing frame on every loop iteration.
    parts = [metadata[metadata["split"] == 0], metadata[metadata["split"] == 1]]
    test_pool = metadata[metadata["split"] == 2]
    for (y, a), k in zip(comb_list, n_test):
        sub = test_pool[(test_pool["y"] == y) & (test_pool["a"] == a)]
        if k > len(sub):
            raise ValueError(
                f"group (y={y}, a={a}) has {len(sub)} test rows; "
                f"cannot draw {k} without replacement"
            )
        picks = np.random.choice(len(sub), k, replace=False)
        parts.append(sub.iloc[picks])
    pd.concat(parts, axis=0).to_csv(f"{data_dir}/{name}.csv", index=False)

def gen_llm(n_test, dataset="multinli", name="test", features_path="total"):
    """Write a split file whose training rows come from saved features.

    Rows whose index equals the ``.id`` of an object loaded from
    ``features_path`` become training rows (``split == 0``); all other rows
    start as validation (``split == 1``). From the non-training rows,
    ``n_test[i]`` rows per ``(y, a)`` group are drawn without replacement and
    relabelled as test (``split == 2``). The function prints each split's row
    count and writes the result to ``dataset/<dataset>/<name>.csv``.

    Args:
        n_test: Test-set size per ``(y, a)`` group, one entry per group.
        dataset: Case-insensitive dataset name; only ``multinli`` is
            supported.
        name: Stem of the output CSV.
        features_path: File passed to ``torch.load``. It must hold an
            iterable of objects with an ``id`` attribute. Defaults to
            ``"total"`` (relative to the working directory), the path that
            was previously hard-coded.

    Raises:
        NotImplementedError: If ``dataset`` is unsupported.
        ValueError: If ``n_test`` does not have one entry per group.
    """
    # Reseeds NumPy's *global* RNG so repeated calls give identical splits.
    np.random.seed(42)
    dataset = dataset.lower()
    data_dir = f"dataset/{dataset}"
    if dataset == "multinli":
        comb_list = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
    else:
        raise NotImplementedError(f"unsupported dataset: {dataset!r}")
    if len(n_test) != len(comb_list):
        raise ValueError(
            f"n_test has {len(n_test)} entries but {dataset!r} has "
            f"{len(comb_list)} (y, a) groups"
        )
    metadata = pd.read_csv(f"{data_dir}/metadata_{dataset}.csv")

    # The file holds pickled Python objects, not just tensors, so it needs
    # weights_only=False: PyTorch >= 2.6 defaults to True and rejects them.
    # SECURITY: this unpickles arbitrary objects; only load trusted files.
    try:
        train = torch.load(features_path, weights_only=False)
    except TypeError:  # torch < 1.13 has no `weights_only` argument
        train = torch.load(features_path)
    tr_id_list = [item.id for item in train]
    # Set membership keeps this O(n); the old list lookup was O(n * m).
    # Assumes ids are plain ints that hash by value (TODO confirm).
    tr_id_set = set(tr_id_list)
    other = [row for row in range(len(metadata)) if row not in tr_id_set]
    metadata["split"] = 1
    metadata.loc[tr_id_list, "split"] = 0
    # Loop-invariant: only the y/a columns of these rows are read below.
    te_val = metadata.iloc[other]
    for (y, a), k in zip(comb_list, n_test):
        sub = te_val[(te_val["y"] == y) & (te_val["a"] == a)]
        picks = np.random.choice(len(sub), k, replace=False)
        metadata.loc[sub.index[picks], "split"] = 2
    for split_id in range(3):
        print((metadata["split"] == split_id).sum())
    metadata.to_csv(f"{data_dir}/{name}.csv", index=False)

if __name__ == "__main__":
    # Dataset to re-split; edit this line by hand to switch datasets.
    dataset = "snli"

    # The *_test names come from the star import of `params`. Each is
    # presumably a list of per-group test-size lists, one per split file to
    # generate (TODO confirm against params.py).
    if dataset == "snli":
        total_n_test = n6_test2
    elif dataset == "chexpert":
        total_n_test = n12_test
    elif dataset == "multinli":
        total_n_test = n6_test
    else:
        total_n_test = n4_test

    # metashift uses half of each configured test size.
    halve_sizes = dataset == "metashift"
    for i, n_test in enumerate(total_n_test):
        if halve_sizes:
            n_test = [size // 2 for size in n_test]
        print(n_test)
        # The output file stem is the variant index i.
        if dataset == "snli":
            gen_snli(n_test=n_test, name=i)
        elif dataset == "multinli":
            gen_llm(n_test=n_test, dataset="multinli", name=i)
        else:
            gen(n_test=n_test, dataset=dataset, name=i)
        

