import copy
import functools
import math
from pathlib import Path
from uuid import uuid4

import torch
import transformers
from torch.nn import Linear, Sequential, SmoothL1Loss
from transformers import AutoTokenizer

from algorithms.convergence_algorithms.egl import EGL
from algorithms.mapping.trust_region import TanhTrustRegion
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.nn.datasets import PairsInEpsRangeDataset
from algorithms.nn.modules import BigLinearNetwork, BaseSequentialModel
from applications.checkers import (
    SpeedChecker,
    LengthChecker,
    MultipleCheckers,
    ResultAccuracy,
    RaySamplesChecker,
    RayMultipleParamsChecker,
)
from applications.code_space import LLMPythonGeneratorSpace
from applications.common import (
    FUNCTIONS,
    remove_not_function_elements,
    MODELS,
    add_function_prompt,
)
from applications.models_wrapper import (
    TransformerGeneratorWrapper,
    SoftPromptTuningWrapper,
    LoraEmbeddingWrapper,
    ExtractCodeWrapper,
    MultipleSoftTuningLora,
    LoraEmbedLayerWrapper,
)
from applications.saver import SaverCallbackHandler, GenerateCode
from handlers.drawer_handlers import LoggerDrawerHandler
from handlers.drawers.loss_drawer import StepSizeDrawer
from handlers.drawers.utils import convert_to_real_drawer
from utils.logger import create_logger
from utils.python import timestamp_file_signature

def _speed_loss(func_data):
    """Build a speed checker run over every parameter set of ``func_data``."""
    speed_checker = SpeedChecker(func_data.max_lost)
    return RayMultipleParamsChecker(
        speed_checker,
        func_data.params,
        func_data.activation_func_name,
    )


def _length_loss(func_data):
    """Build a fixed-limit (250) length checker.

    ``func_data`` is accepted only so every loss factory shares one signature.
    """
    return LengthChecker(250)


# Loss name (``loss`` argument of ``main``) -> factory taking a FUNCTIONS entry.
losses = dict(speed=_speed_loss, length=_length_loss)

# Parameterization name (``parameters`` argument) -> wrapper class to instantiate.
parameter_type = dict(
    soft_tune=SoftPromptTuningWrapper,
    embed=LoraEmbeddingWrapper,
    soft_lora=MultipleSoftTuningLora,
    embed_lora=LoraEmbedLayerWrapper,
)


def find_tag_index_in_tokens(tokenizer, tokens, tag):
    """Return the positions of every token whose decoded text contains ``tag``.

    Each token is decoded on its own, so a tag split across several tokens
    is not detected.
    """
    positions = []
    for position, token in enumerate(tokens):
        decoded = tokenizer.decode(token)
        if tag in decoded:
            positions.append(position)
    return positions


def model_from_pipeline(
    pipeline, tokenizer, function_data, model_data, parameters, samples=10
):
    """Build the tunable, code-extracting generator for ``function_data``.

    The prompt is tokenized without special tokens. The tuned region is the
    span of tokens strictly between the two tokens containing ``\"\"\"``.

    Args:
        pipeline: text-generation pipeline whose ``model.model.embed_tokens``
            is the input embedding layer.
        tokenizer: tokenizer matching ``pipeline.model``.
        function_data: FUNCTIONS entry providing ``prompt`` and
            ``post_code_creator_processor``.
        model_data: MODELS entry providing ``extract_code``.
        parameters: key into ``parameter_type`` selecting the wrapper class.
        samples: number of samples passed to ``TransformerGeneratorWrapper``.

    Returns:
        The parameter wrapper nested inside four ``ExtractCodeWrapper`` layers.

    Raises:
        ValueError: if the prompt does not contain exactly two tokens holding
            ``\"\"\"``.
        KeyError: if ``parameters`` is not a key of ``parameter_type``.
    """
    embeddings = pipeline.model.model.embed_tokens
    tokens = tokenizer.encode(
        function_data.prompt, return_tensors="pt", add_special_tokens=False
    )
    tag_positions = find_tag_index_in_tokens(tokenizer, tokens[0], '"""')
    # The tuned span is delimited by exactly one opening and one closing tag.
    # Fail with a clear message instead of an opaque unpacking error.
    if len(tag_positions) != 2:
        raise ValueError(
            "Expected exactly two tokens containing '\"\"\"' in the prompt, "
            f"found {len(tag_positions)} at positions {tag_positions}"
        )
    index_for_start_text, index_for_end_text = tag_positions
    # Skip the opening tag itself so only the text between the tags is tuned.
    index_for_start_text += 1
    word_embeddings = embeddings(tokens)
    parameters_type = parameter_type[parameters]

    model = parameters_type(
        model=TransformerGeneratorWrapper(
            pipeline.model,
            tokenizer,
            samples=samples,
        ),
        embeddings=embeddings,
        inputs=word_embeddings,
        tokens=tokens,
        start_idx=index_for_start_text,
        # Number of tokens strictly between the two tags.
        tuning_len=index_for_end_text - index_for_start_text,
        rank=4,
    )
    # NOTE(review): presumably each ExtractCodeWrapper applies its function to
    # the inner model's output, so the innermost layer runs first:
    # add_function_prompt -> extract_code -> remove_not_function_elements ->
    # post_code_creator_processor. Confirm against models_wrapper.
    model = ExtractCodeWrapper(
        function_data.post_code_creator_processor,
        ExtractCodeWrapper(
            remove_not_function_elements,
            ExtractCodeWrapper(
                functools.partial(model_data.extract_code, function_data),
                ExtractCodeWrapper(
                    functools.partial(add_function_prompt, function_data), model
                ),
            ),
        ),
    )
    return model


def _build_weighted_checkers(
    loss, function_data, error_logs_path, run_name, algorithm_name
):
    """Build the ``(checker, weight)`` pairs scored by the search space.

    The ``losses[loss]`` checker and a ``ResultAccuracy`` checker (evaluated
    over ``function_data.params``) are each wrapped in two ``RaySamplesChecker``
    layers and weighted 0.5. The accuracy checker gets its own logger
    created at ``error_logs_path``.
    """
    loss_checker = RaySamplesChecker(
        RaySamplesChecker(losses[loss](function_data), True, 6)
    )
    accuracy_checker = RaySamplesChecker(
        RaySamplesChecker(
            RayMultipleParamsChecker(
                ResultAccuracy(
                    function_data.loss,
                    function_data.max_lost,
                    create_logger(
                        error_logs_path,
                        None,
                        run_name,
                        algorithm_name,
                        None,
                        "error_logs",
                    ),
                ),
                function_data.params,
                function_data.activation_func_name,
            ),
            True,
            6,
        )
    )
    return [(loss_checker, 0.5), (accuracy_checker, 0.5)]


def main(model_name: str, loss: str, parameters: str, func: str, device):
    """Optimize the prompt-side parameters of an LLM with EGL for one function.

    Args:
        model_name: key into ``MODELS``.
        loss: key into ``losses`` selecting the checker that is combined with
            result accuracy.
        parameters: key into ``parameter_type`` selecting the tunable wrapper.
        func: key into ``FUNCTIONS`` selecting the target function and prompt.
        device: device for the EGL helper networks and optimizer. The LLM is
            placed by ``device_map="auto"``.

    Exceptions raised during training are logged with a traceback and are
    not propagated.
    """
    dtype = torch.float32
    base_dir = Path(__file__).parent.parent
    run_name = "basic"
    algorithm_name = "fine_tuning_len"
    # One signature per run so the save dir and both log dirs can be matched.
    signature = timestamp_file_signature()
    results_save_path = base_dir / "saves" / f"{signature}-{uuid4()}"
    logs_path = base_dir / "app_logs"
    normal_logs_path = (
        logs_path / f"logs_for_parallel-{algorithm_name}-{run_name}-{signature}"
    )
    error_logs_path = (
        logs_path / f"error_logs-{algorithm_name}-{run_name}-{signature}"
    )
    # Work on a copy so rewriting the prompt below does not modify the shared
    # FUNCTIONS registry. Otherwise repeated calls would re-apply
    # manipulate_input to an already-manipulated prompt.
    function_data = copy.copy(FUNCTIONS[func])
    model_data = MODELS[model_name]
    logger = create_logger(
        normal_logs_path, None, run_name, algorithm_name, None, "normal-logs"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_data.model_name)
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_data.model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        temperature=0.7,
        top_k=50,
        do_sample=True,
        trust_remote_code=True,
    )
    pipeline.model.eval()
    function_data.prompt = model_data.manipulate_input(function_data.prompt)
    model = model_from_pipeline(
        pipeline, tokenizer, function_data, model_data, parameters
    )
    dims = model.numel()
    model_to_train = BaseSequentialModel(
        Sequential(Linear(dims, 1, bias=False, dtype=dtype)).to(device=device)
    )
    grad_network = BigLinearNetwork(dims, [dims // 2, dims // 2, dims], device).to(
        dtype=dtype
    )
    grad_opt = torch.optim.Adam(
        grad_network.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
    )
    model_opt = torch.optim.Adam(
        model_to_train.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-04
    )

    space = LLMPythonGeneratorSpace(
        model,
        MultipleCheckers(
            _build_weighted_checkers(
                loss, function_data, error_logs_path, run_name, algorithm_name
            )
        ),
        upper_bound=torch.ones(dims),
        lower_bound=-torch.ones(dims),
        logger=logger,
        budget=150_000,
    )

    egl = EGL(
        space,
        helper_network=grad_network,
        model_to_train=model_to_train,
        value_optimizer=grad_opt,
        model_to_train_optimizer=model_opt,
        # Initial epsilon scales with sqrt(dims), i.e. with the diagonal of
        # the [-1, 1]^dims search box.
        epsilon=0.1 * math.sqrt(dims),
        epsilon_factor=0.97,
        min_epsilon=1e-4,
        perturb=0,
        grad_loss=SmoothL1Loss(),
        database_type=PairsInEpsRangeDataset,
        database_size=120_000,
        input_mapping=TanhTrustRegion(
            space.upper_bound,
            space.lower_bound,
            min_trust_region_size=0,
            dtype=dtype,
        ),
        output_mapping=AdaptedOutputUnconstrainedMapping(output_epsilon=5e-4),
        dtype=dtype,
        device=device,
        logger=logger,
    )
    logger.info(f"starting with {egl}")

    # The saver gets its own single-sample model built from the same pipeline.
    saver_callback_handler = SaverCallbackHandler(
        results_save_path,
        GenerateCode(
            model_from_pipeline(
                pipeline, tokenizer, function_data, model_data, parameters, samples=1
            )
        ),
    )
    try:
        egl.train(
            epochs=10_000,
            exploration_size=32,
            num_loop_without_improvement=20,
            min_iteration_before_shrink=60,
            helper_model_training_epochs=1,
            callback_handlers=[
                LoggerDrawerHandler(
                    convert_to_real_drawer(StepSizeDrawer()),
                    logger=logger,
                    name=f"step size {space}",
                ),
                saver_callback_handler,
            ],
        )
    except Exception:
        # Top-level boundary: record the full traceback and still log the finish.
        logger.exception("Train has stopped")
    logger.info("Finish running")
