import torch 
import torch.nn as nn 
import tqdm 
import numpy as np 

from .envs import Grid2D, LogRewardDist 
from .flows import ForwardFlow, BackwardFlow, StateFlow  

import matplotlib.pyplot as plt 

import matplotlib.pyplot as plt 

def plot_samples_grid2d(gflownet, num_samples, width, height, log_reward):
    """Plot the target reward distribution next to the sampled distribution.

    Draws two side-by-side heat maps on the current matplotlib figure:
    the normalized exponentiated log-reward of every grid cell
    ('Targeted') and the normalized histogram of the positions held by
    the environment returned from ``gflownet.sample`` ('Learned').

    Args:
        gflownet: Object exposing ``sample(grid)``; the returned object
            must carry the sampled positions in its ``pos`` attribute.
        num_samples: Batch size of the environment handed to
            ``gflownet.sample``.
        width: Number of cells along the first grid axis.
        height: Number of cells along the second grid axis.
        log_reward: Forwarded to ``Grid2D`` as its ``log_reward`` argument.

    Returns:
        Tuple ``(target_dist, learned_dist)`` of ``(width, height)`` NumPy
        arrays. ``target_dist`` is divided by its sum; ``learned_dist`` is
        divided by its sum unless no sample fell on the grid, in which
        case it is left as all zeros instead of becoming NaN.
    """
    # Target: evaluate the reward of each cell by placing a single-element
    # environment on that cell.
    target_dist = []
    for i in range(width):
        row = []
        for j in range(height):
            grid = Grid2D(width, height, batch_size=1, log_reward=log_reward)
            grid.pos = torch.tensor([[i, j]])
            row.append(grid.log_reward().exp().item())
        target_dist.append(row)
    target_dist = np.array(target_dist)
    target_dist = target_dist / target_dist.sum()

    plt.subplot(1, 2, 1)
    plt.imshow(target_dist, vmin=0)
    plt.title('Targeted')

    # Learned: count how many sampled positions land on each cell.
    grid = Grid2D(width, height, batch_size=num_samples, log_reward=log_reward)
    positions = gflownet.sample(grid).pos
    learned_dist = np.array(
        [
            [
                ((positions[:, 0] == i) & (positions[:, 1] == j)).sum().item()
                for j in range(height)
            ]
            for i in range(width)
        ],
        dtype=float,
    )
    total = learned_dist.sum()
    if total > 0:
        # Skip normalisation on an empty histogram to avoid a 0/0 NaN array.
        learned_dist = learned_dist / total

    plt.subplot(1, 2, 2)
    plt.imshow(learned_dist, vmin=0)
    plt.title('Learned')
    plt.tight_layout()
    return target_dist, learned_dist

@torch.no_grad()
def create_env(log_reward=None, batch_size=None, width=None, height=None, **kwargs):
    """Instantiate a ``Grid2D`` environment from keyword configuration.

    ``width`` and ``height`` are passed positionally, ``batch_size`` and
    ``log_reward`` by keyword. Any additional keyword arguments are
    accepted and ignored.
    """
    env = Grid2D(width, height, batch_size=batch_size, log_reward=log_reward)
    return env

@torch.no_grad()
def unique_smp(samples):
    """Locate the distinct rows of ``samples.pos``.

    Returns:
        Tuple ``(indices, counts)`` as produced by ``np.unique`` along
        axis 0: the index of the first occurrence of every distinct row
        and the number of times that row appears.
    """
    positions = samples.pos.cpu()
    unique_info = np.unique(
        positions,
        axis=0,
        return_index=True,
        return_counts=True,
    )
    # unique_info[0] holds the distinct rows themselves, which are not needed.
    first_seen, multiplicity = unique_info[1], unique_info[2]
    return first_seen, multiplicity

@torch.no_grad()
def create_gfn(client=None, width=None, height=None, hidden_dim=None, device='cpu', **kwargs):
    """Build the forward/backward flows and an optional per-client reward.

    Args:
        client: Client index. When given, a single reference point is
            drawn from a generator seeded with ``(client + 1) * 42`` and
            wrapped in ``LogRewardDist``; when ``None`` no reward is built.
        width: The x reference coordinate is drawn from ``[0, width]``.
        height: The y reference coordinate is drawn from ``[0, height]``.
        hidden_dim: Passed to ``ForwardFlow``.
        device: Device of the random generator and of the sampled
            reference point.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(forward_flow, backward_flow, log_reward)`` where
        ``log_reward`` is a ``LogRewardDist`` or ``None``.
    """
    if client is None:
        log_reward = None
    else:
        g = torch.Generator(device=device)
        g.manual_seed((client + 1) * 42)
        # The output tensor has to live on the generator's device; without
        # ``device=`` a non-CPU generator raises a device-mismatch error.
        xs = torch.randint(width + 1, size=(1,), generator=g, device=device)
        ys = torch.randint(height + 1, size=(1,), generator=g, device=device)
        refs = torch.vstack([xs, ys]).t()  # shape (1, 2): one (x, y) point

        # Alternative fixed reference layout (disabled):
        # alpha = .75
        # refs = torch.tensor([
        #     [width * (1 - alpha), height * alpha],
        #     [width * alpha, height * (1 - alpha)],
        #     [width * alpha, height * alpha],
        # ])
        log_reward = LogRewardDist(refs)

    forward_flow = ForwardFlow(hidden_dim)
    backward_flow = BackwardFlow()

    return forward_flow, backward_flow, log_reward

@torch.no_grad()
def create_state_flow(hidden_dim, **kwargs):
    """Return a new ``StateFlow`` built from ``hidden_dim``.

    Any additional keyword arguments are accepted and ignored.
    """
    return StateFlow(hidden_dim)

# def create_env(width, height, batch_size, refs, log_reward): 
#     return Grid2D(width, height, batch_size, refs=refs, log_reward=log_reward) 

# plot_samples(gflownet, int(1e5), width, height, temperature=temperature)  
# def train_grid2d(gflownet, epochs=int(1e3), width=12, height=12, batch_size=512, temperature=.5, lr=1e-3, refs=None, log_reward=None): 
#     optimizer = torch.optim.Adam(gflownet.parameters(), lr=lr) 
    
#     pbar = tqdm.tqdm(range(epochs)) 

#     for _ in pbar: 
#         optimizer.zero_grad() 
#         grid = Grid2D(width, height, batch_size, temperature=temperature, refs=refs, log_reward=log_reward) 
#         loss = gflownet(grid) 
#         loss.backward() 
#         optimizer.step() 
#         pbar.set_postfix(loss=loss)  
    
#     return gflownet 