# Finite-difference solver for the 2D transient heat equation in a steel plate
# Author: Leonardo Antonio de Araujo

import numpy as np
import math
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import matplotlib.animation as ani
from tqdm import tqdm
from matplotlib import cm


def simulate_heat_equation (problem_dict, grid_dict, radar=False):
	"""Run the full simulation: initial condition, time stepping, optional radar sampling.

	The field starts as an (N, N, NUM_T) array of zeros, is seeded with the
	"flower" initial condition and then advanced in time by the explicit
	finite-difference solver.

	Parameters
	----------
	problem_dict : dict
		Physical parameters forwarded to the initial-condition and solver helpers.
	grid_dict : dict
		Grid parameters; ``'N'`` and ``'NUM_T'`` fix the array shape, and the
		whole dict is also forwarded to the helpers.
	radar : bool
		When true, also sample the solution at the radar stations.

	Returns
	-------
	numpy.ndarray or tuple
		The solution array alone, or ``(solution, radar_data)`` when ``radar``
		is true.
	"""
	shape = (grid_dict['N'], grid_dict['N'], grid_dict['NUM_T'])

	field = initial_condition_flower(np.zeros(shape), problem_dict, grid_dict)
	field = heat_equation_solver(field, problem_dict, grid_dict)

	if not radar:
		return field
	return field, extract_radar_data(field, grid_dict)


def initial_condition_disk (T, problem_dict, grid_dict):
	"""Seed a uniform-temperature disk around the central grid point.

	Every grid point (i, j) whose Euclidean index distance from the centre
	point is at most ``R`` gets ``T[i, j, 0] = T0``. All other entries of
	``T`` are left untouched.

	Parameters
	----------
	T : numpy.ndarray
		Solution array indexed as ``T[i, j, t]``; modified in place.
	problem_dict : dict
		Must contain ``'T0'``, the temperature assigned inside the disk.
	grid_dict : dict
		Must contain ``'N'`` (grid points along each axis) and ``'R'``
		(disk radius, in grid-index units).

	Returns
	-------
	numpy.ndarray
		The same array ``T``.
	"""
	T0 = problem_dict['T0']   # temperature inside the disk
	N = grid_dict['N']        # grid points along each axis
	R = grid_dict['R']        # disk radius, in grid-index units

	# index of the central grid point; exactly centred when N is odd
	center = int((N-1)/2)

	# Test the distance on the real grid instead of looping over the square
	# [center-R, center+R]. When R reached past the grid edge, that loop
	# wrapped around through negative indices (heating the opposite edge) or
	# raised IndexError. A mask only ever touches points that exist, and it
	# selects exactly the same points when R fits inside the grid.
	i, j = np.ogrid[:N, :N]
	inside = np.hypot(i - center, j - center) <= R
	T[inside, 0] = T0

	return T


def initial_condition_ring (T, problem_dict, grid_dict):
	"""Seed a uniform-temperature annulus around the central grid point.

	Every grid point (i, j) whose Euclidean index distance ``d`` from the
	centre satisfies ``r <= d <= R`` gets ``T[i, j, 0] = T0``. All other
	entries of ``T`` are left untouched.

	Parameters
	----------
	T : numpy.ndarray
		Solution array indexed as ``T[i, j, t]``; modified in place.
	problem_dict : dict
		Must contain ``'T0'``, the temperature assigned inside the ring.
	grid_dict : dict
		Must contain ``'N'`` (grid points along each axis), ``'R'`` (outer
		radius) and ``'r'`` (inner radius), both radii in grid-index units.

	Returns
	-------
	numpy.ndarray
		The same array ``T``.
	"""
	T0 = problem_dict['T0']   # temperature inside the ring
	N = grid_dict['N']        # grid points along each axis
	R = grid_dict['R']        # outer radius, in grid-index units
	r = grid_dict['r']        # inner radius, in grid-index units

	# index of the central grid point; exactly centred when N is odd
	center = int((N-1)/2)

	# A mask over the real grid replaces the square loop, which wrapped around
	# through negative indices (or raised IndexError) when R reached past the
	# grid edge. It selects the same points whenever R fits inside the grid.
	i, j = np.ogrid[:N, :N]
	dist = np.hypot(i - center, j - center)
	ring = (dist >= r) & (dist <= R)
	T[ring, 0] = T0

	return T

def initial_condition_flower (T, problem_dict, grid_dict, crop_to_square=True):
	"""Seed a four-lobed "flower" annulus around the central grid point.

	Let theta = arctan2(j - center, i - center) and
	w(theta) = sin(4*theta + pi/2). A grid point at index distance ``d`` from
	the centre gets ``T[i, j, 0] = T0`` when

		r + 0.3*r*w(theta) <= d <= R + 0.3*R*w(theta)

	All other entries of ``T`` are left untouched.

	Parameters
	----------
	T : numpy.ndarray
		Solution array indexed as ``T[i, j, t]``; modified in place.
	problem_dict : dict
		Must contain ``'T0'``, the temperature assigned inside the flower.
	grid_dict : dict
		Must contain ``'N'`` (grid points along each axis), ``'R'`` (outer base
		radius) and ``'r'`` (inner base radius), both in grid-index units.
	crop_to_square : bool
		When true (the historical behaviour), only points with
		``|i - center| <= R`` and ``|j - center| <= R`` are considered. This
		clips the petal tips, which reach out to 1.3*R. Pass False to keep
		the full petals.

	Returns
	-------
	numpy.ndarray
		The same array ``T``.
	"""
	T0 = problem_dict['T0']   # temperature inside the flower
	N = grid_dict['N']        # grid points along each axis
	R = grid_dict['R']        # outer base radius, in grid-index units
	r = grid_dict['r']        # inner base radius, in grid-index units

	center = int((N-1)/2)

	# Offsets of every real grid point from the centre. Working on the actual
	# grid avoids the old loop's wrap-around through negative indices (and its
	# IndexError) when the square window reached past the grid edge.
	i, j = np.ogrid[:N, :N]
	di = i - center
	dj = j - center

	dist = np.hypot(di, dj)
	# angular modulation shared by both boundaries; period pi/2 -> four lobes
	wave = np.sin(4 * np.arctan2(dj, di) + math.pi * 0.5)
	mask = (dist <= R + 0.3 * R * wave) & (dist >= r + 0.3 * r * wave)

	if crop_to_square:
		# reproduce the original square search window
		mask &= (np.abs(di) <= R) & (np.abs(dj) <= R)

	T[mask, 0] = T0
	return T

def heat_equation_solver(T, problem_dict, grid_dict, progress=True):
	"""Advance ``T`` in time with an explicit (forward-Euler, centred-space) scheme.

	For every step t the interior of slice ``t+1`` is computed from slice
	``t`` as

		T[i, j, t+1] = T[i, j, t] + K*dt*(d2T/dx2 + d2T/dy2)

	using second-order central differences with equal spacing ``dx`` in both
	directions. The outermost rows/columns of each new slice are never
	written, so they keep whatever values ``T`` already holds there.

	Parameters
	----------
	T : numpy.ndarray
		Solution array indexed as ``T[i, j, t]``; slice 0 must hold the
		initial condition. Modified in place.
	problem_dict : dict
		Must contain ``'MAX_T'`` (final time), ``'K'`` (diffusivity) and
		``'L'`` (side length of the plate).
	grid_dict : dict
		Must contain ``'NUM_T'`` (number of time slices) and ``'N'`` (grid
		points along each axis).
	progress : bool
		Show a tqdm progress bar over the time steps (default True, as before).

	Returns
	-------
	numpy.ndarray
		The same array ``T``, filled for time slices 1 .. NUM_T-1.

	Warns
	-----
	RuntimeWarning
		If K*dt/dx**2 > 1/4. Beyond that limit this explicit 2D scheme is
		unstable and the solution blows up.
	"""
	MAX_T = problem_dict['MAX_T']
	K = problem_dict['K']
	L = problem_dict['L']

	NUM_T = grid_dict['NUM_T']
	N = grid_dict['N']

	# NOTE(review): if the N points span [0, L] inclusively the spacing would be
	# L/(N-1); L/N is kept as-is -- confirm against how the plot grid is built.
	dx = L / N
	dt = MAX_T / NUM_T

	# Von Neumann stability limit of the explicit scheme in 2D with dx == dy.
	fourier = K * dt / dx**2
	if fourier > 0.25:
		import warnings
		warnings.warn(
			f"explicit heat-equation scheme is unstable: K*dt/dx**2 = {fourier:.3g} > 0.25; "
			"increase NUM_T or use fewer grid points",
			RuntimeWarning,
			stacklevel=2,
		)

	steps = range(0, int(NUM_T)-1)
	if progress:
		steps = tqdm(steps)

	# Whole-interior slicing replaces the per-point Python loops. Reads come
	# only from slice t and writes go only to slice t+1, so the result is the
	# same as updating point by point.
	for t in steps:
		u = T[:, :, t]
		a = (u[2:, 1:-1] - 2*u[1:-1, 1:-1] + u[:-2, 1:-1]) / dx**2  # d2dx2
		b = (u[1:-1, 2:] - 2*u[1:-1, 1:-1] + u[1:-1, :-2]) / dx**2  # d2dy2
		T[1:-1, 1:-1, t+1] = K * dt * (a+b) + u[1:-1, 1:-1]

	return T


def extract_radar_data (T, radars_dict):
	"""Collect the temperature history at five circular "radar" stations.

	The stations are centred at multiples of a common step (about N/4): three
	near the quarter points and one at the middle of the grid. Every grid
	index within ``RR`` (Euclidean distance) of a station centre is sampled.

	Parameters
	----------
	T : numpy.ndarray
		Solution array indexed as ``T[i, j, t]``.
	radars_dict : dict
		Must contain ``'N'`` (grid points along each axis) and ``'RR'``
		(station radius, in grid-index units).

	Returns
	-------
	dict
		``'T'``: array whose row k is ``T[i_k, j_k, :]``.
		``'indeces'``: array of the sampled ``(i, j)`` pairs, in the same order.
	"""
	grid_n = radars_dict['N']
	radius = radars_dict['RR']

	step = int(((grid_n-1)/2-1)/2) + 1
	stations = [(step, step), (3*step, step), (step, 3*step),
				(2*step, 2*step), (3*step, 3*step)]

	# station by station, row-major within each station's bounding square
	points = [
		(row, col)
		for cx, cy in stations
		for row in range(cx - radius, cx + radius + 1)
		for col in range(cy - radius, cy + radius + 1)
		if np.hypot(row - cx, col - cy) <= radius
	]

	histories = np.array([T[row, col, :] for row, col in points])
	return {'T': histories, 'indeces': np.array(points)}


def plot_snap (T, X, Y, t):
	"""Show a 3D surface plot of the temperature field at time index ``t``.

	Parameters
	----------
	T : numpy.ndarray
		Solution array indexed as ``T[i, j, t]``.
	X, Y : array-like
		Coordinate grids passed straight to ``plot_surface``.
	t : int
		Time index of the snapshot to draw.

	The z-axis spans 0 .. ``T.max()`` over the whole history, so successive
	snapshots share one scale. Blocks in ``plt.show()``.
	"""
	snapshot = T[:, :, t]
	ax = plt.figure().add_subplot(111, projection='3d')
	ax.plot_surface(X, Y, snapshot, cmap='gist_rainbow_r', edgecolor='none')
	ax.set(xlabel='X [m]', ylabel='Y [m]', zlabel='T [°]', zlim=(0, T.max()))
	plt.show()
