import torch 
import torch.nn as nn
import torch.autograd as autograd  
import os 
import tqdm 
import wandb 
import numpy as np 

from abc import ABCMeta, abstractmethod 

from itertools import chain 

from var_red_gfn.gradients import gradient_kl_div, gradient_renyi_div, \
    gradient_tsallis_div, gradient_rev_kl_div 

def compute_gradients(gflownet, traj_stats, loss):
    """Fill in ``.grad`` on the GFlowNet according to ``gflownet.criterion``.

    For the criteria ``'tb'``, ``'db'``, ``'subtb'`` and ``'dbc'`` the
    gradients of ``loss`` are taken directly with autograd and written to the
    parameters of ``gflownet.pf`` and to ``gflownet.log_z``.  For the
    divergence criteria the work is delegated to the matching
    ``gradient_*_div`` helper.

    Args:
        gflownet: object exposing ``criterion``, ``pf`` and ``log_z``.
        traj_stats: trajectory statistics forwarded to the divergence helpers
            (unused for the autograd-based criteria).
        loss: loss tensor returned by the GFlowNet's forward pass.

    Raises:
        Exception: if ``gflownet.criterion`` is not one of the known names.
    """
    criterion = gflownet.criterion

    if criterion in ('tb', 'db', 'subtb', 'dbc'):
        pf_params = list(gflownet.pf.parameters())
        pf_grads = autograd.grad(loss, pf_params, retain_graph=True)
        for param, grad in zip(pf_params, pf_grads):
            param.grad = grad
        # log_z may not take part in every loss, hence allow_unused; in that
        # case its gradient is set to None.
        (gflownet.log_z.grad,) = autograd.grad(
            loss, gflownet.log_z, allow_unused=True, retain_graph=True
        )
    elif criterion in ('kl', 'jeffrey'):
        gradient_kl_div(gflownet, loss, traj_stats)
    elif criterion == 'rev_kl':
        gradient_rev_kl_div(gflownet, loss, traj_stats)
    elif criterion == 'renyi':
        gradient_renyi_div(gflownet, loss, traj_stats)
    elif criterion == 'tsallis':
        gradient_tsallis_div(gflownet, loss, traj_stats)
    else:
        raise Exception(f'invalid criterion: {criterion}')

def train(gflownet, create_env, epochs, use_scheduler=False, use_progress_bar=False):
    """Build an Adam optimizer for ``gflownet`` and run the training loop.

    The parameters of ``gflownet.pf`` and ``gflownet.pb`` share one parameter
    group (learning rate 1e-3); ``gflownet.log_z`` gets its own group with
    learning rate 10.

    Args:
        gflownet: model exposing ``pf``, ``pb`` and ``log_z``.
        create_env: zero-argument callable producing a fresh environment.
        epochs: number of optimisation steps.
        use_scheduler: when true, decay the learning rates with a linear
            (power 1) ``PolynomialLR`` schedule over ``epochs`` steps.
        use_progress_bar: show a tqdm progress bar while training.

    Returns:
        The (trained) ``gflownet`` and the list of per-step mean losses.
    """
    policy_params = [*gflownet.pf.parameters(), *gflownet.pb.parameters()]
    param_groups = [
        {'params': policy_params, 'lr': 1e-3},
        {'params': gflownet.log_z, 'lr': 1e1},
    ]
    optimizer = torch.optim.Adam(param_groups)

    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.PolynomialLR(
            optimizer, total_iters=epochs, power=1.
        )

    history = train_step(
        gflownet, create_env, epochs, optimizer, scheduler, use_progress_bar
    )
    return gflownet, history

def train_step(gflownet, create_env, epochs, optimizer, scheduler=None, use_progress_bar=False):
    """Run ``epochs`` optimisation steps and return the per-step mean losses.

    Each step creates a fresh environment, evaluates the GFlowNet on it,
    computes gradients via ``compute_gradients``, and applies one optimizer
    (and, if given, scheduler) update.  The raw loss is logged to wandb.

    Args:
        gflownet: callable model; ``gflownet(env)`` returns ``(loss, traj_stats)``.
        create_env: zero-argument callable producing a fresh environment.
        epochs: number of optimisation steps.
        optimizer: optimizer whose ``zero_grad``/``step`` are invoked each step.
        scheduler: optional LR scheduler stepped once per optimisation step.
        use_progress_bar: show a tqdm progress bar while training.

    Returns:
        list[float]: mean loss of every step.
    """
    pbar = tqdm.trange(epochs, disable=not use_progress_bar)
    losses = []
    for _ in pbar:
        optimizer.zero_grad()
        env = create_env()
        loss, traj_stats = gflownet(env)
        compute_gradients(gflownet, traj_stats, loss)
        # Reduce to a plain Python float once.  Passing the tensor itself to
        # tqdm renders it as ``tensor(..., grad_fn=...)`` in the postfix.
        mean_loss = loss.detach().mean().item()
        pbar.set_postfix(loss=mean_loss)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        wandb.log({'loss': loss})
        losses.append(mean_loss)
    return losses
   
class Environment(metaclass=ABCMeta):
    """Abstract base class for a batch of environment states.

    Subclasses implement ``apply`` and ``backward`` and may override
    ``unique_input``.  The base class only keeps per-sample bookkeeping
    tensors and the log-reward callable.
    """

    def __init__(self, batch_size: int, max_trajectory_length: int, log_reward: nn.Module, device: torch.device):
        """Allocate the per-sample bookkeeping tensors on ``device``.

        Args:
            batch_size: number of states held by this environment.
            max_trajectory_length: stored as-is; not used by the base class.
            log_reward: callable invoked as ``log_reward(env)``.
            device: device on which the bookkeeping tensors are created.
        """
        self.batch_size = batch_size
        self.device = device
        self.max_trajectory_length = max_trajectory_length

        per_sample = (batch_size,)
        self.batch_ids = torch.arange(batch_size, device=device)
        # Float tensors: trajectory sizes start at 1, nothing is stopped and
        # every sample is flagged as initial.
        self.traj_size = torch.ones(per_sample, device=device)
        self.stopped = torch.zeros(per_sample, device=device)
        self.is_initial = torch.ones(per_sample, device=device)
        self._log_reward = log_reward

    @abstractmethod
    def apply(self, actions: torch.Tensor):
        """Apply ``actions`` to the batch (implemented by subclasses)."""

    @abstractmethod
    def backward(self, actions: torch.Tensor):
        """Undo ``actions`` on the batch (implemented by subclasses)."""

    @torch.no_grad()
    def log_reward(self):
        """Return ``self._log_reward(self)`` evaluated without gradients."""
        reward_fn = self._log_reward
        return reward_fn(self)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the samples of ``batch_state`` to this environment in place.

        Concatenates ``batch_ids``, ``stopped`` and ``is_initial`` and adds the
        batch sizes.  NOTE(review): ``traj_size`` is not concatenated here --
        confirm whether subclasses are expected to handle it.
        """
        self.batch_ids = torch.hstack((self.batch_ids, batch_state.batch_ids))
        self.batch_size = self.batch_size + batch_state.batch_size
        self.stopped = torch.hstack((self.stopped, batch_state.stopped))
        self.is_initial = torch.hstack((self.is_initial, batch_state.is_initial))

    @property
    def unique_input(self):
        """Representation of the states used for de-duplication.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

class ForwardPolicyMeta(nn.Module, metaclass=ABCMeta):
    """Abstract forward policy with epsilon-uniform exploration.

    Subclasses provide ``get_latent_emb`` and ``get_pol``.  ``forward`` calls
    them as ``get_latent_emb(batch_state)`` and
    ``get_pol(latent_emb, forward_mask)``; the latter must return a pair
    ``(pol, gflows)`` where ``pol`` is a 2-D tensor of per-sample action
    probabilities.
    """

    # Not referenced in this class; presumably the fill value subclasses use
    # for masked logits -- TODO confirm.
    masked_value = -1e5

    def __init__(self, eps=.3, device='cpu'):
        """
        Args:
            eps: weight of the uniform component mixed into the policy while
                the module is in training mode.
            device: device used for the seeded random generator.
        """
        super().__init__()
        self.eps = eps
        self.seed = None
        self.device = device

    @abstractmethod
    def get_latent_emb(self):
        pass

    @abstractmethod
    def get_pol(self):
        pass

    def set_seed(self, seed):
        """Make subsequent sampling use a generator seeded with ``seed``."""
        self.seed = seed

    def unset_seed(self):
        """Return to sampling from the global random generator."""
        self.seed = None

    def get_actions(self, pol, mask=None):
        """Sample one action per row from the exploratory policy.

        The exploratory policy is ``(1 - eps) * pol + eps * uniform`` where
        ``uniform`` is the uniform distribution over all actions (no mask) or
        over the actions with ``mask == 1``.  ``eps`` is 0 in eval mode.

        Args:
            pol: 2-D tensor of action probabilities, one row per sample.
            mask: optional tensor of the same shape; entries equal to 1 mark
                the actions that receive exploration mass.

        Returns:
            ``(actions, exp_pol)``: sampled action indices (1-D) and the
            exploratory policy they were drawn from.
        """
        if mask is None:
            # Normalise the uniform component.  ``ones_like`` alone would make
            # ``exp_pol`` sum to ``1 - eps + eps * n_actions`` per row, so the
            # returned sampling probabilities would not be probabilities and
            # the effective mixing weight would differ from the masked branch.
            uniform_pol = torch.full_like(pol, 1. / pol.shape[1])
        else:
            uniform_pol = (mask == 1.).to(pol.dtype)
            uniform_pol = uniform_pol / uniform_pol.sum(dim=1, keepdim=True)

        eps = self.eps if self.training else 0.
        exp_pol = pol * (1 - eps) + eps * uniform_pol

        generator = None
        if self.seed is not None:
            # A fresh generator is seeded on every call, so repeated calls
            # with the same seed and policy reproduce the same draws.
            generator = torch.Generator(device=self.device)
            generator.manual_seed(self.seed)
        actions = torch.multinomial(
            exp_pol, num_samples=1, replacement=True, generator=generator
        )
        return actions.squeeze(dim=-1), exp_pol

    def forward(self, batch_state, actions=None, return_ps=False):
        """Evaluate the policy on ``batch_state`` and optionally sample actions.

        If ``batch_state`` has no ``forward_mask`` attribute, one is created
        and set to ``None``.  When ``actions`` is ``None`` they are sampled
        with ``get_actions``; otherwise the given actions are scored.

        Args:
            batch_state: state batch exposing ``batch_ids`` (and optionally
                ``forward_mask``).
            actions: optional tensor of action indices to score.
            return_ps: if true, the third returned value is the log-probability
                of the last action column instead of ``gflows``.

        Returns:
            ``(actions, log_pf, third, log_sampling)`` where ``log_pf`` is the
            log-probability of the actions under the policy, ``log_sampling``
            the same under the policy actually used for sampling, and ``third``
            is ``gflows`` or ``log pol[:, -1]`` depending on ``return_ps``.
        """
        if not hasattr(batch_state, 'forward_mask'):
            batch_state.forward_mask = None
        latent_emb = self.get_latent_emb(batch_state)
        pol, gflows = self.get_pol(latent_emb, batch_state.forward_mask)

        spol = pol  # sampling policy; equals `pol` when actions are supplied
        if actions is None:
            actions, spol = self.get_actions(pol, batch_state.forward_mask)

        rows = batch_state.batch_ids
        log_pf = torch.log(pol[rows, actions])
        log_sampling = torch.log(spol[rows, actions])
        if return_ps:
            return actions, log_pf, torch.log(pol[rows, -1]), log_sampling
        return actions, log_pf, gflows, log_sampling

class ForwardPolicyProduct(ForwardPolicyMeta):
    """Forward policy proportional to the product of several GFlowNet policies.

    The member policies are the ``pf`` attributes of the given GFlowNets.  They
    are kept in a plain Python list, so they are not registered as submodules
    of this module.
    """

    def __init__(self, gflownets, eps=.3, device='cpu'):
        """
        Args:
            gflownets: iterable of objects exposing a ``pf`` forward policy.
            eps: exploration weight, see ``ForwardPolicyMeta``.
            device: device used for the seeded random generator.
        """
        super().__init__(eps=eps)
        self.device = device
        self.policies = [gflownet.pf for gflownet in gflownets]

    def get_latent_emb(self, batch_state):
        """Return the list of latent embeddings, one per member policy."""
        return [policy.get_latent_emb(batch_state) for policy in self.policies]

    def get_pol(self, latent_emb_lst, mask):
        """Return the row-normalised product of the member policies.

        Unlike the base-class contract this returns only the policy tensor,
        not a ``(pol, gflows)`` pair; ``forward`` below is written accordingly.
        """
        member_pols = [
            policy.get_pol(latent_emb, mask)[0]
            for policy, latent_emb in zip(self.policies, latent_emb_lst)
        ]
        pol = torch.stack(member_pols, dim=-1).prod(dim=-1)
        # Out-of-place division: modifying the output of `prod` in place would
        # invalidate it for autograd if gradients are ever required.
        return pol / pol.sum(dim=1, keepdim=True)

    def forward(self, batch_state, actions=None):
        """Score (or sample and score) actions under the product policy.

        Args:
            batch_state: state batch exposing ``batch_ids`` (and optionally
                ``forward_mask``).
            actions: optional tensor of action indices; sampled when ``None``.

        Returns:
            ``(actions, log_prob)`` with the log-probability of each action
            under the product policy.
        """
        # Mirror the base class: tolerate states without a forward mask.
        if not hasattr(batch_state, 'forward_mask'):
            batch_state.forward_mask = None
        latent_emb_lst = self.get_latent_emb(batch_state)
        pol = self.get_pol(latent_emb_lst, batch_state.forward_mask)
        if actions is None:
            # get_actions returns (actions, exploratory_policy); only the
            # sampled indices may be used to index `pol`.
            actions, _ = self.get_actions(pol, batch_state.forward_mask)
        return actions, torch.log(pol[batch_state.batch_ids, actions])
    
def log_artifact_tensor(tensor, artifact_name):
    """Save ``tensor`` locally and upload it as a wandb artifact.

    The tensor is written with ``torch.save`` to a file named
    ``artifact_name`` and that file is attached to a new artifact of type
    ``'tensor'`` which is logged on the active wandb run.
    """
    local_path = artifact_name
    torch.save(tensor, local_path)

    tensor_artifact = wandb.Artifact(artifact_name, type='tensor')
    tensor_artifact.add_file(local_path)
    wandb.run.log_artifact(tensor_artifact)
    
def load_artifact_tensor(artifact_name):
    """Download the latest ``'tensor'`` artifact and load it with ``torch.load``.

    The artifact is looked up under the ``weekday`` entity in the project named
    by the ``WANDB_PROJECT_NAME`` environment variable; the file inside the
    artifact is expected to be called ``artifact_name``.

    Raises:
        KeyError: if ``WANDB_PROJECT_NAME`` is not set.
    """
    project = os.environ["WANDB_PROJECT_NAME"]
    tensor_artifact = wandb.run.use_artifact(
        f'weekday/{project}/{artifact_name}:latest', type='tensor'
    )
    download_dir = tensor_artifact.download()
    tensor_path = os.path.join(download_dir, artifact_name)
    return torch.load(tensor_path)

def log_artifact_module(module, artifact_name, artifact_filename):
    """Save ``module``'s state dict and upload it as a wandb ``'model'`` artifact.

    Args:
        module: object providing ``state_dict()``.
        artifact_name: name of the wandb artifact to create.
        artifact_filename: local path the state dict is written to and that is
            attached to the artifact.
    """
    weights = module.state_dict()
    torch.save(weights, artifact_filename)

    model_artifact = wandb.Artifact(artifact_name, type='model')
    model_artifact.add_file(artifact_filename)
    wandb.run.log_artifact(model_artifact)

def load_artifact_module(module, artifact_name, artifact_filename):
    """Load the latest ``'model'`` artifact's state dict into ``module``.

    The artifact is looked up under the ``weekday`` entity in the project named
    by the ``WANDB_PROJECT_NAME`` environment variable.  Only the base name of
    ``artifact_filename`` is used to locate the file inside the download
    directory.

    Returns:
        The same ``module`` instance, after ``load_state_dict``.

    Raises:
        KeyError: if ``WANDB_PROJECT_NAME`` is not set.
    """
    project = os.environ["WANDB_PROJECT_NAME"]
    model_artifact = wandb.run.use_artifact(
        f'weekday/{project}/{artifact_name}:latest', type='model'
    )
    download_dir = model_artifact.download()
    weights_path = os.path.join(download_dir, os.path.basename(artifact_filename))
    state_dict = torch.load(weights_path)
    module.load_state_dict(state_dict)
    return module

@torch.no_grad()
def init_logz(gflownet, create_env, num_batches, batch_size):
    """Initialise ``gflownet.log_z`` in place from sampled terminal states.

    For ``num_batches`` freshly sampled batches, the log-reward and a
    per-sample log-marginal estimate (log-mean-exp over 8 backward
    trajectories returned by ``sample_many_backward``) are collected.
    ``log_z`` is then set to the log-mean-exp of
    ``log_reward - log_marginal`` over all collected samples.

    Args:
        gflownet: model exposing ``sample``, ``sample_many_backward`` and
            ``log_z``.
        create_env: zero-argument callable producing a fresh environment.
        num_batches: number of batches to sample.
        batch_size: unused by this function.
    """
    reward_chunks = []
    marginal_chunks = []
    for _ in tqdm.tqdm(range(num_batches)):
        env = gflownet.sample(create_env())
        reward_chunks.append(env.log_reward())

        back_log_prob = gflownet.sample_many_backward(env, num_trajectories=8)
        num_traj = torch.tensor(back_log_prob.shape[1])
        # log-mean-exp over the trajectory axis
        log_marginal = torch.logsumexp(back_log_prob, dim=1)
        log_marginal -= torch.log(num_traj)
        marginal_chunks.append(log_marginal)

    log_rewards = torch.hstack(reward_chunks)
    marginal_log_dist = torch.hstack(marginal_chunks)
    num_samples = torch.tensor(len(log_rewards))

    # log-mean-exp of the importance weights, written into log_z in place
    log_weight_sum = torch.logsumexp(log_rewards - marginal_log_dist, dim=0)
    gflownet.log_z.fill_(log_weight_sum)
    gflownet.log_z.add_(-torch.log(num_samples))

    assert gflownet.log_z.requires_grad

@torch.no_grad()
def sample_massive_batch(gflownet, create_env, num_batches, num_back_traj=1, use_progress_bar=True):
    """Sample several batches and merge them into one large batch.

    The first batch is always sampled; ``num_batches - 1`` further batches are
    sampled and merged into the first environment with ``env.merge``.

    Args:
        gflownet: model exposing ``sample`` and ``sample_many_backward``.
        create_env: zero-argument callable producing a fresh environment.
        num_batches: total number of batches to sample.
        num_back_traj: forwarded as ``num_trajectories`` to
            ``sample_many_backward``.
        use_progress_bar: show a tqdm progress bar over the extra batches.

    Returns:
        ``(env, marginal_log, log_rewards)``: the merged environment, the
        row-wise stacked results of ``sample_many_backward`` and the
        concatenated log-rewards.
    """
    env = create_env()
    env = gflownet.sample(env)
    marginal_chunks = [gflownet.sample_many_backward(env, num_trajectories=num_back_traj)]
    reward_chunks = [env.log_reward()]

    for _ in tqdm.trange(num_batches - 1, disable=not use_progress_bar):
        env_i = create_env()
        env_i = gflownet.sample(env_i)
        env.merge(env_i)
        marginal_chunks.append(
            gflownet.sample_many_backward(env_i, num_trajectories=num_back_traj)
        )
        reward_chunks.append(env_i.log_reward())

    # Concatenate once at the end; re-stacking the growing result on every
    # iteration copies all previous data each time (quadratic in num_batches).
    if len(marginal_chunks) == 1:
        # Return the single chunk untouched so its shape is never altered by
        # vstack/hstack promotion.
        return env, marginal_chunks[0], reward_chunks[0]
    return env, torch.vstack(marginal_chunks), torch.hstack(reward_chunks)

def unique(x, dim=-1):
    """``torch.unique`` that additionally returns an index for each unique value.

    Args:
        x: input tensor.
        dim: dimension along which slices are de-duplicated (forwarded to
            ``torch.unique``).  ``None`` de-duplicates the flattened tensor.

    Returns:
        ``(values, inverse, counts, indices)``.  The first three are the
        outputs of ``torch.unique``.  ``indices[k]`` is a position along
        ``dim`` (or in the flattened tensor when ``dim`` is ``None``) at which
        ``values[k]`` occurs.  When a value occurs several times, which of its
        positions is reported is not guaranteed.
    """
    values, inverse, counts = torch.unique(
        x, return_inverse=True, return_counts=True, dim=dim
    )
    # With an explicit `dim` the inverse is 1-D regardless of which `dim` was
    # requested, so it must be indexed along axis 0 -- using `dim` itself
    # raises for any dim >= 1.  Flattening also covers dim=None, where the
    # inverse has the shape of `x`.
    flat_inverse = inverse.reshape(-1)
    positions = torch.arange(
        flat_inverse.numel(), dtype=inverse.dtype, device=inverse.device
    )
    indices = inverse.new_empty(counts.numel()).scatter_(0, flat_inverse, positions)
    return values, inverse, counts, indices

def marginal_dist(env, marginal_log, log_rewards, dim=-1):
    """Compare the learned and target distributions over the unique samples.

    Samples are grouped by ``env.unique_input`` (de-duplicated along ``dim``).
    For every unique state the exponentiated entries of ``marginal_log`` are
    averaged over all of its occurrences and over all columns, then
    normalised to sum to one.  The target distribution is the softmax of the
    log-rewards of the unique states.

    Args:
        env: environment exposing ``unique_input``.
        marginal_log: 2-D tensor with one row per sample and one column per
            backward trajectory -- presumably log importance weights from
            ``sample_many_backward``; TODO confirm.
        log_rewards: 1-D tensor with one log-reward per sample.
        dim: dimension of ``env.unique_input`` that indexes the samples.

    Returns:
        ``(learned_dist, target_dist)``, both indexed by unique state.
    """
    values, inverse, counts, indices = unique(env.unique_input, dim=dim)
    num_unique = counts.numel()
    num_traj = marginal_log.shape[1]

    # Sum exp(marginal_log) of all samples mapped to the same unique state.
    # index_add_ accumulates whole rows.  A scatter_add_ with an (N, 1) index
    # only touches column 0 and silently ignores every other trajectory.
    prob_sums = torch.zeros(
        (num_unique, num_traj), dtype=marginal_log.dtype, device=values.device
    )
    prob_sums.index_add_(0, inverse.reshape(-1), marginal_log.exp())

    # Average over occurrences and trajectories, then normalise.
    marginal_prob = prob_sums.sum(dim=-1) / (counts * num_traj)
    learned_dist = marginal_prob / marginal_prob.sum()

    # Target distribution: softmax of the log-rewards of the unique states.
    unique_log_rewards = log_rewards[indices]
    target_dist = (
        unique_log_rewards - torch.logsumexp(unique_log_rewards, dim=0)
    ).exp()
    return learned_dist, target_dist

def compute_marginal_dist(gflownet, create_env, num_batches, num_back_traj, use_progress_bar=False):
    """Sample from ``gflownet`` and return the learned and target distributions.

    Draws ``num_batches`` batches with ``sample_massive_batch`` and passes the
    merged result to ``marginal_dist`` with samples indexed along dimension 0.

    Returns:
        ``(learned_dist, target_dist)`` as produced by ``marginal_dist``.
    """
    batch = sample_massive_batch(
        gflownet, create_env, num_batches, num_back_traj, use_progress_bar
    )
    merged_env, marginal_log, log_rewards = batch
    return marginal_dist(merged_env, marginal_log, log_rewards, dim=0)
