import torch 
import torch.nn as nn
import os 
import tqdm 
import wandb 
import copy 
import numpy as np 

from abc import ABCMeta, abstractmethod 

from gfn.gflownet import LEDGFlowNet 

class MockOptimizer(torch.optim.Optimizer):
    """Optimizer stand-in that owns no parameters and never updates anything.

    Useful wherever the training loop expects an optimizer object but there
    is nothing to optimise. ``Optimizer.__init__`` is intentionally not
    called (presumably because the base constructor refuses an empty
    parameter list -- confirm against the installed torch version), so only
    the attributes set below exist on the instance.
    """

    def __init__(self):
        self.param_groups = list()
        # Attributes that code written against ``torch.optim.Optimizer``
        # commonly reads; kept empty because there is nothing to optimise.
        self.defaults = dict()
        self.state = dict()

    def step(self, closure=None):
        """Do nothing; evaluate ``closure`` if given and return its result.

        Accepting ``closure`` keeps the signature compatible with the base
        ``Optimizer.step`` API.
        """
        if closure is None:
            return None
        return closure()

    def zero_grad(self, set_to_none=True):
        """Do nothing; ``set_to_none`` is accepted for API compatibility."""
        pass

class OptGroup:
    """Named collection of optimizers, readable as attributes.

    ``group.register(opt, 'opt_gfn')`` makes the optimizer available as
    ``group.opt_gfn``.
    """

    def __init__(self):
        self.optimizers = dict()

    def register(self, opt, name):
        """Store ``opt`` under ``name``, replacing any previous entry."""
        self.optimizers[name] = opt

    def __getattr__(self, name):
        """Resolve attributes that normal lookup missed via the registry.

        Raises:
            AttributeError: if no optimizer was registered under ``name``.
        """
        # Read the registry through ``__dict__``: on a half-built instance
        # (e.g. during ``copy``/``pickle`` reconstruction) ``self.optimizers``
        # would re-enter ``__getattr__`` and recurse forever.
        registry = self.__dict__.get('optimizers', {})
        try:
            return registry[name]
        except KeyError:
            # ``hasattr``/``getattr(obj, name, default)`` only understand
            # AttributeError, not KeyError.
            raise AttributeError(
                f'{type(self).__name__!r} has no optimizer registered as {name!r}'
            ) from None
    
class SchGroup:
    """Named collection of learning-rate schedulers, readable as attributes.

    ``group.register(sch, 'sch_gfn')`` makes the scheduler available as
    ``group.sch_gfn``.
    """

    def __init__(self):
        self.schedulers = dict()

    def register(self, sch, name):
        """Store ``sch`` under ``name``, replacing any previous entry."""
        self.schedulers[name] = sch

    def __getattr__(self, name):
        """Resolve attributes that normal lookup missed via the registry.

        Raises:
            AttributeError: if no scheduler was registered under ``name``.
        """
        # Read the registry through ``__dict__``: on a half-built instance
        # (e.g. during ``copy``/``pickle`` reconstruction) ``self.schedulers``
        # would re-enter ``__getattr__`` and recurse forever.
        registry = self.__dict__.get('schedulers', {})
        try:
            return registry[name]
        except KeyError:
            # ``hasattr``/``getattr(obj, name, default)`` only understand
            # AttributeError, not KeyError.
            raise AttributeError(
                f'{type(self).__name__!r} has no scheduler registered as {name!r}'
            ) from None

def train_phi_step(gfn, create_env, config, opt_phi, sch_phi):
    """Run ``config.train_phi_rounds`` optimisation rounds on the second loss of ``gfn``.

    Each round builds a fresh environment with ``create_env()``, calls
    ``gfn(env)``, back-propagates the second element of the returned pair and
    steps ``opt_phi``. ``sch_phi`` is stepped once, after all rounds.

    Returns:
        The loss of the last round, or ``torch.nan`` when
        ``config.train_phi_rounds`` is zero (the same placeholder
        ``train_step`` uses when no such loss is available).
    """
    # Pre-bind so that zero rounds returns a placeholder instead of raising
    # UnboundLocalError.
    loss_ls = torch.nan
    for _ in range(config.train_phi_rounds):
        opt_phi.zero_grad()
        env = create_env()
        _, loss_ls = gfn(env)
        loss_ls.backward()
        opt_phi.step()
    sch_phi.step()
    return loss_ls

class ModesList:
    """Collects the distinct high-reward states ("modes") seen so far.

    A state counts as a mode when its log-reward is at least ``th``. The
    threshold is either given up front or estimated by :meth:`warmup` as the
    ``q``-quantile of sampled log-rewards.
    """

    def __init__(self, q=.9, th=None):
        self.th = th
        self.q = q
        # List of 2-D tensors; each row is one distinct mode.
        self.modes = list()

    def append(self, env):
        """Record the modes of ``env`` that have not been stored yet.

        Reads ``env.log_reward()``, ``env.unique_input`` (one row per sample)
        and ``env.batch_size``.

        Raises:
            RuntimeError: if no threshold is set (call :meth:`warmup` or pass
                ``th`` to the constructor first).
        """
        if self.th is None:
            raise RuntimeError('ModesList threshold is not set; call warmup() or pass th')

        candidates = env.unique_input
        is_mode = env.log_reward() >= self.th

        # Check if modes are already in list
        if self.modes:
            known = torch.vstack(self.modes).view(1, -1, candidates.shape[1])
            already_seen = (
                known == candidates.view(env.batch_size, 1, -1)
            ).all(dim=-1).any(dim=1)
        else:
            already_seen = torch.zeros_like(is_mode, dtype=torch.bool)

        new_modes = candidates[is_mode & (~already_seen)]
        if new_modes.shape[0] == 0:
            return
        # The same state can appear several times in one batch; keep a single
        # copy so that len() counts distinct modes only.
        self.modes.append(torch.unique(new_modes, dim=0))

    def warmup(self, gfn, create_env_func, epochs):
        """Set ``th`` to the ``q``-quantile of log-rewards from ``epochs`` sampled batches."""
        log_reward_lst = list()
        for _ in range(epochs):
            env = gfn.sample(create_env_func())
            log_reward_lst.append(env.log_reward())
        log_reward_lst = torch.hstack(log_reward_lst)
        self.th = torch.quantile(log_reward_lst, q=self.q)

    def __len__(self):
        """Return the number of stored modes."""
        # Summing row counts avoids re-stacking every stored tensor.
        return sum(len(chunk) for chunk in self.modes)


def train_step(gfn, create_env, config, opt_group, sch_group, use_wandb=True, mode_lst: ModesList = None): 
    """Train ``gfn`` for ``config.epochs_per_step`` iterations and return it.

    Per iteration:
      1. if ``config.learn_potential`` is set, run ``train_phi_step`` with
         ``opt_group.opt_phi`` / ``sch_group.sch_phi``;
      2. zero the ``opt_gfn`` and ``opt_gamma`` gradients, build a fresh
         environment and evaluate ``gfn(env)``, which returns a pair;
      3. back-propagate, step both optimizers, then both schedulers;
      4. log ``loss_gfn`` / ``loss_ls`` to the progress bar and, when
         ``use_wandb`` is true, to wandb;
      5. if ``config.criterion == 'td'``, call ``gfn.gamma_func.step()``;
      6. if ``mode_lst`` is given, append the environment of this iteration.

    ``opt_group`` must expose ``opt_gfn`` and ``opt_gamma`` (plus ``opt_phi``
    when ``config.learn_potential``); ``sch_group`` the matching ``sch_*``.
    """

    # Placeholder logged when no 'ls' loss is produced in an iteration.
    loss_ls = torch.nan 
    for epoch in (pbar := tqdm.trange(config.epochs_per_step, disable=config.disable_pbar)): 
        if config.learn_potential: 
            loss_ls = train_phi_step(gfn, create_env, config, opt_group.opt_phi, sch_group.sch_phi) 
    
        opt_group.opt_gfn.zero_grad() 
        opt_group.opt_gamma.zero_grad() 

        env = create_env() 
        loss = gfn(env) 

        # The second element of the returned pair depends on the model type.
        if isinstance(gfn, LEDGFlowNet): 
            loss, loss_ls = loss 
        else: 
            loss, loss_gamma = loss 
            if config.gamma_func == 'learn' and config.criterion in ['td', 'regdb']: 
                # retain_graph: loss.backward() below is called right after
                # (presumably over a shared graph -- confirm).
                loss_gamma.backward(retain_graph=True) 
                    
        loss.backward() 
        opt_group.opt_gamma.step() 
        opt_group.opt_gfn.step()

        # Schedulers advance once per iteration, after the optimizer steps.
        sch_group.sch_gfn.step() 
        sch_group.sch_gamma.step() 
        
        # NOTE(review): the logged values are the raw loss objects, not
        # detached Python floats.
        log = {'loss_gfn': loss, 'loss_ls': loss_ls}
        if use_wandb: 
            wandb.log(log) 
        pbar.set_postfix(**log)  
        
        if config.criterion == 'td': 
            gfn.gamma_func.step() 

        if mode_lst is not None: 
            mode_lst.append(env) 
            
    return gfn 
        
class Environment(metaclass=ABCMeta):
    """Abstract batched environment.

    Holds the per-sample bookkeeping tensors (``batch_ids``, ``traj_size``,
    ``stopped``, ``is_initial``) and a reward callable that is evaluated on
    the environment itself. Subclasses implement ``apply`` and ``backward``
    and may provide ``unique_input``.
    """

    def __init__(self, batch_size: int, max_trajectory_length: int, log_reward: nn.Module, device: torch.device):
        self.batch_size = batch_size
        self.device = device
        self.max_trajectory_length = max_trajectory_length
        self._log_reward = log_reward

        # One entry per sample in the batch.
        self.batch_ids = torch.arange(self.batch_size, device=self.device)
        self.traj_size = torch.ones((self.batch_size,), device=self.device)
        self.stopped = torch.zeros(self.batch_size, device=self.device)
        self.is_initial = torch.ones((self.batch_size,), device=self.device)

    @abstractmethod
    def apply(self, actions: torch.Tensor):
        """Subclass hook; receives a tensor of actions."""
        pass

    @abstractmethod
    def backward(self, actions: torch.Tensor):
        """Subclass hook; receives a tensor of actions."""
        pass

    @torch.no_grad()
    def log_reward(self):
        """Evaluate the reward callable on this environment, without gradients."""
        return self._log_reward(self)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append ``batch_state`` to this environment in place.

        Concatenates ``batch_ids``, ``stopped`` and ``is_initial`` and adds
        the batch sizes.
        NOTE(review): ``traj_size`` is not merged here -- confirm whether
        subclasses take care of it.
        """
        for field in ('batch_ids', 'stopped', 'is_initial'):
            joined = torch.hstack([getattr(self, field), getattr(batch_state, field)])
            setattr(self, field, joined)
        self.batch_size += batch_state.batch_size

    @property
    def unique_input(self):
        """Subclass-provided representation of each sample; not implemented here."""
        raise NotImplementedError

class BaseNN(nn.Module):
    """Fully connected network: an input layer followed by ``num_layers`` activation + linear pairs.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of every hidden layer.
        num_layers: number of (activation, linear) pairs after the input
            layer. With ``0`` the network is a single linear map from
            ``input_dim`` to ``output_dim``.
        output_dim: size of the output; defaults to ``hidden_dim``.
        act: activation module inserted before every linear layer after the
            first. The same instance is reused at every position. Defaults
            to a fresh ``nn.LeakyReLU()`` per network.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim=None, act=None):
        super(BaseNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim if output_dim is not None else hidden_dim

        # Build the default here rather than in the signature so that
        # networks do not all share one module created at import time.
        if act is None:
            act = nn.LeakyReLU()

        if num_layers <= 0:
            # No hidden stack: map straight to the advertised output size
            # instead of silently emitting ``hidden_dim`` features.
            layers = [nn.Linear(input_dim, self.output_dim)]
        else:
            layers = [nn.Linear(input_dim, hidden_dim)]
            for layer in range(num_layers):
                is_last = layer == num_layers - 1
                layers.append(act)
                layers.append(
                    nn.Linear(self.hidden_dim, self.output_dim if is_last else self.hidden_dim)
                )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``."""
        return self.model(x)
    
class ForwardPolicyMeta(nn.Module, metaclass=ABCMeta):
    """Base class for forward policies with epsilon-uniform exploration.

    Subclasses implement ``get_latent_emb`` and ``get_pol``. ``forward``
    calls them as ``get_latent_emb(batch_state)`` and
    ``get_pol(latent_emb, batch_state.forward_mask)``; the latter must return
    a ``(pol, gflows)`` pair where ``pol`` holds per-sample action
    probabilities (one row per sample).
    """

    masked_value = -1e5

    def __init__(self, eps=.3, device='cpu'):
        super(ForwardPolicyMeta, self).__init__()
        self.eps = eps
        self.seed = None
        self.device = device

    @abstractmethod
    def get_latent_emb(self):
        """Subclass hook; ``forward`` invokes it with the batch state."""
        pass

    @abstractmethod
    def get_pol(self):
        """Subclass hook; ``forward`` invokes it with the embedding and the forward mask."""
        pass

    def set_seed(self, seed):
        """Make ``get_actions`` sample from a generator seeded with ``seed`` on every call."""
        self.seed = seed

    def unset_seed(self):
        """Return to sampling from the global random state."""
        self.seed = None

    def get_actions(self, pol, mask=None):
        """Sample one action per row from an epsilon-mixture of ``pol`` and a uniform policy.

        Args:
            pol: action probabilities, one row per sample.
            mask: optional tensor of the same layout; entries equal to ``1``
                mark the actions the uniform component may choose.

        Returns:
            ``(actions, exp_pol)``: the sampled action indices and the
            mixture distribution they were drawn from. Exploration is only
            active in training mode; in eval mode ``exp_pol`` equals ``pol``.
        """
        if mask is None:
            uniform_pol = torch.ones_like(pol)
        else:
            uniform_pol = torch.where(mask==1., 1., 0.)
        # Normalise in both branches: without it the unmasked "uniform" is a
        # tensor of ones, the mixture does not sum to one and its log is not
        # a log-probability.
        uniform_pol = uniform_pol / uniform_pol.sum(dim=1, keepdim=True)

        eps = 0. if not self.training else self.eps
        exp_pol = pol * (1 - eps) + eps * uniform_pol

        if self.seed is not None:
            # A fresh generator is seeded on every call, so repeated calls
            # with the same distribution draw the same actions.
            g = torch.Generator(device=self.device)
            g.manual_seed(self.seed)
            actions = torch.multinomial(exp_pol, num_samples=1, replacement=True, generator=g)
        else:
            actions = torch.multinomial(exp_pol, num_samples=1, replacement=True)
        actions = actions.squeeze(dim=-1)
        return actions, exp_pol

    def forward(self, batch_state, actions=None, return_ps=False):
        """Evaluate the policy on ``batch_state`` and, if needed, sample actions.

        Args:
            batch_state: object with ``batch_ids``; ``forward_mask`` is set
                to ``None`` on it when absent.
            actions: actions to score; sampled with ``get_actions`` if ``None``.
            return_ps: selects the third returned element (see below).

        Returns:
            ``(actions, log pol[actions], third, log spol[actions])`` where
            ``spol`` is the distribution the actions were sampled from
            (``pol`` itself when ``actions`` was supplied) and ``third`` is
            the log-probability of the last action column when ``return_ps``
            is true, otherwise the ``gflows`` value returned by ``get_pol``.
        """
        if not hasattr(batch_state, 'forward_mask'):
            batch_state.forward_mask = None
        latent_emb = self.get_latent_emb(batch_state)
        pol, gflows = self.get_pol(latent_emb, batch_state.forward_mask)
        spol = pol  # sampling policy
        if actions is None:
            actions, spol = self.get_actions(pol, batch_state.forward_mask)

        ids = batch_state.batch_ids
        log_pol = torch.log(pol[ids, actions])
        third = torch.log(pol[ids, -1]) if return_ps else gflows
        log_spol = torch.log(spol[ids, actions])
        return actions, log_pol, third, log_spol

class GammaFuncMeta(nn.Module):
    """Weighting function scaled by a temperature that decays linearly to zero.

    Subclasses provide ``weight_func(batch_state_t, batch_state_tp1)``.
    Each ``step()`` lowers the temperature by ``temperature / total_iters``
    (computed from the initial values), never going below zero.
    """

    def __init__(self, temperature=1., total_iters=1):
        super().__init__()
        self.temperature = temperature
        self.total_iters = total_iters

        # Fixed decrement applied by every call to step().
        self.step_len = self.temperature / self.total_iters

    @torch.no_grad()
    def forward(self, batch_state_t, batch_state_tp1):
        """Return the subclass weights multiplied by the current temperature."""
        return self.temperature * self.weight_func(batch_state_t, batch_state_tp1)

    def step(self):
        """Lower the temperature by one decrement, clamped at zero."""
        cooled = self.temperature - self.step_len
        self.temperature = max(cooled, 0.)
    
class GammaFuncConst(GammaFuncMeta):
    """Gamma function that gives every sample in the batch the same weight."""

    def weight_func(self, batch_state_t, batch_state_tp1):
        """Return a vector of ones with one entry per sample of ``batch_state_t``."""
        num_samples = batch_state_t.batch_size
        target_device = batch_state_t.device
        return torch.ones(num_samples, device=target_device)

def log_artifact_tensor(tensor, artifact_name):
    """Save ``tensor`` to the local file ``artifact_name`` and log it as a wandb artifact.

    The same string is used as the local file name and the artifact name;
    the artifact type is ``'tensor'``. Requires an active wandb run.
    """
    torch.save(tensor, artifact_name)
    tensor_artifact = wandb.Artifact(artifact_name, type='tensor')
    tensor_artifact.add_file(artifact_name)
    wandb.run.log_artifact(tensor_artifact)
    
def load_artifact_tensor(artifact_name):
    """Download the latest ``'tensor'`` artifact named ``artifact_name`` and load it.

    The artifact is looked up under the hard-coded ``weekday`` prefix and the
    project named by the ``WANDB_PROJECT_NAME`` environment variable.
    Requires an active wandb run.
    """
    remote = wandb.run.use_artifact(
        f'weekday/{os.environ["WANDB_PROJECT_NAME"]}/{artifact_name}:latest',
        type='tensor',
    )
    download_root = remote.download()
    tensor_path = os.path.join(download_root, artifact_name)
    # NOTE(review): torch.load unpickles the file; only use with trusted artifacts.
    return torch.load(tensor_path)

def log_artifact_module(module, artifact_name, artifact_filename):
    """Save ``module.state_dict()`` to ``artifact_filename`` and log it as a wandb artifact.

    The artifact is named ``artifact_name`` with type ``'model'``. Requires
    an active wandb run.
    """
    weights = module.state_dict()
    torch.save(weights, artifact_filename)
    model_artifact = wandb.Artifact(artifact_name, type='model')
    model_artifact.add_file(artifact_filename)
    wandb.run.log_artifact(model_artifact)

def load_artifact_module(module, artifact_name, artifact_filename):
    """Download the latest ``'model'`` artifact and load its weights into ``module``.

    The artifact is looked up under the hard-coded ``weekday`` prefix and the
    project named by the ``WANDB_PROJECT_NAME`` environment variable. Only
    the base name of ``artifact_filename`` is used to find the file inside
    the download directory. Returns ``module``.
    """
    remote = wandb.run.use_artifact(
        f'weekday/{os.environ["WANDB_PROJECT_NAME"]}/{artifact_name}:latest',
        type='model',
    )
    download_root = remote.download()
    weights_path = os.path.join(download_root, os.path.basename(artifact_filename))
    # NOTE(review): torch.load unpickles the file; only use with trusted artifacts.
    state = torch.load(weights_path)
    module.load_state_dict(state)
    return module

@torch.no_grad()
def sample_massive_batch(gflownet, create_env, num_batches, num_back_traj=1, use_progress_bar=True):
    """Sample ``num_batches`` batches (at least one) and merge them into a single environment.

    For every batch the environment is sampled with ``gflownet.sample`` and
    scored with ``gflownet.sample_many_backward`` and ``log_reward``. Later
    batches are merged into the first environment with ``env.merge``.

    Returns:
        ``(env, marginal_log, log_rewards)``: the merged environment, the
        per-batch backward results stacked with ``torch.vstack``, and the
        per-batch log-rewards joined with ``torch.hstack``.
    """
    merged = gflownet.sample(create_env())
    stacked_marginals = gflownet.sample_many_backward(merged, num_trajectories=num_back_traj)
    joined_rewards = merged.log_reward()

    extra_batches = num_batches - 1  # the first batch was drawn above
    for _ in tqdm.trange(extra_batches, disable=not use_progress_bar):
        batch = gflownet.sample(create_env())
        merged.merge(batch)
        batch_marginals = gflownet.sample_many_backward(batch, num_trajectories=num_back_traj)
        stacked_marginals = torch.vstack([stacked_marginals, batch_marginals])
        joined_rewards = torch.hstack([joined_rewards, batch.log_reward()])

    return merged, stacked_marginals, joined_rewards

def unique(x, dim=-1):
    """``torch.unique`` along ``dim`` that also returns one source index per unique slice.

    Returns:
        ``(values, inverse, counts, indices)``. The first three are the
        outputs of ``torch.unique(..., return_inverse=True,
        return_counts=True, dim=dim)``. ``indices[k]`` is the position along
        ``dim`` of one occurrence of ``values``' k-th slice; when a slice
        occurs several times, which occurrence is reported depends on the
        write order of ``scatter_``.
    """
    values, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True, dim=dim)
    positions = torch.arange(inverse.numel(), dtype=inverse.dtype, device=values.device)
    # ``inverse`` and the output buffer are 1-D, so the scatter must run
    # along axis 0 whatever ``dim`` the caller deduplicated over.
    indices = inverse.new_empty(values.size(dim)).scatter_(0, inverse, positions)
    return values, inverse, counts, indices

def marginal_dist(env, marginal_log, log_rewards, dim=-1, normalize=True):
    """Aggregate per-sample estimates into one value per distinct state.

    Samples are grouped by ``env.unique_input`` (deduplicated along ``dim``;
    every caller in this file passes ``dim=0``). For each distinct state the
    exponentiated entries of ``marginal_log`` are averaged over all of its
    occurrences and over all columns of ``marginal_log``.

    Args:
        env: object exposing ``unique_input``.
        marginal_log: 2-D tensor, one row per sample (its ``shape[1]`` is
            used as the number of estimates per sample).
        log_rewards: one log-reward per sample.
        dim: axis of ``env.unique_input`` to deduplicate over.
        normalize: if true, both outputs are normalised to sum to one.

    Returns:
        ``(learned_dist, target_dist)``. With ``normalize=False`` the first
        is the raw averaged probability and the second the raw log-reward of
        each distinct state.
    """
    values, inverse, counts, indices = unique(env.unique_input, dim=dim)
    num_estimates = marginal_log.shape[1]

    # Compute learned distribution
    prob_sum = torch.zeros((counts.numel(), num_estimates), device=values.device)
    # scatter_add_ only visits positions present in ``index``: a (N, 1) index
    # would add just the first column, so expand it over every column.
    scatter_index = inverse.view(-1, 1).expand_as(marginal_log)
    prob_sum.scatter_add_(dim=0, index=scatter_index, src=marginal_log.exp())
    marginal_log_batch = prob_sum.sum(dim=-1)
    marginal_log_batch /= (counts * num_estimates)

    if normalize:
        learned_dist = marginal_log_batch / marginal_log_batch.sum()
        # Compute the target distribution
        target_dist = (log_rewards[indices] - torch.logsumexp(log_rewards[indices], dim=0)).exp()
    else:
        learned_dist = marginal_log_batch
        target_dist = log_rewards[indices]

    return learned_dist, target_dist

def compute_marginal_dist(gflownet, create_env, num_batches, num_back_traj, use_progress_bar=False):
    """Sample with ``sample_massive_batch`` and return ``marginal_dist`` of the result.

    Returns the ``(learned_dist, target_dist)`` pair computed over the
    distinct rows (``dim=0``) of the merged sample.
    """
    # Sample from the learned distribution
    merged_env, marginal_log, log_rewards = sample_massive_batch(
        gflownet,
        create_env,
        num_batches,
        num_back_traj=num_back_traj,
        use_progress_bar=use_progress_bar,
    )
    return marginal_dist(merged_env, marginal_log, log_rewards, dim=0)

@torch.no_grad()
def compute_marginal_grid(gfn, create_env_func, num_back_traj):
    """Estimate the learned and target distributions over every cell of a grid environment.

    A fresh environment is overwritten so that its batch contains every
    ``(x, y)`` position with ``0 <= x <= env.width`` and
    ``0 <= y <= env.height``. The learned probability of each cell is the
    mean of ``num_back_traj`` estimates from ``gfn.sample_many_backward``;
    the target comes from ``env.log_reward()``. Both are normalised over the
    grid.

    Returns:
        ``(learned_prob, target_prob, positions)`` where ``positions`` is
        the ``(num_cells, 2)`` tensor assigned to ``env.pos``.
    """
    env = create_env_func()
    # Ensure the correct batch size
    num_cells = (env.width + 1) * (env.height + 1)
    env.batch_size = num_cells
    env.batch_ids = torch.arange(num_cells, device=env.device)
    env.forward_mask = torch.zeros((num_cells, 3), device=env.device)
    env.backward_mask = torch.zeros((num_cells, 2), device=env.device)
    env.stopped = torch.ones((num_cells,), device=env.device)
    env.is_initial = torch.zeros((num_cells,), device=env.device)

    # 'ij' is the layout the original implicit default produced; passing it
    # explicitly avoids the deprecation warning.
    xs, ys = torch.meshgrid(
        torch.arange(env.width + 1, device=env.device),
        torch.arange(env.height + 1, device=env.device),
        indexing='ij',
    )
    env.pos = torch.stack((xs.flatten(), ys.flatten()), dim=1).type(env.forward_mask.dtype)

    env.update_backward_mask()

    # Only the last forward action is enabled (presumably the stop action --
    # confirm against the environment).
    env.forward_mask[:, -1] = 1

    marginal_log = gfn.sample_many_backward(env, num_back_traj)
    # Log of the mean over the sampled estimates.
    learned_log_prob = torch.logsumexp(marginal_log, dim=1) - np.log(num_back_traj)
    target_log_prob = env.log_reward()

    learned_log_prob = learned_log_prob - torch.logsumexp(learned_log_prob, dim=0)
    target_log_prob = target_log_prob - torch.logsumexp(target_log_prob, dim=0)
    return learned_log_prob.exp(), target_log_prob.exp(), env.pos

@torch.no_grad()
def compute_fcs(gflownet, create_env, config):
    """Estimate the total-variation gap between learned and target distributions on sampled buckets.

    For ``config.env == 'grids'`` the exact grid marginals from
    ``compute_marginal_grid`` are used. Otherwise ``config.epochs_eval``
    buckets of ``config.fcs_bucket_size`` samples are drawn, the half L1
    distance between ``marginal_dist``'s two outputs is computed per bucket,
    and the mean over buckets is returned.

    ``config.batch_size`` is temporarily capped at ``config.fcs_bucket_size``
    and restored before returning, even if sampling raises.
    """
    if config.env == 'grids':
        learned_prob, target_prob, _ = compute_marginal_grid(gflownet, create_env, num_back_traj=config.num_back_traj)
        return (learned_prob - target_prob).abs().sum() / 2.

    l1_lst = list()
    original_batch_size = config.batch_size
    config.batch_size = min(config.fcs_bucket_size, config.batch_size)

    try:
        # sample_massive_batch draws exactly ``num_batches`` batches in total
        # (one up front plus num_batches - 1 in its loop), so the full count
        # is passed; subtracting one here left the bucket one batch short.
        num_batches = config.fcs_bucket_size // config.batch_size
        for _ in tqdm.trange(config.epochs_eval, disable=config.disable_pbar):
            samples, marginal_log, log_rewards = sample_massive_batch(
                gflownet,
                create_env,
                num_batches=num_batches,
                num_back_traj=config.num_back_traj,
                use_progress_bar=False)
            assert samples.batch_size == config.fcs_bucket_size
            # Compute the L1 distance for the sampled subset
            learned_dist, target_dist = marginal_dist(samples, marginal_log, log_rewards, dim=0)
            l1_lst.append((learned_dist - target_dist).abs().sum())
    finally:
        # Re-assign the initial batch size
        config.batch_size = original_batch_size

    l1_lst = .5 * torch.tensor(l1_lst)
    return l1_lst.mean()

@torch.no_grad()
def compute_tv(gflownet, create_env, config, verbose=True, return_dist=False):
    """Compute the total-variation distance over the states enumerated by the environment.

    States are obtained in chunks from ``env.list_all_states``; for each
    chunk the unnormalised outputs of ``marginal_dist`` are collected. Both
    collections are then normalised over all states and half their L1
    distance is returned.

    Returns:
        ``tv`` or, when ``return_dist`` is true,
        ``(tv, learned_dist, target_dist)``.
    """
    env = create_env()

    learned_chunks = []
    target_chunks = []

    if verbose:
        print('Enumerating the state space')
    for batch_state in env.list_all_states(max_batch_size=config.batch_size):
        assert batch_state.batch_size <= config.batch_size
        chunk_log_rewards = batch_state.log_reward()
        chunk_marginal_log = gflownet.sample_many_backward(
            batch_state, num_trajectories=config.num_back_traj
        )
        chunk_learned, chunk_target = marginal_dist(
            batch_state, chunk_marginal_log, chunk_log_rewards, dim=0, normalize=False
        )
        learned_chunks.append(chunk_learned)
        target_chunks.append(chunk_target)

    if verbose:
        print('Computing the L1 norm')
    # Learned values are probabilities, targets are log-rewards; normalise
    # both in log space.
    log_learned = torch.log(torch.hstack(learned_chunks))
    log_target = torch.hstack(target_chunks)

    learned_dist = (log_learned - torch.logsumexp(log_learned, dim=0)).exp()
    target_dist = (log_target - torch.logsumexp(log_target, dim=0)).exp()

    tv = .5 * (learned_dist - target_dist).abs().sum()

    if not return_dist:
        return tv
    return tv, learned_dist, target_dist
    