import argparse
import logging
import traceback

import math
import numpy as np
import os
import torch
import torch.nn as nn

from matplotlib import cm
import matplotlib.pyplot as plt

from datetime import datetime
import pandas as pd

import utils
from models import MLP



def true_sol(pts):
    """Evaluate the reference solution u(t, x) = exp(-t/4) * sin(x - t).

    Column 0 of ``pts`` is used as ``t`` and column 1 as ``x``.

    Args:
        pts: 2-D ``torch.Tensor`` or ``np.ndarray`` with at least two columns.

    Returns:
        The solution values as a column of shape ``(-1, 1)``, in the same
        array library as the input.

    Raises:
        TypeError: if ``pts`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray`` (previously this fell through and returned None).
    """
    if isinstance(pts, torch.Tensor):
        backend = torch
    elif isinstance(pts, np.ndarray):
        backend = np
    else:
        raise TypeError(
            f'pts must be a torch.Tensor or np.ndarray, got {type(pts).__name__}'
        )

    t, x = pts[:, 0], pts[:, 1]
    output = backend.exp(-0.25 * t) * backend.sin(x - t)
    return output.reshape(-1, 1)
    
def source_f(pts):
    """Return the PDE source term at ``pts``, which is identically zero.

    Args:
        pts: 2-D ``torch.Tensor`` or ``np.ndarray``; only its first column is
            used, as a template for shape, dtype and (for tensors) device.

    Returns:
        A 1-D array of zeros, one per row of ``pts``, in the same array
        library as the input.

    Raises:
        TypeError: if ``pts`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray`` (previously this fell through and returned None).
    """
    if isinstance(pts, torch.Tensor):
        return torch.zeros_like(pts[:, 0])
    if isinstance(pts, np.ndarray):
        return np.zeros_like(pts[:, 0])
    raise TypeError(
        f'pts must be a torch.Tensor or np.ndarray, got {type(pts).__name__}'
    )
    
def generate_interior_points(num_o, random=False, tensor=False, device='cpu'):
    """Sample collocation points inside the rectangle (0, 1) x (0, pi).

    Args:
        num_o: requested number of interior points.
        random: if true, draw ``num_o`` points uniformly at random; otherwise
            build a regular grid that excludes the boundary.
        tensor: if true, return a float32 ``torch.Tensor`` with
            ``requires_grad=True`` moved to ``device`` instead of a NumPy array.
        device: target device, only used when ``tensor`` is true.

    Returns:
        An array/tensor with two columns: first coordinate in [0, 1],
        second coordinate in [0, pi].
    """
    # Grid resolution per axis; presumably two integers whose product is
    # num_o -- confirm against utils.closest_factors.
    grid_nx, grid_ny = utils.closest_factors(num_o)

    if not random:
        # Regular grid: drop the first and last linspace nodes so that no
        # point lies on the boundary.
        axis_0 = np.linspace(0, 1, grid_nx + 2)[1:-1]
        axis_1 = np.linspace(0, np.pi, grid_ny + 2)[1:-1]
        mesh_0, mesh_1 = np.meshgrid(axis_0, axis_1)
        pts_o = np.column_stack((mesh_0.ravel(), mesh_1.ravel()))
    else:
        # First coordinate is drawn before the second; keep this order so a
        # fixed seed reproduces the same samples.
        coord_0 = np.random.random(size=num_o)          # in [0, 1)
        coord_1 = np.pi * np.random.random(size=num_o)  # in [0, pi)
        pts_o = np.column_stack((coord_0, coord_1))

    if not tensor:
        return pts_o
    return torch.tensor(pts_o, dtype=torch.float32, requires_grad=True).to(device)
    
def generate_boundary_points(num_b, random=False, tensor=False, device='cpu'):
    """Sample points on three edges of the rectangle [0, 1] x [0, pi].

    The three edges are, in return order:
      1. first coordinate fixed at 0, second coordinate varying in (0, pi);
      2. second coordinate fixed at 0, first coordinate varying in (0, 1);
      3. second coordinate fixed at pi, first coordinate varying in (0, 1).

    Args:
        num_b: number of points per edge.
        random: if true, draw the varying coordinate uniformly at random;
            otherwise use equally spaced values that exclude the end points.
        tensor: if true, return float32 ``torch.Tensor`` objects with
            ``requires_grad=True`` moved to ``device`` instead of NumPy arrays.
        device: target device, only used when ``tensor`` is true.

    Returns:
        A 3-tuple of ``(num_b, 2)`` arrays/tensors, one per edge.
    """
    if random:
        # The draw order (edge 1, edge 2, edge 3) is kept fixed so that a
        # seeded run reproduces the same samples.
        edge1_second = np.pi * np.random.random(size=num_b)
        edge2_first = np.random.random(size=num_b)
        edge3_first = np.random.random(size=num_b)
    else:
        unit_nodes = np.linspace(0, 1, num_b + 2)[1:-1]
        edge1_second = unit_nodes * np.pi
        edge2_first = unit_nodes
        edge3_first = unit_nodes

    zero_col = np.zeros(num_b)
    pi_col = np.full(num_b, np.pi)
    edges = (
        np.stack([zero_col, edge1_second], axis=1),
        np.stack([edge2_first, zero_col], axis=1),
        np.stack([edge3_first, pi_col], axis=1),
    )

    if tensor:
        edges = tuple(
            torch.tensor(edge, dtype=torch.float32, requires_grad=True).to(device)
            for edge in edges
        )
    return edges

    
def _str2bool(value):
    """Convert a command-line token such as 'true' or '0' to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--random False`` would silently enable the option. This parser reads the
    text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching the old bool('') result.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``gpu``, ``random``,
        ``no``, ``nb``, ``m``, ``activation``, ``p``, ``D``, ``lr``,
        ``epochs``, ``val_freq`` and ``optimizer``.
    """
    parser = argparse.ArgumentParser()

    # Run / sampling configuration.
    parser.add_argument('--seed', type=int, default=5884, help='random seed')
    parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: 0]')
    # A bare `--random` enables resampling; `--random True/False` still works.
    parser.add_argument('--random', type=_str2bool, nargs='?', const=True, default=False,
                        help='resample the training points at every epoch')
    parser.add_argument('--no', type=int, default=300, help='number of interior points')
    parser.add_argument('--nb', type=int, default=100, help='number of points per boundary edge')

    # Model / optimisation configuration.
    parser.add_argument('--m', type=int, default=1000,
                        help='first positional argument passed to MLP')
    parser.add_argument('--activation', type=str, default='relu', help="'relu' or 'tanh'")
    parser.add_argument('--p', type=int, default=2, help='value passed to MLP as p')
    parser.add_argument('--D', type=int, default=2, help='value passed to MLP as D')
    parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('--epochs', type=int, default=400000, help='number of training epochs')
    parser.add_argument('--val_freq', type=int, default=200,
                        help='log and evaluate every this many epochs')
    parser.add_argument('--optimizer', type=str, default='SGD', help="'SGD' or 'Adam'")

    return parser.parse_args()
    
    
def _select_activation(name):
    """Map an activation name ('relu' or 'tanh', case-insensitive) to a module."""
    key = name.lower()
    if key == 'relu':
        return nn.ReLU()
    if key == 'tanh':
        return nn.Tanh()
    raise ValueError(f"unsupported activation {name!r}; expected 'relu' or 'tanh'")


def _make_optimizer(name, params, lr):
    """Build the optimizer named 'sgd' or 'adam' (case-insensitive)."""
    key = name.lower()
    if key == 'sgd':
        return torch.optim.SGD(params=params, lr=lr)
    if key == 'adam':
        return torch.optim.Adam(params=params, lr=lr)
    raise ValueError(f"unsupported optimizer {name!r}; expected 'SGD' or 'Adam'")


def _save_comparison_plot(model, plot_pts, grid_n, out_path):
    """Save side-by-side 3-D surfaces of the true and predicted solutions."""
    with torch.no_grad():
        X = plot_pts[:, 0].cpu().numpy().reshape(grid_n, grid_n)
        Y = plot_pts[:, 1].cpu().numpy().reshape(grid_n, grid_n)
        Z = model(plot_pts).squeeze().cpu().numpy()
        # true_sol returns an (N, 1) column; squeeze it so that Z - Z_true is
        # element-wise instead of broadcasting to an (N, N) matrix.
        Z_true = true_sol(plot_pts).squeeze().cpu().numpy()

    fig, axs = plt.subplots(1, 2, subplot_kw={"projection": "3d"}, figsize=(10*2, 10))
    axs[0].plot_surface(X, Y, Z_true.reshape(grid_n, grid_n),
                        cmap=cm.coolwarm, linewidth=0, antialiased=False)
    axs[1].plot_surface(X, Y, Z.reshape(grid_n, grid_n),
                        cmap=cm.coolwarm, linewidth=0, antialiased=False)
    axs[0].set_title('True solution')
    axs[1].set_title(f'Predicted solution\n(MSE={np.square(Z-Z_true).mean()})')
    fig.savefig(fname=out_path)
    plt.close(fig)


def _train(args, logger, save_dir_temp):
    """Train the two networks and write checkpoints, plot and CSVs to save_dir_temp."""
    m = args.m
    # NOTE(review): the device is always CUDA; there is no CPU fallback.
    device = f'cuda:{args.gpu}'
    activation = _select_activation(args.activation)

    resample = args.random
    num_o = args.no
    num_b = args.nb
    epochs = args.epochs
    val_freq = args.val_freq

    # Fixed (grid) training points; replaced every epoch when resampling.
    pts_o = generate_interior_points(num_o, random=False, tensor=True, device=device)
    pts_b1, pts_b2, pts_b3 = generate_boundary_points(num_b, random=False, tensor=True, device=device)
    srcs_o = source_f(pts_o).to(device)

    # Evaluation grid used only for the comparison plot.
    grid_n = 101
    mesh_t, mesh_x = np.meshgrid(np.linspace(0, 1, grid_n), np.linspace(0, np.pi, grid_n))
    plot_pts = np.stack([mesh_t.flatten(), mesh_x.flatten()], axis=1)
    plot_pts = torch.tensor(plot_pts, dtype=torch.float32).to(device)

    df_losses = pd.DataFrame()
    df_errors = pd.DataFrame()
    start = datetime.now()
    logger.info(f"Case: m={m}\t({start.strftime('%y.%m.%d-%H:%M:%S')})")

    for p in [args.p]:
        logger.info(f'\tp={p}')
        # model_1 approximates u; model_2 approximates u_x (tied to the
        # autograd derivative of model_1 through the gradient-matching loss).
        # NOTE(review): model_2 uses a hard-coded D=2 rather than args.D -- confirm intent.
        model_1 = MLP(m, d=2, d_out=1, p=p, D=args.D, activation=activation)
        model_2 = MLP(m, d=2, d_out=1, p=p, D=2, activation=activation)
        model_1.to(device)
        model_2.to(device)
        model_1.train()
        model_2.train()

        optimizer1 = _make_optimizer(args.optimizer, model_1.parameters(), args.lr)
        optimizer2 = _make_optimizer(args.optimizer, model_2.parameters(), args.lr)

        PINN_losses = []
        L2_losses = []
        best_epoch = 0

        for epoch in range(epochs + 1):
            is_val_epoch = epoch % val_freq == 0

            if resample:
                pts_o = generate_interior_points(num_o, random=resample, tensor=True, device=device)
                pts_b1, pts_b2, pts_b3 = generate_boundary_points(num_b, random=resample, tensor=True, device=device)
                srcs_o = source_f(pts_o).to(device)

            # Interior PDE residual: u_t + u_x - 0.25 * u_xx = f, where u_xx is
            # taken as the x-derivative of model_2's output.
            out_o1 = model_1(pts_o).squeeze()
            grad_o = torch.autograd.grad(out_o1, pts_o, grad_outputs=torch.ones_like(out_o1),
                                         create_graph=True)[0]
            u_t = grad_o[:, 0]
            u_x = grad_o[:, 1]

            out_o2 = model_2(pts_o).squeeze()
            u_xx = torch.autograd.grad(out_o2, pts_o, grad_outputs=torch.ones_like(out_o2),
                                       create_graph=True)[0][:, 1]

            residual = u_t + u_x - 0.25*u_xx
            loss_o = torch.square(residual - srcs_o).sum() / len(pts_o)

            # Interior gradient-matching loss: model_2 should reproduce u_x.
            loss_gm = torch.square(u_x - out_o2).sum() / len(pts_o)

            # Boundary loss on the edges t=0, x=0 and x=pi.
            loss_b1 = torch.square(model_1(pts_b1).squeeze() - true_sol(pts_b1).squeeze()).sum()
            loss_b2 = torch.square(model_1(pts_b2).squeeze() - true_sol(pts_b2).squeeze()).sum()
            loss_b3 = torch.square(model_1(pts_b3).squeeze() - true_sol(pts_b3).squeeze()).sum()
            loss_b = (loss_b1 + loss_b2 + loss_b3) / (len(pts_b1) + len(pts_b2) + len(pts_b3))

            loss = loss_o + loss_gm + 10*loss_b

            if is_val_epoch:
                loss_value = loss.item()
                # NOTE(review): this compares against the previous validation
                # loss, not the best loss seen so far -- confirm intent.
                if PINN_losses and loss_value < PINN_losses[-1]:
                    stale_ckpt = os.path.join(save_dir_temp, f'PINN_model_1({best_epoch}).pth')
                    try:
                        os.remove(stale_ckpt)
                    except FileNotFoundError:
                        pass  # nothing saved yet
                    except OSError as err:
                        # Best effort: a leftover checkpoint must not stop training.
                        logger.warning(f'could not remove {stale_ckpt}: {err}')
                    best_epoch = epoch
                    torch.save(model_1.state_dict(),
                               os.path.join(save_dir_temp, f'PINN_model_1({best_epoch}).pth'))
                    _save_comparison_plot(model_1, plot_pts, grid_n,
                                          os.path.join(save_dir_temp, 'result.png'))

                PINN_losses.append(loss_value)
                logger.info(f"epoch:{epoch}\t{loss_value}\t={loss_o.item()}\t+{loss_gm.item()}\t+10*{loss_b.item()}")

            optimizer1.zero_grad()
            optimizer2.zero_grad()
            loss.backward()
            optimizer1.step()
            optimizer2.step()

            if is_val_epoch:
                # Discrete L2 error on the training points, after the update.
                with torch.no_grad():
                    # Squeeze the (N, 1) true solution so the difference is
                    # element-wise rather than an (N, N) broadcast.
                    err_o = model_1(pts_o).squeeze() - true_sol(pts_o).squeeze()
                    # NOTE(review): the weight uses pi**2, but the domain
                    # [0, 1] x [0, pi] has area pi -- confirm the scaling.
                    l2_error = torch.sqrt(torch.square(err_o).sum() * (np.square(np.pi) / len(pts_o)))
                L2_losses.append(l2_error.item())

        df_losses[p] = PINN_losses
        df_errors[p] = L2_losses

    df_losses.transpose().to_csv(os.path.join(save_dir_temp, 'losses.csv'), index=True)
    df_errors.transpose().to_csv(os.path.join(save_dir_temp, 'errors.csv'), index=True)


def main():
    """Run one PINN training experiment configured from the command line.

    Creates ``<save_dir>(doing)``, stores the configuration and a log file
    there, trains the networks, writes ``losses.csv`` / ``errors.csv`` and
    finally renames the directory to ``<save_dir>``.

    Raises:
        ValueError: if ``--activation`` or ``--optimizer`` is not supported.
    """
    args = parse_args()
    utils.set_random_seed(args.seed)

    if args.p:
        save_dir = f'./convection_diffusion/{args.optimizer}/split(m{args.m})/{args.activation}(p{args.p})/{args.seed}(lr{args.lr})'
    else:
        save_dir = f'./convection_diffusion/{args.optimizer}/split(m{args.m})/{args.activation}/{args.seed}(lr{args.lr})'
    save_dir_temp = save_dir + '(doing)'
    utils.mkdir(save_dir_temp)
    utils.save_configs(save_dir_temp, vars(args))

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    handlers = [
        logging.StreamHandler(),
        logging.FileHandler(filename=os.path.join(save_dir_temp, 'results.log'),
                            mode='w', encoding='utf-8'),
    ]
    for handler in handlers:
        logger.addHandler(handler)

    try:
        logger.info(args)
        _train(args, logger, save_dir_temp)
        logger.info('Finished')
    finally:
        # Release the log file before the directory is renamed (by this
        # function on success, or by the caller's error handler on failure).
        for handler in handlers:
            logger.removeHandler(handler)
            handler.close()

    os.rename(save_dir_temp, save_dir)

if __name__ == '__main__':
    # Script entry point: run the experiment and, on failure, log the
    # traceback and mark the working directory as '(error)'.
    try:
        main()
    except (Exception, KeyboardInterrupt):
        # SystemExit is deliberately not caught: argparse raises it for
        # --help and usage errors, before any output directory exists.
        logging.error(traceback.format_exc())
        args = parse_args()
        if args.p:
            save_dir = f'./convection_diffusion/{args.optimizer}/split(m{args.m})/{args.activation}(p{args.p})/{args.seed}(lr{args.lr})'
        else:
            save_dir = f'./convection_diffusion/{args.optimizer}/split(m{args.m})/{args.activation}/{args.seed}(lr{args.lr})'
        save_dir_temp = save_dir + '(doing)'
        # The failure may have happened before the directory was created, and
        # the rename itself can fail (e.g. the target already exists); neither
        # case should hide the original error logged above.
        if os.path.isdir(save_dir_temp):
            try:
                os.rename(save_dir_temp, save_dir+'(error)')
            except OSError:
                logging.error(traceback.format_exc())
        # Report the failure to the caller instead of exiting with status 0.
        raise SystemExit(1)