import functools
import math
from typing import List, Tuple

import numpy
import torch
from OptimizationTestFunctions import (
    Michalewicz,
    Transformation,
    SchwefelSin,
    Weierstrass,
    Sphere,
    Rastrigin,
)

from algorithms.space.optimization_test_function import OptimizationTestSpace
from algorithms.space.utils import COCO_SUITE, HIGH_DIM_COCO_SUITE
from problems.types import Benchmarks, PYTHONIC_SUITE
from problems.utils import (
    BUDGET,
    coco_space_from_funcnum_and_dim,
    coco_space_from_index,
    coco_largescale_space_from_index,
)


def opt_gan_benchmark(dim, budget: int = BUDGET):
    """Build the OPT-GAN benchmark suite in ``dim`` dimensions.

    The returned list contains, in this order:

    1. Michalewicz on ``[0, 4]^dim``;
    2. SchwefelSin, shifted by 2 per coordinate and rotated, on ``[-100, 100]^dim``;
    3. Weierstrass, shifted by 2 per coordinate and rotated, on ``[-100, 100]^dim``;
    4. Sphere on ``[-5, 5]^dim``;
    5. Rastrigin on ``[-5, 5]^dim``;
    6. COCO function 7, instance 1, in ``dim`` dimensions.

    Every space is created with the same evaluation ``budget``.

    NOTE(review): ``Michalewicz()`` is constructed without ``dim`` -- confirm it
    is valid for every dimensionality this is called with (e.g. 10).
    NOTE(review): ``rotation_matrix=dim`` presumably asks ``Transformation`` to
    generate a rotation of that size; confirm whether it is seeded, otherwise
    the rotated spaces may differ between runs.
    """

    def bounds(lower, upper):
        # Same integer bound on every coordinate, returned as (lower, upper) tensors.
        return torch.tensor([lower] * dim), torch.tensor([upper] * dim)

    def shifted_and_rotated(base_function):
        return Transformation(
            base_function, shift_step=numpy.array([2] * dim), rotation_matrix=dim
        )

    # (function factory, lower bound, upper bound). Factories are invoked lazily
    # in table order, so objects are constructed in the order they are returned.
    table = (
        (Michalewicz, 0, 4),
        (lambda: shifted_and_rotated(SchwefelSin(dim)), -100, 100),
        (lambda: shifted_and_rotated(Weierstrass(dim)), -100, 100),
        (lambda: Sphere(dim), -5, 5),
        (lambda: Rastrigin(dim), -5, 5),
    )
    suite = [
        OptimizationTestSpace(make_function(), *bounds(lower, upper), budget=budget)
        for make_function, lower, upper in table
    ]
    suite.append(coco_space_from_funcnum_and_dim(7, dim, 1, budget))
    return suite


def high_dim_benchmark(budget: int = BUDGET, dim: int = 1000):
    """Instantiate every entry of ``PYTHONIC_SUITE`` in ``dim`` dimensions.

    Each suite entry is called with ``dim``. The resulting function is searched
    over the box ``[-5, 5]^dim`` with the given evaluation ``budget``.
    """
    lower = [-5] * dim
    upper = [5] * dim
    return [
        OptimizationTestSpace(
            make_function(dim), torch.tensor(lower), torch.tensor(upper), budget=budget
        )
        for make_function in PYTHONIC_SUITE
    ]


def coco_benchmark(budget: int = BUDGET):
    """Return one space per problem index of ``COCO_SUITE``, in index order."""
    suite_size = len(COCO_SUITE)
    return [coco_space_from_index(index, budget) for index in range(suite_size)]


def coco_largescale_benchmark(budget: int = BUDGET):
    """Return the large-scale COCO spaces whose dimension exceeds 40.

    Every index of ``HIGH_DIM_COCO_SUITE`` is turned into a space. Falsy results
    are skipped, as are spaces with ``dimension <= 40``.
    """
    selected = []
    for index in range(len(HIGH_DIM_COCO_SUITE)):
        space = coco_largescale_space_from_index(index, budget)
        if not space or space.dimension <= 40:
            continue
        selected.append(space)
    return selected


def coco_partial_largescale_benchmark(budget: int = BUDGET):
    """Return every 10th large-scale COCO space with suite index in [600, 1400)."""
    indices = range(600, 1400, 10)
    return [coco_largescale_space_from_index(index, budget) for index in indices]


def coco_partial_largescale_benchmark_2(budget: int = BUDGET):
    """Return every 10th large-scale COCO space with suite index in [1200, 2000)."""
    indices = range(1200, 2000, 10)
    return [coco_largescale_space_from_index(index, budget) for index in indices]


def partial_coco_benchmark(budget: int = BUDGET):
    """Return every 10th COCO space with suite index in [1200, 2000)."""
    indices = range(1200, 2000, 10)
    return [coco_space_from_index(index, budget) for index in indices]


def partial_coco_40_benchmark(budget: int = BUDGET):
    """Return COCO functions 1-24 in 40 dimensions for instances 1, 3, 96 and 100.

    The result is ordered by function number, and by instance within each function.
    """
    dimension = 40
    instances = (1, 3, 96, 100)
    return [
        coco_space_from_funcnum_and_dim(func_id, dimension, inst, budget)
        for func_id in range(1, 25)
        for inst in instances
    ]


def partial_coco_low_benchmark(budget: int = BUDGET):
    """Return COCO functions 1-24 in 2 and 3 dimensions for instances 3 and 96.

    All 2-D spaces come first, then all 3-D spaces. Within each dimension the
    spaces are ordered by function number, then by instance.
    """
    instances = (3, 96)
    return [
        coco_space_from_funcnum_and_dim(func_id, dimension, inst, budget)
        for dimension in (2, 3)
        for func_id in range(1, 25)
        for inst in instances
    ]


def create_problems_from_instances(problems: List[Tuple[int, int, int]], budget: int = BUDGET):
    """Build one COCO space per problem tuple, preserving the input order.

    Each tuple is unpacked positionally into ``coco_space_from_funcnum_and_dim``
    ahead of ``budget``, i.e. as (function number, dimension, instance).
    """
    spaces = []
    for triple in problems:
        spaces.append(coco_space_from_funcnum_and_dim(*triple, budget))
    return spaces


def partial_unsolved_benchmark(budget: int = BUDGET):
    """Return COCO spaces for hand-picked problems of functions 15, 16, 21 and 22.

    Each entry below is a ``(function number, dimension, instance)`` triple.
    The triples are handed to ``create_problems_from_instances``. The
    "unsolved" naming suggests these are problems earlier runs failed to
    solve -- TODO confirm how the lists were produced.

    The lists are concatenated in the order 15, 16, 21, 22, and each list keeps
    its internal order. This order matters when the result is split by
    ``find_problems_to_run(parts=...)``, which partitions by list position.
    """
    func_15_unsolved_problems = [
        (15, 10, 97),
        (15, 3, 93),
        (15, 40, 96),
        (15, 2, 96),
        (15, 20, 5),
        (15, 40, 99),
        (15, 5, 96),
        (15, 20, 93),
        (15, 10, 1),
        (15, 3, 95),
        (15, 20, 4),
        (15, 40, 98),
        (15, 5, 98),
        (15, 20, 95),
        (15, 3, 94),
        (15, 40, 2),
        (15, 3, 97),
        (15, 40, 100),
        (15, 10, 91),
        (15, 5, 97),
        (15, 20, 94),
        (15, 5, 100),
        (15, 20, 97),
        (15, 10, 2),
        (15, 3, 96),
        (15, 10, 5),
        (15, 5, 1),
        (15, 10, 93),
        (15, 5, 99),
        (15, 40, 92),
        (15, 20, 96),
        (15, 20, 99),
        (15, 10, 4),
        (15, 40, 3),
        (15, 3, 98),
        (15, 40, 91),
        (15, 10, 95),
        (15, 20, 98),
        (15, 5, 91),
        (15, 40, 5),
        (15, 5, 2),
        (15, 10, 94),
        (15, 40, 93),
        (15, 20, 100),
        (15, 5, 93),
        (15, 3, 1),
        (15, 40, 4),
        (15, 3, 99),
        (15, 5, 4),
        (15, 10, 96),
        (15, 3, 92),
        (15, 40, 95),
        (15, 5, 92),
        (15, 5, 95),
        (15, 20, 92),
        (15, 3, 91),
        (15, 40, 94),
        (15, 40, 97),
        (15, 5, 94),
        (15, 20, 91),
        (15, 2, 5),
        (15, 3, 5),
    ]
    func_16_unsolved_problems = [
        (16, 40, 3),
        (16, 40, 91),
        (16, 20, 98),
        (16, 10, 2),
        (16, 40, 2),
        (16, 10, 5),
        (16, 40, 5),
        (16, 10, 93),
        (16, 5, 100),
        (16, 40, 93),
        (16, 20, 97),
        (16, 20, 100),
        (16, 40, 4),
        (16, 40, 92),
        (16, 5, 4),
        (16, 10, 95),
        (16, 20, 1),
        (16, 40, 95),
        (16, 20, 99),
        (16, 5, 92),
        (16, 3, 91),
        (16, 40, 94),
        (16, 20, 3),
        (16, 20, 91),
        (16, 3, 93),
        (16, 20, 2),
        (16, 40, 96),
        (16, 20, 5),
        (16, 20, 93),
        (16, 10, 98),
        (16, 20, 4),
        (16, 40, 98),
        (16, 5, 95),
        (16, 20, 92),
        (16, 20, 95),
        (16, 10, 97),
        (16, 40, 97),
        (16, 40, 100),
        (16, 20, 94),
        (16, 40, 1),
        (16, 10, 99),
        (16, 40, 99),
        (16, 5, 99),
        (16, 20, 96),
    ]
    func_21_unsolved_problems = [
        (21, 5, 96),
        (21, 3, 4),
        (21, 10, 1),
        (21, 10, 99),
        (21, 5, 95),
        (21, 40, 97),
        (21, 2, 97),
        (21, 40, 100),
        (21, 5, 97),
        (21, 10, 2),
        (21, 5, 1),
        (21, 20, 96),
        (21, 3, 95),
        (21, 40, 91),
        (21, 20, 95),
        (21, 5, 91),
        (21, 40, 2),
        (21, 3, 97),
        (21, 2, 2),
        (21, 3, 100),
        (21, 5, 2),
        (21, 5, 100),
        (21, 20, 97),
        (21, 20, 100),
        (21, 2, 1),
        (21, 3, 1),
        (21, 2, 4),
        (21, 10, 96),
        (21, 2, 92),
        (21, 20, 99),
        (21, 5, 92),
        (21, 3, 3),
        (21, 5, 3),
        (21, 2, 91),
        (21, 3, 91),
        (21, 40, 94),
        (21, 10, 98),
        (21, 2, 94),
        (21, 20, 3),
        (21, 5, 94),
        (21, 20, 91),
        (21, 40, 5),
        (21, 5, 5),
        (21, 10, 97),
        (21, 20, 2),
        (21, 5, 93),
        (21, 2, 96),
    ]
    func_22_unsolved_problems = [
        (22, 40, 2),
        (22, 5, 96),
        (22, 20, 97),
        (22, 40, 1),
        (22, 10, 99),
        (22, 10, 92),
        (22, 5, 98),
        (22, 40, 92),
        (22, 20, 96),
        (22, 10, 3),
        (22, 40, 3),
        (22, 3, 98),
        (22, 10, 91),
        (22, 40, 91),
        (22, 2, 100),
        (22, 5, 100),
        (22, 20, 98),
        (22, 10, 5),
        (22, 2, 1),
        (22, 2, 99),
        (22, 3, 100),
        (22, 40, 93),
        (22, 20, 2),
        (22, 5, 92),
        (22, 3, 1),
        (22, 2, 3),
        (22, 5, 3),
        (22, 2, 91),
        (22, 3, 3),
        (22, 10, 94),
        (22, 5, 5),
        (22, 10, 97),
        (22, 2, 93),
        (22, 20, 3),
        (22, 5, 93),
        (22, 20, 91),
        (22, 3, 5),
        (22, 5, 4),
        (22, 2, 92),
        (22, 3, 93),
        (22, 20, 5),
        (22, 5, 95),
        (22, 3, 4),
        (22, 5, 94),
        (22, 5, 97),
    ]
    # Concatenation order (15, 16, 21, 22) fixes the position of every space.
    return create_problems_from_instances(
        func_15_unsolved_problems
        + func_16_unsolved_problems
        + func_21_unsolved_problems
        + func_22_unsolved_problems,
        budget,
    )


def high_unsolved_benchmark(budget: int = BUDGET):
    """Return COCO spaces for hand-picked 20-D and 40-D problems.

    The problems cover functions 3, 4, 15, 16, 18, 21 and 22. Each entry is a
    ``(function number, dimension, instance)`` triple that is handed to
    ``create_problems_from_instances``. The "unsolved" naming suggests these
    are problems earlier runs failed to solve -- TODO confirm how the list was
    produced.

    List order is preserved in the result. This matters when the result is
    split by ``find_problems_to_run(parts=...)``, which partitions by position.
    """
    problems = [
        (3, 40, 1),
        (3, 20, 3),
        (3, 40, 92),
        (3, 20, 91),
        (3, 40, 3),
        (3, 20, 2),
        (3, 40, 91),
        (3, 20, 93),
        (3, 40, 5),
        (3, 40, 93),
        (3, 20, 92),
        (3, 20, 95),
        (3, 40, 4),
        (3, 40, 95),
        (3, 20, 94),
        (3, 20, 97),
        (3, 40, 94),
        (3, 40, 97),
        (3, 20, 96),
        (3, 40, 96),
        (3, 40, 99),
        (3, 20, 98),
        (3, 40, 98),
        (3, 20, 100),
        (3, 40, 2),
        (3, 20, 1),
        (3, 40, 100),
        (4, 40, 92),
        (4, 40, 95),
        (4, 20, 94),
        (4, 40, 94),
        (4, 20, 93),
        (4, 20, 98),
        (4, 20, 97),
        (4, 40, 99),
        (4, 20, 91),
        (4, 40, 91),
        (4, 40, 2),
        (4, 40, 5),
        (4, 20, 92),
        (4, 40, 4),
        (15, 40, 98),
        (15, 20, 97),
        (15, 20, 100),
        (15, 40, 2),
        (15, 40, 100),
        (15, 20, 99),
        (15, 40, 99),
        (15, 40, 92),
        (15, 20, 91),
        (15, 40, 3),
        (15, 40, 91),
        (15, 20, 5),
        (15, 20, 93),
        (15, 40, 5),
        (15, 20, 4),
        (15, 40, 93),
        (15, 20, 92),
        (15, 20, 95),
        (15, 40, 4),
        (15, 40, 95),
        (15, 20, 94),
        (15, 40, 94),
        (15, 40, 97),
        (15, 20, 96),
        (15, 40, 96),
        (15, 20, 98),
        (16, 40, 97),
        (16, 40, 96),
        (16, 20, 98),
        (16, 40, 98),
        (16, 20, 97),
        (16, 20, 100),
        (16, 40, 2),
        (16, 20, 1),
        (16, 40, 100),
        (16, 20, 99),
        (16, 40, 1),
        (16, 40, 99),
        (16, 20, 3),
        (16, 40, 92),
        (16, 20, 91),
        (16, 40, 3),
        (16, 20, 2),
        (16, 40, 91),
        (16, 20, 5),
        (16, 20, 93),
        (16, 40, 5),
        (16, 20, 4),
        (16, 40, 93),
        (16, 20, 92),
        (16, 20, 95),
        (16, 40, 4),
        (16, 40, 95),
        (16, 20, 94),
        (16, 40, 94),
        (18, 40, 1),
        (18, 40, 3),
        (18, 40, 91),
        (18, 40, 2),
        (18, 40, 5),
        (18, 40, 4),
        (18, 40, 92),
        (18, 40, 95),
        (18, 40, 94),
        (18, 20, 96),
        (18, 40, 96),
        (18, 40, 97),
        (18, 20, 1),
        (21, 40, 94),
        (21, 20, 96),
        (21, 20, 95),
        (21, 20, 97),
        (21, 40, 97),
        (21, 40, 100),
        (21, 20, 99),
        (21, 20, 91),
        (21, 20, 2),
        (21, 20, 100),
        (21, 40, 91),
        (21, 40, 2),
        (21, 40, 5),
        (21, 20, 3),
        (22, 20, 97),
        (22, 20, 96),
        (22, 20, 98),
        (22, 20, 2),
        (22, 40, 2),
        (22, 40, 1),
        (22, 20, 3),
        (22, 40, 92),
        (22, 20, 91),
        (22, 40, 3),
        (22, 40, 91),
        (22, 20, 5),
        (22, 40, 93),
    ]
    return create_problems_from_instances(problems, budget)


def simple_unsolved_problems(budget: int = BUDGET):
    """Return COCO spaces for hand-picked problems of functions 3 and 4.

    Each entry is a ``(function number, dimension, instance)`` triple that is
    handed to ``create_problems_from_instances``. The "unsolved" naming
    suggests these are problems earlier runs failed to solve -- TODO confirm
    how the lists were produced.

    Function 3 problems come first, then function 4, each in list order. The
    order matters when the result is split by ``find_problems_to_run(parts=...)``,
    which partitions by position.
    """
    func_3_unsolved_problems = [
        (3, 3, 3),
        (3, 10, 94),
        (3, 40, 94),
        (3, 20, 3),
        (3, 5, 93),
        (3, 20, 91),
        (3, 3, 5),
        (3, 5, 4),
        (3, 20, 2),
        (3, 40, 96),
        (3, 5, 92),
        (3, 5, 95),
        (3, 3, 4),
        (3, 20, 93),
        (3, 10, 98),
        (3, 40, 98),
        (3, 20, 92),
        (3, 20, 95),
        (3, 10, 97),
        (3, 3, 94),
        (3, 40, 97),
        (3, 40, 100),
        (3, 20, 94),
        (3, 10, 1),
        (3, 40, 1),
        (3, 10, 99),
        (3, 40, 99),
        (3, 5, 98),
        (3, 20, 96),
        (3, 10, 3),
        (3, 40, 3),
        (3, 10, 91),
        (3, 5, 97),
        (3, 40, 91),
        (3, 5, 100),
        (3, 20, 98),
        (3, 10, 2),
        (3, 40, 2),
        (3, 10, 5),
        (3, 2, 1),
        (3, 40, 5),
        (3, 5, 1),
        (3, 10, 93),
        (3, 5, 99),
        (3, 40, 93),
        (3, 20, 97),
        (3, 20, 100),
        (3, 10, 4),
        (3, 40, 4),
        (3, 3, 99),
        (3, 10, 92),
        (3, 40, 92),
        (3, 20, 1),
        (3, 40, 95),
        (3, 5, 91),
    ]
    func_4_unsolved_problems = [
        (4, 5, 96),
        (4, 3, 5),
        (4, 20, 94),
        (4, 10, 1),
        (4, 10, 99),
        (4, 40, 94),
        (4, 20, 93),
        (4, 2, 93),
        (4, 40, 5),
        (4, 2, 4),
        (4, 10, 91),
        (4, 5, 97),
        (4, 20, 98),
        (4, 10, 100),
        (4, 40, 95),
        (4, 20, 97),
        (4, 10, 4),
        (4, 3, 99),
        (4, 5, 91),
        (4, 10, 3),
        (4, 40, 99),
        (4, 10, 5),
        (4, 20, 91),
        (4, 10, 96),
        (4, 40, 91),
        (4, 40, 2),
        (4, 10, 98),
        (4, 2, 99),
        (4, 5, 94),
        (4, 20, 92),
        (4, 40, 4),
        (4, 3, 94),
        (4, 40, 92),
    ]
    return create_problems_from_instances(
        func_3_unsolved_problems + func_4_unsolved_problems, budget
    )


def local_coco_benchmark(budget: int = BUDGET):
    """Return COCO functions 16, 18, 19, 21 and 23 in 20 and 40 dimensions.

    Instances 1 and 91 are used. The result is ordered by function, then by
    dimension, then by instance.
    """
    func_ids = (16, 18, 19, 21, 23)
    dimensions = (20, 40)
    instances = (1, 91)
    spaces = []
    for func_id in func_ids:
        for dimension in dimensions:
            for inst in instances:
                spaces.append(
                    coco_space_from_funcnum_and_dim(func_id, dimension, inst, budget)
                )
    return spaces


def _first_coco_problem_benchmark(b):
    # Smallest smoke-test suite: only the COCO problem at index 0.
    return [coco_space_from_index(0, b)]


def _opt_gan_2d_and_10d_benchmark(b):
    # The OPT-GAN suite in 2 dimensions, followed by the same suite in 10 dimensions.
    return opt_gan_benchmark(2, b) + opt_gan_benchmark(10, b)


# Registry mapping each ``Benchmarks`` member to a factory that is called with
# the evaluation budget (positionally) and returns a list of spaces.
BENCHMARK_MAPPER = {
    Benchmarks.TEST: _first_coco_problem_benchmark,
    Benchmarks.COCO: coco_benchmark,
    Benchmarks.PARTIAL_COCO: partial_coco_benchmark,
    Benchmarks.PARTIAL_COCO_40: partial_coco_40_benchmark,
    Benchmarks.OPT_GAN: _opt_gan_2d_and_10d_benchmark,
    Benchmarks.HIGH_DIM: functools.partial(high_dim_benchmark, dim=1000),
    Benchmarks.COCO_HIGH_DIM: coco_largescale_benchmark,
    Benchmarks.PARTIAL_COCO_HIGH_DIM: coco_partial_largescale_benchmark,
    Benchmarks.PARTIAL_COCO_HIGH_DIM_2: coco_partial_largescale_benchmark_2,
    Benchmarks.LOCAL_COCO: local_coco_benchmark,
    Benchmarks.SIMPLE_UNSOLVED: simple_unsolved_problems,
    Benchmarks.PAR_UNSOLVED: partial_unsolved_benchmark,
    Benchmarks.HIGH_UNSOLVED: high_unsolved_benchmark,
    Benchmarks.PAR_COCO_LOW: partial_coco_low_benchmark,
}


def find_problems_to_run(
    benchmark: Tuple[Benchmarks, ...],
    budget: int,
    func_num: List[int] = None,
    parts: List[Tuple[int, int]] = None,
    func_dim: List[int] = None,
    func_inst: List[int] = None,
) -> list:
    """Collect the optimization spaces to run for the requested benchmarks.

    Args:
        benchmark: Benchmarks whose spaces are concatenated, in order.
        budget: Evaluation budget forwarded to every benchmark factory.
        func_num: If given, keep only spaces whose ``func_id`` is listed.
        parts: ``(part_num, num_of_parts)`` pairs. For each pair, the
            concatenated space list is cut into ``num_of_parts`` contiguous
            chunks of ``len(spaces) // num_of_parts`` items, and chunk
            ``part_num`` is selected. The last chunk also takes the leftover
            items, so no space is silently skipped. If ``parts`` is omitted or
            empty, every space is selected.
        func_dim: If given, keep only spaces whose ``dimension`` is listed.
        func_inst: If given, keep only spaces whose ``func_instance`` is listed.

    Returns:
        The selected spaces in first-seen order. Duplicates, as determined by
        the spaces' ``__hash__``/``__eq__``, are removed.

    Raises:
        ValueError: If a requested benchmark has no entry in ``BENCHMARK_MAPPER``.
    """
    spaces = []
    for bench in benchmark:
        try:
            spaces_factory = BENCHMARK_MAPPER[bench]
        except KeyError:
            # Previously ``.get()`` returned None, which failed later with an
            # opaque "NoneType is not callable" TypeError.
            raise ValueError(f"Unknown benchmark: {bench!r}") from None
        spaces += spaces_factory(budget)

    if not parts:
        # No partitioning requested: run everything instead of nothing.
        spaces_to_run = list(spaces)
    else:
        spaces_to_run = []
        total = len(spaces)
        for part_num, num_of_parts in parts:
            part_size = total // num_of_parts
            start = part_num * part_size
            # Earlier parts keep their original boundaries. The final part
            # extends to the end so the ``total % num_of_parts`` leftover
            # spaces are not dropped.
            end = total if part_num == num_of_parts - 1 else start + part_size
            spaces_to_run += spaces[start:end]

    if func_num:
        spaces_to_run = [space for space in spaces_to_run if space.func_id in func_num]
    if func_dim:
        spaces_to_run = [space for space in spaces_to_run if space.dimension in func_dim]
    if func_inst:
        spaces_to_run = [space for space in spaces_to_run if space.func_instance in func_inst]

    # dict.fromkeys de-duplicates like set() but keeps a deterministic order.
    return list(dict.fromkeys(spaces_to_run))
