import torch
import torch.utils.data as Data
import numpy as np
import random
from .data_generator_base import *

def single_func_new(args, x, single_prompt):
    """Shift ``x`` by the offset that ``single_prompt`` selects.

    Every composition mode defines a table of additive offsets. Prompt ids
    are 1-based: prompt ``k`` picks the ``k``-th entry of the active table.

    Args:
        args: Object exposing ``composition_mode``; one of ``'default'``,
            ``'positive'``, ``'mode1'`` or ``'mode2'``.
        x: Value the offset is added to.
        single_prompt: Prompt id, expected to be one of 1..10. Only ids up
            to the length of the selected table are usable (10 for
            ``'default'``, 4 for ``'positive'``, 2 for the others).

    Returns:
        ``x`` plus the selected offset.

    Raises:
        ValueError: If ``composition_mode`` is unknown or ``single_prompt``
            is not one of 1..10.
        IndexError: If the selected mode defines no offset for
            ``single_prompt``.
    """
    prompt_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    offsets_by_mode = {
        'default': [5, 1, -2, -8, 3, 4, -7, -6, 2, -4],
        'positive': [4, 1, 2, 6],
        'mode1': [1, -1],
        'mode2': [1, -2],
    }

    mode = args.composition_mode
    offsets = offsets_by_mode.get(mode)
    if offsets is None:
        raise ValueError(
            "composition_mode should be one of 'default', 'positive', "
            "'mode1' or 'mode2', got {!r}".format(mode)
        )

    try:
        slot = prompt_ids.index(single_prompt)
    except ValueError:
        raise ValueError(
            'single_prompt should be an integer from 1 to 10, got {!r}'.format(
                single_prompt
            )
        ) from None

    # The shorter tables cover only the first few prompt ids; fail with a
    # message that names the mode instead of a bare "list index out of range".
    if slot >= len(offsets):
        raise IndexError(
            'composition_mode {!r} defines only {} prompts, got prompt {!r}'.format(
                mode, len(offsets), single_prompt
            )
        )
    return x + offsets[slot]

def task_composition_1backward(args, mode, data_size):
    """Generate rows where an operand is directly followed by one prompt.

    Each row has ``args.seq_len + 1`` entries, initialised with random
    integers from ``[args.data_min, args.data_max)``. An ``(operand, prompt)``
    pair is written at a random position inside the first ``seq_len`` slots,
    and the last slot is overwritten with ``single_func_new(operand)``.

    Args:
        args: Config object; reads ``data_min``, ``data_max``, ``seq_len``
            here and is forwarded to ``single_func_new``.
        mode: Indexable whose first item is parsed with ``int`` as the
            prompt id.
        data_size: Number of rows to generate.

    Returns:
        A list of ``data_size`` rows (Python lists).
    """
    rows = np.random.randint(
        args.data_min, args.data_max, size=(data_size, args.seq_len + 1)
    ).tolist()

    # generate_list comes from data_generator_base (not visible here);
    # presumably it returns the sequence of admissible operands -- confirm.
    operand_pool = generate_list(args.data_min, args.data_max)

    for row in rows:
        prompt = int(mode[0])

        # randint's upper bound is exclusive, so start + 1 <= seq_len - 1 and
        # the pair never touches the final (target) slot.
        start = np.random.randint(0, args.seq_len - 1)
        operand = random.choice(operand_pool)

        row[start] = operand
        row[start + 1] = prompt

        row[-1] = single_func_new(args, operand, prompt)

    return rows

def task_composition_1forward(args, mode, data_size):
    """Generate rows where one prompt directly precedes its operand.

    Each row has ``args.seq_len + 1`` entries, initialised with random
    integers from ``[args.data_min, args.data_max)``. A ``(prompt, operand)``
    pair is written at a random position inside the first ``seq_len`` slots,
    and the last slot is overwritten with ``single_func_new(operand)``.

    Args:
        args: Config object; reads ``data_min``, ``data_max``, ``seq_len``
            here and is forwarded to ``single_func_new``.
        mode: Indexable whose first item is parsed with ``int`` as the
            prompt id.
        data_size: Number of rows to generate.

    Returns:
        A list of ``data_size`` rows (Python lists).
    """
    rows = np.random.randint(
        args.data_min, args.data_max, size=(data_size, args.seq_len + 1)
    ).tolist()

    # generate_list comes from data_generator_base (not visible here);
    # presumably it returns the sequence of admissible operands -- confirm.
    operand_pool = generate_list(args.data_min, args.data_max)

    for row in rows:
        prompt = int(mode[0])

        # Index of the operand; starting at 1 leaves room for the prompt
        # before it, and the exclusive upper bound spares the target slot.
        operand_at = np.random.randint(1, args.seq_len)
        operand = random.choice(operand_pool)

        row[operand_at - 1] = prompt
        row[operand_at] = operand

        row[-1] = single_func_new(args, operand, prompt)

    return rows

def task_composition_2backward(args, mode, data_size):
    """Generate rows where an operand is followed by two prompts.

    Each row has ``args.seq_len + 1`` entries, initialised with random
    integers from ``[args.data_min, args.data_max)``. The triple
    ``(operand, prompt1, prompt2)`` is written at a random position inside
    the first ``seq_len`` slots. The last slot is overwritten with the result
    of applying ``prompt1`` and then ``prompt2`` to the operand via
    ``single_func_new``.

    Args:
        args: Config object; reads ``data_min``, ``data_max``, ``seq_len``
            here and is forwarded to ``single_func_new``.
        mode: Indexable whose first two items are parsed with ``int`` as the
            prompt ids.
        data_size: Number of rows to generate.

    Returns:
        A list of ``data_size`` rows (Python lists).
    """
    rows = np.random.randint(
        args.data_min, args.data_max, size=(data_size, args.seq_len + 1)
    ).tolist()

    # generate_list comes from data_generator_base (not visible here);
    # presumably it returns the sequence of admissible operands -- confirm.
    operand_pool = generate_list(args.data_min, args.data_max)

    for row in rows:
        first_prompt = int(mode[0])
        second_prompt = int(mode[1])

        # Exclusive upper bound keeps start + 2 <= seq_len - 1, i.e. the
        # triple never reaches the final (target) slot.
        start = np.random.randint(0, args.seq_len - 2)
        operand = random.choice(operand_pool)

        row[start] = operand
        row[start + 1] = first_prompt
        row[start + 2] = second_prompt

        after_first = single_func_new(args, operand, first_prompt)
        row[-1] = single_func_new(args, after_first, second_prompt)

    return rows

def task_composition_2forward(args, mode, data_size):
    """Generate rows where two prompts precede their operand.

    Each row has ``args.seq_len + 1`` entries, initialised with random
    integers from ``[args.data_min, args.data_max)``. The triple
    ``(prompt2, prompt1, operand)`` is written at a random position inside
    the first ``seq_len`` slots, so the prompt adjacent to the operand is the
    one applied first. The last slot is overwritten with the result of
    applying ``prompt1`` and then ``prompt2`` via ``single_func_new``.

    Args:
        args: Config object; reads ``data_min``, ``data_max``, ``seq_len``
            here and is forwarded to ``single_func_new``.
        mode: Indexable whose first two items are parsed with ``int`` as the
            prompt ids.
        data_size: Number of rows to generate.

    Returns:
        A list of ``data_size`` rows (Python lists).
    """
    rows = np.random.randint(
        args.data_min, args.data_max, size=(data_size, args.seq_len + 1)
    ).tolist()

    # generate_list comes from data_generator_base (not visible here);
    # presumably it returns the sequence of admissible operands -- confirm.
    operand_pool = generate_list(args.data_min, args.data_max)

    for row in rows:
        first_prompt = int(mode[0])
        second_prompt = int(mode[1])

        # Index of the operand; starting at 2 leaves room for both prompts
        # before it, and the exclusive upper bound spares the target slot.
        operand_at = np.random.randint(2, args.seq_len)
        operand = random.choice(operand_pool)

        row[operand_at - 2] = second_prompt
        row[operand_at - 1] = first_prompt
        row[operand_at] = operand

        after_first = single_func_new(args, operand, first_prompt)
        row[-1] = single_func_new(args, after_first, second_prompt)

    return rows

def task_composition_3backward(args, mode, data_size):
    """Generate rows where an operand is followed by three prompts.

    Each row has ``args.seq_len + 1`` entries, initialised with random
    integers from ``[args.data_min, args.data_max)``. The run
    ``(operand, prompt1, prompt2, prompt3)`` is written at a random position
    inside the first ``seq_len`` slots. The last slot is overwritten with the
    result of applying the three prompts in that order via
    ``single_func_new``.

    Args:
        args: Config object; reads ``data_min``, ``data_max``, ``seq_len``
            here and is forwarded to ``single_func_new``.
        mode: Indexable whose first three items are parsed with ``int`` as
            the prompt ids.
        data_size: Number of rows to generate.

    Returns:
        A list of ``data_size`` rows (Python lists).
    """
    rows = np.random.randint(
        args.data_min, args.data_max, size=(data_size, args.seq_len + 1)
    ).tolist()

    # generate_list comes from data_generator_base (not visible here);
    # presumably it returns the sequence of admissible operands -- confirm.
    operand_pool = generate_list(args.data_min, args.data_max)

    for row in rows:
        first_prompt = int(mode[0])
        second_prompt = int(mode[1])
        third_prompt = int(mode[2])

        # Exclusive upper bound keeps start + 3 <= seq_len - 1, i.e. the
        # run never reaches the final (target) slot.
        start = np.random.randint(0, args.seq_len - 3)
        operand = random.choice(operand_pool)

        row[start] = operand
        row[start + 1] = first_prompt
        row[start + 2] = second_prompt
        row[start + 3] = third_prompt

        after_first = single_func_new(args, operand, first_prompt)
        after_second = single_func_new(args, after_first, second_prompt)
        row[-1] = single_func_new(args, after_second, third_prompt)

    return rows

def task_composition_3forward(args, mode, data_size):
    """Generate rows where three prompts precede their operand.

    Each row has ``args.seq_len + 1`` entries, initialised with random
    integers from ``[args.data_min, args.data_max)``. The run
    ``(prompt3, prompt2, prompt1, operand)`` is written at a random position
    inside the first ``seq_len`` slots, so the prompt adjacent to the operand
    is the one applied first. The last slot is overwritten with the result of
    applying ``prompt1``, ``prompt2`` and ``prompt3`` in that order via
    ``single_func_new``.

    Args:
        args: Config object; reads ``data_min``, ``data_max``, ``seq_len``
            here and is forwarded to ``single_func_new``.
        mode: Indexable whose first three items are parsed with ``int`` as
            the prompt ids.
        data_size: Number of rows to generate.

    Returns:
        A list of ``data_size`` rows (Python lists).
    """
    rows = np.random.randint(
        args.data_min, args.data_max, size=(data_size, args.seq_len + 1)
    ).tolist()

    # generate_list comes from data_generator_base (not visible here);
    # presumably it returns the sequence of admissible operands -- confirm.
    operand_pool = generate_list(args.data_min, args.data_max)

    for row in rows:
        first_prompt = int(mode[0])
        second_prompt = int(mode[1])
        third_prompt = int(mode[2])

        # Index of the operand; starting at 3 leaves room for the three
        # prompts before it, and the exclusive upper bound spares the target.
        operand_at = np.random.randint(3, args.seq_len)
        operand = random.choice(operand_pool)

        row[operand_at - 3] = third_prompt
        row[operand_at - 2] = second_prompt
        row[operand_at - 1] = first_prompt
        row[operand_at] = operand

        after_first = single_func_new(args, operand, first_prompt)
        after_second = single_func_new(args, after_first, second_prompt)
        row[-1] = single_func_new(args, after_second, third_prompt)

    return rows

def task_composition_backward(args, mode, data_size):
    """Dispatch to the backward generator matching the length of ``mode``.

    ``len(mode)`` of 5, 6 or 7 selects the 1-, 2- or 3-prompt backward
    generator respectively.
    NOTE(review): presumably ``mode`` is the prompt digits followed by a
    fixed 4-character suffix -- confirm against the caller.

    Args:
        args: Config object forwarded unchanged to the selected generator.
        mode: Sized, indexable mode descriptor (see note above).
        data_size: Number of rows to generate.

    Returns:
        The list of rows produced by the selected generator.

    Raises:
        ValueError: If ``len(mode)`` is not 5, 6 or 7.
    """
    mode_len = len(mode)
    if mode_len == 5:
        return task_composition_1backward(args, mode, data_size)
    if mode_len == 6:
        return task_composition_2backward(args, mode, data_size)
    if mode_len == 7:
        return task_composition_3backward(args, mode, data_size)
    # Fail explicitly; falling through used to hit an unbound local variable.
    raise ValueError(
        'mode should have length 5, 6 or 7, got {!r} (length {})'.format(
            mode, mode_len
        )
    )

def task_composition_forward(args, mode, data_size):
    """Dispatch to the forward generator matching the length of ``mode``.

    ``len(mode)`` of 5, 6 or 7 selects the 1-, 2- or 3-prompt forward
    generator respectively.
    NOTE(review): presumably ``mode`` is the prompt digits followed by a
    fixed 4-character suffix -- confirm against the caller.

    Args:
        args: Config object forwarded unchanged to the selected generator.
        mode: Sized, indexable mode descriptor (see note above).
        data_size: Number of rows to generate.

    Returns:
        The list of rows produced by the selected generator.

    Raises:
        ValueError: If ``len(mode)`` is not 5, 6 or 7.
    """
    mode_len = len(mode)
    if mode_len == 5:
        return task_composition_1forward(args, mode, data_size)
    if mode_len == 6:
        return task_composition_2forward(args, mode, data_size)
    if mode_len == 7:
        return task_composition_3forward(args, mode, data_size)
    # Fail explicitly; falling through used to hit an unbound local variable.
    raise ValueError(
        'mode should have length 5, 6 or 7, got {!r} (length {})'.format(
            mode, mode_len
        )
    )

