import torch 
import torch.nn as nn 
import torch.autograd as autograd 
import numpy as np 

import tqdm 
import wandb 

from itertools import chain 
from tqdm import trange 
from copy import deepcopy 

from abc import ABCMeta, abstractmethod 

class Environment(metaclass=ABCMeta):
    """Abstract batched environment.

    Keeps per-sample bookkeeping tensors of length ``batch_size`` together with a
    log-reward callable. Subclasses implement ``apply`` / ``backward`` and may
    provide ``unique_input``.
    """

    def __init__(self, batch_size: int, max_trajectory_length: int, log_reward: nn.Module, device: torch.device):
        self.batch_size = batch_size
        self.device = device
        self.max_trajectory_length = max_trajectory_length
        per_sample = (batch_size,)
        # Row index 0..batch_size-1 (int64).
        self.batch_ids = torch.arange(batch_size, device=device)
        # Float bookkeeping tensors, one entry per sample.
        self.traj_size = torch.ones(per_sample, device=device)
        self.stopped = torch.zeros(per_sample, device=device)
        self.is_initial = torch.ones(per_sample, device=device)
        self._log_reward = log_reward

    @abstractmethod
    def apply(self, actions: torch.Tensor):
        """Advance the batch with forward ``actions`` (subclass-defined)."""

    @abstractmethod
    def backward(self, actions: torch.Tensor):
        """Undo ``actions`` on the batch (subclass-defined)."""

    @torch.no_grad()
    def log_reward(self):
        """Evaluate the stored log-reward callable on this environment, without autograd."""
        return self._log_reward(self)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the samples of ``batch_state`` to this environment, in place.

        Concatenates ``batch_ids``, ``stopped`` and ``is_initial`` and adds the
        batch sizes.

        NOTE(review): ``traj_size`` is not concatenated here, and the appended
        ``batch_ids`` restart from 0 (they are not re-numbered). Confirm that
        subclasses or callers account for both.
        """
        for attr in ('batch_ids', 'stopped', 'is_initial'):
            setattr(self, attr, torch.hstack([getattr(self, attr), getattr(batch_state, attr)]))
        self.batch_size += batch_state.batch_size

    @property
    def unique_input(self):
        """Per-sample representation used for deduplication; subclasses must override."""
        raise NotImplementedError

class ForwardPolicyMeta(nn.Module, metaclass=ABCMeta):
    """Abstract forward policy with epsilon-uniform exploration.

    Subclasses implement ``get_latent_emb`` and ``get_pol``. ``forward`` calls
    them as ``get_latent_emb(batch_state)`` and ``get_pol(latent_emb, mask)``;
    ``get_pol`` must return ``(pol, gflows)`` where ``pol`` holds action
    probabilities along its last dimension (``gflows`` is ignored here).
    """

    # Not referenced in this class; presumably used by subclasses as the logit
    # for masked-out actions -- confirm.
    masked_value = -1e5

    def __init__(self, eps=.3, device='cpu'):
        super(ForwardPolicyMeta, self).__init__()
        # Weight of the uniform exploration component; only used in training mode.
        self.eps = eps
        # When not None, get_actions samples with a freshly seeded generator.
        self.seed = None
        self.device = device

    @abstractmethod
    def get_latent_emb(self):
        """Embed a batch state; called as ``get_latent_emb(batch_state)``."""

    @abstractmethod
    def get_pol(self):
        """Return ``(pol, gflows)``; called as ``get_pol(latent_emb, mask)``."""

    def set_seed(self, seed):
        """Make subsequent ``get_actions`` calls sample with a generator seeded by ``seed``."""
        self.seed = seed

    def unset_seed(self):
        """Return to sampling from the global RNG."""
        self.seed = None

    def get_actions(self, pol, mask=None):
        """Sample one action per row from an epsilon-mixture of ``pol`` and a uniform policy.

        Args:
            pol: action probabilities along the last dimension (one row per state).
            mask: optional tensor shaped like ``pol``; entries equal to 1 mark
                allowed actions, and only those receive uniform mass.

        Returns:
            ``(actions, exp_pol)``: the sampled action indices (trailing sample
            dimension squeezed) and the mixture policy they were drawn from.
        """
        if mask is None:
            uniform_pol = torch.ones_like(pol)
        else:
            uniform_pol = torch.where(mask == 1., 1., 0.)
        # Normalise in both branches so the uniform component has total weight
        # eps; otherwise its weight would scale with the number of actions.
        uniform_pol = uniform_pol / uniform_pol.sum(dim=-1, keepdim=True)

        # Exploration is disabled in eval mode.
        eps = self.eps if self.training else 0.
        exp_pol = pol * (1 - eps) + eps * uniform_pol

        generator = None
        if self.seed is not None:
            # A new generator per call: with a fixed seed every call is reproducible.
            generator = torch.Generator(device=self.device)
            generator.manual_seed(self.seed)
        actions = torch.multinomial(exp_pol, num_samples=1, replacement=True, generator=generator)
        return actions.squeeze(dim=-1), exp_pol

    def forward(self, batch_state, actions=None):
        """Return ``(actions, log_probs)`` for ``batch_state``.

        If ``actions`` is None they are sampled with ``get_actions`` (using the
        exploration mixture). ``log_probs`` are always taken from the un-mixed
        policy ``pol``. Sets ``batch_state.forward_mask = None`` when the
        attribute is missing.
        """
        if not hasattr(batch_state, 'forward_mask'):
            batch_state.forward_mask = None
        latent_emb = self.get_latent_emb(batch_state)
        pol, _gflows = self.get_pol(latent_emb, batch_state.forward_mask)
        if actions is None:
            actions, _ = self.get_actions(pol, batch_state.forward_mask)
        return actions, torch.log(pol[batch_state.batch_ids, actions])

class LogRewardProduct(nn.Module):
    """Log of a product of rewards, i.e. the sum of the individual log-rewards.

    Each entry of ``rewards`` is called with the batch state; its result is added
    in place to a float tensor of shape ``(batch_size,)``.
    """

    def __init__(self, rewards):
        super().__init__()
        # If ``rewards`` is a plain list, its entries are not registered as submodules.
        self.reward_func_lst = rewards

    @torch.no_grad()
    def forward(self, batch_state):
        total = torch.zeros(batch_state.batch_size, device=batch_state.device)
        for term in map(lambda fn: fn(batch_state), self.reward_func_lst):
            total += term
        return total

@torch.no_grad()
def reset_params(m):
    """Re-initialise every parameter of ``m`` in place from U(-0.05, 0.05)."""
    low, high = -5e-2, 5e-2
    for weight in m.parameters():
        nn.init.uniform_(weight, low, high)

@torch.no_grad()
def sample_massive_batch(gflownet, create_env, num_batches, num_back_traj=1, use_progress_bar=True):
    """Draw ``num_batches`` batches from ``gflownet`` and merge them into one environment.

    Args:
        gflownet: object providing ``sample(env)`` and
            ``sample_many_backward(env, num_trajectories=...)``.
        create_env: zero-argument factory returning a fresh environment.
        num_batches: number of batches to draw; one batch is always drawn,
            even if this is < 1.
        num_back_traj: forwarded as ``num_trajectories`` to ``sample_many_backward``.
        use_progress_bar: show a tqdm bar over the additional batches.

    Returns:
        ``(env, marginal_log, log_rewards)``. ``env`` is the first sampled
        environment with every later batch merged into it via ``env.merge``.
        ``marginal_log`` holds the per-batch ``sample_many_backward`` outputs
        stacked with ``torch.vstack``. ``log_rewards`` holds the per-batch
        ``log_reward()`` outputs joined with ``torch.hstack``.
    """
    env = gflownet.sample(create_env())
    marginal_chunks = [gflownet.sample_many_backward(env, num_trajectories=num_back_traj)]
    reward_chunks = [env.log_reward()]

    # Concatenate once at the end rather than re-stacking the growing tensor on
    # every iteration, which costs time quadratic in num_batches.
    for _ in tqdm.trange(num_batches - 1, disable=not use_progress_bar):
        env_i = gflownet.sample(create_env())
        env.merge(env_i)
        marginal_chunks.append(gflownet.sample_many_backward(env_i, num_trajectories=num_back_traj))
        reward_chunks.append(env_i.log_reward())

    if len(marginal_chunks) == 1:
        # A single batch needs no concatenation; return its tensors as-is.
        return env, marginal_chunks[0], reward_chunks[0]
    return env, torch.vstack(marginal_chunks), torch.hstack(reward_chunks)

def unique(x, dim=-1):
    """``torch.unique`` along ``dim`` plus one source index per unique slice.

    Returns:
        ``(values, inverse, counts, indices)``. The first three are as returned by
        ``torch.unique(x, return_inverse=True, return_counts=True, dim=dim)``.
        ``indices[k]`` is the position along ``dim`` of the last occurrence of
        the k-th unique slice in ``x``, so ``x.index_select(dim, indices)``
        reproduces ``values``.
    """
    values, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True, dim=dim)
    # scatter_ with duplicate indices is nondeterministic. A stable sort instead
    # groups equal inverse ids while keeping their original order, so the end of
    # each group is that value's last occurrence. Groups appear in unique-id order
    # with sizes given by counts.
    _, order = torch.sort(inverse, stable=True)
    group_ends = counts.cumsum(0) - 1
    return values, inverse, counts, order[group_ends].to(inverse.dtype)

def marginal_dist(env, marginal_log, log_rewards, dim=-1):
    """Estimate the learned marginal over unique objects and the reward-based target.

    Args:
        env: environment whose ``unique_input`` identifies each sample; it is
            deduplicated along ``dim`` with ``unique``.
        marginal_log: 2-D tensor with one row per sample and one column per
            backward trajectory; ``exp`` of each entry is averaged as a
            probability estimate (presumably log p_F(tau) - log p_B(tau|x) from
            ``sample_many_backward`` -- confirm).
        log_rewards: 1-D tensor of per-sample log-rewards.
        dim: dimension along which ``env.unique_input`` is deduplicated.

    Returns:
        ``(learned_dist, target_dist)``: 1-D tensors over the unique objects, in
        ``torch.unique`` order, each normalised to sum to 1. ``learned_dist``
        averages ``exp(marginal_log)`` over every sample and every backward
        trajectory of an object. ``target_dist`` is the softmax of one
        representative log-reward per object.
    """
    values, inverse, counts, indices = unique(env.unique_input, dim=dim)
    num_traj = marginal_log.shape[1]

    # Sum over backward trajectories first, then accumulate whole rows per unique
    # object. A scatter_add_ with an (N, 1) index would read only column 0 of
    # the source and drop every trajectory but the first.
    per_sample = marginal_log.exp().sum(dim=1)
    learned = torch.zeros(counts.numel(), dtype=per_sample.dtype, device=values.device)
    learned.index_add_(0, inverse, per_sample)
    learned /= counts * num_traj
    learned_dist = learned / learned.sum()

    # Target: reward-proportional distribution over the unique objects.
    target_dist = torch.softmax(log_rewards[indices], dim=0)
    return learned_dist, target_dist

def compute_marginal_dist(gflownet, create_env, num_batches, num_back_traj, use_progress_bar=False):
    """Sample many batches from ``gflownet`` and return ``(learned_dist, target_dist)``.

    Sampling is done by ``sample_massive_batch``; the two distributions are then
    built by ``marginal_dist``, deduplicating samples along dim 0.
    """
    env, marginal_log, log_rewards = sample_massive_batch(
        gflownet,
        create_env,
        num_batches,
        num_back_traj=num_back_traj,
        use_progress_bar=use_progress_bar,
    )
    return marginal_dist(env, marginal_log, log_rewards, dim=0)

def generic_control_variate(loss_func, score_func, model):
    """Compute the gradients used by the control-variate estimator.

    Args:
        loss_func: 1-D tensor of per-sample losses (its length is the batch size).
        score_func: per-sample tensor whose gradient acts as the score term
            (presumably trajectory log-probabilities -- confirm with caller).
        model: module whose parameters are differentiated.

    Returns:
        ``(grad_loss, grad_log_prob, grad_reinforce, params)``. The first three
        are tuples aligned with ``model.parameters()``; an entry is ``None`` when
        that parameter does not affect the corresponding objective. ``params``
        is a fresh ``model.parameters()`` iterator. The autograd graph is freed
        by the last gradient call.
    """
    params = list(model.parameters())
    grad_loss = autograd.grad(loss_func.mean(), params, retain_graph=True, allow_unused=True)
    grad_log_prob = autograd.grad(score_func.mean(), params, retain_graph=True, allow_unused=True)

    # REINFORCE term with a leave-one-out baseline: each sample's loss minus the
    # mean loss of the other samples. O(B), instead of a B x B matrix product.
    losses = loss_func.detach()
    batch_size = losses.shape[0]
    if batch_size > 1:
        baseline = (losses.sum() - losses) / (batch_size - 1)
    else:
        # No other samples to form a baseline; avoids a 0/0 NaN.
        baseline = torch.zeros_like(losses)
    est = losses - baseline
    grad_reinforce = autograd.grad((est * score_func).mean(), params, allow_unused=True)

    return grad_loss, grad_log_prob, grad_reinforce, model.parameters()

def backward_func(loss, score_func, gflownet):
    """Write control-variate-corrected gradients into the ``.grad`` of ``gflownet.pf``.

    For each parameter, the component of the loss gradient along the score
    gradient is projected out, and the leave-one-out REINFORCE gradient from
    ``generic_control_variate`` is added. ``.grad`` is overwritten, not
    accumulated. Parameters whose loss- or score-gradient is ``None`` are left
    untouched.
    """
    grad_loss, grad_log_prob, grad_reinforce, params = generic_control_variate(loss, score_func, gflownet.pf)
    for param, g_loss, g_score, g_reinforce in zip(params, grad_loss, grad_log_prob, grad_reinforce):
        if g_loss is None or g_score is None:
            continue
        denom = (g_score * g_score).sum()
        # An all-zero score gradient has nothing to project out; use coefficient 0
        # instead of the 0/0 NaN that would poison the parameter's gradient.
        coef = torch.where(denom > 0, (g_loss * g_score).sum() / denom, torch.zeros_like(denom))
        param.grad = (g_loss - coef * g_score) + g_reinforce

def train_step(
    gflownet,
    create_env,
    epochs,
    optimizer,
    config,
    previous_model=None,
    scheduler=None,
    use_progress_bar=False,
    fix_gradient=False
):
    """Run ``epochs`` optimisation steps, each on a freshly created environment.

    Args:
        gflownet: called as ``gflownet(env, **kwargs)``. With ``fix_gradient`` it
            must return ``(per_sample_loss, score_func)``; otherwise a scalar loss.
        create_env: zero-argument environment factory.
        epochs: number of steps.
        optimizer: stepped once per iteration after gradients are set.
        config: accepted for interface compatibility; not read here.
        previous_model: forwarded as ``previous_model=`` to ``gflownet`` when given.
        scheduler: optional, stepped after each optimizer step.
        use_progress_bar: show a tqdm bar.
        fix_gradient: use ``backward_func`` (control-variate gradients) instead
            of ``loss.backward()``.

    Returns:
        list of per-step mean losses as Python floats.
    """
    pbar = tqdm.trange(epochs, disable=not use_progress_bar)
    losses = []
    kwargs = {'previous_model': previous_model} if previous_model is not None else {}
    for _ in pbar:
        optimizer.zero_grad()
        env = create_env()
        if fix_gradient:
            per_sample_loss, score_func = gflownet(env, **kwargs)
            backward_func(per_sample_loss, score_func, gflownet)
            loss = per_sample_loss.mean()
        else:
            loss = gflownet(env, **kwargs)
            loss.backward()

        # One device->host sync per step; log a plain float, not a graph-attached tensor.
        loss_value = loss.item()
        pbar.set_postfix(loss=loss_value)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        wandb.log({'loss': loss_value})
        losses.append(loss_value)
    return losses
 
def train_gfn(gfn, create_env, optimizer, scheduler, config, previous_model=None): 

    train_args = {
        'gflownet': gfn, 
        'create_env': create_env, 
        'epochs': config.epochs_per_step, 
        'optimizer': optimizer, 
        'scheduler': scheduler, 
        'use_progress_bar': config.use_progress_bar,
        'config': config, 
        'previous_model': previous_model 
    }

    match config.criterion: 
        case 'kl': 
            with gfn.on_policy(): 
                train_step(**train_args, fix_gradient=True) 
        case 'tb': 
            train_step(**train_args, fix_gradient=False) 
        case _: 
            raise Exception 
    
    return gfn  
