from tkinter import dialog
from turtle import st
import numpy as np
import matplotlib.pyplot as plt

import torch
import copy

import os
from matplotlib.patches import Ellipse, Rectangle, Polygon

import argparse

def RBF_kernel(x, x_nn, sigma=0.01):
    """Gaussian (RBF) weights between each point and its own set of neighbours.

    Args:
        x: query points, e.g. shape (n, d).
        x_nn: neighbours of each query point, e.g. shape (n, k, d).
        sigma: kernel bandwidth.

    Returns:
        exp(-||x_i - x_nn[i, a]||^2 / (2 sigma^2)), e.g. shape (n, k).
    """
    gap = x[:, None] - x_nn
    squared_distance = gap.pow(2).sum(dim=-1)
    bandwidth_term = 2 * sigma ** 2
    return torch.exp(-squared_distance / bandwidth_term)

def SPD_inv(P):
    """Invert symmetric positive-definite matrices through an eigendecomposition.

    Args:
        P: tensor of shape (..., d, d). Each trailing d x d matrix is assumed
            to be symmetric positive-definite; ``torch.linalg.eigh`` only reads
            its lower triangle.

    Returns:
        Tensor of the same shape holding ``U diag(1 / lam) U^T`` per matrix.
    """
    eigvals, U = torch.linalg.eigh(P)
    # Scaling the columns of U is U @ diag(1/lam) without building the diagonal
    # matrix; transpose(-2, -1) accepts any number of leading batch dimensions
    # (permute(0, 2, 1) only accepted exactly one).
    return (U * (1 / eigvals).unsqueeze(-2)) @ U.transpose(-2, -1)

def SPD_log(P, eps=1.0e-6):
    """Matrix logarithm of symmetric positive-definite matrices.

    The input must be symmetric: ``torch.linalg.eigh`` reads only the lower
    triangle, so a non-symmetric argument (e.g. ``A^{-1} B``) silently gives a
    wrong result.

    Args:
        P: tensor of shape (..., d, d) of SPD matrices.
        eps: eigenvalues are clamped to at least ``eps`` before the log, so
            (near-)singular inputs give a large negative finite value
            instead of -inf/NaN.

    Returns:
        Symmetric tensor of the same shape, ``U diag(log(max(lam, eps))) U^T``.
    """
    eigvals, U = torch.linalg.eigh(P)
    log_eigvals = torch.log(eigvals.clamp_min(eps))
    return (U * log_eigvals.unsqueeze(-2)) @ U.transpose(-2, -1)

def SPD_exp(S):
    """Matrix exponential of symmetric matrices through an eigendecomposition.

    Args:
        S: tensor of shape (..., d, d) of symmetric matrices
            (``torch.linalg.eigh`` reads only the lower triangle).

    Returns:
        SPD tensor of the same shape, ``U diag(exp(lam)) U^T`` per matrix.
    """
    # Use a separate name for the eigenvalues instead of rebinding ``S``.
    eigvals, U = torch.linalg.eigh(S)
    return (U * torch.exp(eigvals).unsqueeze(-2)) @ U.transpose(-2, -1)

def SPD_sqrt(P):
    """Principal square root of symmetric positive-(semi)definite matrices.

    Args:
        P: tensor of shape (..., d, d) of symmetric PSD matrices
            (``torch.linalg.eigh`` reads only the lower triangle).

    Returns:
        Tensor of the same shape, ``U diag(sqrt(max(lam, 0))) U^T``.
    """
    eigvals, U = torch.linalg.eigh(P)
    # Round-off can push eigenvalues of a (near-)singular matrix slightly below
    # zero; clamp so sqrt yields 0 rather than NaN (SPD_log/SPD_invsqrt clip too).
    root = torch.sqrt(eigvals.clamp_min(0))
    return (U * root.unsqueeze(-2)) @ U.transpose(-2, -1)

def SPD_invsqrt(P, eps=1.0e-6):
    """Inverse principal square root of symmetric positive-definite matrices.

    Args:
        P: tensor of shape (..., d, d) of SPD matrices
            (``torch.linalg.eigh`` reads only the lower triangle).
        eps: lower bound applied to sqrt(eigenvalue), i.e. eigenvalues are
            clamped to at least ``eps**2``; the result's eigenvalues are thus
            at most ``1 / eps``.

    Returns:
        Tensor of the same shape, ``U diag(1 / sqrt(max(lam, eps^2))) U^T``.
    """
    eigvals, U = torch.linalg.eigh(P)
    # Clamp the eigenvalues themselves (not their square roots): sqrt of a
    # negative round-off eigenvalue is NaN, and clipping NaN leaves it NaN.
    # For non-negative eigenvalues this equals clamping sqrt(lam) at eps.
    inv_root = 1 / torch.sqrt(eigvals.clamp_min(eps * eps))
    return (U * inv_root.unsqueeze(-2)) @ U.transpose(-2, -1)

def SPD_add_noise(P, noise=0.5, type='tangent_Gaussian'):
    """Perturb SPD matrices with random symmetric noise applied in the tangent space.

    For each matrix a random symmetric V is drawn (diagonal entries with std
    ``noise``, off-diagonal entries with std ``noise / sqrt(2)``). The result is
    ``P^{1/2} exp(V) P^{1/2}``.

    Args:
        P: tensor of shape (bs, d, d) of SPD matrices.
        noise: noise scale.
        type: noise model; only 'tangent_Gaussian' is implemented.

    Returns:
        Noisy SPD matrices with the same shape, dtype and device as ``P``.

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type != 'tangent_Gaussian':
        raise NotImplementedError(f"unsupported noise type: {type!r}")
    P_sqrt = SPD_sqrt(P)
    # Same draw order as before, so seeded CPU/float32 runs are reproducible.
    # The diagonal noise now matches P's dtype and device (it previously used
    # the defaults, which fails for CUDA inputs).
    noise_vec1 = noise * torch.randn_like(P)
    noise_vec2 = noise * torch.randn(P.shape[:-1], dtype=P.dtype, device=P.device)
    L = torch.tril(noise_vec1, diagonal=-1) / np.sqrt(2)
    D = torch.diag_embed(noise_vec2, offset=0, dim1=-2, dim2=-1)
    V = L + L.transpose(-2, -1) + D
    return P_sqrt @ SPD_exp(V) @ P_sqrt

def MVKR_filter(x, P, k=20, sigma=0.5, optnum=100, device='cpu', n_iter=1000, step_size=0.01):
    """Smooth an SPD-matrix field by kernel-weighted averaging over nearest neighbours.

    Each matrix starts at its own (noisy) value. For ``n_iter`` steps it moves,
    along the affine-invariant exponential map, a fixed step in the direction
    of the kernel-weighted mean of the log maps to its ``k`` nearest neighbours.

    Args:
        x: point locations, shape (bs, dim_x).
        P: SPD matrices attached to the points, shape (bs, dim, dim).
        k: number of nearest neighbours. The nearest hit is dropped as
            "the point itself".
        sigma: bandwidth of the RBF kernel on point locations.
        optnum: number of points optimised together per chunk.
        device: device on which the optimisation runs.
        n_iter: number of update steps.
        step_size: Frobenius norm of every (ambient) update direction.

    Returns:
        Filtered matrices, shape (bs, dim, dim), located on ``device``.
    """
    # Matrix size comes from P. It was previously taken from x, which is only
    # correct when the spatial and matrix dimensions coincide.
    dim = P.size(-1)
    eps = 1.0e-12

    # k nearest neighbours by mean squared coordinate difference. The first hit
    # (distance 0) is assumed to be the point itself and is dropped; exact
    # duplicate points could break that assumption.
    sq_dist = ((x.unsqueeze(1) - x.unsqueeze(0)) ** 2).mean(dim=-1)
    idxs = torch.topk(sq_dist, k + 1, largest=False).indices[:, 1:]
    x_nn = x[idxs]
    P_nn = P[idxs]

    P_sol = []
    for bP, bP_nn, bx, bx_nn in zip(P.split(optnum), P_nn.split(optnum), x.split(optnum), x_nn.split(optnum)):
        # Loop-invariant per chunk: move to the device once and evaluate the
        # kernel once (previously both were redone in every iteration).
        bP_nn = bP_nn.to(device)
        weights = RBF_kernel(bx.to(device), bx_nn.to(device), sigma=sigma)[..., None, None]  # (n, k, 1, 1)
        bP_opt = bP.to(device).clone()
        for _ in range(n_iter):
            P_sqrt = SPD_sqrt(bP_opt).unsqueeze(1)
            P_inv_sqrt = SPD_invsqrt(bP_opt).unsqueeze(1)
            # P^{-1/2} P_a P^{-1/2} is SPD, so the eigh-based SPD_log is valid on
            # it. The previous argument P^{-1} P_a is not symmetric, and eigh
            # silently read only its lower triangle.
            whitened = P_inv_sqrt @ bP_nn @ P_inv_sqrt
            whitened = 0.5 * (whitened + whitened.transpose(-2, -1))
            log_w = SPD_log(whitened.reshape(-1, dim, dim)).reshape(-1, k, dim, dim)
            # P log(P^{-1} P_a) == P^{1/2} log(P^{-1/2} P_a P^{-1/2}) P^{1/2}
            gradG = (P_sqrt @ (log_w * weights) @ P_sqrt).mean(dim=1)
            # Normalise to unit Frobenius norm. The clamp prevents 0/0 -> NaN
            # when the weights underflow or the matrix already sits at the mean.
            norm = torch.norm(gradG.reshape(len(gradG), -1), dim=1).clamp_min(eps)
            gradG = gradG / norm[:, None, None]
            P_sqrt, P_inv_sqrt = P_sqrt.squeeze(1), P_inv_sqrt.squeeze(1)
            bP_opt = P_sqrt @ SPD_exp(P_inv_sqrt @ (step_size * gradG) @ P_inv_sqrt) @ P_sqrt
        P_sol.append(bP_opt)
    return torch.cat(P_sol, dim=0)
    

def PD_metric_to_ellipse(G, center, scale, **kwargs):
    """Build a matplotlib Ellipse aligned with the eigenvectors of a 2x2 matrix.

    The major axis follows the eigenvector of the largest eigenvalue. The full
    axis lengths are ``2 * scale * sqrt(eigenvalue)``.

    Args:
        G: symmetric 2x2 matrix (anything ``np.linalg.eigh`` accepts).
        center: ellipse centre, forwarded as ``xy``.
        scale: multiplier on the axis lengths.
        **kwargs: forwarded to ``Ellipse`` (e.g. ``alpha``).
    """
    values, vectors = np.linalg.eigh(G)
    # eigh returns ascending eigenvalues; flip so index 0 is the dominant one.
    descending = np.argsort(values)[::-1]
    values = values[descending]
    vectors = vectors[:, descending]

    major_axis = vectors[:, 0]
    angle_deg = np.degrees(np.arctan2(major_axis[1], major_axis[0]))

    width, height = 2 * scale * np.sqrt(values)
    return Ellipse(xy=center, width=width, height=height, angle=angle_deg, **kwargs)

def S_linear_field(x, center, S0, velocity):
    """Symmetric-matrix field that varies linearly with position.

    Computes ``S(x_n) = S0 + sum_i (x_n - center)_i * velocity[i]`` for every row.

    Args:
        x: positions, shape (bs, 2).
        center: reference position, shape (1, 2).
        S0: value at ``center``, shape (2, 2).
        velocity: one symmetric 2x2 slope per input coordinate, shape (2, 2, 2).

    Returns:
        Tensor of shape (bs, 2, 2).
    """
    field = S0.unsqueeze(0) + torch.zeros(x.shape[0], 2, 2)
    displacement = x - center
    field.add_(torch.einsum('bi,ipq->bpq', displacement, velocity))
    return field

def S_polar_field(x, center, S0, velocity, offset=0):
    """Symmetric-matrix field built from a rotated, radially scaled template.

    For each point: ``S = R (velocity * (|x - center|^2 - offset)) R^T + S0``,
    with ``R = [[cos t, sin t], [-sin t, cos t]]`` and ``t = atan2`` of the
    offset from ``center``. Note that the radial factor is the *squared* distance.

    NOTE(review): R rotates by -t, so the dominant axis of ``velocity`` ends
    up at angle -t rather than along the radial direction t. Confirm this is
    the intended pattern.

    Args:
        x: positions, shape (bs, 2).
        center: reference position, shape (1, 2).
        S0: additive base matrix, shape (2, 2).
        velocity: symmetric template, shape (1, 2, 2).
        offset: scalar subtracted from the squared distance.

    Returns:
        Tensor of shape (bs, 2, 2).
    """
    delta = x - center
    sq_radius = (delta ** 2).sum(dim=-1)
    theta = torch.atan2(delta[:, 1], delta[:, 0])
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)

    rot = torch.zeros(x.shape[0], 2, 2)
    rot[:, 0, 0] = cos_t
    rot[:, 1, 1] = cos_t
    rot[:, 0, 1] = sin_t
    rot[:, 1, 0] = -sin_t

    scaled = velocity * (sq_radius.view(-1, 1, 1) - offset)
    return rot @ scaled @ rot.transpose(1, 2) + S0.unsqueeze(0)

def P2_field_generator(type='1'):
    """Generate a synthetic field of 2x2 SPD matrices on four shifted 5x5 grids.

    Type '1' concatenates four 25-point patches, each the matrix exponential
    of a symmetric field built around ``S0 = -2.5 * I``:
      1. linear: the (0, 0) log-entry grows with the first coordinate,
         patch shifted by (1.5, -1);
      2. linear: the (1, 1) log-entry grows with the second coordinate,
         patch shifted by (-1, 1.5);
      3. polar field with template 0.3 * diag(-3, 0), shifted by (1.5, 1.5);
      4. polar field with template 0.3 * diag(3, 0), shifted by (-1, -1).

    Args:
        type: field identifier; only '1' is implemented.

    Returns:
        ``(x, P)`` with x of shape (100, 2) and P of shape (100, 2, 2).

    Raises:
        NotImplementedError: for any other ``type``.
    """
    if type != '1':
        raise NotImplementedError(f"unknown field type: {type!r}")

    axis = torch.linspace(-1, 1, steps=5)
    # Row-major (i, j) grid. This is the same point order as flattening
    # meshgrid(..., indexing='ij'), without meshgrid's missing-`indexing`
    # deprecation warning.
    grid = torch.cartesian_prod(axis, axis)
    origin = torch.zeros((1, 2))

    S0 = torch.zeros(2, 2)
    S0[0, 0] = -2.5
    S0[1, 1] = -2.5

    vel_first_coord = torch.tensor([
        [[1.0, 0.0], [0.0, 0.0]],
        [[0.0, 0.0], [0.0, 0.0]]
    ])
    vel_second_coord = torch.tensor([
        [[0.0, 0.0], [0.0, 0.0]],
        [[0.0, 0.0], [0.0, 1.0]]
    ])
    P2_1 = SPD_exp(S_linear_field(grid, origin, S0, 0.7 * vel_first_coord))
    P2_2 = SPD_exp(S_linear_field(grid, origin, S0, 0.7 * vel_second_coord))

    polar_template = torch.tensor([[[3.0, 0.0], [0.0, 0.0]]])
    P2_p1 = SPD_exp(S_polar_field(grid, origin, S0, 0.3 * -polar_template))
    P2_p2 = SPD_exp(S_polar_field(grid, origin, S0, 0.3 * polar_template))

    x = torch.cat([
        grid + torch.tensor([[1.5, -1]]),
        grid + torch.tensor([[-1, 1.5]]),
        grid + torch.tensor([[1.5, 1.5]]),
        grid + torch.tensor([[-1, -1]])
    ], dim=0)
    P = torch.cat([P2_1, P2_2, P2_p1, P2_p2], dim=0)
    return (x, P)

def draw_P2_field(x, P, scale=0.1, save_path=None):
    """Plot a 2x2 SPD field as ellipses in a fresh matplotlib figure.

    Each point of ``x`` is drawn as a small black dot. An ellipse built from
    ``inverse(P_i)`` (see PD_metric_to_ellipse) is drawn over it with alpha 0.3.
    Axes are equal-scaled and hidden.

    Args:
        x: point locations, shape (bs, 2).
        P: matrices attached to the points, shape (bs, 2, 2).
        scale: ellipse size multiplier.
        save_path: if given, the figure is written there via ``plt.savefig``.
    """
    plt.figure()
    plt.scatter(x[:, 0], x[:, 1], c='k', s=1)
    ax = plt.gca()
    for point, matrix in zip(x, P):
        ax.add_artist(PD_metric_to_ellipse(torch.inverse(matrix), point, scale, alpha=0.3))
    plt.axis('equal')
    plt.axis('off')
    if save_path is not None:
        plt.savefig(save_path)
    
def draw_multiple_P2_fields(x, list_title, list_P, scale=0.1, xylim=((-5, 5), (-5, 5)), save_path=None):
    """Draw several 2x2 SPD fields over the same points, one panel each, side by side.

    Args:
        x: point locations shared by all fields, shape (bs, 2).
        list_title: panel titles, paired with ``list_P``.
        list_P: fields to draw, each of shape (bs, 2, 2).
        scale: ellipse size multiplier.
        xylim: ``(x_limits, y_limits)`` applied to every panel.
        save_path: if given, the figure is written there.
    """
    # squeeze=False always yields a 2-D array, so the single-panel case needs
    # no special handling.
    fig, axes = plt.subplots(1, len(list_P), figsize=(10, 10), squeeze=False)
    ellipse_alpha = 0.3
    x_limits, y_limits = xylim[0], xylim[1]
    for ax, title, P in zip(axes[0], list_title, list_P):
        ax.set_xlim(x_limits)
        ax.set_ylim(y_limits)
        ax.scatter(x[:, 0], x[:, 1], c='k', s=1)
        for matrix, point in zip(P, x):
            ax.add_artist(PD_metric_to_ellipse(torch.inverse(matrix), point, scale, alpha=ellipse_alpha))
        ax.set_aspect('equal')
        ax.set_title(title)
        ax.set_axis_off()
    plt.tight_layout()
    # Negative spacing pulls the panels together after tight_layout.
    plt.subplots_adjust(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=-0.35, hspace=-0.1)
    if save_path is not None:
        fig.savefig(save_path)

def _affine_invariant_half_sq_error(P_est, P_ref):
    """Batch mean of half the squared affine-invariant distance between P_est and P_ref.

    Evaluates ``||log(P_est^{-1/2} P_ref P_est^{-1/2})||_F^2 / 2``. This equals
    ``tr(log(P_est^{-1} P_ref)^2) / 2``, but the argument of SPD_log stays
    symmetric, so the eigh-based logarithm is valid.
    """
    W = SPD_invsqrt(P_est)
    F = SPD_log(W @ P_ref @ W)
    return torch.einsum('nij, nji -> n', F, F).mean() / 2


def main(args):
    """Generate synthetic SPD-field datasets, or filter one and report errors.

    With ``args.mode == 'data_generation'`` this writes ``x.pt``, ``P.pt`` and
    ``P_noisy.pt`` for every noise level and run under
    ``args.save_path/{data}_{noise_type}_{noise_level}_{run}`` and returns None.
    Any other mode loads one such dataset from ``args.data_path``, filters it
    and saves figures plus ``logger.txt`` into ``args.save_path``. It then
    returns ``{'P_filtered': ..., 'error': ...}``.
    """
    os.makedirs(args.save_path, exist_ok=True)
    if args.mode == 'data_generation':
        for data_num in ['1']:
            # The generator is deterministic, so build the clean field once.
            (x, P) = P2_field_generator(type=data_num)
            for noise_type in ['tangent_Gaussian']:
                for noise_level in [0.001, 0.005, 0.01, 0.05, 0.1]:
                    for run in [1, 2, 3, 4, 5]:
                        P_noisy = SPD_add_noise(P.clone(), noise_level, type=noise_type)

                        save_path = os.path.join(args.save_path, f'{data_num}_{noise_type}_{noise_level}_{run}')
                        os.makedirs(save_path, exist_ok=True)
                        torch.save(x, os.path.join(save_path, 'x.pt'))
                        torch.save(P, os.path.join(save_path, 'P.pt'))
                        torch.save(P_noisy, os.path.join(save_path, 'P_noisy.pt'))
        return None

    # data load
    data_path = os.path.join(args.data_path, f'{args.data}_{args.noise_type}_{args.noise_level}_{args.run}')
    x = torch.load(os.path.join(data_path, 'x.pt'))
    P = torch.load(os.path.join(data_path, 'P.pt'))
    P_noisy = torch.load(os.path.join(data_path, 'P_noisy.pt'))

    draw_P2_field(x, P, scale=args.ellipse_scale, save_path=os.path.join(args.save_path, "ground_truth_field.png"))
    draw_P2_field(x, P_noisy, scale=args.ellipse_scale, save_path=os.path.join(args.save_path, "noisy_field.png"))

    # filtering
    if args.filtering_algorithm == 'MVKR':
        P_filtered = MVKR_filter(
            x,
            P_noisy,
            k=args.k,
            sigma=args.sigma,
            optnum=args.optnum,
            device=args.device,
            n_iter=args.n_iter,
            step_size=args.step_size)
    else:
        raise NotImplementedError
    # MVKR_filter returns its result on args.device. Move it next to the loaded
    # data so drawing (numpy) and the error metric also work for device='cuda'.
    P_filtered = P_filtered.to(P.device)

    # draw filtered results
    draw_P2_field(x, P_filtered, scale=args.ellipse_scale, save_path=os.path.join(args.save_path, "filtered_field.png"))

    # save error metric
    init_error = _affine_invariant_half_sq_error(P_noisy, P)
    terminal_error = _affine_invariant_half_sq_error(P_filtered, P)

    with open(os.path.join(args.save_path, "logger.txt"), 'w') as f:
        f.write(f'initial_error: {init_error}' + '\n')
        f.write(f'terminal_error: {terminal_error}')

    # draw together
    draw_multiple_P2_fields(
        x, ['gt', 'noisy', 'filtered'], [P, P_noisy, P_filtered],
        scale=args.ellipse_scale, save_path=os.path.join(args.save_path, "images.png"))

    return {
        'P_filtered': P_filtered,
        'error': terminal_error
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    required_arg = parser.add_argument_group('required arguments')
    required_arg.add_argument('--data_path', type=str, default='./SPD_filtering_test')
    required_arg.add_argument('--save_path', type=str, default='./SPD_filtering_test')
    required_arg.add_argument('--data', type=str, default='1')
    required_arg.add_argument('--noise_type', type=str, default='tangent_Gaussian')
    required_arg.add_argument('--noise_level', type=float, default=0.01)
    required_arg.add_argument('--run', type=int, default=1)
    required_arg.add_argument('--filtering_algorithm', type=str, default='MVKR')
    required_arg.add_argument('--k', type=int, default=10)
    # sigma is a continuous RBF bandwidth; type=int rejected values such as 0.5.
    required_arg.add_argument('--sigma', type=float, default=1.0)
    required_arg.add_argument('--optnum', type=int, default=1000)
    required_arg.add_argument('--device', type=str, default='cpu')
    required_arg.add_argument('--n_iter', type=int, default=1000)
    required_arg.add_argument('--step_size', type=float, default=0.001)
    required_arg.add_argument('--ellipse_scale', type=float, default=0.05)
    # 'data_generation' writes datasets; any other value (including the
    # default None) runs filtering on one stored dataset.
    required_arg.add_argument('--mode', type=str, default=None)

    args = parser.parse_args()
    main(args)