from functools import partial
import itertools
import json
import os
from hashlib import sha256
import string
from typing import List, Tuple
import inspect

import torch
from torch.nn.utils.rnn import pad_sequence

from configs import DataArguments, ScriptArguments
from tasks import task_registry

import numpy as np
from datasets import IterableDataset, Dataset, interleave_datasets, concatenate_datasets
from transformers import PreTrainedTokenizer, TrainingArguments

def data_generator(
    op: str,
    shard: List[range] = None,
    seed: int = None,
    train: bool = None, # mainly controls whether we sample from a range specified in the kwargs
    kwargs: dict = None, # kwargs for the data generation function
):
    """Yield one generated example per index of the single range in ``shard``.

    The random generator is seeded from a SHA-256 digest of ``op``, ``shard``,
    ``seed``, ``train``, ``kwargs`` and the ``LOCAL_RANK`` environment variable
    (0 when unset), so the same arguments always produce the same stream.
    Within this function ``train`` is only used as part of that seed material.

    Args:
        op: Key into ``task_registry`` selecting the task function.
        shard: List holding exactly one ``range``; its length is the number of
            examples to produce.
        seed: Base seed mixed into the digest.
        train: Train/eval flag mixed into the digest.
        kwargs: Keyword arguments forwarded to the task function. ``None`` is
            treated as "no keyword arguments".

    Yields:
        Dicts with ``prompt``, ``target`` and ``loss_mask`` plus any extra
        fields returned by the task as a fourth tuple element.

    Raises:
        AssertionError: If ``shard`` is not a list of exactly one range.
        ValueError: If the task returns a tuple that is not of length 3 or 4.
    """
    assert shard is not None and len(shard) == 1, f'Shard should be a list of one range, but got: {shard}'
    rank = os.environ.get("LOCAL_RANK", 0)
    # The digest is built from the arguments exactly as passed in, so that
    # existing configurations keep generating the same data.
    digest = sha256(f'{op} {shard} {seed} {train} {kwargs} {rank}'.encode()).digest()
    seed_int = int.from_bytes(digest, 'big')
    rng = np.random.default_rng(seed_int)
    print(f'Generating data for {op} {shard} {seed_int} {train} {kwargs} {rank}')
    task_kwargs = {} if kwargs is None else kwargs
    for _ in shard[0]:
        task_output = task_registry[op](rng, **task_kwargs)
        if len(task_output) == 3:
            prompt, target, loss_mask = task_output
            extra_info = {}
        elif len(task_output) == 4:
            prompt, target, loss_mask, extra_info = task_output
        else:
            # Without this branch the names below would be unbound (first
            # iteration) or silently reuse the previous example's values.
            raise ValueError(
                f'Task {op!r} must return 3 or 4 values, but got {len(task_output)}'
            )

        if loss_mask is None:
            # Default: every element of the target contributes to the loss.
            loss_mask = [1] * len(target)

        yield {
            'prompt': prompt,
            'target': target,
            'loss_mask': loss_mask,
            **extra_info,
        }

def get_dataset_display_name(op, kwargs):
    """Return ``'<op>-<k1>=<v1>-<k2>=<v2>...'`` with keys in sorted order.

    Spaces are stripped from the key/value part (not from ``op``) so that
    values such as lists render compactly.
    """
    pairs = []
    for name in sorted(kwargs):
        pairs.append(f'{name}={kwargs[name]}')
    suffix = '-'.join(pairs).replace(' ', '')
    return f'{op}-{suffix}'

def get_shards(num_total, num_workers=0):
    """Split ``range(num_total)`` into contiguous, non-overlapping shards.

    Args:
        num_total: Total number of examples (rounded to an int).
        num_workers: Number of worker processes. With ``num_workers <= 0`` a
            single shard covering everything is returned.

    Returns:
        A list of at most ``num_workers`` non-empty ranges that together cover
        ``range(num_total)`` exactly once. Any remainder is spread over the
        first shards, so shard sizes differ by at most one. When
        ``num_total`` is smaller than ``num_workers`` only ``num_total``
        shards are returned instead of duplicated ones. When there is nothing
        to generate, a single empty range is returned.
    """
    total = round(num_total)
    if num_workers <= 0 or total <= 0:
        return [range(total)]

    # Never create more shards than examples: identical shards would be
    # seeded identically downstream and yield duplicate data.
    num_shards = min(num_workers, total)
    base, extra = divmod(total, num_shards)

    shards = []
    start = 0
    for index in range(num_shards):
        stop = start + base + (1 if index < extra else 0)
        shards.append(range(start, stop))
        start = stop
    return shards

def add_special_tokens(batch, tokenizer: PreTrainedTokenizer, task_id=''):
    """Wrap prompts and targets of a batch with special tokens (in place).

    Each prompt becomes ``bos + task_id*4 + prompt + task_id*4``, each target
    gets the EOS token appended, and each loss mask gets one extra ``1`` for
    that appended EOS token.

    Args:
        batch: Dict of lists with keys ``prompt``, ``target`` and ``loss_mask``.
        tokenizer: Object exposing ``eos_token`` and optionally ``bos_token``.
        task_id: String marker repeated four times on both sides of the prompt.

    Returns:
        The same ``batch`` object with the three columns replaced.
    """
    # The attribute may exist but be None when no BOS token is configured;
    # fall back to EOS in that case as well, not only when it is missing.
    bos_token = getattr(tokenizer, 'bos_token', None)
    if bos_token is None:
        bos_token = tokenizer.eos_token
    eos_token = tokenizer.eos_token
    marker = task_id * 4

    batch['prompt'] = [bos_token + marker + prompt + marker for prompt in batch['prompt']]
    batch['target'] = [target + eos_token for target in batch['target']]
    batch['loss_mask'] = [mask + [1] for mask in batch['loss_mask']]
    return batch

def mask_target(target_ids, loss_mask):
    """Replace every target id whose mask entry is not 1 with -100.

    The result is as long as the shorter of the two inputs.
    """
    masked = []
    for token_id, keep in zip(target_ids, loss_mask):
        masked.append(token_id if keep == 1 else -100)
    return masked

def tokenization_train(batch, tokenizer: PreTrainedTokenizer, mask_prompt=True):
    """Tokenize prompts and targets and join them into training features.

    Args:
        batch: Dict of lists with keys ``prompt``, ``target`` and ``loss_mask``.
        tokenizer: Callable returning ``input_ids`` and ``attention_mask``.
        mask_prompt: When true, prompt positions in ``labels`` are set to -100;
            otherwise the prompt token ids are kept as labels.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``, each a list
        with the prompt part followed by the target part. Target labels are
        filtered through ``mask_target`` using the example's loss mask.
    """
    prompt_enc = tokenizer(batch['prompt'], padding='do_not_pad', add_special_tokens=False)
    target_enc = tokenizer(batch['target'], padding='do_not_pad', add_special_tokens=False)

    input_ids = []
    for prompt_tokens, target_tokens in zip(prompt_enc['input_ids'], target_enc['input_ids']):
        input_ids.append(prompt_tokens + target_tokens)

    labels = []
    for prompt_tokens, target_tokens, mask in zip(prompt_enc['input_ids'], target_enc['input_ids'], batch['loss_mask']):
        if mask_prompt:
            prompt_labels = [-100] * len(prompt_tokens)
        else:
            prompt_labels = prompt_tokens
        labels.append(prompt_labels + mask_target(target_tokens, mask))

    attention_mask = []
    for prompt_attn, target_attn in zip(prompt_enc['attention_mask'], target_enc['attention_mask']):
        attention_mask.append(prompt_attn + target_attn)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': attention_mask,
    }

def tokenization_eval(batch, tokenizer: PreTrainedTokenizer):
    """Tokenize prompts and targets separately for evaluation.

    Args:
        batch: Dict of lists with keys ``prompt``, ``target`` and ``loss_mask``.
        tokenizer: Callable returning a mutable mapping of tokenized columns.

    Returns:
        The mapping produced by tokenizing the prompts, with the target token
        ids added as ``labels``, every key prefixed with ``eval_``, and the
        untouched loss mask stored under ``eval_loss_mask``.
    """
    batch_new = tokenizer(batch['prompt'], padding='do_not_pad', add_special_tokens=False, return_token_type_ids=False)
    batch_new['labels'] = tokenizer(batch['target'], padding='do_not_pad', add_special_tokens=False)['input_ids']
    # Iterate over a snapshot of the keys: renaming entries while iterating
    # the live key view can raise or visit freshly inserted keys again.
    for k in list(batch_new.keys()):
        batch_new['eval_' + k] = batch_new.pop(k)
    batch_new['eval_loss_mask'] = batch['loss_mask']
    return batch_new

def get_train_dataset(args: ScriptArguments, train_args: TrainingArguments, tokenizer: PreTrainedTokenizer=None, no_sample_from: dict[str, Dataset]=None):
    """Build the training dataset described by ``args.train_data``.

    For every entry of ``args.train_data`` either a dataset is loaded from
    ``data_args.dataset_path`` or examples are generated with
    ``data_generator``. Generated datasets are tokenized when a tokenizer is
    given (``tokenization_train`` for ``args.train_algo == 'SFT'``, otherwise
    ``tokenization_eval``). Multiple datasets are interleaved according to
    their ``frac`` weights.

    Args:
        args: Script configuration (data specs, worker count, cache location,
            dataset flavour, training algorithm, ...).
        train_args: Training arguments; only ``seed`` is read here.
        tokenizer: Tokenizer used for special tokens and tokenization. When
            ``None`` the raw prompt/target data is returned.
        no_sample_from: Optional mapping from dataset display name to a
            dataset whose ``prompt`` column lists prompts to exclude. It is
            only consulted by ``filter_eval``, which is currently disabled.
            The mapping itself is not modified.

    Returns:
        The (possibly interleaved) training dataset.
    """
    # Build the lookup in a fresh dict so the caller's mapping is left intact;
    # sets give fast membership tests. ``None`` means "nothing to exclude".
    excluded_prompts = {
        key: set(held_out.to_dict()['prompt'])
        for key, held_out in (no_sample_from or {}).items()
    }

    def filter_eval(example, data_args: DataArguments):
        # Keep the example unless its prompt is listed for this dataset.
        key = get_dataset_display_name(data_args.op, data_args.kwargs)
        if key not in excluded_prompts:
            return True
        return example['prompt'] not in excluded_prompts[key]

    ds_list = []
    for task_id, data_args in args.train_data.items():
        if args.use_iterable_dataset:
            ds_class = IterableDataset
            kwargs = {}
        else:
            ds_class = Dataset
            kwargs = {'num_proc': args.num_workers, 'cache_dir': args.train_cache_loc}
            if args.num_workers == 0:
                kwargs.pop('num_proc')
            os.makedirs(args.train_cache_loc, exist_ok=True)

        if data_args.dataset_path is not None:
            # Stored datasets are used as they are: no generation, no
            # special tokens and no tokenization.
            ds = Dataset.load_from_disk(data_args.dataset_path)
            if args.use_iterable_dataset:
                ds = ds.to_iterable_dataset()
            if 'labels' not in ds.column_names:
                # Copy input_ids to a new column named 'labels'
                ds = ds.map(lambda x: {**x, 'labels': x['input_ids']})
            ds = ds.filter(lambda x: len(x['input_ids']) <= args.max_position_embeddings)
            ds_list.append(ds)
            continue

        ds = ds_class.from_generator(
            data_generator,
            gen_kwargs={
                'shard': get_shards(int(args.num_train), args.num_workers),
                'train': True,
                'op': data_args.op,
                'kwargs': data_args.kwargs,
                'seed': train_args.seed,
            },
            split='train',
            **kwargs,
        )
        # ``cache_dir`` is only passed to ``from_generator``; the ``map``
        # calls below receive the remaining keyword arguments.
        kwargs.pop('cache_dir', None)
        ds.info.description = json.dumps(data_args.kwargs)
        # Filtering against held-out prompts is currently disabled:
        # ds = ds.filter(partial(filter_eval, data_args=data_args), **kwargs)
        if tokenizer is not None: # return raw data if tokenizer is not provided
            ds = ds.map(add_special_tokens, fn_kwargs={'tokenizer': tokenizer, 'task_id': task_id}, batched=True, batch_size=1024, **kwargs)
            if args.train_algo == 'SFT':
                ds = ds.map(tokenization_train, fn_kwargs={'tokenizer': tokenizer, 'mask_prompt': args.mask_prompt}, batched=True, batch_size=1024, **kwargs)
                ds = ds.select_columns(['input_ids', 'labels', 'attention_mask', 'loss_mask'])
            else:
                ds = ds.map(tokenization_eval, fn_kwargs={'tokenizer': tokenizer}, batched=True, batch_size=1024, **kwargs)
                ds = ds.select_columns(['eval_input_ids', 'eval_labels', 'eval_attention_mask', 'eval_loss_mask'])
        ds_list.append(ds)

    if len(ds_list) > 1:
        fracs = [data_args.frac for data_args in args.train_data.values()]
        total_frac = sum(fracs)
        init_probs = [frac / total_frac for frac in fracs]

        # (Near-)uniform weights are passed as None instead of explicit
        # probabilities.
        init_probs_uniform = [1 / len(args.train_data)] * len(args.train_data)
        if np.allclose(init_probs, init_probs_uniform):
            init_probs = None

        ds: Dataset = interleave_datasets(ds_list, probabilities=init_probs, seed=train_args.seed, stopping_strategy='all_exhausted')
    else:
        ds = ds_list[0]

    print('----------- Examples from train: -------------')
    if args.train_algo == 'SFT':
        preview_keys = ('input_ids', 'labels')
    else:
        preview_keys = ('eval_input_ids', 'eval_labels')
    for example in ds.take(3):
        print(example)
        # Decoding is only possible with a tokenizer and for columns that are
        # actually present (raw data has neither).
        if tokenizer is not None:
            for key in preview_keys:
                if key in example:
                    # -100 label placeholders are clamped to 0 for decoding.
                    print(tokenizer.decode([max(t, 0) for t in example[key]]))
        print()

    return ds

def divide_range(start, stop, step=1):
    """Cut ``[start, stop)`` into consecutive ``[lo, hi]`` pairs of width ``step``.

    The last pair is clipped so that its upper bound does not exceed ``stop``.
    """
    chunks = []
    for lo in range(start, stop, step):
        hi = lo + step
        if hi > stop:
            hi = stop
        chunks.append([lo, hi])
    return chunks

def get_eval_kwargs(data_args: DataArguments):
    """Expand ``data_args.kwargs`` into a list of kwargs dicts for evaluation.

    Keys listed together in ``data_args.tied_keys`` have their values split
    with ``divide_range`` and advance in lockstep. Keys in
    ``data_args.eval_keys`` are split with ``divide_range`` on their own. All
    remaining keys keep their configured value. The result is the Cartesian
    product over these per-group choices.
    """
    choices = []            # one iterable of value tuples per key group
    recorded_lengths = []
    free_keys = list(data_args.kwargs.keys())

    for group in data_args.tied_keys:
        per_key_chunks = [divide_range(*data_args.kwargs[name]) for name in group]
        choices.append(zip(*per_key_chunks))
        # NOTE(review): this records the number of keys in the group, whereas
        # the branch below records the number of chunks - confirm which one
        # the assertion is meant to compare.
        recorded_lengths.append(len(group))
        free_keys = [name for name in free_keys if name not in group]

    for name in free_keys:
        if name in data_args.eval_keys:
            options = [(chunk,) for chunk in divide_range(*data_args.kwargs[name])]
            recorded_lengths.append(len(options))
        else:
            options = [(data_args.kwargs[name],)]
        choices.append(options)

    assert all(l == recorded_lengths[0] for l in recorded_lengths), 'All keys should have the same number of values'

    # Key order mirrors the order in which the groups were appended above.
    ordered_keys = sum(data_args.tied_keys, []) + free_keys
    eval_kwargs = []
    for combo in itertools.product(*choices):
        flat_values = sum(combo, ())
        eval_kwargs.append(dict(zip(ordered_keys, flat_values)))
    return eval_kwargs

def get_eval_dataset(args: ScriptArguments, tokenizer: PreTrainedTokenizer):
    """Build the evaluation datasets described by ``args.eval_data``.

    Entries with a ``dataset_path`` are loaded from disk, truncated to
    ``args.num_eval`` rows and stored under the key ``'eval'``. All other
    entries are generated once per kwargs combination returned by
    ``get_eval_kwargs`` and stored under their display name.

    Returns:
        A pair ``(tokenized, raw)`` of dicts. ``tokenized`` maps each key to
        the dataset ready for evaluation; ``raw`` maps the generated keys to
        the untokenized prompt/target dataset.
    """
    tokenized_sets = {}
    raw_sets = {}
    eval_columns = ['eval_input_ids', 'eval_labels', 'eval_attention_mask', 'eval_loss_mask']

    for task_id, data_args in args.eval_data.items():
        if data_args.dataset_path is not None:
            stored = Dataset.load_from_disk(data_args.dataset_path).take(args.num_eval)
            if args.use_iterable_dataset:
                stored = stored.to_iterable_dataset()
            if 'labels' not in stored.column_names:
                stored = stored.map(lambda x: {**x, 'labels': x['input_ids']})
            stored = stored.filter(lambda x: len(x['input_ids']) <= args.max_position_embeddings)
            # NOTE(review): every stored dataset lands under the same key, so
            # a later entry replaces an earlier one.
            tokenized_sets['eval'] = stored
            continue

        for task_kwargs in get_eval_kwargs(data_args):
            os.makedirs(args.eval_cache_loc, exist_ok=True)
            generator_kwargs = {
                'shard': get_shards(int(args.num_eval)),
                'train': False,
                'op': data_args.op,
                'kwargs': task_kwargs,
                'seed': None,
            }
            # The dataset name is a hash of the task's source code.
            source_hash = sha256(inspect.getsource(task_registry[data_args.op]).encode()).hexdigest()
            raw = Dataset.from_generator(
                data_generator,
                gen_kwargs=generator_kwargs,
                num_proc=None,
                keep_in_memory=True,
                split='eval',
                cache_dir=args.eval_cache_loc,
                dataset_name=source_hash,
            )
            raw.info.description = json.dumps(task_kwargs)
            raw.cleanup_cache_files()

            with_tokens = raw.map(
                add_special_tokens,
                fn_kwargs={'tokenizer': tokenizer, 'task_id': task_id},
                batched=True,
                batch_size=1024,
            )
            encoded = with_tokens.map(
                tokenization_eval,
                fn_kwargs={'tokenizer': tokenizer},
                batched=True,
                batch_size=args.num_eval,
            )
            encoded = encoded.select_columns(eval_columns)

            name = get_dataset_display_name(data_args.op, task_kwargs)
            tokenized_sets[name] = encoded
            raw_sets[name] = raw

    print(f'Eval datasets: {tokenized_sets.keys()}')
    print('----------- Examples from eval: -------------')
    for name, dataset in tokenized_sets.items():
        print(f'Examples from {name}:')
        for example in dataset.take(3):
            print(example)

    return tokenized_sets, raw_sets

class PromptAnswerDataCollator():
    """Pad a list of feature dicts into a batch of ``LongTensor`` columns.

    Keys in ``left_pad_list`` are padded on the left, all other tensor keys on
    the right. Keys in ``ignore_list`` are passed through as plain lists.
    Keys containing ``eval_`` are emitted without that substring and use
    ``eval_pad_to`` as the optional fixed width; all others use
    ``train_pad_to``.
    """
    left_pad_list = ['eval_input_ids', 'eval_attention_mask']
    ignore_list = ['prompt', 'target', 'task_id']

    def __init__(self, tokenizer, train_pad_to=None, eval_pad_to=None):
        pad_id = tokenizer.pad_token_id
        self.pad_token_id = pad_id
        # NOTE(review): labels are padded with the pad token id rather than
        # an ignore index such as -100 - confirm this is intended.
        self.label_pad_token_id = pad_id

        self.train_pad_to = train_pad_to
        self.eval_pad_to = eval_pad_to

    def _padding_value_for(self, key):
        # Pick the fill value for a column from the suffix of its name.
        if key.endswith(("input_ids", 'eval_labels')):
            if self.pad_token_id is None:
                raise ValueError(
                    "Padding is enabled, but the tokenizer is not configured with a padding token."
                    " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                    " before calling the trainer."
                )
            return self.pad_token_id
        if key.endswith("labels"):
            return self.label_pad_token_id
        if key.endswith("attention_mask"):
            return 0
        if key.endswith("loss_mask"):
            return 1
        raise ValueError(f"Unexpected key in batch '{key}'")

    def __call__(self, features):
        # Turn the list of per-example dicts into one list per column.
        columns = {
            name: [row[name] for row in features] for name in features[0].keys()
        }

        batch = {}
        for name, values in columns.items():
            if name in self.ignore_list:
                batch[name] = values
                continue

            pad_left = name in self.left_pad_list
            # Left padding is done by reversing, right-padding and flipping
            # the result back at the end.
            if pad_left:
                sequences = [torch.LongTensor(seq[::-1]) for seq in values]
            else:
                sequences = [torch.LongTensor(seq) for seq in values]

            fill = self._padding_value_for(name)

            # remove the eval_ prefix to conform to model input names
            is_eval = 'eval_' in name
            out_name = name.replace('eval_', '') if is_eval else name
            target_width = self.eval_pad_to if is_eval else self.train_pad_to

            padded = pad_sequence(sequences, batch_first=True, padding_value=fill)

            if target_width is not None:
                width = padded.shape[1]
                if width > target_width:
                    raise ValueError(f"Cannot pad {name} to max_length {target_width} because it is already longer than that ({width})")
                padded = torch.nn.functional.pad(padded, (0, target_width - width), value=fill)

            if pad_left:
                padded = padded.flip(dims=[1])

            batch[out_name] = padded

        return batch
