import os, socket
import numpy as np
import random
import torch
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import lightning as L
from functools import partial

from datasets import load_dataset




def collator(sample_list, tokenizer):
    """Build a padded (inputs, targets) batch for next-token prediction.

    Every sample is split into two views shifted by one position: the
    inputs hold all tokens but the last, the targets hold all tokens but
    the first. Both views are padded to the longest sequence in the batch
    using ``tokenizer.pad_token_id``.

    Args:
        sample_list: Sequence of token-id sequences (one per sample).
        tokenizer: Object exposing a ``pad_token_id`` attribute.

    Returns:
        Tuple ``(inputs, targets)`` of batch-first ``LongTensor``s.
    """
    pad_id = tokenizer.pad_token_id

    shifted_inputs = []
    shifted_targets = []
    for token_ids in sample_list:
        # Input drops the final token, target drops the first one.
        shifted_inputs.append(torch.LongTensor(token_ids[:-1]))
        shifted_targets.append(torch.LongTensor(token_ids[1:]))

    inputs = pad_sequence(shifted_inputs, batch_first=True, padding_value=pad_id)
    targets = pad_sequence(shifted_targets, batch_first=True, padding_value=pad_id)
    return inputs, targets

def load_pubmed(tokenizer, model_name, cache_dir, max_seq_length, artificial_numb):
    """Return tokenized PubMedQA "artificial" samples, using an on-disk cache.

    The cache file lives in ``cache_dir`` and its name encodes
    ``model_name``, ``max_seq_length`` and ``artificial_numb``. If it
    exists it is loaded with ``torch.load`` and returned as-is. Otherwise
    the ``pubmed_qa_artificial_source`` config of ``bigbio/pubmed_qa`` is
    loaded, every record is tokenized, samples are filtered by length, the
    result is truncated to ``artificial_numb`` entries and written to the
    cache file.

    Args:
        tokenizer: Object with an ``encode(text)`` method returning a
            sequence of token ids. Only used on a cache miss.
        model_name: Identifier used in the cache file name.
        cache_dir: Directory for both the dataset cache and the token cache.
        max_seq_length: Exclusive upper bound on the kept sample length.
        artificial_numb: Number of samples kept after filtering.

    Returns:
        The tokenized samples (freshly built, or whatever the cache holds).
    """
    cache_file = os.path.join(
        cache_dir,
        f"bigbio_pubmed_qa_{model_name}_len{max_seq_length}_{artificial_numb}.pt",
    )

    # Cache hit: nothing to download or tokenize.
    # NOTE(review): torch.load unpickles the file -- assumes the cache
    # directory only contains files written by this function.
    if os.path.exists(cache_file):
        return torch.load(cache_file)

    raw_split = load_dataset(
        "bigbio/pubmed_qa", "pubmed_qa_artificial_source", cache_dir=cache_dir
    )['train']
    print("load complete")

    # One text per record: contexts, question, long answer and the final
    # decision, joined by single spaces, then tokenized.
    encoded = []
    for record in tqdm.tqdm(raw_split):
        pieces = record['CONTEXTS'] + [
            record['QUESTION'],
            record['LONG_ANSWER'],
            f"Answer: {record['final_decision']}",
        ]
        encoded.append(tokenizer.encode(" ".join(pieces)))

    # Keep samples with more than one token and fewer than max_seq_length.
    kept = []
    for token_ids in tqdm.tqdm(encoded):
        if 1 < len(token_ids) < max_seq_length:
            kept.append(token_ids)

    kept = kept[:artificial_numb]

    torch.save(kept, cache_file)
    return kept

def get_pubmedqa(tokenizer, config):
    """Create train and validation ``DataLoader``s over PubMedQA samples.

    Seeds Python's global ``random`` module with ``config.data_seed``,
    obtains tokenized samples through ``load_pubmed``, shuffles them in
    place and splits them so that the last ``config.val_split`` fraction
    becomes the validation set.

    Args:
        tokenizer: Tokenizer forwarded to ``load_pubmed`` and ``collator``.
        config: Object providing ``data_seed``, ``model_name``,
            ``cache_dir``, ``max_seq_length``, ``pubmedqa.artificial_numb``,
            ``val_split`` and ``effective_batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    random.seed(config.data_seed)

    # Only the part after the last "/" of the model name is passed on.
    short_model_name = config.model_name.split("/")[-1]

    samples = load_pubmed(
        tokenizer,
        model_name=short_model_name,
        cache_dir=config.cache_dir,
        max_seq_length=config.max_seq_length,
        artificial_numb=config.pubmedqa.artificial_numb,
    )

    sample_count = len(samples)
    print("artificial samples:", sample_count)
    print("total samples:", sample_count)

    random.shuffle(samples)
    split_at = int(sample_count * (1 - config.val_split))
    train_samples, val_samples = samples[:split_at], samples[split_at:]

    # Settings shared by both loaders.
    loader_kwargs = dict(
        batch_size=config.effective_batch_size,
        collate_fn=partial(collator, tokenizer=tokenizer),
        num_workers=0,
        pin_memory=True,
    )

    # Training batches are reshuffled and the incomplete last batch dropped.
    train_loader = DataLoader(train_samples, shuffle=True, drop_last=True, **loader_kwargs)
    # Validation keeps a fixed order and every sample.
    val_loader = DataLoader(val_samples, shuffle=False, drop_last=False, **loader_kwargs)

    print("train samples:", len(train_samples))
    print("val samples:", len(val_samples))

    return train_loader, val_loader


if __name__ == "__main__":
    # Manual check: tokenize (or load from cache) the PubMedQA samples for
    # one model and report how many samples and tokens there are.

    from transformers import AutoTokenizer, AutoModelForCausalLM

    cache_dir = '/home/joerg/workspace/python/github/ICML2024_experiments/cache'
    hub_model_id = "facebook/opt-125m"

    tokenizer = AutoTokenizer.from_pretrained(hub_model_id, cache_dir=cache_dir)

    # The cache file name uses only the part after the last "/".
    model_name = hub_model_id.split("/")[-1]
    samples = load_pubmed(tokenizer, model_name, cache_dir, max_seq_length=400, artificial_numb=100000)

    print("tokenized samples:", len(samples))

    total_tokens = sum(len(token_ids) for token_ids in samples)

    print("total tokens:", total_tokens)