import sys  # nopep8
sys.path.append('.')  # nopep8

import os
import argparse
import time
import numpy as np
import yaml
from tqdm import tqdm
import time
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from core.lie_neurons_layers import *
from core.lie_alg_util import *
from core.lie_group_util import *
from experiment.euler_poincare_eq_layers import *


def _build_parser():
    """Create the command-line parser for the Euler-Poincare evaluation script.

    Options are registered in a fixed order so that ``--help`` output is stable.
    """
    cli = argparse.ArgumentParser('Euler Poincare Equation Fitting')
    add = cli.add_argument

    add('--method', type=str, default='dopri5')
    # Evaluation switches: each one disables a stage of the script when given.
    add('--no_quantitative', action='store_true')
    add('--no_test_augmentation', action='store_true')
    # Number of test initial conditions / number of random test rotations.
    add('--num_testing', type=int, default=10)
    add('--num_testing_augmentation', type=int, default=10)
    # Number of time samples used to build the evaluation time grids.
    add('--data_size', type=int, default=1000)
    # NOTE(review): the next five options are not referenced in the part of
    # the script reviewed here; presumably kept for parity with the training
    # script -- confirm before removing.
    add('--batch_time', type=int, default=10)
    add('--batch_size', type=int, default=20)
    add('--niters', type=int, default=2000)
    add('--test_freq', type=int, default=10)
    add('--save_freq', type=int, default=100)
    # Visualization: enable plotting and set the end time of the plotted grid.
    add('--viz', action='store_true')
    add('--viz_time', type=int, default=15)
    # Index of the CUDA device to use when CUDA is available.
    add('--gpu', type=int, default=0)
    # Selects torchdiffeq's odeint_adjoint instead of odeint.
    add('--adjoint', action='store_true')
    # Which hard-coded inertia matrix the ground-truth dynamics use.
    add('--inertia_type', type=str, choices=['iss','model1'], default='iss')
    add('--fig_save_path', type=str, default='figures/euler_poincare_ln')
    add('--model_load_path', type=str, default='weights/ode4_new_training_best_test_loss_acc.pt')
    add('--model_type', type=str, default='neural_ode')
    add('--log_writer_path', type=str, default='logs/euler_poincare_ln')
    return cli


parser = _build_parser()
args = parser.parse_args()

# Pick the ODE solver entry point; both are bound to the same name `odeint`.
if not args.adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint


# Run on the requested CUDA device when CUDA is available, otherwise on CPU.
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')


# Fixed seed so the randomly drawn test initial conditions are reproducible.
torch.manual_seed(81292)
if args.num_testing < 1:
    # Without this check `testing_y0` would stay undefined and the script
    # would fail later with an unrelated-looking NameError.
    parser.error('--num_testing must be at least 1 (got {})'.format(args.num_testing))
if args.num_testing > 1:
    # Sampled on the CPU generator and then moved, so values do not depend on
    # which device is in use.
    testing_y0 = torch.rand((args.num_testing, 3)).to(device)
else:
    # Single test case: use the same fixed initial condition as the plots.
    testing_y0 = torch.tensor([[2., 1.,3.0]]).to(device)

# Initial condition used for the visualization trajectory.
vis_y0 = torch.tensor([[2., 1.,3.0]]).to(device)

# Evaluation horizons (end times). Each horizon gets a time grid whose number
# of samples is proportional to its length: args.data_size samples at t = 25.
t_end_list = [5., 10., 15., 20., 25.]
t = [
    torch.linspace(0., t_end, int(args.data_size/25.*t_end)).to(device)
    for t_end in t_end_list
]
# Time grid for the visualization trajectory.
vis_t = torch.linspace(0., args.viz_time, args.data_size).to(device)

class EulerPoincareEquation(nn.Module):
    """Ground-truth dynamics with a fixed inertia matrix and a time-varying input.

    ``forward(t, w)`` evaluates ``-I^{-1} (hat(w) (I w)) + I^{-1} u(t)`` where
    ``I`` is the inertia matrix selected by ``args.inertia_type`` and
    ``u(t) = 10 * [sin(t), cos(t), sin(t)]``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Inertia matrix of the ISS
        # https://athena.ecs.csus.edu/~grandajj/ME296M/space.pdf
        # page 7-62
        if args.inertia_type == 'iss':
            self.I = torch.Tensor([
                [5410880., -246595., 2967671.],
                [-246595., 29457838., -47804.],
                [2967671., -47804., 26744180.],
            ]).unsqueeze(0).to(device)
        elif args.inertia_type == 'model1':
            self.I = torch.Tensor([
                [12, -5., 7.],
                [-5., 20., -2.],
                [7., -2., 5.],
            ]).unsqueeze(0).to(device)

        # Both matrices carry a leading singleton axis: shape (1, 3, 3).
        self.I_inv = torch.inverse(self.I)
        # Presumably maps a 3-vector to its so(3) matrix -- confirm in HatLayer.
        self.hat_layer = HatLayer(algebra_type='so3').to(device)

        def control_input(t):
            # Time-dependent input u(t), returned as a 3-element tensor.
            return 10*torch.tensor([torch.sin(t),torch.cos(t),torch.sin(t)]).to(device)

        self.ufunc = control_input

    def forward(self,t,w):
        """Return the time derivative of ``w`` at time ``t``.

        w: angular velocity (B,3) or (1,3)
        """
        w_column = w.unsqueeze(2)                       # add a trailing axis for matmul
        inertia_times_w = torch.matmul(self.I, w_column)
        cross_term = torch.matmul(self.hat_layer(w), inertia_times_w)
        unforced = -torch.matmul(self.I_inv, cross_term).squeeze(2)
        forced = torch.matmul(self.I_inv, self.ufunc(t))
        return unforced + forced


# Shape legend for the ground-truth trajectories generated below:
#   N: number of testing data
#   J: number of end time being tested (5)
#   M_j: data size for j-th time
#   D: dimension (3)
with torch.no_grad():
    # The reference dynamics are the same for every horizon, so build the
    # module (and invert the inertia matrix) once instead of once per solve.
    true_dynamics = EulerPoincareEquation()

    # One entry per horizon; each entry is (M_j, N, 1, D). The extra singleton
    # axis is inserted so the layout matches predictions computed from
    # `testing_y0.unsqueeze(1)`.
    testing_y = [
        odeint(true_dynamics, testing_y0, t_j, method='dopri5').unsqueeze(-2)
        for t_j in t
    ]

    # Reference trajectory used only for plotting.
    vis_y = odeint(true_dynamics, vis_y0, vis_t, method='dopri5')
    


def makedirs(dirname):
    """Create ``dirname`` (including missing parents) if it does not exist.

    Calling this for a directory that already exists is a no-op. Using
    ``exist_ok=True`` instead of a separate existence check avoids failing
    when another process creates the directory between the check and the
    creation.

    Raises:
        FileExistsError: if ``dirname`` exists but is not a directory.
    """
    os.makedirs(dirname, exist_ok=True)


# if args.viz:
    # makedirs(args.fig_save_path)
    # fig = plt.figure(figsize=(12, 4), facecolor='white')
    # ax_traj = fig.add_subplot(131, frameon=False)
    # ax_phase = fig.add_subplot(132, frameon=False)
    # ax_vecfield = fig.add_subplot(133, frameon=False)
    # plt.show(block=False)


def visualize(true_y, pred_y, odefunc, itr):
    """Plot ground-truth vs. predicted trajectories and the learned vector field.

    Figures 1-3 show one state component each over ``vis_t``, figure 4 overlays
    all three components, figure 5 is the 3D trajectory and figure 6 is a
    stream plot of ``odefunc`` evaluated at time 0 on a 21x21 grid of the
    z = 0 plane. Only the figure that is current when ``savefig`` runs
    (figure 6) is written to ``args.fig_save_path``; the call then shows the
    figures and waits for a key or button press.

    Args:
        true_y: tensor indexed as ``[:, 0, k]`` for component ``k`` in 0..2.
        pred_y: tensor with the same indexing as ``true_y``.
        odefunc: callable invoked as ``odefunc(t, y)`` with ``y`` of shape
            (441, 1, 3).
        itr: integer used (zero-padded to three digits) as the output file name.
    """
    makedirs(args.fig_save_path)

    plt.rcParams.update({'font.size': 22})

    # Convert once; every plot below reads from these arrays.
    times = vis_t.cpu().numpy()
    truth = true_y.cpu().numpy()
    guess = pred_y.cpu().numpy()

    def clamp_time_axis():
        plt.xlim(vis_t.cpu().min(), vis_t.cpu().max())

    # (figure number, title, y-axis label, legend entries) per state component.
    component_figures = (
        (1, 'Trajectories (x)', 'x', ['Ground Truth', 'Lie Neurons (x)']),
        (2, 'Trajectories (y)', 'y', ['True y', 'Predicted y']),
        (3, 'Trajectories (z)', 'z', ['True z', 'Predicted z']),
    )
    for component, (fig_id, heading, axis_label, legend_entries) in enumerate(component_figures):
        plt.figure(fig_id)
        plt.plot(times, truth[:, 0, component], '-', color='royalblue', linewidth=2)
        plt.plot(times, guess[:, 0, component], '--', color='orange', linewidth=2)
        clamp_time_axis()
        plt.title(heading)
        plt.xlabel('t')
        plt.ylabel(axis_label)
        plt.legend(legend_entries)

    # Figure 4: all components together, ground truth first so the legend
    # order below lines up with the plotted curves.
    plt.figure(4)
    for component, shade in enumerate(('royalblue', 'darkred', 'darkviolet')):
        plt.plot(times, truth[:, 0, component], '-', color=shade, linewidth=2)
    for component, shade in enumerate(('orange', 'mediumseagreen', 'gold')):
        plt.plot(times, guess[:, 0, component], '--', color=shade, linewidth=2)
    clamp_time_axis()
    plt.title('Trajectories')
    plt.xlabel('t')
    plt.ylabel('x, y, z')
    plt.legend(['Ground Truth x', 'Ground Truth y', 'Ground Truth z', 'Lie Neurons x', 'Lie Neurons y', 'Lie Neurons z'])

    # Figure 5: trajectory in 3D state space.
    ax = plt.figure(5).add_subplot(projection='3d')
    ax.plot(truth[:, 0, 0], truth[:, 0, 1], truth[:, 0, 2], '-', color='royalblue')
    ax.plot(guess[:, 0, 0], guess[:, 0, 1], guess[:, 0, 2], '-', color='orange')
    plt.title('Trajectories (3D)')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.legend(['Ground Truth', 'Lie Neurons'])

    # Figure 6: direction field of the learned dynamics on the z = 0 plane.
    plt.figure(6)
    plt.title('Learned Vector Field')
    plt.xlabel('x')
    plt.ylabel('y')

    y, x = np.mgrid[-2:2:21j, -2:2:21j]
    grid_points = np.stack([x, y, np.zeros_like(y)], -1).reshape(21 * 21, 3)
    queries = rearrange(torch.Tensor(grid_points), 'b d -> b 1 d').to(device)
    dydt = odefunc(torch.tensor(0), queries).cpu().detach().numpy()
    dydt = rearrange(dydt, 'b 1 d -> b d')
    # Normalize by the in-plane (x, y) magnitude so only direction is drawn.
    mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
    dydt = (dydt / mag).reshape(21, 21, 3)

    plt.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")

    plt.savefig(args.fig_save_path+'/{:03d}'.format(itr))
    plt.show()
    plt.draw()
    plt.waitforbuttonpress()


    
class RunningAverageMeter:
    """Keep the most recent value and an exponentially weighted average.

    The first value seen after a reset becomes the average as-is; every later
    value is blended in as ``avg * momentum + val * (1 - momentum)``.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Drop all history: no latest value, average back to 0."""
        self.val, self.avg = None, 0

    def update(self, val):
        """Record ``val`` as the latest value and fold it into the average."""
        already_seeded = self.val is not None
        self.avg = (
            self.avg * self.momentum + val * (1 - self.momentum)
            if already_seeded
            else val
        )
        self.val = val

    def get_avg(self):
        """Return the current running average."""
        return self.avg


if __name__ == '__main__':
    # `math` is not imported explicitly at the top of the file (it may only be
    # reachable through the wildcard imports), so import it here for math.pi.
    import math

    ii = 0  # index used to name the saved visualization

    # Constructors are wrapped in lambdas so that only the selected class name
    # is resolved; the classes come from wildcard imports.
    model_builders = {
        'LN_ode': lambda: LNODEFunc(device=device),
        'LN_ode2': lambda: LNODEFunc2(device=device),
        'LN_ode3': lambda: LNODEFunc3(device=device),
        'LN_ode4': lambda: LNODEFunc4(device=device),
        'LN_ode5': lambda: LNODEFunc5(device=device),
        'LN_ode6': lambda: LNODEFunc6(device=device),
        'LN_ode7': lambda: LNODEFunc7(device=device),
        'LN_ode8': lambda: LNODEFunc8(device=device),
        'LN_ode9': lambda: LNODEFunc9(device=device),
        'LN_ode10': lambda: LNODEFunc10(device=device),
        'LN_ode11': lambda: LNODEFunc11(device=device),
        'neural_ode': lambda: ODEFunc(),
    }
    if args.model_type not in model_builders:
        # Previously an unknown type left `func` undefined (NameError later).
        raise ValueError(
            "unknown --model_type {!r}; expected one of: {}".format(
                args.model_type, ', '.join(model_builders)))
    func = model_builders[args.model_type]().to(device)

    # map_location lets a checkpoint saved on a GPU be restored on whichever
    # device this run uses (including CPU-only machines).
    checkpoint = torch.load(args.model_load_path, map_location=device)
    func.load_state_dict(checkpoint['model_state_dict'],strict=True)

    # Shape legend:
    #   N: number of testing data
    #   J: number of end time being tested (5)
    #   M_j: data size for j-th time
    #   D: dimension (3)
    #   testing_y0: (N, D); every horizon j starts from the same y0
    #   testing_y: J entries, each (M_j, N, 1, D)
    if not args.no_quantitative:
        for j, t_j in enumerate(t):
            with torch.no_grad():
                pred_y = odeint(func, testing_y0.unsqueeze(1), t_j).to(device)  # Lie Neuron takes additional feature dimension thus need to unsqueeze(1)
                cur_y = testing_y[j]    # (M_j, N, 1, D)
                loss = torch.mean(torch.norm(pred_y - cur_y, dim=-1))

                print("loss at",t_end_list[j],"sec is", loss.item())

    if not args.no_test_augmentation:
        # generate data for change of reference frame testing:
        # random unit axes scaled by random angles in [0, pi - 1e-6), mapped
        # through hat and exp_so3 (presumably to rotation matrices -- confirm
        # in core.lie_group_util).
        hat_so3 = HatLayer(algebra_type='so3')
        v = torch.rand((args.num_testing_augmentation,3))
        v = torch.div(v,torch.norm(v,dim=-1).unsqueeze(-1))
        phi = (math.pi-1e-6)*torch.rand(args.num_testing_augmentation,1)
        v = phi*v
        v_hat = hat_so3(v)
        R = exp_so3(v_hat).to(device)

        for j, t_j in enumerate(t):
            cur_y = testing_y[j]    # (M_j, N, 1, D)
            # Collect every per-rotation loss and report the plain mean. The
            # earlier RunningAverageMeter(1.00) kept only the first sample,
            # because a momentum of 1 gives new values zero weight.
            losses = []
            for k in range(args.num_testing_augmentation):
                cur_R = R[k,:,:]
                # Apply the rotation to the initial conditions and to the
                # reference trajectory.
                cur_y0 = torch.einsum('ij,kj->ki',cur_R,testing_y0).unsqueeze(1)
                cur_y_conj = torch.einsum('ij,bksj->bksi',cur_R,cur_y)
                # NOTE(review): assumes the selected model exposes set_R;
                # confirm this holds for 'neural_ode' (ODEFunc).
                func.set_R(cur_R)
                with torch.no_grad():
                    pred_y = odeint(func, cur_y0, t_j).to(device)
                    loss = torch.mean(torch.norm(pred_y - cur_y_conj, dim=-1))
                    losses.append(loss.item())

            # With zero rotations requested, keep reporting 0 as before.
            mean_loss = sum(losses) / len(losses) if losses else 0
            print("conjugated loss at",t_end_list[j],"sec is", mean_loss)

    if args.viz:
        with torch.no_grad():
            # Reset the model's rotation to the identity before plotting.
            func.set_R(torch.eye(3).to(device))
            vis_pred_y = odeint(func, vis_y0.unsqueeze(0), vis_t).squeeze(1).to(device)
            print(vis_pred_y.shape)
            visualize(vis_y, vis_pred_y, func, ii)
            ii += 1
