import os, socket
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import lightning as L
from functools import partial

from datasets import load_dataset


def collator(sample_list, tokenizer):
    """Collate token-id sequences into padded next-token-prediction batches.

    Every sample contributes two views of itself: all tokens but the last
    (model input) and all tokens but the first (target, i.e. the input shifted
    left by one position).  Both views are right-padded to the longest
    sequence in the batch using ``tokenizer.pad_token_id``.

    Args:
        sample_list: sequences of integer token ids, one per sample.
        tokenizer: object exposing a ``pad_token_id`` attribute.

    Returns:
        Tuple ``(inputs, targets)`` of LongTensors laid out batch-first.
    """
    pad_id = tokenizer.pad_token_id
    # Convert each sample once; the shifted views below are cheap slices and
    # pad_sequence copies them into freshly allocated batch tensors.
    token_rows = [torch.LongTensor(sample) for sample in sample_list]

    inputs = pad_sequence(
        [row[:-1] for row in token_rows], batch_first=True, padding_value=pad_id
    )
    targets = pad_sequence(
        [row[1:] for row in token_rows], batch_first=True, padding_value=pad_id
    )
    return inputs, targets

def load_gsm8k(tokenizer, cache_dir, max_seq_length):
    """Return the tokenized GSM8K training split, filtered by length.

    Each sample is the result of ``tokenizer.encode`` applied to the text
    ``"<question> Answer: <answer>"``.  The tokenized split is cached in
    ``<cache_dir>/gsm8k.pt`` so later calls skip downloading and tokenizing.

    The length filter is applied on every call, after the cache is read or
    written.  The cache therefore stores the unfiltered tokenization and a
    single cache file can serve calls with different ``max_seq_length``
    values.

    NOTE: cache files written by earlier versions of this function were saved
    already filtered; for those, a larger ``max_seq_length`` cannot bring back
    samples that were dropped when the file was created.
    NOTE(review): the cache is not keyed by tokenizer, so a cache produced
    with a different tokenizer is reused as-is — confirm callers use a
    dedicated ``cache_dir`` per tokenizer.

    Args:
        tokenizer: object exposing ``encode(str)``; only used on a cache miss.
        cache_dir: directory holding ``gsm8k.pt``; also passed to
            ``load_dataset`` as its cache directory.
        max_seq_length: exclusive upper bound on the kept sample length.

    Returns:
        List of encoded samples ``s`` satisfying ``1 < len(s) < max_seq_length``.
    """
    cache_path = os.path.join(cache_dir, "gsm8k.pt")

    if os.path.exists(cache_path):
        samples = torch.load(cache_path)["dataset"]
    else:
        train_split = load_dataset("gsm8k", "main", cache_dir=cache_dir)['train']
        samples = [
            tokenizer.encode(" ".join([row['question'], "Answer:", row['answer']]))
            for row in train_split
        ]
        # Persist the unfiltered tokenization: filtering before saving would
        # bake this call's max_seq_length into the cache for all later calls.
        torch.save({"dataset": samples}, cache_path)

    # Samples need at least two tokens to form an (input, target) pair once
    # shifted, and must stay strictly below the requested maximum length.
    return [s for s in samples if 1 < len(s) < max_seq_length]

def get_gsm8k(tokenizer, max_seq_length, val_split, effective_batch_size, cache_dir):
    """Build train and validation DataLoaders over tokenized GSM8K samples.

    The samples returned by ``load_gsm8k`` are shuffled in place with the
    module-level ``random`` generator, then split: the first
    ``int(len(dataset) * (1 - val_split))`` samples go to training and the
    remainder to validation.  Both loaders batch with ``collator`` bound to
    ``tokenizer``.

    Args:
        tokenizer: forwarded to ``load_gsm8k`` and to ``collator``.
        max_seq_length: forwarded to ``load_gsm8k``.
        val_split: fraction of the samples held out for validation.
        effective_batch_size: batch size used by both loaders.
        cache_dir: forwarded to ``load_gsm8k``.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    dataset = load_gsm8k(tokenizer, cache_dir=cache_dir, max_seq_length=max_seq_length)
    print("GSM8K dataset samples:", len(dataset))

    # The shuffle draws from the global `random` state, so the split is only
    # reproducible if the caller seeds `random` beforehand.
    random.shuffle(dataset)
    split_at = int(len(dataset) * (1 - val_split))
    train_samples, val_samples = dataset[:split_at], dataset[split_at:]

    # Settings common to both loaders; only shuffling and the handling of the
    # final partial batch differ between train and validation.
    shared_kwargs = dict(
        batch_size=effective_batch_size,
        collate_fn=partial(collator, tokenizer=tokenizer),
        num_workers=0,
        pin_memory=True,
    )
    train_loader = DataLoader(train_samples, shuffle=True, drop_last=True, **shared_kwargs)
    val_loader = DataLoader(val_samples, shuffle=False, drop_last=False, **shared_kwargs)

    print("train samples:", len(train_samples))
    print("val samples:", len(val_samples))

    return train_loader, val_loader