"""
Online trainer from 
https://github.com/recursionpharma/gflownet
"""

import copy
import os
import pathlib

import git
import torch
from omegaconf import OmegaConf
from torch import Tensor

from src.tacogfn.algo.trajectory_balance import TrajectoryBalance
from src.tacogfn.data.replay_buffer import ReplayBuffer
from src.tacogfn.data.replay_buffer import RewardPrioritizedReplayBuffer
from src.tacogfn.models.graph_transformer import GraphTransformerGFN

from .trainer import GFNTrainer


class StandardOnlineTrainer(GFNTrainer):
    """Trainer wiring a ``GraphTransformerGFN`` model to trajectory balance.

    ``setup`` builds the optional replay buffer, the Adam optimizers with
    their learning-rate schedules, the (optionally averaged) sampling model
    and the gradient-clipping callback, then dumps the resolved configuration
    to ``<cfg.log_dir>/hps.yaml``.
    """

    def setup_model(self):
        """Create the policy network and store it on ``self.model``."""
        self.model = GraphTransformerGFN(
            self.ctx,
            self.cfg,
            do_bck=self.cfg.algo.tb.do_parameterize_p_b,
        )

    def setup_algo(self):
        """Create the training algorithm selected by ``cfg.algo.method``.

        Raises:
            ValueError: if the configured method is anything other than
                ``"TB"``; the offending value is the exception argument.
        """
        method = self.cfg.algo.method
        if method != "TB":
            raise ValueError(method)
        self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg)

    def setup_data(self):
        """Initialise the training and test datasets as empty lists."""
        self.training_data = []
        self.test_data = []

    def setup(self):
        """Build the replay buffer, optimizers, schedulers and sampling model.

        Also records the current git revision in ``cfg.git_hash`` (or
        ``"unknown"`` when it cannot be determined) and writes the full
        configuration as YAML to ``<cfg.log_dir>/hps.yaml``.
        """
        super().setup()
        self.offline_ratio = 0

        # Replay buffer: optional, reward-prioritized when `keep_top` is set.
        if self.cfg.replay.use:
            if self.cfg.replay.keep_top:
                buffer_cls = RewardPrioritizedReplayBuffer
            else:
                buffer_cls = ReplayBuffer
            self.replay_buffer = buffer_cls(self.cfg, self.rng)
        else:
            self.replay_buffer = None

        # Separate Z parameters from non-Z to allow for LR decay on the former
        if hasattr(self.model, "logZ"):
            Z_params = list(self.model.logZ.parameters())
            # Compare by identity through a set of ids: tensors define
            # elementwise `==`, and a set keeps the filter linear.
            Z_param_ids = {id(p) for p in Z_params}
            non_Z_params = [
                p for p in self.model.parameters() if id(p) not in Z_param_ids
            ]
        else:
            Z_params = []
            non_Z_params = list(self.model.parameters())

        self.opt = torch.optim.Adam(
            non_Z_params,
            self.cfg.opt.learning_rate,
            (self.cfg.opt.momentum, 0.999),
            weight_decay=self.cfg.opt.weight_decay,
            eps=self.cfg.opt.adam_eps,
        )
        # The learning rate is halved every `lr_decay` scheduler steps.
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(
            self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)
        )

        if Z_params:
            self.opt_Z = torch.optim.Adam(
                Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)
            )
            self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(
                self.opt_Z,
                lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay),
            )
        else:
            # torch optimizers reject an empty parameter list, so a model
            # without `logZ` simply has no Z optimizer/scheduler.
            self.opt_Z = None
            self.lr_sched_Z = None

        # With sampling_tau > 0 sampling uses a separate copy of the model
        # that `step` moves towards the trained weights; otherwise the
        # trained model itself is used for sampling.
        self.sampling_tau = self.cfg.algo.sampling_tau
        if self.sampling_tau > 0:
            self.sampling_model = copy.deepcopy(self.model)
        else:
            self.sampling_model = self.model

        self.mb_size = self.cfg.algo.global_batch_size
        self.clip_grad_callback = {
            "value": lambda params: torch.nn.utils.clip_grad_value_(
                params, self.cfg.opt.clip_grad_param
            ),
            "norm": lambda params: torch.nn.utils.clip_grad_norm_(
                params, self.cfg.opt.clip_grad_param
            ),
            "none": lambda x: None,
        }[self.cfg.opt.clip_grad_type]

        # saving hyperparameters
        try:
            git_hash = git.Repo(
                __file__, search_parent_directories=True
            ).head.object.hexsha[:7]
        except (git.InvalidGitRepositoryError, git.NoSuchPathError, ValueError):
            # Not running from a git checkout (or the repository has no
            # commit yet): do not abort training over a missing revision.
            git_hash = "unknown"
        self.cfg.git_hash = git_hash

        os.makedirs(self.cfg.log_dir, exist_ok=True)
        print("\n\nHyperparameters:\n")
        yaml = OmegaConf.to_yaml(self.cfg)
        print(yaml)
        hps_path = pathlib.Path(self.cfg.log_dir) / "hps.yaml"
        with open(hps_path, "w", encoding="utf-8") as f:
            f.write(yaml)

    def step(self, loss: Tensor):
        """Backpropagate ``loss`` and apply one optimization step.

        Gradients are clipped, both optimizers and both schedulers are
        stepped, and, when ``sampling_tau > 0``, the sampling model is
        updated as an exponential moving average of the trained model.

        Args:
            loss: scalar loss tensor to differentiate.
        """
        loss.backward()
        # The callback is applied to each parameter tensor separately, so
        # "norm" clipping bounds every parameter's gradient norm on its own
        # rather than the global norm over all parameters.
        for param in self.model.parameters():
            self.clip_grad_callback(param)
        self.opt.step()
        self.opt.zero_grad()
        if self.opt_Z is not None:
            self.opt_Z.step()
            self.opt_Z.zero_grad()
        self.lr_sched.step()
        if self.lr_sched_Z is not None:
            self.lr_sched_Z.step()
        if self.sampling_tau > 0:
            tau = self.sampling_tau
            # sampling <- tau * sampling + (1 - tau) * trained
            for trained, sampled in zip(
                self.model.parameters(), self.sampling_model.parameters()
            ):
                sampled.data.mul_(tau).add_(trained.data * (1 - tau))
