import pickle
from tqdm import tqdm
from collections import defaultdict
from .base_dataset import DatasetWithPerturbation
from typing import Dict, List
from copy import deepcopy
from .qa_datasets import CoQA, SQUADv2

import numpy as np
from numpy.random import RandomState


def shuffle_words_in_text(
    text,
    rng: RandomState,
    strength=0.1,
):
    """Randomly relocate a fraction of the space-separated words in ``text``.

    The number of relocations is ``round(word_count * strength)``. Each
    relocation removes the word at one random position and re-inserts it at
    another random position. All positions are drawn from ``rng`` in a single
    ``randint`` call.

    Args:
        text: Input string; it is split on single spaces.
        rng: NumPy ``RandomState`` that supplies the random positions.
        strength: Fraction of the word count that sets how many moves are made.

    Returns:
        The text with words moved, or ``text`` itself when fewer than one
        move would be made.
    """
    tokens = text.split(' ')
    n_moves = round(len(tokens) * strength)

    # Too few words (or too low a strength) for even a single move.
    if n_moves < 1:
        return text

    # Draw every index up front, then read them as (source, destination) pairs.
    flat_indices = rng.randint(0, len(tokens), size=2 * n_moves)
    for src, dst in flat_indices.reshape(-1, 2).tolist():
        # The destination is interpreted against the list after the removal.
        tokens.insert(dst, tokens.pop(src))

    return ' '.join(tokens)


def perturbation_router(text:str, ptype:str, strength:float, **kwargs):
    """Apply the perturbation named by ``ptype`` to ``text``.

    Args:
        text: The string to perturb.
        ptype: Perturbation name; only ``'shuffle'`` is handled.
        strength: Perturbation strength forwarded to the implementation.
        **kwargs: Extra keyword arguments forwarded unchanged (for
            ``'shuffle'`` this carries ``rng``).

    Returns:
        The perturbed text.

    Raises:
        NotImplementedError: If ``ptype`` is not a handled perturbation name.
    """
    # Reject unknown types first so the supported path reads straight through.
    if ptype != 'shuffle':
        raise NotImplementedError(f"Perturbation type {ptype} not implemented!")

    return shuffle_words_in_text(text=text, strength=strength, **kwargs)



class CoQAPerturb(DatasetWithPerturbation):
    """CoQA samples expanded into groups of perturbed variants.

    For every source sample one variant is produced per entry of
    ``perturb_strength``. All variants of a sample share the same
    ``perturb_group_idx`` and each records the ``perturb_strength`` that was
    applied to its ``'background'`` text.
    """

    def __init__(
        self,
        perturb_strength=(0., 0.1),
        perturb_type='shuffle', # only 'shuffle' is handled by perturbation_router
        perturb_dynamic=False,
        slim=1024,
        **coqa_kwargs,
    ):
        """Build the static perturbed dataset.

        Args:
            perturb_strength: Iterable of strengths; one variant per strength
                is created for every source sample.
            perturb_type: Perturbation name passed to ``perturbation_router``.
            perturb_dynamic: When true, no static data is built.
            slim: Maximum number of source samples to use; ``None`` means
                no limit.
            **coqa_kwargs: Forwarded to the ``CoQA`` constructor.
        """
        self.coqa_data = CoQA(**coqa_kwargs)

        self.perturb_type = perturb_type
        self.perturb_dynamic = perturb_dynamic
        # Keep a private copy so neither the default nor the caller's
        # sequence is shared with (or mutated through) this instance.
        self.perturb_strength = list(perturb_strength)

        # NOTE: when perturb_dynamic is true nothing is built here, so
        # self.data and self.group_idxs_map are not set by this constructor.
        if not perturb_dynamic:
            # build up the static dataset from CoQA
            self.data = []
            self.group_idxs_map = defaultdict(list)
            # Fixed seed so the generated perturbations are reproducible.
            rng = RandomState(seed=42)
            for sample_id, source_item in enumerate(self.coqa_data):
                if slim is not None and sample_id >= slim:
                    break

                for strength in self.perturb_strength:
                    # Always start from the unperturbed source item so that
                    # perturbations do not compound across strengths.
                    variant = deepcopy(source_item)
                    variant['background'] = perturbation_router(
                        text=source_item['background'],
                        strength=strength,
                        rng=rng,
                        ptype=self.perturb_type,
                    )
                    variant['perturb_group_idx'] = sample_id
                    variant['perturb_strength'] = strength

                    # Record the variant's position in self.data under its group.
                    self.group_idxs_map[sample_id].append(len(self.data))
                    self.data.append(variant)

    def __getitem__(self, item):
        """Return a copy of the stored sample at index ``item``."""
        # returns structured for our rollout scripts
        return self.full_item(item)

    def full_item(self, item):
        """Return a deep copy of the stored sample so callers cannot mutate it."""
        return deepcopy(self.data[item])

    def __len__(self):
        """Return the number of stored (perturbed) samples."""
        return len(self.data)

    def get_perturbation_strength(self, item) -> float:
        """Return the strength that was recorded for sample ``item``."""
        return self.data[item]['perturb_strength']

    def get_perutrbation_group_ids(self, item) -> List:
        """Return the indices of all samples in the same group as ``item``.

        A new list is returned so the internal group map cannot be mutated
        by the caller.
        """
        group_idx = self.data[item]['perturb_group_idx']
        return list(self.group_idxs_map[group_idx])

    def get_problem_system_instruction(self):
        """Return the system instruction provided by the wrapped dataset."""
        return self.coqa_data.get_problem_system_instruction()


class SQUADPerturb(DatasetWithPerturbation):
    """SQuAD v2 samples expanded into groups of perturbed variants.

    For every source sample one variant is produced per entry of
    ``perturb_strength``. All variants of a sample share the same
    ``perturb_group_idx`` and each records the ``perturb_strength`` that was
    applied to its ``'background'`` text.
    """

    def __init__(
        self,
        perturb_strength=(0., 0.1),
        perturb_type='shuffle', # only 'shuffle' is handled by perturbation_router
        perturb_dynamic=False,
        slim=1024,
        **squad_kwargs,
    ):
        """Build the static perturbed dataset.

        Args:
            perturb_strength: Iterable of strengths; one variant per strength
                is created for every source sample.
            perturb_type: Perturbation name passed to ``perturbation_router``.
            perturb_dynamic: When true, no static data is built.
            slim: Maximum number of source samples to use; ``None`` means
                no limit.
            **squad_kwargs: Forwarded to the ``SQUADv2`` constructor.
        """
        # The attribute keeps its historical name `coqa_data` even though it
        # holds the SQuAD v2 dataset, so existing attribute access still works.
        self.coqa_data = SQUADv2(**squad_kwargs)

        self.perturb_type = perturb_type
        self.perturb_dynamic = perturb_dynamic
        # Keep a private copy so neither the default nor the caller's
        # sequence is shared with (or mutated through) this instance.
        self.perturb_strength = list(perturb_strength)

        # NOTE: when perturb_dynamic is true nothing is built here, so
        # self.data and self.group_idxs_map are not set by this constructor.
        if not perturb_dynamic:
            # build up the static dataset from SQuAD v2
            self.data = []
            self.group_idxs_map = defaultdict(list)
            # Fixed seed so the generated perturbations are reproducible.
            rng = RandomState(seed=42)
            for sample_id, source_item in enumerate(self.coqa_data):
                if slim is not None and sample_id >= slim:
                    break

                for strength in self.perturb_strength:
                    # Always start from the unperturbed source item so that
                    # perturbations do not compound across strengths.
                    variant = deepcopy(source_item)
                    variant['background'] = perturbation_router(
                        text=source_item['background'],
                        strength=strength,
                        rng=rng,
                        ptype=self.perturb_type,
                    )
                    variant['perturb_group_idx'] = sample_id
                    variant['perturb_strength'] = strength

                    # Record the variant's position in self.data under its group.
                    self.group_idxs_map[sample_id].append(len(self.data))
                    self.data.append(variant)

    def __getitem__(self, item):
        """Return a copy of the stored sample at index ``item``."""
        # returns structured for our rollout scripts
        return self.full_item(item)

    def full_item(self, item):
        """Return a deep copy of the stored sample so callers cannot mutate it."""
        return deepcopy(self.data[item])

    def __len__(self):
        """Return the number of stored (perturbed) samples."""
        return len(self.data)

    def get_perturbation_strength(self, item) -> float:
        """Return the strength that was recorded for sample ``item``."""
        return self.data[item]['perturb_strength']

    def get_perutrbation_group_ids(self, item) -> List:
        """Return the indices of all samples in the same group as ``item``.

        A new list is returned so the internal group map cannot be mutated
        by the caller.
        """
        group_idx = self.data[item]['perturb_group_idx']
        return list(self.group_idxs_map[group_idx])

    def get_problem_system_instruction(self):
        """Return the system instruction provided by the wrapped dataset."""
        return self.coqa_data.get_problem_system_instruction()