import numpy as np
from omegaconf import OmegaConf
import pathlib
import random
import torch

import config
from models import *
from tasks import *
import trainer

args = config.load_train_config()

# Echo the fully resolved configuration (interpolations expanded, keys sorted)
# so the run log records exactly which settings this experiment used.
print("Experiment configuration:")
resolved_config_yaml = OmegaConf.to_yaml(args, resolve=True, sort_keys=True)
print(resolved_config_yaml)

# Seed every RNG the run may touch (Python, NumPy, PyTorch) from args.seed so
# a given configuration is reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Explicitly seed every visible CUDA device. This also covers the current
# device, so a separate torch.cuda.manual_seed call would be redundant.
torch.cuda.manual_seed_all(args.seed)
# Use deterministic cuDNN kernels and disable cuDNN autotuning. This trades
# some speed for run-to-run repeatability.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Create the output directory (and any missing parents) before training; it is
# handed to the Trainer as `path`. exist_ok lets a run reuse an existing dir.
run_dir = pathlib.Path(args.path)
run_dir.mkdir(exist_ok=True, parents=True)

# Build the task. The task type is chosen by substring match on args.config
# (presumably the config file name/path -- TODO confirm), and each task reads
# its hyper-parameters from the args.task section.
if "spatial_navigation" in args.config:
    task = spatial_navigation.SpatialNavigation(
        box_width=args.task.box_width,
        box_height=args.task.box_height,
        border_region=args.task.border_region,
        border_slow_factor=args.task.border_slow_factor,
        init_pos=args.task.init_pos,
        biased=args.task.biased,
        biased_ratio=args.task.biased_ratio,
        drift_const=args.task.drift_const,
        # Converted to an ndarray here; the config value is presumably a
        # plain list/ListConfig -- confirm against the config files.
        anchor_point=np.array(args.task.anchor_point),
        dt=args.task.dt,
        sigma=args.task.sigma,
        b=args.task.b,
        mu=args.task.mu,
        use_place_cells=args.task.use_place_cells,
        place_cells_num=args.task.place_cells_num,
        place_cells_sigma=args.task.place_cells_sigma,
        place_cells_surround_scale=args.task.place_cells_surround_scale,
        place_cells_dog=args.task.place_cells_dog,
        sequence_length=args.task.sequence_length,
        batch_size=args.task.batch_size
    )
elif "lorenz_system" in args.config:
    task = lorenz_system.LorenzSystem(
        init_xyz=args.task.init_xyz,
        init_param=args.task.init_param,
        dt=args.task.dt,
        sigma=args.task.sigma,
        rho=args.task.rho,
        beta=args.task.beta,
        sequence_length=args.task.sequence_length,
        batch_size=args.task.batch_size
    )
elif "head_direction" in args.config:
    task = head_direction.HeadDirection(
        dimensionality=args.task.dimensionality,
        init_hd=args.task.init_hd,
        biased=args.task.biased,
        drift_const=args.task.drift_const,
        anchor_angle=args.task.anchor_angle,
        dt=args.task.dt,
        sigma=args.task.sigma,
        mu=args.task.mu,
        use_hd_cells=args.task.use_hd_cells,
        hd_cells_num=args.task.hd_cells_num,
        hd_cells_angular_spread=args.task.hd_cells_angular_spread,
        sequence_length=args.task.sequence_length,
        batch_size=args.task.batch_size
    )
else:
    # Without this branch an unrecognized config leaves `task` unbound, and
    # the script fails later with an unhelpful NameError.
    raise ValueError(
        f"Cannot determine task from config {args.config!r}: expected it to "
        "contain 'spatial_navigation', 'lorenz_system' or 'head_direction'."
    )

# Training batches come from the task's generator. The test data is fetched
# once up front (presumably a single fixed batch -- confirm in the task classes).
train_data_generator = task.get_generator()
test_data = task.get_test_batch()

rnn_cfg = args.rnn
# The config stores noise levels as variances (sigma2_*), while the RNN is
# given their square roots via the sigma_* arguments.
model = rnn.RNN(
    task=task,
    n_in=rnn_cfg.n_in,
    n_rec=rnn_cfg.n_rec,
    n_out=rnn_cfg.n_out,
    n_init=rnn_cfg.n_init,
    sigma_in=np.sqrt(rnn_cfg.sigma2_in),
    sigma_rec=np.sqrt(rnn_cfg.sigma2_rec),
    sigma_out=np.sqrt(rnn_cfg.sigma2_out),
    dt=rnn_cfg.dt,
    tau=rnn_cfg.tau,
    feedback_freq=rnn_cfg.feedback_freq,
    bias=rnn_cfg.bias,
    activation_fn=rnn_cfg.activation_fn,
    device=args.device,
)
model.to(args.device)
trainer_cfg = args.trainer
# Wire the model and data into the Trainer and run training.
# NOTE(review): feedback_freq is read from the rnn section, not the trainer
# section -- presumably so model and trainer share one value; confirm.
model_trainer = trainer.Trainer(
    model=model,
    train_data=train_data_generator,
    test_data=test_data,
    n_epochs=trainer_cfg.n_epochs,
    lr=trainer_cfg.lr,
    weight_decay=trainer_cfg.weight_decay,
    feedback=trainer_cfg.feedback,
    feedback_freq=args.rnn.feedback_freq,
    compute_all_metrics=trainer_cfg.compute_all_metrics,
    record_errors=trainer_cfg.record_errors,
    test_freq=trainer_cfg.test_freq,
    save_freq=trainer_cfg.save_freq,
    path=args.path,
    device=args.device,
)
model_trainer.train()
