import os
import re
import copy
import yaml
import glob
import time
import json
import random
import logging
import linecache

import numpy as np
import torch

from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import pandas as pd

from transformers import AutoTokenizer
from collections import OrderedDict, defaultdict

    
def ordered_yaml():
    """Build a yaml Loader/Dumper pair that round-trips ``OrderedDict``.

    The C-accelerated classes are preferred; the pure-Python ones are used
    when the C bindings cannot be imported. The chosen classes are patched
    in place: mappings are constructed as ``OrderedDict`` on load, and
    ``OrderedDict`` instances are emitted as ordinary mappings on dump.

    Returns:
        tuple: ``(Loader, Dumper)`` classes, in that order.
    """
    try:
        from yaml import CDumper as Dumper, CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    def _load_mapping(loader, node):
        # construct_pairs keeps the key order found in the document.
        return OrderedDict(loader.construct_pairs(node))

    def _dump_mapping(dumper, data):
        return dumper.represent_dict(data.items())

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    Loader.add_constructor(mapping_tag, _load_mapping)
    Dumper.add_representer(OrderedDict, _dump_mapping)
    return (Loader, Dumper)

class EasyDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    Every assignment (attribute or item) is stored both as an instance
    attribute and as a dict item. Plain dicts assigned as values are
    converted to ``EasyDict``; dicts found inside a list or tuple value are
    converted too, so ``cfg.a.b`` and ``cfg.items_list[0].name`` both work.
    """

    def __init__(self, d=None, **kwargs):
        """Populate the mapping from ``d`` and/or keyword arguments.

        Args:
            d: Optional mapping (or anything ``dict()`` accepts). It is
                copied, so the caller's object is not modified.
            **kwargs: Extra entries; they override same-named keys of ``d``.
        """
        if d is None:
            d = {}
        else:
            d = dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes: copy non-dunder attributes declared on the class
        # (e.g. defaults on a subclass) into the instance so they show up as
        # dict items as well.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        """Store ``value`` under ``name`` as both attribute and dict item."""
        if isinstance(value, (list, tuple)):
            # Rebuild the sequence with nested dicts wrapped, keeping its type.
            value = type(value)(self.__class__(x)
                     if isinstance(x, dict) else x for x in value)
        elif isinstance(value, dict) and not isinstance(value, EasyDict):
            value = EasyDict(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Merge entries from ``e`` and keyword arguments into this mapping.

        Args:
            e: Optional mapping or iterable of ``(key, value)`` pairs. It is
                copied before merging, so the caller's object is never
                mutated.
            **f: Extra entries; they override same-named keys of ``e``.
        """
        # Falsy ``e`` (None, empty mapping, ...) is treated as "nothing".
        merged = dict(e) if e else {}
        merged.update(f)
        for k, v in merged.items():
            setattr(self, k, v)

    def pop(self, k, *args):
        """Remove key ``k`` and return its value, like ``dict.pop``.

        Args:
            k: Key to remove.
            *args: Optional default returned when ``k`` is absent.

        Raises:
            KeyError: If ``k`` is absent and no default was given.
        """
        # Only delete the mirrored instance attribute. Checking the instance
        # dict (instead of hasattr) avoids trying to delete class-level names
        # such as methods, which would raise AttributeError.
        if k in self.__dict__:
            delattr(self, k)
        return super(EasyDict, self).pop(k, *args)

class EasyYml:
    """Read a YAML file and return its content as an ``EasyDict``.

    Calling ``EasyYml(path)`` never yields an ``EasyYml`` instance:
    ``__new__`` returns the parsed configuration object directly.
    """

    def __new__(cls, path):
        # NOTE(review): the loader comes from ordered_yaml(), i.e. PyYAML's
        # (C)Loader rather than the safe loader - only feed trusted files.
        with open(path, mode='r') as stream:
            loader_cls = ordered_yaml()[0]
            parsed = yaml.load(stream, Loader=loader_cls)
        return EasyDict(parsed)
    

class GenEvalDataset(Dataset):
    """Prompt dataset loaded from a JSON-lines metadata file.

    Each non-blank line of ``config.datap[0]`` must be a JSON object with at
    least a ``"prompt"`` key. All records are parsed eagerly at construction.
    """

    def __init__(self, config):
        """Load every record of the metadata file into memory.

        Args:
            config: Object exposing ``datap``, a sequence whose first element
                is the path of the JSON-lines file.

        Raises:
            OSError: If the file cannot be opened (a missing file is reported
                instead of silently yielding an empty dataset).
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        super().__init__()

        self.pathh = config.datap[0]
        # Read the file directly rather than through linecache: linecache
        # swallows I/O errors (returning no lines) and keeps a second copy of
        # the whole file in a process-wide cache. 'utf-8-sig' accepts files
        # with or without a BOM.
        with open(self.pathh, mode='r', encoding='utf-8-sig') as f:
            # Blank lines (e.g. a trailing empty line) carry no record.
            self.meta_list = [json.loads(line) for line in f if line.strip()]
        self.totnum = len(self.meta_list)
        print('total num:', self.totnum)

    def __getitem__(self, idx):
        """Return ``{'text': prompt, 'meta': full record}`` for ``idx``."""
        record = self.meta_list[idx]
        return {
            'text': record["prompt"],
            'meta': record
        }

    def __len__(self):
        """Number of records loaded from the file."""
        return self.totnum
    

class DPGDataset(Dataset):
    """Prompt/question dataset backed by a single DPG CSV file.

    The CSV has one row per question, with the columns
    'item_id', 'text', 'keywords', 'proposition_id', 'dependency',
    'category_broad', 'category_detailed', 'tuple',
    'question_natural_language'. Rows sharing an ``item_id`` belong to the
    same prompt; one dataset item is produced per distinct ``item_id``.
    """

    def __init__(self, config):
        """Read the CSV and index prompts and questions by ``item_id``.

        Args:
            config: Object exposing ``data_source`` and ``datap`` sequences;
                the first element of each is used.

        Raises:
            ValueError: If the source check below fails, or if more than one
                data path is configured.
        """
        super().__init__()
        self.data_source = config.data_source[0]
        # NOTE(review): this rejects only configs that list several sources
        # AND whose first source is not 'dpg'; a single non-dpg source passes.
        # `or` may have been intended - confirm with callers before changing.
        if len(config.data_source) > 1 and self.data_source != 'dpg':
            raise ValueError(f'Only dpg dataset is support in DPGDataset! But found {config.data_source}')

        self.data_path = config.datap[0]

        if len(config.datap) > 1:
            raise ValueError('Found data source > 1 in DPGDataset')

        # item_id -> question metadata, in order of first appearance.
        self.data_dict = self.prepare_dpg_data_dict(self.data_path)

        self.key_list = list(self.data_dict.keys())

        # item_id -> prompt text.
        self.prompt_dict = self.prepare_dpg_prompt_dict(self.data_path)

        self.data_num = len(self.key_list)

    def __len__(self):
        """Number of distinct ``item_id`` values in the CSV."""
        return self.data_num

    def __getitem__(self, idx):
        """Return ``{'text': prompt, 'meta': question metadata}`` for ``idx``."""
        key = self.key_list[idx]
        data_item = {
            'text': self.prompt_dict[key],
            'meta': self.data_dict[key]
        }
        return data_item

    def prepare_dpg_prompt_dict(self, csv_path):
        """Map each ``item_id`` to its prompt text.

        Args:
            csv_path: Path of the DPG CSV file.

        Returns:
            dict: ``{item_id: text}``, using the first row of each item.
        """
        df = pd.read_csv(csv_path)

        # Every row of an item repeats the prompt; the first one is enough.
        df_unique = df.drop_duplicates(subset=['item_id'], keep='first')

        item_prompt_map = df_unique.set_index('item_id')['text'].to_dict()

        return item_prompt_map

    def prepare_dpg_data_dict(self, csv_path):
        """Group the question rows of the CSV by ``item_id``.

        Args:
            csv_path: Path of the DPG CSV file.

        Returns:
            dict: ``{item_id: {'qid2tuple': {qid: tuple},
            'qid2dependency': {qid: [int, ...]},
            'qid2question': {qid: question}}}`` where ``qid`` is the integer
            ``proposition_id``. Items keep their order of first appearance.
        """
        question_dict = dict()
        data = pd.read_csv(csv_path)
        # Note: the first data row is intentionally NOT skipped; skipping it
        # would drop the first question of the dpg prompts.
        for _, line in data.iterrows():
            current_id = line['item_id']
            qid = int(line['proposition_id'])

            # 'dependency' is a comma-separated list of question ids. pandas
            # parses the column as integers when every row holds a single id,
            # so coerce to str before splitting; blank tokens are ignored.
            dependency_list_int = [
                int(token)
                for token in str(line['dependency']).split(',')
                if token.strip()
            ]

            # Dict lookup keeps this O(1) per row (vs. scanning a list of
            # previously seen ids).
            entry = question_dict.get(current_id)
            if entry is None:
                entry = dict(qid2tuple={}, qid2dependency={}, qid2question={})
                question_dict[current_id] = entry

            entry['qid2tuple'][qid] = line['tuple']
            entry['qid2dependency'][qid] = dependency_list_int
            entry['qid2question'][qid] = line['question_natural_language']

        return question_dict