from DNN import *
from DNN_positive import *
from CLOUD import *
from PINN import *
from time import time
from Numeric_edit import *


from Init_to_coeffs import *

from symfit import parameters, variables, sin, cos, Fit
from fourier_series import *

from Fourier_fitting_p_coeff import *


def _imported_samples_path(args):
    """Return the path of the ``samples.bin`` file to import collocation points from.

    The directory is built from the conformal-map name selected by
    ``args.arg_conformal``; a "custom" map is redirected to a named result
    directory according to the number of entries in ``args.custom_coeffs``.
    """
    name_conformal = args.PSI_list[args.arg_conformal]
    if name_conformal == "custom":
        # Custom maps with 3 / 7 / 9 coefficients reuse the sample sets stored
        # under these names; any other length keeps the name "custom".
        aliases = {3: "square", 7: "spike1", 9: "spike2"}
        name_conformal = aliases.get(len(args.custom_coeffs), name_conformal)

    if args.args_type == "None":
        base = "results/" + name_conformal + "/"
    else:
        base = "results/" + name_conformal + "/" + args.args_type + "/"

    return base + "2024_PINN_neutral_test_%d/samples.bin" % (args.import_title)


def _assemble_p_vector(p0, Re_p, Im_p, order):
    """Pack ``p0`` and ``order`` complex coefficients into a complex64 column.

    Entry 0 holds ``p0``; entry ``i+1`` holds ``Re_p[i] + 1j*Im_p[i]``.  The
    column has ``order + 5`` rows, so the last four entries stay zero
    (presumably padding expected by ``Numeric`` -- TODO confirm).
    """
    p = torch.zeros(order + 5, 1)
    p = p.type(torch.complex64)
    p[0][0] = p0.item()
    for i in range(order):
        p[i+1][0] = Re_p[i] + 1.j*Im_p[i]
    return p


def learn(args, device, PATH):
    """Train the interior/exterior PINN together with the interface parameter.

    The interface parameter is either a trainable column of
    ``2*args.order_interface + 1`` coefficients (``args.p_nn_flag == "False"``)
    or a ``DNN_positive`` network (``args.p_nn_flag == "True"``).  Every
    ``args.loss_tracking_interval`` iterations the exterior network output is
    compared with the series solution built from ``Numeric``.

    Files written under ``PATH``: ``checkpoint.bin`` (during training),
    ``final.bin``, ``samples.bin``, ``hist_forward_inverse.bin``,
    ``after_learning.bin`` and, for the network parameterisation only,
    ``p_nn_input_output.bin``.

    Args:
        args: Run configuration; flag attributes are the strings "True"/"False".
        device: Device every network and collocation tensor is moved to.
        PATH: Output directory prefix (file names are appended with "/").

    Returns:
        Tuple ``(model, after_learning, learning_error)``.  ``after_learning``
        and ``learning_error`` are dicts holding the tensors / errors of the
        last loss-tracking evaluation; their evaluation entries are ``None``
        when no tracking step ran during training.

    Raises:
        ValueError: If ``args.p_nn_flag`` or ``args.import_flag`` is neither
            "True" nor "False".
    """
    # Fail early with a clear message instead of a NameError deep in training.
    if args.p_nn_flag not in ("True", "False"):
        raise ValueError('args.p_nn_flag must be "True" or "False", got %r' % (args.p_nn_flag,))
    if args.import_flag not in ("True", "False"):
        raise ValueError('args.import_flag must be "True" or "False", got %r' % (args.import_flag,))
    use_p_network = args.p_nn_flag == "True"

    #### Define our PINN-model.
    dnn_int = DNN(args.layers, args.activation).to(device)
    dnn_ext = DNN(args.layers, args.activation).to(device)

    cloud = CLOUD(args.PSI_list[args.arg_conformal], args.eps_bd, args.Conformal_order, args.custom_coeffs, args.Fixed)

    #### Histories of the two errors recorded at every tracking step.
    hist_forward = []   # relative L2 error against the numerical solution
    hist_inverse = []   # mean squared difference to the background field

    if use_p_network:
        #### Interface parameter represented by a network.
        inv_param = DNN_positive(args).to(device)

        model = PINN(dnn_int, dnn_ext, args, inv_param)

        trainable = list(model.dnn_int.parameters()) + list(model.dnn_ext.parameters()) + list(inv_param.parameters())

        optimizer = torch.optim.Adam([{'params': trainable, 'lr': args.lr_pinn}])
    else:
        #### Interface parameter represented by a coefficient column.
        coeffs = torch.zeros(2*args.order_interface+1, 1)

        if args.init_flag == "True":
            coeffs = Init_to_coeffs(args, coeffs)

        inv_param = coeffs.to(device).requires_grad_(True)

        model = PINN(dnn_int, dnn_ext, args, inv_param)

        trainable = list(model.dnn_int.parameters()) + list(model.dnn_ext.parameters())

        # The coefficients get their own parameter group / learning rate.
        optimizer = torch.optim.Adam([{'params': trainable, 'lr': args.lr_pinn}, {'params': inv_param, 'lr': args.lr_inv}])

    # scheduler.step() is called once per training iteration below, so the
    # decay happens every 1000 iterations (not epochs).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=args.gamma)

    # Results of the most recent tracking step.  Pre-initialised so that the
    # save/return code after training does not hit a NameError when no
    # tracking step ran (e.g. loss_tracking_interval > S_training).
    u_NN_exterior = u_p_out = H = None
    prediction_error = Relative_L2_error = Background_L2_error = None
    p_nn_real = p0 = Re_p = None

    #### Start timer
    t0 = time()

    #### Training the model.
    for epoch in range(1, args.num_epochs+1):

        if args.import_flag == "True":
            ## We import collocation points from default settings.
            # NOTE(review): torch.load unpickles the file (which also stores
            # the saved `args` object), so only load trusted result files.
            # The file is re-read every epoch; left as is because hoisting it
            # would change behaviour if get_loss mutates its inputs.
            samples_dict_import = torch.load(_imported_samples_path(args))

            boundary_disk = samples_dict_import['boundary_disk'].to(device)
            exterior_disk = samples_dict_import['exterior_disk'].to(device)
            interior_psi = samples_dict_import['interior_psi'].to(device)
            boundary_psi = samples_dict_import['boundary_psi'].to(device)
            exterior_psi = samples_dict_import['exterior_psi'].to(device)
            normal_psi = samples_dict_import['normal_psi'].to(device)
            trace_int_psi = samples_dict_import['trace_int_psi'].to(device)
            trace_ext_psi = samples_dict_import['trace_ext_psi'].to(device)
            w = samples_dict_import['w']

        else:
            ## First, generate the data; collocation points.
            ## Below are scattered over the original domain.
            boundary_disk, exterior_disk = cloud.generate_disk_data(args.L_dom, args.R_bd, args.S_carte, args.S_boundary)
            boundary_disk, exterior_disk = boundary_disk.to(device), exterior_disk.to(device)
            ## And apply own conformal map here.
            ## These data will be used to compute the losses.
            interior_psi = cloud.generate_data_interior(args.R_bd, args.S_int_angular, args.S_int_radial)
            interior_psi = interior_psi.to(device)
            boundary_psi, exterior_psi, normal_psi, trace_int_psi, trace_ext_psi = cloud.generate_data_bd_ext(boundary_disk, exterior_disk)
            boundary_psi, exterior_psi, normal_psi, trace_int_psi, trace_ext_psi = boundary_psi.to(device), exterior_psi.to(device), normal_psi.to(device), trace_int_psi.to(device), trace_ext_psi.to(device)

            # w[k] is filled for k = 1 .. W_expansion-1; slice 0 stays zero.
            w = torch.zeros(args.W_expansion, exterior_disk.shape[0], exterior_disk.shape[1]).cpu()

            for k in range(1, args.W_expansion):
                w[k, :, :] = cloud.exterior_disk_inverse(exterior_disk, k).cpu()

        ## We should always make the boundary_disk_tspace (second polar
        ## column of the boundary points), it is needed for the Fourier fit
        ## and is part of the returned results.
        boundary_disk_tspace = cloud.carte_to_polar(boundary_disk)[:, [1]].detach().cpu()

        #######################
        for itr in range(1, args.S_training+1):

            t1 = time()
            optimizer.zero_grad()

            loss = model.get_loss(interior_psi, boundary_disk, exterior_psi, normal_psi, trace_int_psi, trace_ext_psi)

            if args.reg_flag == "True":

                if use_p_network:
                    # Sum of squared weights of the interface network.
                    # named_parameters() is used instead of state_dict():
                    # state_dict() returns tensors detached from autograd, so
                    # the penalty would never influence the gradients.
                    reg_term = 0
                    for name, param in inv_param.named_parameters():
                        if 'weight' in name:
                            reg_term = reg_term + torch.sum(torch.square(param))
                else:
                    # Weighted sum of squared coefficients: the constant term
                    # with weight 2*pi, each degree-`deg` pair with weight
                    # 4*pi*r^2*(1 + deg^2).
                    r = args.R_bd
                    reg_term = 2*pi*torch.square(model.inv_param[0])

                    for deg in range(1, args.order_interface+1):
                        reg_term = reg_term + (4*(r**2)*pi) * (1+(deg**2)) * (torch.square(model.inv_param[2*deg-1]) + torch.square(model.inv_param[2*deg]))

                loss = loss + (args.reg_weight)*reg_term

            # Penalise negative interface values (sum of max(0, -interface)).
            interface = model.get_interface(boundary_disk)
            if use_p_network:
                interface = interface[:, [0]]
            loss_interface = torch.sum(torch.max(torch.zeros_like(interface), -interface))
            loss = loss + loss_interface

            loss.backward()
            optimizer.step()
            scheduler.step()

            # Measured every iteration: both the plotting and the tracking
            # branch print it, and their intervals need not coincide.
            one_itr_time = time()-t1

            if itr % args.plot_interval == 0:

                if use_p_network:

                    print("### Epoch : %d (Itr : %d (%.3fs)), Loss : %.2E = %.2E + %.2E + %.2E + %.2E / ###"%(epoch, itr, one_itr_time,model.hist[0][-1], model.hist[1][-1], model.hist[2][-1], model.hist[3][-1], model.hist[4][-1]))

                    torch.save({
                        'dnn_int_state_dict': dnn_int.state_dict(),
                        'dnn_ext_state_dict': dnn_ext.state_dict(),
                        'dnn_inv_param' : inv_param.state_dict(),
                    }, PATH+'/checkpoint.bin')

                else:

                    # Rows 1.. are viewed as pairs; column 1 of each pair is
                    # reported as the "imaginary" coefficient.
                    mean_p_imag = torch.mean(torch.abs(model.inv_param[1:, :].reshape(-1, 2)[:, 1])).detach().cpu()
                    print("### Epoch : %d (Itr : %d (%.3fs)), Loss : %.2E = %.2E + %.2E + %.2E + %.2E // p0 : %.2E, Mean of |p_imag| : %.2E ###"%(epoch, itr, one_itr_time,model.hist[0][-1], model.hist[1][-1], model.hist[2][-1], model.hist[3][-1], model.hist[4][-1], model.inv_param[[0], :].item(), mean_p_imag))

                    torch.save({
                        'dnn_int_state_dict': dnn_int.state_dict(),
                        'dnn_ext_state_dict': dnn_ext.state_dict(),
                        'Param_p' : inv_param.detach().cpu(),
                    }, PATH+'/checkpoint.bin')

            if itr % args.loss_tracking_interval == 0:

                if use_p_network:
                    print('Itr {}, args.lr_pinn {}'.format(itr, optimizer.param_groups[0]['lr']))
                else:
                    print('Itr {}, args.lr_pinn {}, args.lr_inv {}'.format(itr, optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr']))

                ################
                #### Compute the numerical results
                ## (1) Take the basic materials.

                a = cloud.conformal_coefficients()

                if use_p_network:

                    # Fit Fourier coefficients to the network output on the
                    # boundary points.
                    p_nn_real = inv_param(boundary_disk)[:, [0]].detach().cpu().numpy()

                    p0, Re_p, Im_p = Fourier_fitting_p_coeff(args, boundary_disk_tspace, p_nn_real, args.p_nn_order)

                    p = _assemble_p_vector(p0, Re_p, Im_p, args.p_nn_order)

                else:

                    # Unpack the coefficient column: row 0 is p0, then
                    # (real, imaginary) pairs in consecutive rows.
                    p0 = inv_param[0][0]

                    N = (inv_param.shape[0]-1)//2

                    Re_p = torch.zeros(N, 1)
                    Im_p = torch.zeros(N, 1)

                    for i in range(N):
                        Re_p[i][0] = inv_param[2*i+1][0]
                        Im_p[i][0] = inv_param[2*i+2][0]

                    p = _assemble_p_vector(p0, Re_p, Im_p, N)

                ################

                ## (2) Define the matrix materials.
                numeric = Numeric(args, p, a)

                S, beta = numeric.Matricse_S_beta()

                ## (3) Compute the PINN, Numerical sources/ and prediction error.
                # Series over the stored w[k]; here a[0] should be a real number.
                # NOTE: this is the formula for the field along x
                # (args.arg_field == 0).  The original code used the very same
                # formula for every other arg_field value, marked "TEMPORARY";
                # the field along y is not implemented yet.
                u_p_out = S[1, 1]*(w[1, :, 0]+1.j*w[1, :, 1])

                for k in range(2, args.W_expansion):
                    u_p_out = u_p_out + S[1, k]*(w[k, :, 0]+1.j*w[k, :, 1])

                u_p_out = u_p_out.real

                u_p_out = u_p_out + exterior_psi[:, 0].cpu()-a[0]*torch.ones(exterior_psi.shape[0])

                u_NN_exterior = model.get_u_ext(exterior_psi).detach().cpu().reshape(-1)
                u_p_out = u_p_out.detach().cpu().reshape(-1)

                Numerics_L2_norm = torch.mean(torch.square(u_p_out))
                prediction_error = torch.mean(torch.square(u_NN_exterior - u_p_out))
                Relative_L2_error = torch.divide(prediction_error, Numerics_L2_norm)

                H = model.background_field(exterior_psi).detach().cpu().reshape(-1)

                Background_L2_error = torch.mean(torch.square(u_NN_exterior - H))  # error = (u_NN - H)

                model.hist_numeric.append(Relative_L2_error.item())

                print("### Epoch : %d (Itr : %d (%.3fs)), Relative_L2_error with numerical simulation : %.2E, Background_L2_error : %.2E ###"%(epoch, itr, one_itr_time, Relative_L2_error, Background_L2_error))
                print("##########################################")

                #### Save history: loss-forward, loss-inverse
                hist_forward.append(Relative_L2_error.item())
                hist_inverse.append(Background_L2_error.item())

    #### Training is over. Save the final results.

    interface = model.get_interface(boundary_disk)
    interface = interface.detach().cpu()

    dnn_int = dnn_int.cpu()
    dnn_ext = dnn_ext.cpu()

    if use_p_network:

        inv_param = inv_param.cpu()

        torch.save({
            'dnn_int_state_dict': dnn_int.state_dict(),
            'dnn_ext_state_dict': dnn_ext.state_dict(),
            'dnn_inv_param' : inv_param.state_dict(),
            'Loss_hist' : model.hist,
            'Inverse parameter hist' : model.hist_inv
        }, PATH+'/final.bin')

    else:

        inv_param = inv_param.detach().cpu()

        torch.save({
            'dnn_int_state_dict': dnn_int.state_dict(),
            'dnn_ext_state_dict': dnn_ext.state_dict(),
            'Param_p' : inv_param,
            'Loss_hist' : model.hist,
            'Inverse parameter hist' : model.hist_inv
        }, PATH+'/final.bin')

    print('\nComputation time: {} seconds'.format(time()-t0))

    #### Save the collocation points (of the last epoch).

    boundary_disk = boundary_disk.cpu()
    exterior_disk = exterior_disk.cpu()

    interior_psi = interior_psi.cpu()
    boundary_psi = boundary_psi.cpu()
    exterior_psi = exterior_psi.cpu()
    normal_psi = normal_psi.cpu()

    trace_int_psi = trace_int_psi.cpu()
    trace_ext_psi = trace_ext_psi.cpu()

    torch.save({
        'boundary_disk' : boundary_disk,
        'exterior_disk' : exterior_disk,
        'interior_psi' : interior_psi,
        'boundary_psi' : boundary_psi,
        'exterior_psi' : exterior_psi,
        'normal_psi' : normal_psi,
        'trace_int_psi' : trace_int_psi,
        'trace_ext_psi' : trace_ext_psi,
        'w' : w,
        'args' : args
    }, PATH+'/samples.bin')

    after_learning = {
        "u_NN_exterior" : u_NN_exterior,
        "u_p_out" : u_p_out,
        "H" : H,
        "boundary_disk_tspace" : boundary_disk_tspace,
        "interface" : interface,
        "exterior_psi" : exterior_psi,
        "interior_psi" : interior_psi,
        "boundary_psi" : boundary_psi
    }

    learning_error = {
        "prediction_error" : prediction_error,
        "Relative_L2_error" : Relative_L2_error,
        "Background_L2_error" : Background_L2_error
    }

    if use_p_network:

        torch.save({
            'boundary_disk_tspace' : boundary_disk_tspace,
            'p_nn_real' : p_nn_real,
            'p0' : p0,
            'Re_p' : Re_p
        }, PATH+'/p_nn_input_output.bin')

    ######## Save the history for forward-loss, inverse-loss.

    torch.save({
        'hist_forward' : hist_forward,
        'hist_inverse' : hist_inverse
    }, PATH+'/hist_forward_inverse.bin')

    ######## Save the after-learning tensors.

    torch.save({
        "u_NN_exterior" : u_NN_exterior,
        "u_p_out" : u_p_out,
        "H" : H,
        "exterior_psi" : exterior_psi,
        "interior_psi" : interior_psi,
        "boundary_psi" : boundary_psi
    }, PATH+'/after_learning.bin')

    return model, after_learning, learning_error