import torch
import torch.utils.data as Data
import numpy as np
import random
from data_generator import *

class MyDataSet(Data.Dataset):
    """Next-token dataset: each item pairs a sequence with its one-step shift.

    ``datas`` is any indexable collection of sequences (a list of lists or a
    2-D numpy array in this module). For item ``i`` the model input is
    ``datas[i][:-1]`` and the target is ``datas[i][1:]``.
    """

    def __init__(self,datas):
        self.datas = datas

    def __getitem__(self, item):
        seq = self.datas[item]
        # Teacher forcing: predict token t+1 from tokens up to t.
        src, tgt = seq[:-1], seq[1:]
        return dict(
            decoder_input=src,
            decoder_input_len=len(src),
            decoder_output=tgt,
            decoder_output_len=len(tgt),
        )

    def __len__(self):
        return len(self.datas)

    def padding_batch(self, batch):
        """Collate ``__getitem__`` dicts into ``(inputs, targets)`` long tensors.

        Despite the name, no padding is applied: ``torch.tensor`` needs every
        sequence in the batch to have the same length.
        """
        inputs = [sample["decoder_input"] for sample in batch]
        outputs = [sample["decoder_output"] for sample in batch]
        return (torch.tensor(inputs, dtype=torch.long),
                torch.tensor(outputs, dtype=torch.long))
    

def generate_random_list(seq_len=7, data_min=20, data_max=100):
    """Return ``seq_len`` integers drawn uniformly from ``[data_min, data_max]`` (inclusive)."""
    values = []
    for _ in range(seq_len):
        values.append(random.randint(data_min, data_max))
    return values


def generate_mod_list(data_min=20, data_max=100, mod=8):
    """Split ``range(data_min, data_max)`` by divisibility by ``mod``.

    Returns:
        ``(train_lst, test_lst)``: the values not divisible by ``mod`` and the
        values divisible by ``mod``, each in ascending order.
    """
    candidates = range(data_min, data_max)
    test_lst = [v for v in candidates if v % mod == 0]
    train_lst = [v for v in candidates if v % mod != 0]
    return train_lst, test_lst

def generate_mod_list_specific(data_min=20, data_max=100, mod=8):
    """Build one train/test split per residue class modulo ``mod``.

    For each residue ``r`` in ``range(mod)`` (keyed by ``str(r)``), the test
    list holds the values of ``range(data_min, data_max)`` with ``v % mod == r``.
    The train list holds all the other values.

    Returns:
        ``(train_lst, test_lst)``: two dicts mapping ``str(r)`` to lists.
    """
    train_lst, test_lst = {}, {}
    for residue in range(mod):
        key = str(residue)
        train_lst[key] = [v for v in range(data_min, data_max) if v % mod != residue]
        test_lst[key] = [v for v in range(data_min, data_max) if v % mod == residue]
    return train_lst, test_lst


def generate_sequence(args, dataset, mode=1, **kwargs):
    """Produce one training sequence according to ``args.target``.

    ``dataset`` and ``kwargs`` are accepted but not used here.

    Returns:
        For ``args.target == 'single_chain_search'``, the result of
        ``task_single_chain_search(mode=mode)``. Otherwise, a random list of
        ``args.seq_len + 1`` integers in ``[args.data_min, args.data_max]``.
    """
    # The random list is drawn unconditionally, even when it is discarded,
    # so the global RNG advances the same way for every target.
    random_seq = generate_random_list(args.seq_len+1, args.data_min, args.data_max)
    if args.target != 'single_chain_search':
        return random_seq
    return task_single_chain_search(mode=mode)


def get_data(args, return_dict=False, **kwargs):
    """Fetch both the training and the test data.

    With ``return_dict`` False, returns ``(train_result, test_result)``, where
    each element is the tuple returned by ``get_train_data`` or
    ``get_test_data``. With ``return_dict`` True, returns a single dict: the
    train dict updated with the test dict. Train data is always built first.
    """
    if not return_dict:
        train_part = get_train_data(args, False, **kwargs)
        test_part = get_test_data(args, False, **kwargs)
        return train_part, test_part
    merged = get_train_data(args, True, **kwargs)
    merged.update(get_test_data(args, True, **kwargs))
    return merged
    


def get_train_data(args, return_dict=False, **kwargs):
    """Build the training DataLoader from a pre-generated ``train.npz`` archive.

    Sequences are not generated here. Every group named in ``args.data_name``
    is read from a fixed result directory, and the groups whose
    ``args.data_mask`` entry is 0 are concatenated into the training set.

    Args:
        args: Config namespace. Reads ``data_percent``, ``data_mode``,
            ``data_name`` and ``data_mask``, which are zipped together, so the
            shortest one bounds the number of groups. Also reads ``batch_size``.
        return_dict: If True, return a dict instead of a tuple.
        **kwargs: Accepted for interface compatibility and currently unused.
            The ``use_mod_list_specific`` flag only affected the disabled
            generation path.

    Returns:
        ``(train_data_loader, train_seq_group, train_seq_list)``, or the same
        values in a dict keyed by those names. ``train_seq_group`` maps every
        array name in the archive to its array; it is an empty dict when no
        groups are configured. ``train_seq_list`` is the numpy array of
        unmasked sequences.

    Raises:
        KeyError: if a configured group name is missing from the archive.
            This is checked for masked groups too.
    """
    # NOTE(review): hard-coded archive location; it does not use
    # args.working_dir. The former on-the-fly generation path
    # (generate_sequence over generate_mod_list splits) is disabled.
    tmp_dir = 'result/GPT_2_step_reasoning/single_chain_search-seed_1-N_200000-3L1H_shown_in_paper'

    # data_percent / data_mode are only zipped to keep the iteration count
    # (zip stops at the shortest config list).
    groups = list(zip(args.data_percent, args.data_mode, args.data_name, args.data_mask))

    train_seq_group = {}
    train_seq_list = []
    if groups:
        # Load the archive once. It used to be reopened on every iteration and
        # never closed; the context manager closes the underlying zip file.
        with np.load(f'{tmp_dir}/data/train.npz') as archive:
            train_seq_group = dict(archive)
        for _, _, name, mask in groups:
            seqs = train_seq_group[name]  # KeyError for unknown names, masked or not
            if mask == 0:
                train_seq_list.extend(seqs.tolist())

    train_seq_list = np.array(train_seq_list)

    train_dataset = MyDataSet(train_seq_list)
    train_data_loader = Data.DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                        drop_last=True, collate_fn=train_dataset.padding_batch)

    if return_dict:
        datas = {'train_data_loader': train_data_loader, 'train_seq_group': train_seq_group, 'train_seq_list': train_seq_list}
        return datas
    else:
        return train_data_loader, train_seq_group, train_seq_list
    

def get_test_data(args, return_dict=False, **kwargs):
    """Build the test DataLoader from a pre-generated ``test.npz`` archive.

    Every group named in ``args.data_name`` is read from a fixed result
    directory, and the groups whose ``args.data_mask`` entry is 0 are
    concatenated into the test set. When ``args.test_data_size == 0``, nothing
    is loaded and the loader is ``None``.

    Args:
        args: Config namespace. Reads ``data_percent``, ``data_mode``,
            ``data_name`` and ``data_mask``, which are zipped together, so the
            shortest one bounds the number of groups. Also reads
            ``test_data_size`` and ``batch_size``.
        return_dict: If True, return a dict instead of a tuple.
        **kwargs: Accepted for interface compatibility and currently unused.
            The ``use_mod_list_specific`` flag only affected the disabled
            generation path.

    Returns:
        ``(test_data_loader, test_seq_group, test_seq_list)``, or the same
        values in a dict keyed by those names. ``test_seq_group`` maps every
        array name in the archive to its array; it is an empty dict when
        nothing was loaded. ``test_seq_list`` is the numpy array of unmasked
        sequences.

    Raises:
        KeyError: if a configured group name is missing from the archive.
            This is checked for masked groups too.
    """
    # NOTE(review): hard-coded archive location; it does not use
    # args.working_dir. The former on-the-fly generation path
    # (generate_sequence over generate_mod_list splits) is disabled.
    tmp_dir = 'result/GPT_2_step_reasoning/single_chain_search-seed_1-N_200000-3L1H_shown_in_paper'

    # data_percent / data_mode are only zipped to keep the iteration count
    # (zip stops at the shortest config list).
    groups = list(zip(args.data_percent, args.data_mode, args.data_name, args.data_mask))

    test_seq_group = {}
    test_seq_list = []
    if args.test_data_size != 0 and groups:
        # Load the archive once. It used to be reopened on every iteration and
        # never closed; the context manager closes the underlying zip file.
        with np.load(f'{tmp_dir}/data/test.npz') as archive:
            test_seq_group = dict(archive)
        for _, _, name, mask in groups:
            seqs = test_seq_group[name]  # KeyError for unknown names, masked or not
            if mask == 0:
                test_seq_list.extend(seqs.tolist())

    test_seq_list = np.array(test_seq_list)

    test_dataset = MyDataSet(test_seq_list)
    if args.test_data_size == 0:
        test_data_loader = None
    else:
        test_data_loader = Data.DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size,
                                           drop_last=True, collate_fn=test_dataset.padding_batch)

    if return_dict:
        datas = {'test_data_loader': test_data_loader, 'test_seq_group': test_seq_group, 'test_seq_list': test_seq_list}
        return datas
    else:
        return test_data_loader, test_seq_group, test_seq_list




def load_data(args, return_dict=False, **kwargs):
    """Load train and test sequences from ``args.working_dir/data/{train,test}.npz``.

    For every group whose ``args.data_mask`` entry is 0, the rows of
    ``train.npz[name]`` go into the training set and the rows of
    ``test.npz[name]`` go into the test set.

    Args:
        args: Config namespace. Reads ``working_dir``, ``data_percent``,
            ``data_mode``, ``data_name``, ``data_mask`` and ``batch_size``.
        return_dict: If True, return a dict instead of a tuple.
        **kwargs: Accepted for interface compatibility and unused.

    Returns:
        ``(train_data_loader, test_data_loader, train_seq_group,
        test_seq_group, train_seq_list, test_seq_list)``, or the same values
        in a dict keyed by those names. The ``*_seq_group`` values are the
        objects returned by ``np.load``. They are left open because they are
        handed to the caller.
    """
    train_seq_group = np.load(f'{args.working_dir}/data/train.npz')
    test_seq_group = np.load(f'{args.working_dir}/data/test.npz')

    # Only data_name and data_mask matter. data_percent / data_mode are zipped
    # so that the shortest config list still bounds the number of groups.
    groups = zip(args.data_percent, args.data_mode, args.data_name, args.data_mask)
    selected = [name for _, _, name, mask in groups if mask == 0]

    test_seq_list = []
    train_seq_list = []
    for name in selected:
        test_seq_list.extend(test_seq_group[name])
        # Bug fix: training sequences were previously taken from
        # test_seq_group, i.e. the model trained on the test set.
        train_seq_list.extend(train_seq_group[name])

    test_seq_list, train_seq_list = np.array(test_seq_list), np.array(train_seq_list)

    train_dataset = MyDataSet(train_seq_list)
    train_data_loader = Data.DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                        drop_last=True, collate_fn=train_dataset.padding_batch)

    test_dataset = MyDataSet(test_seq_list)
    test_data_loader = Data.DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size,
                                       drop_last=True, collate_fn=test_dataset.padding_batch)

    if return_dict:
        datas = {'train_data_loader': train_data_loader, 'test_data_loader': test_data_loader,
                'train_seq_group': train_seq_group, 'test_seq_group': test_seq_group,
                'train_seq_list': train_seq_list, 'test_seq_list': test_seq_list}
        return datas
    else:
        return train_data_loader, test_data_loader, train_seq_group, test_seq_group, train_seq_list, test_seq_list
