from typing import Any, Dict

from tasks.task_bff import ThreeBitFF
from tasks.task_seq import DelayedDiscrimination
from tasks.task_sine import SineWaveDataset
from tasks.task_pathint import PathIntegration
from network import Network, Network_muP


def _select_for_task(value: Any, task_idx: int) -> Any:
    """Return ``value[task_idx]`` when *value* is a list, else *value* as-is."""
    return value[task_idx] if isinstance(value, list) else value


def get_task_params(task: str, args: Any, task_idx: int) -> Dict[str, Any]:
    """Assemble the parameter dictionary for one task.

    The task is selected by substring match on ``task``, tested in this
    order: ``'bff'``, ``'delayed_discrimination'``, ``'sinewave'``,
    ``'path_integration'``. The first match wins.

    Args:
        task: Task identifier; only needs to *contain* one of the keys above.
        args: Object exposing the configuration as attributes (presumably an
            ``argparse.Namespace`` -- confirm against the caller).
        task_idx: Index used to pick a per-task entry when ``args.dim_input``
            or ``args.dim_output`` is a list. Only consulted for
            ``'delayed_discrimination'``.

    Returns:
        The shared base parameters merged with the task-specific ones. Every
        returned dict contains ``'dim_input'`` and ``'dim_output'``.

    Raises:
        ValueError: If ``task`` matches none of the known task keys.
        AttributeError: If ``args`` lacks an attribute that is read.
        IndexError: If ``task_idx`` is out of range for a list-valued
            ``dim_input`` / ``dim_output``.
    """
    # Shared settings. All of them are read eagerly, so a missing attribute on
    # ``args`` raises here, before the task name is even inspected.
    base_params = {
        'task_name': args.task_name,
        'n_hidden': args.n_hidden,
        'rnn_type': args.rnn_type,
        'muP_param': args.muP_param,
        'tau': args.tau,
        'gain': args.gain,
        'gamma': args.gamma,
        'verbose': False,
        'seed': args.seed,
        'lr': args.lr,
        'device': args.device,
        'n_batch': args.n_batch,
        'save_path': args.save_path,
        'lr_scheduler': args.lr_scheduler,
        'W_rank_reg': args.W_rank_reg,
        'W_l1_reg': args.W_l1_reg,
        'trainable_ratio': args.trainable_ratio,
        'input_reproducing': args.input_reproducing,
        'small_weight_init': args.small_weight_init,
        'init_sigma': args.init_sigma,
        'weight_decay': args.weight_decay,
        'aux_loss': args.aux_loss,
        'time_shuffled': args.time_shuffled,
    }

    if 'bff' in task:
        return {
            **base_params,
            'n_bits': args.n_bits,
            'p_flip': args.p_flip,
            'n_time': args.n_time,
            # Input and output widths both equal the number of bits.
            'dim_input': args.n_bits,
            'dim_output': args.n_bits,
            'n_fps': args.n_fps,
        }

    if 'delayed_discrimination' in task:
        # Resolve each dimension on its own: previously dim_output was indexed
        # whenever dim_input was a list, which failed for a scalar dim_output
        # and let a list-valued dim_output through when dim_input was scalar.
        return {
            **base_params,
            'dim_input': _select_for_task(args.dim_input, task_idx),
            'dim_output': _select_for_task(args.dim_output, task_idx),
            'low_ts': args.dd_low_ts,
            'high_ts': args.dd_high_ts,
            'max_delay': args.dd_max_delay,
            'increment': args.dd_increment,
            'target': args.dd_target,
            'num_samples': args.num_samples,
        }

    if 'sinewave' in task:
        return {
            **base_params,
            'num_channels': args.num_channels,
            # Input and output widths both equal the channel count.
            'dim_input': args.num_channels,
            'dim_output': args.num_channels,
            'freq_range': [args.freq_min, args.freq_max],
            'freq_num': args.freq_num,
            'dt': args.dt,
            'sequence_length': args.sequence_length,
            'num_samples': args.num_samples,
        }

    if 'path_integration' in task:
        return {
            **base_params,
            'n_trials': args.n_trials,
            'n_timesteps': args.n_timesteps,
            'v_max': args.v_max,
            'theta_std': args.theta_std,
            'phi_std': args.phi_std,
            'speed_std': args.speed_std,
            'x_noise_std': args.x_noise_std,
            'xy_noise_std': args.xy_noise_std,
            'xyz_noise_std': args.xyz_noise_std,
            'stop_mean_duration': args.stop_mean_duration,
            'go_mean_duration': args.go_mean_duration,
            'environment_size': args.environment_size,
            # 'dim', 'dim_input' and 'dim_output' all mirror args.path_dim.
            'dim': args.path_dim,
            'dim_input': args.path_dim,
            'dim_output': args.path_dim,
        }

    raise ValueError(f"Task '{task}' is not implemented.")


def construct_task(
    task: str,
    args: Any,
    task_idx: int = 0,
    pretrained_model: Any = None
) -> Any:
    """Instantiate the task selected by ``task`` and attach a new model to it.

    Parameters are gathered with ``get_task_params`` and handed both to the
    task class and to the network class. The network class is ``Network_muP``
    when ``params['muP_param']`` is truthy, otherwise ``Network``. The chosen
    parameterization is announced on stdout.

    Args:
        task: Task identifier, matched by substring against the known keys.
        args: Configuration object forwarded to ``get_task_params``.
        task_idx: Per-task index forwarded to ``get_task_params``.
        pretrained_model: Accepted but not used by this function.

    Returns:
        The task object, with its ``model`` attribute set and
        ``model.savename`` set to ``params['task_name']``.

    Raises:
        ValueError: If ``task`` matches none of the known task keys.
    """
    params = get_task_params(task, args, task_idx)

    # Ordered (substring, class) pairs; the first substring found wins.
    registry = (
        ('bff', ThreeBitFF),
        ('delayed_discrimination', DelayedDiscrimination),
        ('sinewave', SineWaveDataset),
        ('path_integration', PathIntegration),
    )
    task_cls = next((cls for key, cls in registry if key in task), None)
    if task_cls is None:
        raise ValueError(f"Task '{task}' is not implemented.")
    task_obj = task_cls(params)

    if params['muP_param']:
        model_cls, announcement = Network_muP, "Using muP parameterization"
    else:
        model_cls, announcement = Network, "Using vanilla RNN"
    task_obj.model = model_cls(params)
    print(announcement)

    task_obj.model.savename = params['task_name']
    return task_obj
