# Import libraries
import os

import numpy as np
import time
import math
import scipy.io
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

import argparse
import seaborn as sns
import torch

from PCVFR.model.model_utiities import rmse
from PCVFR.util.system_util import to_numpy, to_tensor
from PCVFR.model.network_architecture import HeatSirenNet, HeatTanhNet, HeatNetwork
from heat_utils import simulate_heat_equation, plot_snap

# Tensor dtype and compute device shared by the functions in this script
dtype = torch.float32
device = torch.device('cpu')

def main(problem_dict, solution_dict):
    """Build the heat-equation dataset, then train (or load) a network and evaluate it.

    Run configuration is read from module-level globals set in the ``__main__``
    section (``cargs``, ``NUM_OBS_TRAIN``, ``TRAIN``, ``SEED``, ``LR``, ...).

    Steps:
      1. Build the (t, x, y) evaluation grid and the reference solution from
         ``simulate_heat_equation``.
      2. Split grid points into train / (optional) validation / test sets.
         Training points are drawn separately from the ring-shaped source, the
         source at early times, and the rest of the domain.
      3. Write the splits as CSV files to ``./data/heat_equation/``.
      4. If ``TRAIN``: train the network, optionally with early stopping on
         validation R2, then write the history CSV and the test metrics to
         ``RESULTS_PATH``. Otherwise, load a saved model and save prediction plots.

    Args:
        problem_dict: physical setup; reads MIN_X, MAX_X, MIN_T, MAX_T, R, r.
        solution_dict: discretisation; reads N (points per axis) and NUM_T.
    """
    # problem setting parameters
    x_min, x_max = problem_dict['MIN_X'], problem_dict['MAX_X']
    t_min, t_max = problem_dict['MIN_T'], problem_dict['MAX_T']

    # grid for plotting / evaluation
    num_x, num_t = solution_dict['N'], solution_dict['NUM_T']
    x = np.linspace(x_min, x_max, num_x)  # Partitioned x axis
    y = np.linspace(x_min, x_max, num_x)  # Partitioned y axis
    t = np.linspace(t_min, t_max, num_t)  # Partitioned time axis
    t_grid, x_grid, y_grid = np.meshgrid(t, x, y, indexing='ij')
    txy_grid = np.stack((t_grid, x_grid, y_grid), -1)
    txy_test_numpy = txy_grid.reshape(-1, 3)

    # analytical (numerical) solution
    T_test_idx = simulate_heat_equation(problem_dict, solution_dict)
    X, Y = np.meshgrid(x, y)
    nplots = 10
    if cargs['plots']:
        for t_idx in np.linspace(0, num_t - 1, nplots, dtype='int'):
            plot_snap(T_test_idx, X, Y, t_idx)

    # NOTE(review): the transpose is assumed to bring the solver output into the
    # (t, x, y) order of txy_grid -- confirm against simulate_heat_equation.
    T_temp = T_test_idx.T
    T_data = T_temp.reshape(-1, 1)

    R = problem_dict['R']
    r = problem_dict['r']
    # observation budget: 30% on the source at early times, 1% on the source
    # afterwards, the remainder anywhere else
    n_init = int(NUM_OBS_TRAIN * 0.3)
    n_in = int(NUM_OBS_TRAIN * 0.01)
    n_out = NUM_OBS_TRAIN - n_in - n_init
    c = 0.5 * (x_max + x_min)

    # information about source, which is ring shaped (r <= distance <= R)
    dist = np.hypot(txy_test_numpy[:, 1] - c, txy_test_numpy[:, 2] - c)
    in_cond = (dist <= R) & (dist >= r)

    # information about initial condition, over the source
    init_time_cond = txy_test_numpy[:, 0] <= t_max / 50.
    init_cond = init_time_cond & in_cond

    # remove init_cond points (which lie on the source by definition) from in_cond
    in_cond = in_cond & ~init_cond
    out_cond = ~in_cond & ~init_cond

    # extract indices for each condition
    in_idx = np.flatnonzero(in_cond)
    out_idx = np.flatnonzero(out_cond)
    init_idx = np.flatnonzero(init_cond)

    assert in_idx.size + out_idx.size + init_idx.size == txy_test_numpy.shape[0]

    # fixed seed so every configuration sees the same training points
    torch.manual_seed(123456)
    np.random.seed(123456)

    idx_train_in = np.random.choice(in_idx, n_in, replace=False)
    idx_train_out = np.random.choice(out_idx, n_out, replace=False)
    idx_train_init = np.random.choice(init_idx, n_init, replace=False)
    idx_train = np.concatenate((idx_train_in, idx_train_out, idx_train_init))
    # pools are disjoint and sampled without replacement -> indices must be unique
    assert np.unique(idx_train).size == idx_train.size
    idx_test = np.setdiff1d(np.arange(T_data.shape[0]), idx_train)

    # set aside validation data for early stopping
    early_stopping = cargs['early_stopping']
    if early_stopping:
        n_valid = int(idx_test.shape[0] * 0.2)
        idx_valid = np.random.choice(idx_test, n_valid, replace=False)
        idx_test = np.setdiff1d(idx_test, idx_valid)
        assert idx_test.shape[0] + idx_train.shape[0] + idx_valid.shape[0] == T_data.shape[0]
        T_valid = T_data[idx_valid]
        txy_valid_np = txy_test_numpy[idx_valid]
        txy_valid_tensor = to_tensor(txy_valid_np, device=device)

    noise_std = 0.01
    T_train_temp = T_data[idx_train]  # noise-free training targets (kept for plotting)
    T_train = T_train_temp + np.random.normal(0., noise_std, size=T_train_temp.shape)
    T_test = T_data[idx_test]

    T_train_tensor = to_tensor(T_train, device=device)

    txy_train_np = txy_test_numpy[idx_train]
    txy_test_np = txy_test_numpy[idx_test]

    txy_train_tensor = to_tensor(txy_train_np, device=device)
    txy_test_tensor = to_tensor(txy_test_np, device=device)

    tot_data = T_data.shape[0] / 100  # divisor that turns counts into percentages
    if early_stopping:
        assert T_train.shape[0] + T_test.shape[0] + T_valid.shape[0] == T_data.shape[0]
        assert txy_train_tensor.shape[0] + txy_test_tensor.shape[0] + txy_valid_tensor.shape[0] == txy_test_numpy.shape[0]
        print(f"Dataset composition: {T_train.shape[0]/tot_data:.3f} train, {T_valid.shape[0]/tot_data:.3f} valid, {T_test.shape[0]/tot_data:.3f} test")
    else:
        assert T_train.shape[0] + T_test.shape[0] == T_data.shape[0]
        assert txy_train_tensor.shape[0] + txy_test_tensor.shape[0] == txy_test_numpy.shape[0]
        print(f"Dataset composition: {T_train.shape[0] / tot_data:.3f} train, {T_test.shape[0] / tot_data:.3f} test")

    if cargs['plots']:
        # clean (red) vs noisy (blue) training targets, one figure per time window
        dt = t_max / nplots
        for t_snap in np.linspace(0, t_max, nplots):
            if t_snap == t_max:
                break
            fig = plt.figure()
            ax = fig.add_subplot(projection='3d')
            time_cond = np.abs(txy_train_np[:, 0] - t_snap - dt * 0.5) < dt * 0.5
            time_ind = np.where(time_cond)
            samples_idx = txy_train_np[time_ind][:, 1:]
            ax.scatter(samples_idx[:, 0], samples_idx[:, 1], T_train_temp[time_ind], color='r')
            ax.scatter(samples_idx[:, 0], samples_idx[:, 1], T_train[time_ind], color='b')
            ax.set_zlim(0, 1)
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(x_min, x_max)
            plt.show()

    # save data to file (the validation split only exists with early stopping)
    DATA_PATH = "./data/heat_equation/"
    os.makedirs(DATA_PATH, exist_ok=True)
    _split_to_frame(txy_train_np, T_train).to_csv(DATA_PATH + 'train.csv')
    if early_stopping:
        _split_to_frame(txy_valid_np, T_valid).to_csv(DATA_PATH + 'valid.csv')
    _split_to_frame(txy_test_np, T_test).to_csv(DATA_PATH + 'test.csv')

    MODEL_PATH = "./models/"
    checkpoint_file = MODEL_PATH + MODEL_NAME + '.model'

    obs_train = dict(x=txy_train_tensor, T=T_train_tensor)

    true_test_values = dict(T=T_test)
    if early_stopping:
        true_valid_values = dict(T=T_valid)

    loss_history = []
    val_r2_history = []
    val_rmse_history = []

    step = 100  # evaluate / log every `step` epochs
    best_r2 = -100
    if TRAIN:
        # set random seed for torch model
        torch.manual_seed(SEED)
        np.random.seed(SEED)

        model = _build_model(problem_dict)  # raises ValueError on unknown activation

        # Loss and optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=.9999)

        print('Start training...')
        tic = time.time()
        pde_samples = model.sample_signal_domain(num_samples=NUM_SAMPLES, fraction_mcmc=-1)

        epoch = 0  # defined even if training is interrupted immediately
        patience = 0  # consecutive evaluations without R2 improvement
        checkpoint_saved = False
        try:
            for epoch in range(1, NUM_EPOCHS + 1):
                model.train()
                optimizer.zero_grad()

                loss_pde = model.loss_pde(pde_samples)  # Loss function of PDE
                loss_obs = model.loss_ic(**obs_train)  # Loss function of obs

                if cargs['nopde']:
                    loss = loss_obs
                else:
                    loss = cargs['w_pde'] * loss_pde + loss_obs  # Total loss function G(theta)

                # update weights
                loss.backward()
                optimizer.step()

                # decay the learning rate until it reaches LR_MIN
                if optimizer.param_groups[0]['lr'] > LR_MIN:
                    scheduler.step()

                model.eval()
                # periodically resample the PDE collocation points
                if epoch % 500 == 0:
                    pde_samples = model.sample_signal_domain(NUM_SAMPLES, fraction_mcmc=MCMC_SAMPLE * FRAC_MCMC_SAMPLE, burnin=1000)

                # log losses
                if epoch % step == 0:
                    if early_stopping:
                        true_values = true_valid_values
                        predicted_values = predict_to_grid(model, txy_valid_tensor)
                    else:
                        true_values = true_test_values
                        predicted_values = predict_to_grid(model, txy_test_tensor)

                    rmse_T = rmse(predicted_values['T'], true_values['T'])
                    r2 = r2_score(true_values['T'], predicted_values['T'])
                    loss_history.append(loss.item())
                    val_r2_history.append(r2)
                    val_rmse_history.append(rmse_T)

                    # early stopping: keep the best-R2 checkpoint, stop after
                    # more than `early_stopping` evaluations without improvement
                    if early_stopping:
                        if r2 > best_r2:
                            patience = 0
                            best_r2 = r2
                            print('->model saved<-')
                            torch.save(model.state_dict(), checkpoint_file)
                            checkpoint_saved = True
                        else:
                            patience += 1
                            if patience > early_stopping:
                                break

                    if epoch % 500 == 0:
                        print(
                            f'epoch {epoch} loss_pde:{loss_pde:.2e}, \tloss_obs:{loss_obs:.2e}, \tloss:{loss.item():.2e}')
                        print(f"RMSE Test - T:{rmse_T:.2e}")
                        print(f"R2 Test - T:{r2:.2e}")

        except KeyboardInterrupt:
            print(f"Aborted training at epoch {epoch}.. continue with plots")
        toc = time.time()
        print(f'Total training time: {toc - tic}')

        history_dict = {'loss': loss_history, 'rmse': val_rmse_history, 'r2': val_r2_history}
        pd.DataFrame(history_dict).to_csv(RESULTS_PATH + MODEL_NAME + '.csv')

        # only reload when this run actually wrote a checkpoint; otherwise a stale
        # file from an earlier run (or none at all) would be loaded
        if early_stopping and checkpoint_saved:
            best_model = _build_model(problem_dict)
            best_model.load_state_dict(torch.load(checkpoint_file, map_location=device))
            best_model.eval()
        else:
            best_model = model

        if cargs['plots']:
            length = len(loss_history)
            logged_epochs = np.linspace(step, step * length, length)
            plt.plot(logged_epochs, loss_history, label='loss')
            plt.plot(logged_epochs, val_r2_history, label='r2')
            plt.plot(logged_epochs, val_rmse_history, label='rmse')
            plt.title(MODEL_NAME)
            plt.legend()
            plt.show()

            # full-grid prediction is only needed for these surface plots
            preds = to_numpy(best_model(to_tensor(txy_test_numpy, device=device))).reshape(t_grid.shape)
            pred_err = np.abs(T_temp - preds)
            for t_idx in np.linspace(0, num_t - 1, nplots, dtype='int'):
                plot_surface(X, Y, pred_err, t_idx)
                plot_surface(X, Y, preds, t_idx)

        predicted_test_values = predict_to_grid(best_model, txy_test_tensor)
        pred_T, true_T = predicted_test_values['T'], true_test_values['T']
        abs_err = np.abs(pred_T - true_T).mean()
        rmse_err = rmse(pred_T, true_T)
        r2_err = r2_score(true_T, pred_T)
        print(f"ABS ERROR: {abs_err:.2e}")
        print(f"RMSE ERROR: {rmse_err:.2e}")
        print(f"R2 ERROR: {r2_err:.2e}")

        with open(RESULTS_PATH + MODEL_NAME + ".txt", "w") as f:
            f.write(f"ABS:{abs_err:.10f}\n")
            f.write(f"RMSE:{rmse_err:.10f}\n")
            f.write(f"R2:{r2_err:.10f}\n")

    else:
        # prediction mode: load a previously trained model and plot time slices
        best_model = _build_model(problem_dict)
        paper_model_path = "./results_paper/models/"
        best_model.load_state_dict(torch.load(paper_model_path + MODEL_NAME + '.model', map_location=device))
        best_model.eval()

        # set seaborn style
        sns.set_style("white")
        preds_plot = to_numpy(best_model(to_tensor(txy_test_numpy, device=device)))
        T_plot = T_temp.reshape(-1, 1, order='C')
        txy_test_plot = txy_test_numpy

        offset = t_max / nplots / 20  # half-width of the time slice drawn per snapshot

        for t_snap in np.linspace(0, t_max, nplots):
            time_mask = (txy_test_plot[:, 0] <= (t_snap + offset)) & (txy_test_plot[:, 0] >= (t_snap - offset))
            slice_x = txy_test_plot[time_mask][:, 1]
            slice_y = txy_test_plot[time_mask][:, 2]

            plt.figure(figsize=(10, 8))
            sc = plt.scatter(x=slice_x, y=slice_y, cmap='Blues', c=preds_plot[time_mask])
            plt.xlim(x_min, x_max)
            plt.ylim(x_min, x_max)
            plt.yticks(fontsize=18)
            plt.xticks(fontsize=18)
            plt.tight_layout()
            if cargs['mcmc']:
                cbar = plt.colorbar(sc, ticks=np.linspace(0, 1, 6, endpoint=True))
                cbar.ax.tick_params(labelsize=18)
                # aspect must be set before saving to affect the written image
                plt.gca().set_aspect('equal', adjustable='box')
                plt.savefig(f"predicted_temperature_mcmc_{t_snap:.3f}.png", dpi=100)
            else:
                plt.savefig(f"predicted_temperature_unif_{t_snap:.3f}.png", dpi=100)
            plt.close()

            plt.figure(figsize=(8, 8))
            sns.scatterplot(x=slice_x, y=slice_y, cmap='Blues', c=T_plot[time_mask])
            plt.xlim(x_min, x_max)
            plt.ylim(x_min, x_max)
            plt.yticks(fontsize=18)
            plt.xticks(fontsize=18)
            plt.tight_layout()
            plt.savefig(f"true_temperature_{t_snap:.3f}.png", dpi=200)
            plt.close()  # avoid accumulating open figures across snapshots


def _build_model(problem_dict):
    """Instantiate the network chosen by ``--activation`` on ``device`` / ``dtype``.

    Uses the NUM_UNITS, NUM_LAYERS and DROPOUT globals.

    Raises:
        ValueError: if the activation is neither 'tanh' nor 'siren'.
    """
    if cargs['activation'] == 'tanh':
        net_cls = HeatTanhNet
    elif cargs['activation'] == 'siren':
        net_cls = HeatSirenNet
    else:
        raise ValueError('Activation must be one of tanh or siren')
    return net_cls(problem_dict, num_units=NUM_UNITS, num_layers=NUM_LAYERS,
                   dropout=DROPOUT, device=device).to(device, dtype=dtype)


def _split_to_frame(txy, T):
    """Return a DataFrame with columns t, x, y (from ``txy``) and T (flattened ``T``)."""
    frame = pd.DataFrame(txy, columns=['t', 'x', 'y'])
    frame['T'] = T.ravel()
    return frame

def plot_surface(X, Y, T, idx):
    """Display snapshot ``T[idx]`` as a 3-D surface over the (X, Y) mesh.

    The z axis runs from 0 to the maximum of the whole ``T`` array, so every
    snapshot is drawn on the same vertical scale.
    """
    axes = plt.figure().add_subplot(111, projection='3d')
    snapshot = T[idx]
    axes.plot_surface(X, Y, snapshot, cmap='gist_rainbow', edgecolor='none')
    axes.set_zlim(0, T.max())
    axis_labels = ((axes.set_xlabel, 'X [m]'),
                   (axes.set_ylabel, 'Y [m]'),
                   (axes.set_zlabel, 'T [°]'))
    for set_label, text in axis_labels:
        set_label(text)
    plt.show()


def plot_contour(X, Y, T, idx, problem_dict):
    """Display snapshot ``T[idx]`` as labelled contour lines over a heat-map image.

    The image extent is the square [MIN_X, MAX_X] x [MIN_X, MAX_X] from
    ``problem_dict``.
    """
    lo, hi = problem_dict['MIN_X'], problem_dict['MAX_X']
    snapshot = T[idx]
    plt.figure()
    lines = plt.contour(X, Y, snapshot)
    plt.clabel(lines, inline=True, fontsize=8)
    plt.imshow(snapshot, extent=[lo, hi, lo, hi], origin='lower', alpha=1)
    plt.colorbar()
    plt.show()

def plot_contour_samples(X, Y, T, idx, samples, problem_dict, solution_dict):
    """Plot snapshot ``T[idx]`` as a contour/heat map and overlay nearby samples.

    Samples whose time coordinate lies within half a time step of snapshot
    ``idx`` are drawn as red crosses.

    Args:
        X, Y: meshgrid coordinates passed to ``plt.contour``.
        T: array indexed first by time snapshot; ``T.shape[0]`` must equal
            ``solution_dict['NUM_T']``.
        idx: integer index of the snapshot to draw.
        samples: torch tensor; column 0 is treated as time and columns 1-2 as
            the spatial coordinates.
        problem_dict: reads MIN_X, MAX_X, MAX_T and (if present) MIN_T.
        solution_dict: reads NUM_T.
    """
    min_x, max_x = problem_dict['MIN_X'], problem_dict['MAX_X']
    plt.figure()
    contours = plt.contour(X, Y, T[idx])
    plt.clabel(contours, inline=True, fontsize=8)
    plt.imshow(T[idx], extent=[min_x, max_x, min_x, max_x],
               origin='lower', alpha=1)
    plt.colorbar()
    np_samples = samples.cpu().detach().numpy()
    print('np_samples', np_samples.shape, np_samples)
    num_t, max_t = T.shape[0], problem_dict['MAX_T']
    min_t = problem_dict.get('MIN_T', 0.)
    assert num_t == solution_dict['NUM_T']
    # snapshot times follow the same linspace used to build the grid in main()
    t_axis = np.linspace(min_t, max_t, num_t)
    dt = (max_t - min_t) / (num_t - 1) if num_t > 1 else (max_t - min_t)
    # compare the physical sample time with the snapshot's time, not its index
    time_cond = np.abs(np_samples[:, 0] - t_axis[idx]) < dt * 0.5
    samples_idx = np_samples[time_cond][:, 1:]
    print('samples_idx', samples_idx.shape, samples_idx)
    plt.scatter(samples_idx[:, 0], samples_idx[:, 1], marker='X', color='r')
    plt.show()


def predict_to_grid(model: HeatNetwork, txy_test_tensor: torch.Tensor) -> dict:
    """Evaluate ``model`` on a batch of test coordinates.

    Args:
        model: network to evaluate; it is called as ``model(txy_test_tensor)``.
        txy_test_tensor: input coordinates, passed to the model unchanged.

    Returns:
        dict with the single key ``'T'`` holding the model output converted to
        a NumPy array by ``to_numpy``.
    """
    return {'T': to_numpy(model(txy_test_tensor))}


def sample_initial_conditions(num_ic_train, problem_dict):
    """Draw random points at t = 0 together with their initial-condition targets.

    x and y are each drawn with ``np.random.uniform(MIN_X, MAX_X)``; t is 0 for
    every point. Targets come from ``initial_conditions``.

    Args:
        num_ic_train: number of points to draw.
        problem_dict: reads MIN_X and MAX_X here, plus whatever
            ``initial_conditions`` reads.

    Returns:
        dict with ``T`` (targets) and ``x`` (the stacked (t, x, y) points),
        both converted with ``to_tensor`` onto ``device``.
    """
    lo, hi = problem_dict['MIN_X'], problem_dict['MAX_X']

    xs = np.random.uniform(lo, hi, num_ic_train)
    ys = np.random.uniform(lo, hi, num_ic_train)
    ts = np.zeros(num_ic_train)
    points = np.column_stack((ts, xs, ys))

    targets = initial_conditions(points, problem_dict)
    return dict(T=to_tensor(targets, device=device), x=to_tensor(points, device=device))


def initial_conditions(x, problem_dict):
    """Return the initial temperature at each (t, x, y) point.

    Points strictly inside the disc of radius ``R``, centred at the middle of
    the square domain, get temperature ``T0``; all other points get 0. The
    time column is ignored.

    NOTE(review): main() treats the source as a ring (r <= distance <= R),
    whereas this uses a full disc of radius R -- confirm which one matches the
    reference solver.

    Args:
        x: array-like of shape (n, 3) with columns (t, x, y).
        problem_dict: reads R, T0, MIN_X and MAX_X.

    Returns:
        np.ndarray of shape (n, 1), dtype float64.

    Raises:
        ValueError: if ``x`` is not a 2-D array with 3 columns.
    """
    x = np.asarray(x)
    # validate before indexing columns 1 and 2
    if x.ndim != 2 or x.shape[-1] != 3:
        raise ValueError(f"expected points of shape (n, 3), got {x.shape}")

    R = problem_dict['R']  # radius of heat source
    T0 = problem_dict['T0']
    center = 0.5 * (problem_dict['MIN_X'] + problem_dict['MAX_X'])

    inside_disc = (x[:, 1] - center) ** 2 + (x[:, 2] - center) ** 2 < R ** 2

    T = np.zeros((x.shape[0], 1))
    T[inside_disc] = T0
    return T



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", help="Device to use", nargs='?',
                        choices=('cpu', 'cuda'),
                        default="cpu")
    parser.add_argument("--activation", help="activation function", nargs='?',
                        choices=('siren', 'tanh'),
                        default="tanh")
    parser.add_argument("--lr", default=5e-4, type=float)
    parser.add_argument("--lr-min", default=5e-5, type=float)
    parser.add_argument("--w-pde", default=0.05, type=float)
    parser.add_argument("--num-obs-train", default=1000, type=int)
    parser.add_argument("--num-samples", default=128, type=int)
    parser.add_argument("--early-stopping", default=30, type=int)
    parser.add_argument("--num-epochs", default=15_000, type=int)
    parser.add_argument("--seed", default=1, type=int)
    parser.add_argument("--dropout", default=0, type=float)
    parser.add_argument("-predict", action="store_true", default=False)
    parser.add_argument("-mcmc", action="store_true", default=False)
    parser.add_argument("-plots", action="store_true", default=False)
    parser.add_argument("-nopde", action="store_true", default=False)

    cargs = vars(parser.parse_args())

    # honour --device (the module-level default is CPU)
    device = torch.device(cargs['device'])

    # Initialization
    LR = cargs["lr"]  # Learning rate
    LR_MIN = cargs["lr_min"]  # the scheduler stops decaying once lr <= LR_MIN

    NUM_IC_TRAIN = 10_000
    NUM_OBS_TRAIN = cargs["num_obs_train"]  # number of noisy observation points used for training
    NUM_EPOCHS = cargs["num_epochs"]  # Number of maximum iterations
    NUM_SAMPLES = cargs['num_samples']  # Random sampled points for pde loss

    # problem setup
    T0 = 1  # temperature of source
    K = 0.1  # diffusivity (1.172E-5 would be steel, 1% carbon)
    MIN_X, MAX_X = -2.5, 2.5
    MIN_T, MAX_T = 0., 0.2
    L = MAX_X - MIN_X

    R = L / 3  # outer radius of the ring-shaped source
    r = L / 4  # inner radius of the ring-shaped source

    problem_dict = dict(T0=T0, K=K, MIN_T=MIN_T, MAX_T=MAX_T,
                        MIN_X=MIN_X, MAX_X=MAX_X, L=L, R=R, r=r)

    # exact solution
    N = 157  # number of points along each axis
    DT = 0.002  # time step
    NUM_T = int(round(MAX_T / DT))  # round rather than truncate to absorb float error
    RS = int(R / L * N)  # radius, but in terms of indices --> int
    rs = int(r / L * N)  # radius, but in terms of indices --> int

    assert N % 2 == 1
    assert (((N - 1) / 2 - 1) % 2) == 1

    # Explicit (FTCS) 2-D diffusion is stable only for K*DT/dx^2 <= 1/4.
    # NOTE(review): assumes simulate_heat_equation uses an explicit scheme with
    # grid spacing ~ L/N -- confirm against heat_utils.
    dx = L / N
    cn = K * DT / dx ** 2
    print(f"diffusion number K*DT/dx^2 = {cn:.4f}")
    if cn > 0.25:
        raise ValueError(f'Unstable Solution! K*DT/dx^2 = {cn:.4f} > 0.25')

    solution_dict = dict(N=N, NUM_T=NUM_T, R=RS, r=rs)

    W_PDE = cargs['w_pde']

    NUM_UNITS = 32
    NUM_LAYERS = 2

    TRAIN = not cargs["predict"]
    MCMC_SAMPLE = cargs["mcmc"]
    FRAC_MCMC_SAMPLE = .9

    MODEL_PATH = "./models/"
    os.makedirs(MODEL_PATH, exist_ok=True)

    RESULTS_PATH = "./results/"
    os.makedirs(RESULTS_PATH, exist_ok=True)

    if cargs['nopde']:
        sampling = 'nopde'
    elif cargs["mcmc"]:
        sampling = 'mcmc'
    else:
        sampling = 'uniform'

    DROPOUT = cargs['dropout']

    SEED = cargs['seed']
    ACTIVATION = cargs['activation']

    MODEL_NAME = f"{ACTIVATION}_{sampling}_{W_PDE}_{NUM_OBS_TRAIN}_{NUM_SAMPLES}_{SEED}"
    print(f'MODEL_NAME: {MODEL_NAME}')

    main(problem_dict, solution_dict)

