import torch
import numpy as np
import torch.nn.functional
import math
import os
from models import gradient, ImplicitNet
import matplotlib.pyplot as plt
from PIL import Image

def plot_2D(network, input_data, output_data, T, file_name, device='cpu', batch_size=5000, step=2):
    """Plot forward and backward transport of two 2-D point clouds and save the figure.

    A random subset of each dataset is prepended with a constant coordinate
    (0 for ``input_data``, 1 for ``output_data``; presumably the time axis)
    and passed through ``network``. Each point is then moved along the spatial
    part of the gradient of the network output:
    ``x + T * grad`` for input points and ``y - T * grad`` for output points.

    The left panel ("Forward") shows input points (tomato) joined to their
    transported positions (deepskyblue). The right panel ("Backward") shows
    output points (deepskyblue) joined to their transported positions (tomato).

    Args:
        network: callable mapping an ``(N, 1 + d)`` float tensor to an output
            that ``models.gradient`` can differentiate. Presumably a scalar
            potential per point; TODO confirm.
        input_data, output_data: array-like (numpy array or torch tensor)
            indexable as ``data[idx, :]``. Only columns 0 and 1 are plotted.
        T: scalar multiplier applied to the gradient displacement.
        file_name: path passed to ``Figure.savefig``.
        device: device the network is evaluated on.
        batch_size: maximum number of points sampled (without replacement)
            from each dataset. It is capped at the dataset size.
        step: draw a connector/arrow for every ``step``-th sampled point.
    """
    # Sample input first, then output, so the global numpy RNG sequence is
    # unchanged.
    input_pnts = _sample_rows(input_data, batch_size)
    output_pnts = _sample_rows(output_data, batch_size)

    # Forward: input side at constant coordinate 0, moved by +T * grad.
    transported_io = _transport_points(network, input_pnts, 0.0, T, device)
    # Backward: output side at constant coordinate 1, moved by -T * grad.
    transported_oi = _transport_points(network, output_pnts, 1.0, -T, device)

    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    try:
        _draw_transport_panel(ax[0], input_pnts.numpy(), transported_io,
                              'tomato', 'deepskyblue', step, 'Forward')
        _draw_transport_panel(ax[1], output_pnts.numpy(), transported_oi,
                              'deepskyblue', 'tomato', step, 'Backward')
        fig.savefig(file_name, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _sample_rows(data, batch_size):
    """Return up to ``batch_size`` random distinct rows of ``data`` as a CPU float32 tensor."""
    n = min(batch_size, data.shape[0])
    indices = torch.tensor(np.random.choice(data.shape[0], n, False))
    # as_tensor avoids the copy-construct warning when ``data`` is already a tensor.
    return torch.as_tensor(data[indices, :], dtype=torch.float32).detach().cpu()


def _transport_points(network, pnts, t, scale, device):
    """Prepend constant ``t``, differentiate ``network`` and return ``pnts + scale * grad`` (numpy)."""
    t_col = torch.full((pnts.shape[0], 1), t, dtype=pnts.dtype)
    # Build a fresh leaf on the target device so ``gradient`` can differentiate w.r.t. it.
    x = torch.cat((t_col, pnts), 1).to(device).requires_grad_(True)
    grad = gradient(x, network(x))[:, 1:]  # drop the derivative w.r.t. the constant column
    return (x[:, 1:] + scale * grad).detach().cpu().numpy()


def _draw_transport_panel(ax, src, dst, src_color, dst_color, step, title):
    """Draw connectors, direction arrows and both point sets on one axes."""
    ax.set_title(title, fontsize=12)
    mid = (src + dst) / 2
    # NOTE(review): arrow components are divided by the norm of the midpoint's
    # position, not the displacement. Because each quiver call autoscales its
    # single arrow, this mostly affects only direction; confirm intent.
    size = np.linalg.norm(mid, axis=-1)

    # All connector segments in one call: column j of each (2, k) array is one line.
    ax.plot(np.stack((src[::step, 0], dst[::step, 0])),
            np.stack((src[::step, 1], dst[::step, 1])),
            color='black', linewidth=0.5)
    # One quiver per arrow on purpose: quiver autoscales per call, so merging
    # the calls would change the rendered arrow sizes.
    for i in range(0, src.shape[0], step):
        ax.quiver(mid[i, 0], mid[i, 1],
                  (dst[i, 0] - src[i, 0]) / size[i], (dst[i, 1] - src[i, 1]) / size[i],
                  color='black', headwidth=30, headlength=30, width=0.001)

    ax.scatter(src[:, 0], src[:, 1], color=src_color, edgecolors='black', s=50, alpha=1)
    ax.scatter(dst[:, 0], dst[:, 1], color=dst_color, edgecolors='black', s=40, alpha=1)
    
  
def plot_transported_colors(network, init_data, target_data, height_init, width_init,
                            height_target, width_target, exp_dir, epoch, device, T=None):
    """Save forward and backward colour-transport results as JPEG images.

    Each row of ``init_data`` / ``target_data`` is prepended with a constant
    coordinate (0 for the initial image, 1 for the target; presumably time)
    and fed to ``network.push_nograd``. Columns 1: of that result are used as
    a displacement:

    * forward:  ``init_data + T * push``  -> ``<exp_dir>/Forward_ep{epoch}.jpg``
    * backward: ``target_data - T * push`` -> ``<exp_dir>/Backward_ep{epoch}.jpg``

    Values are scaled by 255, clipped to [0, 255], truncated to uint8 and
    reshaped to ``(height, width, 3)``. The data are therefore presumably RGB
    colours in [0, 1] with ``height * width`` rows; confirm with callers.

    Args:
        network: model exposing ``push_nograd(points)``. This is an opaque
            method; its output must have at least 2 columns.
        init_data, target_data: torch tensors of colours, one row per pixel.
        height_init, width_init, height_target, width_target: image sizes.
        exp_dir: directory the two JPEGs are written to.
        epoch: integer embedded in the file names.
        device: device the network is evaluated on.
        T: scalar displacement multiplier. If omitted, a module-level ``T``
            is used when one exists.

    Raises:
        TypeError: if ``T`` is not given and no module-level ``T`` exists.
    """
    if T is None:
        # The original body read a free name ``T`` that this file never defines.
        # Honour a module-level ``T`` if one has been set; otherwise fail clearly
        # instead of with a bare NameError.
        T = globals().get('T')
        if T is None:
            raise TypeError(
                "plot_transported_colors() missing the transport multiplier 'T'; "
                "pass it as a keyword argument")

    _save_transported_image(network, init_data, 0.0, T, height_init, width_init,
                            os.path.join(exp_dir, f'Forward_ep{epoch}.jpg'), device)
    _save_transported_image(network, target_data, 1.0, -T, height_target, width_target,
                            os.path.join(exp_dir, f'Backward_ep{epoch}.jpg'), device)


def _save_transported_image(network, colors, t, scale, height, width, path, device):
    """Move ``colors`` by ``scale * push_nograd(...)`` at constant ``t`` and save as an RGB image."""
    # Build the constant column where the data lives so torch.cat never mixes devices.
    t_col = torch.full((colors.shape[0], 1), t, dtype=colors.dtype, device=colors.device)
    # Fresh leaf tensor on ``device``. The original also set requires_grad;
    # kept in case push_nograd relies on it (NOTE(review): confirm).
    pnts = torch.cat((t_col, colors), 1).detach().to(device).requires_grad_(True)
    moved = colors.to(device) + scale * network.push_nograd(pnts)[:, 1:]
    rgb = torch.clip(255 * moved, min=0, max=255)
    # astype truncates toward zero, matching the previous list -> uint8 conversion.
    pixels = rgb.detach().cpu().numpy().astype(np.uint8).reshape((height, width, 3))
    Image.fromarray(pixels, 'RGB').save(path)