import copy
import os
import pathlib
from itertools import chain

import gc
import git
import torch
from omegaconf import OmegaConf
from torch import Tensor

from gflownet.algo.advantage_actor_critic import A2C
from gflownet.algo.flow_matching import FlowMatching
from gflownet.algo.soft_q_learning import SoftQLearning
from gflownet.algo.trajectory_balance import TrajectoryBalance
from gflownet.data.replay_buffer import ReplayBuffer
from gflownet.models.graph_transformer import GraphTransformerGFN, GraphTransformer

from .trainer import GFNTrainer


class StandardOnlineTrainer(GFNTrainer):
    def setup_model(self):
        """Instantiate ``self.model``, optionally wiring in a shared GraphTransformer.

        Two local switches select the construction mode. Both are currently
        hard-coded so that only the plain (non-shared) path is taken:

        * ``share_weight`` and ``num_shared_layers`` both truthy: a
          GraphTransformer with ``num_shared_layers`` layers is stored on
          ``self.shared_layers`` and handed to the model through
          ``shared_num_layers`` / ``shared_layers``.
        * only ``share_weight`` truthy: a GraphTransformer with
          ``cfg.model.num_layers`` layers is stored on
          ``self.shared_graph_transformer`` and handed over as ``shared_transf``.
        * otherwise: the model is built without any shared component.
        """
        # TODO: make these configurable, in particular the number of shared layers.
        share_weight = False
        num_shared_layers = 0

        def make_transformer(depth):
            # Every sharing mode uses the same dimensions; only the depth differs.
            return GraphTransformer(
                x_dim=self.ctx.num_node_dim,
                e_dim=self.ctx.num_edge_dim,
                g_dim=self.ctx.num_cond_dim,
                num_emb=self.cfg.model.num_emb,
                num_layers=depth,
                num_heads=self.cfg.model.graph_transformer.num_heads,
                ln_type=self.cfg.model.graph_transformer.ln_type,
            )

        sharing_kwargs = {}
        if share_weight and num_shared_layers:
            self.shared_layers = make_transformer(num_shared_layers)
            sharing_kwargs = {
                "shared_num_layers": num_shared_layers,
                "shared_layers": self.shared_layers,
            }
        elif share_weight:
            self.shared_graph_transformer = make_transformer(self.cfg.model.num_layers)
            sharing_kwargs = {"shared_transf": self.shared_graph_transformer}

        self.model = GraphTransformerGFN(
            self.ctx,
            self.cfg,
            do_bck=self.cfg.algo.tb.do_parameterize_p_b,
            **sharing_kwargs,
        )

    def setup_algo(self):
        algo = self.cfg.algo.method
        if algo == "TB":
            algo = TrajectoryBalance
        elif algo == "FM":
            algo = FlowMatching
        elif algo == "A2C":
            algo = A2C
        elif algo == "SQL":
            algo = SoftQLearning
        else:
            raise ValueError(algo)
        self.algo = algo(self.env, self.ctx, self.rng, self.cfg)

    def setup_data(self):
        self.training_data = []
        self.test_data = []

    def _get_additional_parameters(self):
        return []

    def setup(self):
        """Build optimizers, LR schedulers and the sampling model, then log the config.

        After running the base-class ``setup`` this method:

        * creates the replay buffer when ``cfg.replay.use`` is set (else ``None``);
        * partitions the model parameters into ``logZ`` parameters, temperature
          (``tempScale``) parameters when ``cfg.algo.scale_temp`` is set, and
          everything else;
        * gives each partition its own Adam optimizer and an exponentially
          decaying ``LambdaLR`` schedule (the learning rate halves every
          ``lr_decay`` / ``Z_lr_decay`` scheduler steps);
        * deep-copies the model into ``self.sampling_model`` when
          ``cfg.algo.sampling_tau > 0``, otherwise aliases the model itself;
        * selects the gradient clipping callback from ``cfg.opt.clip_grad_type``
          (``"value"``, ``"norm"`` or ``"none"``);
        * records the short git hash in ``cfg.git_hash``, prints the resolved
          config and writes it to ``<log_dir>/hps.yaml``.

        Raises:
            KeyError: if ``cfg.opt.clip_grad_type`` is not a known clipping mode.
        """
        super().setup()
        self.offline_ratio = 0
        self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None

        # Separate Z parameters from non-Z to allow for LR decay on the former
        Z_params = list(self.model.logZ.parameters()) if hasattr(self.model, "logZ") else []
        # Temperature-Conditional GFlowNets: the temperature scaling parameters
        # get a dedicated optimizer as well.
        scale_temp = self.cfg.algo.scale_temp
        temp_params = list(self.model.tempScale.parameters()) if scale_temp else []

        # Parameters are compared by identity: tensor ``==`` is element-wise, so
        # ``in`` on a list of tensors would not do what we want. A set of ids
        # keeps the partitioning linear in the number of parameters.
        dedicated_ids = {id(p) for p in chain(Z_params, temp_params)}
        main_params = [p for p in self.model.parameters() if id(p) not in dedicated_ids]

        # The additional parameters are included regardless of ``scale_temp`` so
        # that subclasses overriding ``_get_additional_parameters`` are honoured
        # in both configurations.
        self.opt = torch.optim.Adam(
            chain(main_params, self._get_additional_parameters()),
            self.cfg.opt.learning_rate,
            (self.cfg.opt.momentum, 0.999),
            weight_decay=self.cfg.opt.weight_decay,
            eps=self.cfg.opt.adam_eps,
        )
        if scale_temp:
            # The temperature optimizer reuses the Z learning rate and decay settings.
            self.opt_temperature = torch.optim.Adam(temp_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999))
            self.lr_sched_temperature = torch.optim.lr_scheduler.LambdaLR(
                self.opt_temperature, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay)
            )

        # NOTE(review): torch optimizers raise ValueError on an empty parameter
        # list, so a model without ``logZ`` parameters cannot get past this
        # line -- confirm whether that case is meant to be supported.
        self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999))

        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay))
        self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(
            self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay)
        )

        # With a positive tau, sampling uses a lagged copy of the model that
        # ``step`` blends towards the trained weights; otherwise both names
        # refer to the same module.
        self.sampling_tau = self.cfg.algo.sampling_tau
        if self.sampling_tau > 0:
            self.sampling_model = copy.deepcopy(self.model)
        else:
            self.sampling_model = self.model

        self.mb_size = self.cfg.algo.global_batch_size
        self.clip_grad_callback = {
            "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param),
            "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param),
            "none": lambda x: None,
        }[self.cfg.opt.clip_grad_type]

        # saving hyperparameters
        git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7]
        self.cfg.git_hash = git_hash

        os.makedirs(self.cfg.log_dir, exist_ok=True)
        print("\n\nHyperparameters:\n")
        yaml = OmegaConf.to_yaml(self.cfg)
        print(yaml)
        with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f:
            f.write(yaml)

    def step(self, loss: Tensor):
        loss.backward()
        for i in self.model.parameters():
            self.clip_grad_callback(i)
        self.opt.step()
        self.opt.zero_grad()
        self.opt_Z.step()
        self.opt_Z.zero_grad()
        self.lr_sched.step()
        self.lr_sched_Z.step()

        dev = self.ctx.device
        self.sampling_model = self.sampling_model.to(dev)
        self.model = self.model.to(dev)
        
        if self.cfg.algo.scale_temp:
            self.opt_temperature.step()
            self.lr_sched_temperature.step()
        
        if self.sampling_tau > 0:
            for a, b in zip(self.model.parameters(), self.sampling_model.parameters()):
                a = a.to(dev); b = b.to(dev)
                b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))
