import torch
import torch.utils.data as Data
import numpy as np
import random
from .data_generator_base import *


def single_func(args, x, single_prompt):
    """Shift ``x`` by the fixed offset associated with one anchor prompt.

    The offset table is chosen by ``args.composition_mode``:

    * ``'default'``:  anchors 1..8 map to 5, 1, -2, -8, 3, -4, -7, 6
    * ``'positive'``: anchors 1..4 map to 4, 1, 2, 6 (only four offsets are
      defined, so anchors 5..8 raise ``IndexError`` in this mode)

    Args:
        args: Config object; only ``composition_mode`` is read.
        x: Value to shift.
        single_prompt: Anchor id; must equal one of 1..8.

    Returns:
        ``x`` plus the offset for ``single_prompt``.

    Raises:
        ValueError: If ``composition_mode`` is not ``'default'`` or
            ``'positive'``, or if ``single_prompt`` is not one of 1..8.
        IndexError: In ``'positive'`` mode when ``single_prompt`` is 5..8.
    """
    mode = args.composition_mode
    if mode == 'default':
        shifts = [5, 1, -2, -8, 3, -4, -7, 6]
    elif mode == 'positive':
        shifts = [4, 1, 2, 6]
    else:
        raise ValueError('composition_mode should be default or positive')

    anchors = [1, 2, 3, 4, 5, 6, 7, 8]
    # Position of the anchor in the list selects its shift.
    slot = anchors.index(single_prompt)
    return x + shifts[slot]


def task_composition_1(args, mode, data_size):
    """Generate sequences for one-step composition (a single anchor).

    Every row starts as ``args.seq_len + 1`` integers drawn uniformly from
    ``[args.data_min, args.data_max)``. A key ``x`` is written at a random
    position ``pos`` and the anchor ``int(mode[0])`` at ``pos + 1``. The last
    element (index ``seq_len``) is overwritten with the label
    ``single_func(args, x, anchor)``.

    Args:
        args: Config object; reads ``data_min``, ``data_max`` and ``seq_len``
            here (``single_func`` additionally reads ``composition_mode``).
        mode: String whose first character is the anchor digit and whose last
            three characters select the key pool. ``'xel'`` draws ``x`` from
            the train remainder dict; ``'xm0'`` draws it from the test
            remainder dict.
        data_size: Number of sequences to generate.

    Returns:
        A list of ``data_size`` rows, each a list of length ``args.seq_len + 1``.

    Raises:
        ValueError: If ``mode`` does not end in ``'xel'`` or ``'xm0'``, if
            ``mode[0]`` is not a digit, or if ``args.seq_len < 2``
            (``np.random.randint`` rejects an empty range).
    """
    suffix = mode[-3:]
    # Fail fast: with any other suffix ``x`` would never be assigned.
    if suffix not in ('xel', 'xm0'):
        raise ValueError(f"mode must end with 'xel' or 'xm0', got {mode!r}")

    seq_array = np.random.randint(args.data_min, args.data_max, size=(data_size, args.seq_len + 1))
    seq_list = seq_array.tolist()

    train_remainder_dict, test_remainder_dict = generate_mod_list(args.data_min, args.data_max, args.seq_len)
    remainder_dict = train_remainder_dict if suffix == 'xel' else test_remainder_dict

    # The anchor is fixed by ``mode``, so parse it once instead of per row.
    anchors = [int(c) for c in mode[:1]]
    n_anchors = len(anchors)

    for row in seq_list:
        # ``high`` is exclusive, so pos + n_anchors <= seq_len - 1. The key and
        # anchor therefore always land before the label slot at index seq_len.
        pos = np.random.randint(0, args.seq_len - n_anchors)
        # The dict is keyed by str(pos % seq_len); pos < seq_len here, so this
        # is str(pos). Presumably candidates are grouped by position residue.
        x = random.choice(remainder_dict[str(pos % args.seq_len)])
        row[pos:pos + n_anchors + 1] = [x, *anchors]

        y = x
        for anchor in anchors:
            y = single_func(args, y, anchor)
        row[-1] = y

    return seq_list


def task_composition(args, mode, data_size):
    """Generate sequences for two-step composition (two anchors).

    Every row starts as ``args.seq_len + 1`` integers drawn uniformly from
    ``[args.data_min, args.data_max)``. A key ``x`` is written at a random
    position ``pos``, followed by anchors ``int(mode[0])`` and
    ``int(mode[1])`` at ``pos + 1`` and ``pos + 2``. The last element (index
    ``seq_len``) is overwritten with the label
    ``single_func(args, single_func(args, x, a1), a2)``.

    Args:
        args: Config object; reads ``data_min``, ``data_max`` and ``seq_len``
            here (``single_func`` additionally reads ``composition_mode``).
        mode: String whose first two characters are the anchor digits and
            whose last three characters select the key pool. ``'xel'`` draws
            ``x`` from the train remainder dict; ``'xm0'`` draws it from the
            test remainder dict.
        data_size: Number of sequences to generate.

    Returns:
        A list of ``data_size`` rows, each a list of length ``args.seq_len + 1``.

    Raises:
        ValueError: If ``mode`` does not end in ``'xel'`` or ``'xm0'``, if the
            anchor characters are not digits, or if ``args.seq_len < 3``
            (``np.random.randint`` rejects an empty range).
    """
    suffix = mode[-3:]
    # Fail fast: with any other suffix ``x`` would never be assigned.
    if suffix not in ('xel', 'xm0'):
        raise ValueError(f"mode must end with 'xel' or 'xm0', got {mode!r}")

    seq_array = np.random.randint(args.data_min, args.data_max, size=(data_size, args.seq_len + 1))
    seq_list = seq_array.tolist()

    train_remainder_dict, test_remainder_dict = generate_mod_list(args.data_min, args.data_max, args.seq_len)
    remainder_dict = train_remainder_dict if suffix == 'xel' else test_remainder_dict

    # The anchors are fixed by ``mode``, so parse them once instead of per row.
    anchors = [int(c) for c in mode[:2]]
    n_anchors = len(anchors)

    for row in seq_list:
        # ``high`` is exclusive, so pos + n_anchors <= seq_len - 1. The key and
        # anchors therefore always land before the label slot at index seq_len.
        pos = np.random.randint(0, args.seq_len - n_anchors)
        # The dict is keyed by str(pos % seq_len); pos < seq_len here, so this
        # is str(pos). Presumably candidates are grouped by position residue.
        x = random.choice(remainder_dict[str(pos % args.seq_len)])
        row[pos:pos + n_anchors + 1] = [x, *anchors]

        # Apply the anchors' shifts left to right.
        y = x
        for anchor in anchors:
            y = single_func(args, y, anchor)
        row[-1] = y

    return seq_list

def task_composition_3(args, mode, data_size):
    """Generate sequences for three-step composition (three anchors).

    Every row starts as ``args.seq_len + 1`` integers drawn uniformly from
    ``[args.data_min, args.data_max)``. A key ``x`` is written at a random
    position ``pos``, followed by the anchors ``int(mode[0])``,
    ``int(mode[1])`` and ``int(mode[2])`` at ``pos + 1`` .. ``pos + 3``. The
    last element (index ``seq_len``) is overwritten with the label obtained by
    applying ``single_func`` for each anchor in order, starting from ``x``.

    Args:
        args: Config object; reads ``data_min``, ``data_max`` and ``seq_len``
            here (``single_func`` additionally reads ``composition_mode``).
        mode: String whose first three characters are the anchor digits and
            whose last three characters select the key pool. ``'xel'`` draws
            ``x`` from the train remainder dict; ``'xm0'`` draws it from the
            test remainder dict.
        data_size: Number of sequences to generate.

    Returns:
        A list of ``data_size`` rows, each a list of length ``args.seq_len + 1``.

    Raises:
        ValueError: If ``mode`` does not end in ``'xel'`` or ``'xm0'``, if the
            anchor characters are not digits, or if ``args.seq_len < 4``
            (``np.random.randint`` rejects an empty range).
    """
    suffix = mode[-3:]
    # Fail fast: with any other suffix ``x`` would never be assigned.
    if suffix not in ('xel', 'xm0'):
        raise ValueError(f"mode must end with 'xel' or 'xm0', got {mode!r}")

    seq_array = np.random.randint(args.data_min, args.data_max, size=(data_size, args.seq_len + 1))
    seq_list = seq_array.tolist()

    train_remainder_dict, test_remainder_dict = generate_mod_list(args.data_min, args.data_max, args.seq_len)
    remainder_dict = train_remainder_dict if suffix == 'xel' else test_remainder_dict

    # The anchors are fixed by ``mode``, so parse them once instead of per row.
    anchors = [int(c) for c in mode[:3]]
    n_anchors = len(anchors)

    for row in seq_list:
        # ``high`` is exclusive, so pos + n_anchors <= seq_len - 1. The key and
        # anchors therefore always land before the label slot at index seq_len.
        pos = np.random.randint(0, args.seq_len - n_anchors)
        # The dict is keyed by str(pos % seq_len); pos < seq_len here, so this
        # is str(pos). Presumably candidates are grouped by position residue.
        x = random.choice(remainder_dict[str(pos % args.seq_len)])
        row[pos:pos + n_anchors + 1] = [x, *anchors]

        # Apply the anchors' shifts left to right.
        y = x
        for anchor in anchors:
            y = single_func(args, y, anchor)
        row[-1] = y

    return seq_list

def task_composition_4(args, mode, data_size):
    """Generate sequences for four-step composition (four anchors).

    Every row starts as ``args.seq_len + 1`` integers drawn uniformly from
    ``[args.data_min, args.data_max)``. A key ``x`` is written at a random
    position ``pos``, followed by the anchors ``int(mode[0])`` ..
    ``int(mode[3])`` at ``pos + 1`` .. ``pos + 4``. The last element (index
    ``seq_len``) is overwritten with the label obtained by applying
    ``single_func`` for each anchor in order, starting from ``x``.

    Args:
        args: Config object; reads ``data_min``, ``data_max`` and ``seq_len``
            here (``single_func`` additionally reads ``composition_mode``).
        mode: String whose first four characters are the anchor digits and
            whose last three characters select the key pool. ``'xel'`` draws
            ``x`` from the train remainder dict; ``'xm0'`` draws it from the
            test remainder dict.
        data_size: Number of sequences to generate.

    Returns:
        A list of ``data_size`` rows, each a list of length ``args.seq_len + 1``.

    Raises:
        ValueError: If ``mode`` does not end in ``'xel'`` or ``'xm0'``, if the
            anchor characters are not digits, or if ``args.seq_len < 5``
            (``np.random.randint`` rejects an empty range).
    """
    suffix = mode[-3:]
    # Fail fast: with any other suffix ``x`` would never be assigned.
    if suffix not in ('xel', 'xm0'):
        raise ValueError(f"mode must end with 'xel' or 'xm0', got {mode!r}")

    seq_array = np.random.randint(args.data_min, args.data_max, size=(data_size, args.seq_len + 1))
    seq_list = seq_array.tolist()

    train_remainder_dict, test_remainder_dict = generate_mod_list(args.data_min, args.data_max, args.seq_len)
    remainder_dict = train_remainder_dict if suffix == 'xel' else test_remainder_dict

    # The anchors are fixed by ``mode``, so parse them once instead of per row.
    anchors = [int(c) for c in mode[:4]]
    n_anchors = len(anchors)

    for row in seq_list:
        # ``high`` is exclusive, so pos + n_anchors <= seq_len - 1. The key and
        # anchors therefore always land before the label slot at index seq_len.
        pos = np.random.randint(0, args.seq_len - n_anchors)
        # The dict is keyed by str(pos % seq_len); pos < seq_len here, so this
        # is str(pos). Presumably candidates are grouped by position residue.
        x = random.choice(remainder_dict[str(pos % args.seq_len)])
        row[pos:pos + n_anchors + 1] = [x, *anchors]

        # Apply the anchors' shifts left to right.
        y = x
        for anchor in anchors:
            y = single_func(args, y, anchor)
        row[-1] = y

    return seq_list


def task_composition_1and2(args, mode, data_size):
    """Generate one- or two-anchor composition data, chosen by ``len(mode)``.

    A 5-character ``mode`` is routed to ``task_composition_1``, which reads one
    anchor digit. A 6-character ``mode`` is routed to ``task_composition``,
    which reads two. All arguments are forwarded unchanged.

    Args:
        args: Config object forwarded to the selected generator.
        mode: Task mode string; its length selects the generator.
        data_size: Number of sequences to generate.

    Returns:
        The list of sequences produced by the selected generator.

    Raises:
        ValueError: If ``len(mode)`` is neither 5 nor 6 (plus anything the
            selected generator raises).
    """
    if len(mode) == 5:
        return task_composition_1(args, mode, data_size)
    if len(mode) == 6:
        return task_composition(args, mode, data_size)
    # Report the bad mode explicitly rather than hitting an unbound result name.
    raise ValueError(f"mode must have length 5 or 6, got {mode!r} (length {len(mode)})")

def task_composition_all(args, mode, data_size):
    """Generate 1- to 4-anchor composition data, chosen by ``len(mode)``.

    Lengths 5, 6, 7 and 8 route to ``task_composition_1``,
    ``task_composition``, ``task_composition_3`` and ``task_composition_4``.
    Those generators read 1, 2, 3 and 4 leading anchor digits respectively.
    All arguments are forwarded unchanged.

    Args:
        args: Config object forwarded to the selected generator.
        mode: Task mode string; its length selects the generator.
        data_size: Number of sequences to generate.

    Returns:
        The list of sequences produced by the selected generator.

    Raises:
        ValueError: If ``len(mode)`` is not in 5..8 (plus anything the
            selected generator raises).
    """
    generators = {
        5: task_composition_1,
        6: task_composition,
        7: task_composition_3,
        8: task_composition_4,
    }
    generator = generators.get(len(mode))
    if generator is None:
        # Report the bad mode explicitly rather than hitting an unbound result name.
        raise ValueError(f"mode must have length 5-8, got {mode!r} (length {len(mode)})")
    return generator(args, mode, data_size)
