import torch
import torch.utils.data as Data
import numpy as np
import random
from .data_generator_base import *


def single_func(x, single_prompt):
    """Shift ``x`` by the fixed offset associated with ``single_prompt``.

    Offsets by prompt: 1 -> +5, 2 -> +1, 3 -> -2, 4 -> -8.

    Raises:
        ValueError: if ``single_prompt`` is not one of 1, 2, 3, 4
            (propagated from ``list.index``).
    """
    prompts = [1, 2, 3, 4]
    offsets = [5, 1, -2, -8]
    shift = offsets[prompts.index(single_prompt)]
    return x + shift





def task_composition(args, mode, data_size):
    """Generate sequences for the two-step composition task.

    Each row holds ``args.seq_len`` input tokens followed by one label token.
    Input tokens are drawn from ``[args.data_min, args.data_max)``. At a
    random position ``pos`` the triple ``(x, a1, a2)`` is written, where
    ``a1``/``a2`` are the first two characters of ``mode`` parsed as ints and
    ``x`` is drawn from one of the tables returned by ``generate_mod_list``.
    The label is ``single_func(single_func(x, a2), a1)``.

    Args:
        args: config object; reads ``data_min``, ``data_max`` and ``seq_len``.
        mode: string whose first two characters are the prompt digits and
            whose last three characters select the table ``x`` is drawn from:
            ``'xel'`` -> first table (``train_remainder_dict``),
            ``'xm0'`` -> second table (``test_remainder_dict``).
        data_size: number of rows to generate.

    Returns:
        list of ``data_size`` lists, each of length ``args.seq_len + 1``.

    Raises:
        ValueError: if ``mode`` ends with neither ``'xel'`` nor ``'xm0'``,
            or its first two characters are not integers.
    """
    suffix = mode[-3:]
    if suffix not in ('xel', 'xm0'):
        # Previously this fell through and crashed with UnboundLocalError on x.
        raise ValueError(
            f"unsupported mode {mode!r}: expected suffix 'xel' or 'xm0'")
    # The prompt tokens are the same for every row, so parse them once.
    a1 = int(mode[0])
    a2 = int(mode[1])

    # data_size x (seq_len + 1) matrix; each element drawn from [data_min, data_max).
    seq_array = np.random.randint(args.data_min, args.data_max, size=(data_size, args.seq_len+1))
    seq_list = seq_array.tolist()

    train_remainder_dict, test_remainder_dict = generate_mod_list(args.data_min, args.data_max, args.seq_len)
    remainder_dict = train_remainder_dict if suffix == 'xel' else test_remainder_dict

    for row in seq_list:
        # Pick a random position: pos goes to x, pos+1 to a1, pos+2 to a2.
        # randint's upper bound is exclusive, so pos <= seq_len - 3 and the
        # triple never overwrites the label slot at index seq_len.
        pos = np.random.randint(0, args.seq_len-2)
        x = random.choice(remainder_dict[str(pos % args.seq_len)])
        row[pos], row[pos+1], row[pos+2] = x, a1, a2

        # Apply a2's offset first, then a1's.
        row[-1] = single_func(single_func(x, a2), a1)

    return seq_list


def task_similar_token(args, data_size):
    '''
        Generate data for the similar_token task.

        Each sample is a 1-D numpy array of length ``args.seq_len + 1``.
        Filler tokens are drawn from ``[args.data_min, args.data_max - 10)``.
        One anchor token ``x`` from ``1 .. 2*args.anchor_num`` is written at a
        random position ``pos`` in ``[1, args.seq_len - 2]``. The label (last
        element) is ``seq[pos-1] + x % args.anchor_num``. Samples whose label
        is not below ``args.data_max`` are rejected and redrawn.

        Args:
            args: config object; reads anchor_num, data_min, data_max, seq_len.
            data_size: number of samples to return.

        Returns:
            list of ``data_size`` numpy arrays (rows are not converted to
            Python lists).

        Raises:
            AssertionError: if ``args.data_min <= 2*args.anchor_num``.
    '''
    # Anchors are 1 .. 2*anchor_num; requiring data_min to exceed the largest
    # anchor keeps filler tokens disjoint from anchor tokens.
    assert args.data_min > 2*args.anchor_num, "data_min should be greater than 2*anchor_num"
    anchor_list = np.arange(1, 2*args.anchor_num + 1)

    seq_list = []
    while len(seq_list) < data_size:
        seq = np.random.randint(args.data_min, args.data_max - 10, size=(args.seq_len+1))
        pos = np.random.randint(1, args.seq_len-1)
        x = random.choice(anchor_list)

        seq[pos] = x
        seq[-1] = seq[pos-1] + (x % args.anchor_num)
        # Rejection sampling: keep only samples whose label is below data_max.
        if seq[-1] < args.data_max:
            seq_list.append(seq)

    return seq_list

def task_binary_classification(args, data_size):
    '''
        Generate data for the binary (parity) classification task.

        Returns ``data_size`` Python lists of length ``args.seq_len + 1``.
        Entries are drawn from ``[args.data_min, args.data_max)``; the last
        entry is then overwritten with the label: ``1`` if the second-to-last
        entry is odd, ``0`` if it is even.
    '''
    rows = np.random.randint(args.data_min, args.data_max, size=(data_size, args.seq_len+1)).tolist()
    for row in rows:
        # Python int % 2 is always 0 or 1 (also for negatives), which is
        # exactly the even -> 0 / odd -> 1 label.
        row[-1] = row[-2] % 2
    return rows


            
        