import datasets
import os
import json
import copy
from ft_datasets.utils import ConcatDataset
from torch.utils.data import Dataset
import torch

# Delimiters placed around the instruction and the system section of a prompt.
B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# System section prepended to every prompt: the reporting guidelines wrapped
# in the B_SYS / E_SYS markers. Each fragment below is one guideline line; the
# trailing space before each "\n" is part of the prompt text and is kept as-is.
SYSTEM_PROMPT = "".join((
    B_SYS,
    " Generate conclusions from patient data and metrics based on these guidelines: ",
    "\n1. AFib: Focus on heart rates <60 or >130, AFib duration, and longest RR interval. ",
    "\n2. RR Interval: Report >5.0 seconds as critical; document >2.0 seconds. ",
    "\n3. HRV: Mostly automated analysis; limited clinical use except for electrophysiologists. ",
    "\n4. Heart Rate: Normal range 60-100. For critical patients, 80-90 ideal; <40 or >110 concerning. Note fastest, slowest, and average rates with rhythm type. ",
    "\n5. Arrhythmias: Identify and report all non-normal rhythms. ",
    "\n6. Age and Gender: Elderly may have slower rates. Use 60-100 standard for ambulatory ECG, considering clinical context.",
    E_SYS,
))

def get_ecg_dataset(dataset_config, tokenizer, train_dataset_path, max_words=30, for_completion=False, concat=False, split="train"):
    """Build the ECG instruction dataset, optionally wrapped for concatenation.

    Args:
        dataset_config: Configuration object forwarded to ``InstructionDataset``.
        tokenizer: Tokenizer forwarded to ``InstructionDataset``.
        train_dataset_path: Data file path forwarded to ``InstructionDataset``.
        max_words: Length limit forwarded to ``InstructionDataset``.
        for_completion: Forwarded to ``InstructionDataset``.
        concat: When truthy, the dataset is built without padding and wrapped
            in ``ConcatDataset``; otherwise a padded ``InstructionDataset`` is
            returned directly.
        split: Split name forwarded to ``InstructionDataset``.

    Returns:
        A ``ConcatDataset`` around an unpadded ``InstructionDataset`` when
        ``concat`` is truthy, else a padded ``InstructionDataset``.
    """
    # Padding and concatenation are mutually exclusive: pad only when the
    # samples are not going to be handed to ConcatDataset.
    dataset = InstructionDataset(
        dataset_config,
        tokenizer,
        train_dataset_path,
        max_words,
        for_completion=for_completion,
        pad=not concat,
        split=split,
    )
    return ConcatDataset(dataset) if concat else dataset

class InstructionDataset(Dataset):
    """Instruction-tuning dataset built from a JSON list of ECG records.

    Every non-empty record in the JSON file must provide a ``"metrics"`` entry
    (used as the instruction body) and a ``"conclusion"`` entry (used as the
    target text). Both are assumed to be strings, since ``.strip()`` is called
    on them -- TODO confirm against the data files.

    Each item is rendered as
    ``SYSTEM_PROMPT + B_INST + " " + metrics + " " + E_INST`` followed, unless
    ``for_completion`` is set, by ``" " + conclusion + " "``, then tokenized.

    Args:
        dataset_config: Object exposing a ``data_path`` attribute (directory
            that contains the data file).
        tokenizer: Object providing ``encode(text)`` returning a list of token
            ids and an ``eos_token_id`` attribute.
        train_dataset_path: Path of the JSON file, joined onto
            ``dataset_config.data_path``.
        max_words: Target sequence length used when ``pad`` is true. Despite
            the name it is compared against the number of *tokens*.
        for_completion: When true, items contain only the prompt (no target).
        pad: When true, items are right-padded or truncated to ``max_words``.
        split: ``'train'`` keeps the first 1000 records, ``'test'`` keeps the
            last 10; any other value keeps everything. Note that the two
            slices overlap when the file holds fewer than 1010 records.

    Raises:
        KeyError: A record lacks ``"metrics"`` or ``"conclusion"``.
        TypeError: A record is not subscriptable by string keys.
    """

    def __init__(self, dataset_config, tokenizer, train_dataset_path, max_words=30, for_completion=False, pad=True, split='train'):
        self.max_words = max_words
        self.tokenizer = tokenizer
        self.for_completion = for_completion
        self.pad = pad

        # Read dataset: the file is a single JSON document holding a list of
        # already-structured records.
        data_file_path = os.path.join(dataset_config.data_path, train_dataset_path)
        with open(data_file_path, 'r', encoding='utf-8') as file:
            self.json_list = json.load(file)

        self.data = []
        for record in self.json_list:
            if not record:
                # Skip empty placeholders (null, {}, "") in the list.
                continue
            try:
                self.data.append({
                    'metrics': record["metrics"],
                    'interpretation': record['conclusion'],
                })
            except (KeyError, TypeError) as e:
                # Records are already parsed, so the realistic failure here is
                # a missing key or a non-dict entry; report it and re-raise.
                print(f"Malformed record: {e!r} in record: {record}")
                raise

        # Split the data into train and test sets
        print(f"num of instances in raw dataset: {len(self.data)}")
        if split == 'train':
            self.data = self.data[:1000]
        elif split == 'test':
            self.data = self.data[-10:]

    def __len__(self):
        """Return the number of records kept for the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize one record and build its training tensors.

        Returns:
            A dict with ``input_ids`` (int64 token ids, padding replaced by
            0), ``labels`` (int64 copy of the ids where the prompt tokens and
            padding are set to -100) and ``attention_mask`` (float tensor,
            1.0 on real tokens and 0.0 on padding).
        """
        IGNORE_INDEX = -100

        ann = self.data[index]
        # Keys must match what __init__ stored: 'metrics' / 'interpretation'.
        prompt = SYSTEM_PROMPT + B_INST + " " + ann['metrics'].strip() + " " + E_INST
        if self.for_completion:
            example = prompt
        else:
            example = prompt + " " + ann['interpretation'].strip() + " "

        n_word = len(example.split())

        prompt = torch.tensor(
            self.tokenizer.encode(prompt), dtype=torch.int64
        )

        example = self.tokenizer.encode(example)

        # Diagnostic only: flag unusually long samples.
        n_token = len(example)
        if n_token >= 1000:
            print(f"len of: example {n_word} | tokenized example {n_token}")

        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)

        if self.pad:
            padding = self.max_words - example.shape[0]
            if padding > 0:
                # -1 is a temporary pad marker; it is turned into the masks
                # below and then replaced by 0 / IGNORE_INDEX.
                example = torch.cat(
                    (example, torch.full((padding,), -1, dtype=torch.int64))
                )
            elif padding < 0:
                # Truncation drops the tail, including the appended EOS.
                example = example[: self.max_words]

        labels = example.clone()
        # Mask the prompt so that only the tokens after it contribute to the
        # loss. Assumes the prompt's tokens are a prefix of the example's.
        labels[: len(prompt)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        example_mask = example_mask.float()

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
        }
