from .data_loader import Sample
from typing import List
from .fewshotsampler import FewshotSampler
from .data_loader import NERDataset
import torch
from torch.utils.data import DataLoader

def get_tag_label_mapping(samples: "List[Sample]"):
    """Build the tag <-> integer label lookup tables for a set of samples.

    The tag classes returned by each sample's ``get_tag_class()`` are pooled
    and de-duplicated. ``'O'`` is always assigned label 0; every other tag is
    numbered from 1 in sorted order, so the same set of tags always produces
    the same mapping.

    Args:
        samples: Objects exposing ``get_tag_class()``, which returns an
            iterable of tag strings.

    Returns:
        A ``(tag2label, label2tag)`` tuple: ``tag2label`` maps each tag to its
        integer label and ``label2tag`` is the inverse mapping.
    """
    tag_classes = set()
    for sample in samples:
        tag_classes.update(sample.get_tag_class())

    # 'O' is pinned to label 0 below; drop any copy reported by a sample so it
    # is not listed twice (which would give 'O' a second, non-zero label).
    tag_classes.discard('O')

    # Sort instead of relying on set iteration order, which for strings can
    # differ between interpreter runs because of hash randomization.
    classes = ['O'] + sorted(tag_classes)

    tag2label = {tag: ind for ind, tag in enumerate(classes)}
    label2tag = dict(enumerate(classes))
    return tag2label, label2tag

def prepare_initial_tensor_supportset(trainset, tokenizer, max_length, ignore_idx, tag2label, K=100, batch_size=8):
    """Sample a support set from ``trainset`` and return a DataLoader over it.

    A ``FewshotSampler`` is drawn from once to obtain support indices into
    ``trainset``; the selected samples are wrapped in an ``NERDataset`` and
    its ``tokenized_samples`` are served in batches.

    Args:
        trainset: Indexable collection of samples; also passed to
            ``FewshotSampler``.
        tokenizer: Forwarded unchanged to ``NERDataset``.
        max_length: Forwarded unchanged to ``NERDataset``.
        ignore_idx: Forwarded to ``NERDataset`` as ``ignore_label_id``.
        tag2label: Forwarded to ``NERDataset`` as ``tag2label``.
        K: Second positional argument of ``FewshotSampler`` (presumably the
            number of shots per class -- confirm against the sampler).
        batch_size: Number of tokenized samples per batch. Defaults to 8,
            the value that was previously hard-coded.

    Returns:
        A ``DataLoader`` over the support set's ``tokenized_samples`` whose
        batches are built by concatenating the per-sample tensors.
    """
    def collate_fn(data):
        # Merge the per-sample entries into one batch by concatenating, for
        # every key of the first item, the matching tensors along dim 0.
        # Assumes each item is a dict of tensors sharing the same keys.
        return {k: torch.cat([d[k] for d in data], 0) for k in data[0]}

    sampler = FewshotSampler(-1, K, 0, trainset, sample_query=False)
    # Only the middle element of the sampler's triple (the support indices)
    # is needed here.
    _, support_idx, _ = next(sampler)
    support_samples = [trainset[i] for i in support_idx]
    support_set = NERDataset(
        None,
        tokenizer,
        max_length,
        samples=support_samples,
        ignore_label_id=ignore_idx,
        tag2label=tag2label,
    )
    return DataLoader(support_set.tokenized_samples, batch_size=batch_size, collate_fn=collate_fn)

def load_data_from_file(filepath):
    """Read a blank-line-separated file and build one ``Sample`` per group.

    Every line is stripped of surrounding whitespace. Consecutive non-empty
    lines form one group, and each group is passed as a list of strings to
    ``Sample``. A blank line ends the current group.

    Args:
        filepath: Path of the UTF-8 encoded text file to read.

    Returns:
        A list of ``Sample`` objects in file order.
    """
    samples = []
    group = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                group.append(line)
            elif group:
                # A blank line closes the current group. Runs of blank lines
                # are skipped so no Sample is built from an empty list.
                samples.append(Sample(group))
                group = []
    # Flush the last group: a file that does not end with a blank line would
    # otherwise silently lose its final sample.
    if group:
        samples.append(Sample(group))
    return samples