import logging
import os
import time
from typing import Callable

import torch
from beam import task
from beam.concurrent import BeamParallel
from beam.distributed import RayDispatcher

# Variable name that generated code binds its output to; read back by the
# checkers after executing the code.
RESULT_NAME: str = "results"
# Variable name that holds the input parameters in generated code (see
# construct_assign_params_code); read back alongside RESULT_NAME.
PARAM_NAME: str = "params"


def construct_assign_params_code(
    param_name: str, params, activation_function_name: str
):
    """Build a two-line Python snippet that feeds ``params`` to an activation.

    Line 1 binds ``PARAM_NAME`` to ``params``. Line 2 binds ``param_name``
    to ``activation_function_name(params)``, unpacking ``params`` as
    positional arguments when it is a tuple. The lines are joined with
    ``os.linesep``.

    ``params`` is embedded through its ``str()`` form twice. The call
    argument is therefore a separately evaluated expression rather than a
    reference to the ``PARAM_NAME`` variable.

    NOTE(review): the ``str()`` form must be a valid Python expression
    (true for numbers, tuples and lists; not, e.g., for plain strings).
    """
    splat = "*" if isinstance(params, tuple) else ""
    snippet_lines = [
        f"{PARAM_NAME} = {params}",
        f"{param_name}={activation_function_name}({splat}{params})",
    ]
    return os.linesep.join(snippet_lines)


class ResultAccuracy:
    """Score generated code by the distance between its input and its output.

    The code is executed and is expected to bind the names ``PARAM_NAME``
    (the input) and ``RESULT_NAME`` (the output). Both are read back; a
    missing name reads as ``None``. They are passed as
    ``measure_distance(param, result)``.

    Args:
        measure_distance: callable ``(param, result) -> loss``.
        max_loss: upper bound on the returned loss. It is also the loss
            returned when executing or scoring the code raises.
        logger: logger for failures; defaults to this module's logger.
    """

    def __init__(self, measure_distance: Callable, max_loss: float, logger=None):
        self.measure_distance = measure_distance
        self.max_loss = max_loss
        self.logger = logger or logging.getLogger(__name__)

    def __call__(self, llm_code: str):
        # Execute in one private namespace that serves as both globals and
        # locals, and read the results back from it.
        # - A bare ``exec`` followed by ``locals()`` no longer sees the
        #   executed assignments on Python 3.13+ (PEP 667).
        # - Separate globals/locals make functions defined by the code
        #   unable to see each other or the code's own imports.
        # Seeding from this module's globals keeps names such as ``torch``
        # visible to the code as before.
        # NOTE(review): ``llm_code`` is model-generated and runs unsandboxed.
        namespace = dict(globals())
        try:
            exec(llm_code, namespace)
            code_res = namespace.get(RESULT_NAME)
            param = namespace.get(PARAM_NAME)
            return min(self.measure_distance(param, code_res), self.max_loss)
        except Exception:
            self.logger.exception("Error in code %s", llm_code)
            return self.max_loss


class SpeedChecker:
    """Score generated code by how long it takes to execute.

    The loss is the elapsed ``time.perf_counter()`` time, in seconds, of
    executing the code. It is capped at ``max_loss`` so that slow but
    working code never scores worse than code that fails.

    Args:
        max_loss: upper bound on the loss. It is also the loss returned when
            the code raises.
        logger: logger for failures; defaults to this module's logger.
    """

    def __init__(self, max_loss: float, logger=None):
        self.max_loss = max_loss
        self.logger = logger or logging.getLogger(__name__)

    def __call__(self, llm_code: str):
        # Use one private namespace as both globals and locals. With split
        # namespaces, functions defined by the code cannot call each other.
        # The namespace is copied before timing starts so the copy is not
        # measured.
        # NOTE(review): ``llm_code`` is model-generated and runs unsandboxed.
        namespace = dict(globals())
        try:
            start_time = time.perf_counter()
            exec(llm_code, namespace)
            elapsed = time.perf_counter() - start_time
        except Exception:
            self.logger.exception("Error in code %s", llm_code)
            return self.max_loss
        return min(elapsed, self.max_loss)


class LengthChecker:
    """Penalise code whose token count deviates from an expected length.

    Tokens are the whitespace-separated words of the code. The loss is half
    the squared difference between the token count and ``expected_length``.
    """

    def __init__(self, expected_length: int):
        self.expected_length = expected_length

    def __call__(self, code: str):
        # str.split() with no separator never yields empty strings, so no
        # extra filtering is needed.
        deviation = len(code.split()) - self.expected_length
        return deviation ** 2 / 2


class MultipleParamsChecker:
    """Run ``checker`` once per entry of ``params`` and sum the losses.

    Args:
        checker: callable mapping a code string to a loss.
        params: iterable of parameter values; the checker runs once per entry.
        activation_function_name: optional name of a function defined by
            the scored code.
            - When given, each run appends the snippet built by
              ``construct_assign_params_code``. The snippet binds
              ``PARAM_NAME`` to the entry and ``RESULT_NAME`` to the
              function applied to it, so the checker scores the code on
              that input.
            - When ``None`` (the default), ``llm_code`` is passed unchanged
              on every run, and ``params`` only sets how many runs happen.

    Returns:
        The sum (not the mean) of the per-entry losses; ``0`` when
        ``params`` is empty.
    """

    def __init__(self, checker, params, activation_function_name=None):
        self.checker = checker
        self.params = params
        self.activation_function_name = activation_function_name

    def _code_for(self, llm_code: str, param) -> str:
        """Return the code to score for one parameter entry."""
        if self.activation_function_name is None:
            return llm_code
        return os.linesep.join(
            [
                llm_code,
                construct_assign_params_code(
                    RESULT_NAME, param, self.activation_function_name
                ),
            ]
        )

    def __call__(self, llm_code: str):
        loss = 0
        for param in self.params:
            loss += self.checker(self._code_for(llm_code, param))
        return loss


class MultipleCheckers:
    """Combine several checkers into a weighted sum of their losses.

    ``checkers`` is an iterable of ``(checker, weight)`` pairs. Each checker
    is called on the same code and its loss is multiplied by its weight.
    The result is ``0`` when there are no checkers.
    """

    def __init__(self, checkers):
        self.checkers = checkers

    def __call__(self, llm_code: str):
        return sum(weight * checker(llm_code) for checker, weight in self.checkers)


class CheckerForBatch:
    """Score a batch of code strings concurrently with a Beam worker pool.

    Each code string becomes one Beam ``task`` running ``checker``. The
    tasks are executed by ``BeamParallel``, and the ``"result"`` entry of
    every task result is collected into a tensor.

    Args:
        checker: callable mapping one code string to a numeric score.
        n_workers: worker count passed to ``BeamParallel`` (default 20).
        method: execution method passed to ``BeamParallel`` (default
            ``"threading"``).
    """

    def __init__(self, checker, n_workers: int = 20, method: str = "threading"):
        self.checker = checker
        self.n_workers = n_workers
        self.method = method

    def __call__(self, llm_code: list):
        tasks = [task(self.checker)(code) for code in llm_code]
        pool = BeamParallel(
            func=self.checker, n_workers=self.n_workers, method=self.method
        )
        outcome = pool(tasks)
        # NOTE(review): assumes ``outcome.results`` preserves task order, so
        # scores line up with ``llm_code``. Confirm against Beam's docs.
        return torch.tensor([r["result"] for r in outcome.results])


class CheckerForSampler:
    """Average a tensor-valued checker over several code samples.

    ``checker`` must return a tensor for each sample, and all tensors must
    have the same shape. They are stacked along a new leading axis and
    averaged over it. An empty sample list raises, because ``torch.stack``
    needs at least one tensor.
    """

    def __init__(self, checker):
        self.checker = checker

    def __call__(self, llm_code: list):
        scores = list(map(self.checker, llm_code))
        stacked = torch.stack(scores)
        return stacked.mean(dim=0)


class LoopSamplesChecker:
    """Apply ``checker`` to each code sample sequentially and collect the scores.

    Args:
        checker: callable mapping one code string to a score (a number or
            a tensor).
        should_mean: if True, return the mean score over the samples.
            Otherwise return the per-sample scores stacked along a new
            leading dimension.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, checker, should_mean=False, *args, **kwargs):
        self.checker = checker
        self.should_mean = should_mean

    def __call__(self, llm_code: list):
        results = [self.checker(code) for code in llm_code]
        if self.should_mean:
            values = torch.tensor(results)
            # mean() rejects integer/bool dtypes. Promote those to float, as
            # RaySamplesChecker does with .float(), but leave float and
            # complex scores at their original precision.
            if not (values.is_floating_point() or values.is_complex()):
                values = values.float()
            return values.mean(dim=0)
        # as_tensor returns tensors unchanged and wraps plain numbers so they
        # can be stacked too.
        return torch.stack([torch.as_tensor(r) for r in results])


class RaySamplesChecker:
    """Score code samples in parallel with one ``RayDispatcher`` per sample.

    Each dispatcher runs ``checker`` with ``num_cpus = 1 / concurrency``.
    Presumably this lets about ``concurrency`` tasks share a CPU; confirm
    against Ray's fractional-resource scheduling.

    With ``should_mean`` the float mean over samples is returned. Otherwise
    the per-sample tensors are stacked along a new leading dimension.
    """

    def __init__(self, checker, should_mean=False, concurrency=4):
        self.checker = checker
        self.should_mean = should_mean
        self.concurrency = concurrency

    def __call__(self, llm_code: list):
        # Create all dispatchers, then submit every sample, then block on
        # each result in order.
        dispatchers = [
            RayDispatcher(
                self.checker, remote_kwargs={"num_cpus": 1 / self.concurrency}
            )
            for _sample in llm_code
        ]
        handles = [
            dispatch(sample) for sample, dispatch in zip(llm_code, dispatchers)
        ]
        scores = [handle.value for handle in handles]
        if not self.should_mean:
            return torch.stack(scores)
        return torch.tensor(scores).float().mean(dim=0)


class RayMultipleParamsChecker:
    """Score code against several parameter sets, one Ray task per set.

    For every entry of ``params``, ``llm_code`` is extended with the snippet
    from ``construct_assign_params_code``. The snippet binds ``RESULT_NAME``
    to ``activation_function_name`` applied to that entry. ``checker`` then
    scores the extended code in its own ``RayDispatcher``.

    Each task is waited on with ``timeout``. A ``None`` from ``wait`` is
    treated as a timeout: that task is killed and charged
    ``checker.max_loss``. The return value is the mean loss over all
    parameter sets, or ``checker.max_loss`` when there are none.

    Args:
        checker: callable scoring a code string; must expose ``max_loss``.
        params: iterable of parameter values, one Ray task per entry.
        activation_function_name: name of the function in ``llm_code`` to call.
        concurrency: each task requests ``1 / concurrency`` CPUs.
        timeout: argument passed to each result's ``wait``. Default 5,
            presumably seconds; confirm against RayDispatcher. Waits happen
            one after another, so the total can reach ``len(params) * timeout``.
    """

    def __init__(
        self, checker, params, activation_function_name, concurrency=4, timeout=5
    ):
        self.checker = checker
        self.params = params
        self.concurrency = concurrency
        self.activation_function_name = activation_function_name
        self.timeout = timeout

    def _code_with_activation(self, llm_code: str, param) -> str:
        """Append the parameter assignment and activation call for ``param``."""
        return os.linesep.join(
            [
                llm_code,
                construct_assign_params_code(
                    RESULT_NAME, param, self.activation_function_name
                ),
            ]
        )

    def __call__(self, llm_code: str):
        # Exactly one dispatcher per parameter set. Iterating over
        # ``llm_code`` would yield one per character of the code string, and
        # zipping that with ``params`` would silently drop params whenever
        # the code is shorter than the param list.
        num_cpus = 1 / self.concurrency
        results = [
            RayDispatcher(self.checker, remote_kwargs={"num_cpus": num_cpus})(
                self._code_with_activation(llm_code, param)
            )
            for param in self.params
        ]
        results_values = [res.wait(self.timeout) for res in results]
        loss = 0
        for value, res in zip(results_values, results):
            if value is None:
                # No value within the timeout: stop the task and charge the
                # maximum loss.
                res.kill()
                loss += self.checker.max_loss
            else:
                loss += value
        return loss / len(results_values) if results_values else self.checker.max_loss
