## Import libraries.
import os
import sys
import math
import numpy as np
import scipy.io
# from tqdm import tqdm
from time import time
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

import torch
import torch.nn as nn
import torch.nn.functional as F

# Short module-level aliases for the constant pi and the torch trigonometric functions.
pi, cos, sin = math.pi, torch.cos, torch.sin

# import scipy
from tabulate import tabulate

## Numerical simulation

class Numeric():
	"""Assemble the truncated matrices of the numerical simulation.

	All matrices returned by the methods of this class are square with
	``Matrix_truncation + 1`` rows; row and column 0 are padding (the
	recurrences and loops below all start at index 1).

	Attributes read from ``args``: ``sigma``, ``Matrix_truncation``, ``R_bd``,
	``arg_field`` and ``complex_background``.
	"""

	def __init__(self, args, parameters_real, Conformal_coefficients):
		"""Store the configuration.

		Args:
			args: object exposing ``sigma`` (indexable, entries 0 and 1 are
				used), ``Matrix_truncation`` (int), ``R_bd``, ``arg_field``
				(0, 1 or 2) and ``complex_background`` (indexable, entries
				0 and 1 are used as real and imaginary part).
			parameters_real: indexable tensor of (complex) parameters ``p``;
				``len(p)`` bounds which entries are used.
			Conformal_coefficients: column tensor (second dimension of size 1)
				of coefficients ``a[0], a[1], ...``; it is zero-padded
				internally.
		"""
		super(Numeric, self).__init__()
		self.sigma = args.sigma
		self.parameters = parameters_real
		self.Conformal_coefficients = Conformal_coefficients
		self.Matrix_truncation = args.Matrix_truncation
		self.radius = args.R_bd
		self.arg_field = args.arg_field
		self.complex_background = args.complex_background

	def Faber_poly_coeff(self):
		"""Build the lower-triangular coefficient matrix ``P`` by recurrence.

		Returns:
			Real tensor of shape ``(Matrix_truncation+1, Matrix_truncation+1)``
			with ones on the diagonal (from index 1 on).
		"""
		M = self.Matrix_truncation

		# Zero-pad so that every a[i, 0] with i <= 2*M is addressable.
		# if kite, a = [0, 0.1, 0.25, -0.05, 0.05, -0.04, 0.02][0, 0, ... 0] ^ T
		a = torch.cat((self.Conformal_coefficients, torch.zeros(2*M, 1)), dim=0)

		# Auxiliary sequence feeding column 1 of P.  It is written up to index
		# 2*M + 2, so it needs 2*M + 3 entries (the former 3*M + 1 was too
		# short for M == 1 and raised IndexError).
		column = torch.zeros(2*M + 3, 1)
		column[1, 0] = 1
		column[2, 0] = -a[0, 0]
		P = torch.zeros(2*M + 2, 2*M + 2)

		for m in range(1, 2*M + 1):
			value = -m*a[m, 0] - a[m, 0] - a[0, 0]*column[m+1, 0]
			for k in range(1, m):
				value = value - a[m-k, 0]*column[k+1, 0]
			column[m+2, 0] = value
			# Diagonal and first sub-diagonal are known in closed form.
			P[m, m] = 1
			P[m+1, m] = -(m+1)*a[0, 0]

		# Column 1 of P.
		for m in range(1, 2*M + 1):
			value = column[m+1, 0] - a[0, 0]*P[m, 1]
			for k in range(1, m):
				value = value - a[m-k, 0]*P[k, 1]
			P[m+1, 1] = value

		# Remaining columns: row m+1 only depends on rows <= m.
		for m in range(1, M + 1):
			for n in range(2, m + 1):
				value = P[m, n-1] - a[0, 0]*P[m, n]
				for k in range(n, m):
					value = value - a[m-k, 0]*P[k, n]
				P[m+1, n] = value

		return P[0:M+1, 0:M+1]

	def Grunsky(self):
		"""Build the coefficient matrix ``C`` by recurrence.

		Row 1 and column 1 are seeded with ``C[1, k] = a[k]`` and
		``C[m, 1] = m*a[m]``; the other entries follow from
		``C[m+1, k] = C[m, k+1] + a[m+k] - sum_n a[m-n]*C[n, k]
		+ sum_n a[k-n]*C[m, n]``.

		Returns:
			Real tensor of shape ``(Matrix_truncation+1, Matrix_truncation+1)``.
		"""
		M = self.Matrix_truncation
		C = torch.zeros(2*M + 1, 2*M + 1)
		a = torch.cat((self.Conformal_coefficients, torch.zeros(2*M, 1)), dim=0)

		for m in range(1, 2*M + 1):
			C[1, m] = a[m, 0]
			C[m, 1] = m*a[m, 0]

		# Row m+1 reads C[m, k+1], i.e. one column further to the right than
		# it writes.  Every row is therefore filled up to column 2*M - m so
		# that the entries kept after truncation never depend on a cell that
		# was left at its initial zero.  Rows beyond M are not needed.
		for m in range(1, M):
			for k in range(1, 2*M - m + 1):
				value = C[m, k+1] + a[m+k, 0]
				for n in range(1, m):
					value = value - a[m-n, 0]*C[n, k]
				for n in range(1, k):
					value = value + a[k-n, 0]*C[m, n]
				C[m+1, k] = value

		return C[0:M+1, 0:M+1]

	def Matrices(self):
		"""Return the diagonal helper matrices.

		Returns:
			Tuple ``(N, gamma_N, gamma_2N, gamma_NN, gamma_2NN, I)`` of real
			``(Matrix_truncation+1)``-square tensors whose n-th diagonal entry
			(n >= 1) is ``n``, ``r**n``, ``r**(2n)``, ``r**(-n)``,
			``r**(-2n)`` and ``1`` respectively, with ``r = self.radius``.
			The (0, 0) entry of each matrix is 0, including ``I``.
		"""
		size = self.Matrix_truncation + 1
		r = self.radius

		N = torch.zeros(size, size)
		gamma_N = torch.zeros(size, size)
		gamma_2N = torch.zeros(size, size)
		gamma_NN = torch.zeros(size, size)
		gamma_2NN = torch.zeros(size, size)
		I = torch.zeros(size, size)

		for n in range(1, size):
			N[n, n] = n
			gamma_N[n, n] = r**n
			gamma_2N[n, n] = r**(2*n)
			gamma_NN[n, n] = r**(-1*n)
			gamma_2NN[n, n] = r**(-2*n)
			I[n, n] = 1

		return N, gamma_N, gamma_2N, gamma_NN, gamma_2NN, I

	def P_positive(self):
		"""Return the complex matrix with entries ``p[m+n] * r**(m+n)``.

		Entries whose index ``m+n`` is not available in ``self.parameters``
		stay zero.
		"""
		size = self.Matrix_truncation + 1
		p = self.parameters
		r = self.radius
		n_params = len(p)

		P_positive = torch.zeros(size, size, dtype=torch.complex64)
		for n in range(1, size):
			for m in range(1, size):
				if m+n < n_params:
					P_positive[m, n] = p[m+n]*r**(m+n)

		return P_positive

	def P_negative(self):
		"""Return the complex matrix built from ``p[|m-n|]``.

		The diagonal holds ``p[0]``; below the diagonal the entry is
		``p[m-n] * r**(m-n)``, above it ``conj(p[n-m]) * r**(m-n)``.  Entries
		whose index is not available in ``self.parameters`` stay zero.
		"""
		size = self.Matrix_truncation + 1
		p = self.parameters
		r = self.radius
		n_params = len(p)

		P_negative = torch.zeros(size, size, dtype=torch.complex64)
		for n in range(1, size):
			for m in range(1, size):
				if m == n:
					P_negative[m, n] = p[0]
				elif m > n and m-n < n_params:
					P_negative[m, n] = p[m-n]*r**(m-n)
				elif m < n and n-m < n_params:
					P_negative[m, n] = torch.conj(p[n-m])*r**(m-n)

		return P_negative

	def my_inv(self, Mat):
		"""Pseudo-invert ``Mat`` while ignoring its padding row and column 0.

		Args:
			Mat: 2-D tensor.

		Returns:
			Tensor of shape ``(Mat.shape[1], Mat.shape[0])`` whose first row
			and first column are zero and whose remaining block is
			``torch.linalg.pinv(Mat[1:, 1:])``.
		"""
		n_rows = Mat.shape[0]
		n_cols = Mat.shape[1]
		inner_inv = torch.linalg.pinv(Mat[1:n_rows, 1:n_cols])

		# Padding is created with the dtype/device of the inverse so the
		# concatenation never has to mix tensor types.
		zero_row = inner_inv.new_zeros((1, n_rows-1))
		zero_col = inner_inv.new_zeros((n_cols, 1))

		padded = torch.cat((zero_row, inner_inv), dim=0)
		padded = torch.cat((zero_col, padded), dim=1)

		return padded

	def Matrices_AB(self):
		"""Assemble the four complex matrices ``A1``, ``A2``, ``B1``, ``B2``.

		They combine ``Grunsky()``, ``Matrices()``, ``P_positive()`` and
		``P_negative()`` with ``sigma_m = self.sigma[0]`` and
		``sigma_c = self.sigma[1]``.

		Returns:
			Tuple ``(A1, A2, B1, B2)`` of ``(Matrix_truncation+1)``-square
			tensors.
		"""
		sigma_c = self.sigma[1]
		sigma_m = self.sigma[0]
		C = self.Grunsky().type(torch.complex64)
		N, gamma_N, gamma_2N, gamma_NN, gamma_2NN, I = (
			mat.type(torch.complex64) for mat in self.Matrices()
		)

		PP = self.P_positive()
		PN = self.P_negative()

		# This pseudo-inverse appears four times in B1/B2; compute it once.
		K = self.my_inv(I - gamma_2NN@torch.conj(C)@gamma_2NN@C)

		A1 = (sigma_c-sigma_m)*gamma_N@torch.conj(PP)@gamma_N + (sigma_c-sigma_m)*C@gamma_NN@PN@gamma_N + sigma_c*sigma_m*C@N
		A2 = (sigma_c-sigma_m)*gamma_N@torch.conj(PN)@gamma_N + (sigma_c-sigma_m)*C@gamma_NN@PP@gamma_N - sigma_c*sigma_m*gamma_2N@N
		B1 = ((sigma_c-sigma_m)*I + 2*sigma_m*K)@gamma_NN@PP@gamma_N + 2*sigma_m*K@gamma_2NN@C@gamma_NN@(PN)@gamma_N
		B2 = ((sigma_c-sigma_m)*I + 2*sigma_m*K)@gamma_NN@PN@gamma_N + 2*sigma_m*K@gamma_2NN@torch.conj(C)@gamma_NN@torch.conj(PP)@gamma_N + sigma_c*sigma_m*N

		return A1, A2, B1, B2

	def Matricse_S_beta(self):
		"""Compute the matrices ``S`` and ``beta`` for the configured field.

		``alpha`` is zero except for its (1, 1) entry, which is ``1`` for
		``arg_field == 0``, ``-1j`` for ``arg_field == 1`` and
		``complex_background[0] + 1j*complex_background[1]`` for
		``arg_field == 2``.  Any other ``arg_field`` leaves ``alpha`` (and
		hence ``S``) identically zero.

		Returns:
			Tuple ``(S, beta)`` of complex ``(Matrix_truncation+1)``-square
			tensors.
		"""
		size = self.Matrix_truncation + 1
		alpha = torch.zeros(size, size).type(torch.complex64)

		sigma_c = self.sigma[1]
		sigma_m = self.sigma[0]

		A1, A2, B1, B2 = (mat.type(torch.complex64) for mat in self.Matrices_AB())
		C = self.Grunsky().type(torch.complex64)
		gamma_2NN = self.Matrices()[4].type(torch.complex64)

		if self.arg_field == 0:
			alpha[1, 1] = 1
		elif self.arg_field == 1:
			alpha[1, 1] = torch.tensor([-1.j])
		elif self.arg_field == 2:
			alpha[1, 1] = torch.tensor([self.complex_background[0] + self.complex_background[1]*1.j])

		# Shared factors of both terms of S.
		B1_conj = torch.conj(B1)
		B2_conj_inv = self.my_inv(torch.conj(B2))
		D_inv = self.my_inv(B2 - B1@B2_conj_inv@B1_conj)

		# Both terms use matrix products throughout; the first term previously
		# multiplied by conj(B1) element-wise ('*'), unlike its counterpart.
		S = (
			-alpha@(A1 - A2@B2_conj_inv@B1_conj)@D_inv
			- torch.conj(alpha)@((torch.conj(A2) - torch.conj(A1)@B2_conj_inv@B1_conj)@D_inv)
		)
		beta = sigma_m/sigma_c*alpha - sigma_m/sigma_c*(torch.conj(S)+S@gamma_2NN@torch.conj(C))@self.my_inv(torch.eye(size)-gamma_2NN@C@gamma_2NN@torch.conj(C))@gamma_2NN

		return S, beta
	
