import numpy as np
import scipy as sp
import gc
import torch
import trimesh
import pickle
import ot
import sys
sys.path.append('../utils/SGW/lib')
sys.path.append('..')

from QDOT.QDOT_numpy import *
from risgw import risgw_gpu
from sgw_numpy import sgw_cpu

from joblib import Parallel, delayed
import os

def rotation_matrix_x(theta):
    """Return the 3x3 matrix of a rotation by ``theta`` about the x-axis.

    ``theta`` is passed straight to ``np.cos``/``np.sin``, i.e. it is
    interpreted in radians.
    """
    c, s = np.cos(theta), np.sin(theta)
    rows = [
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ]
    return np.array(rows)

def rotation_matrix_y(theta):
    """Return the 3x3 matrix of a rotation by ``theta`` about the y-axis.

    ``theta`` is passed straight to ``np.cos``/``np.sin``, i.e. it is
    interpreted in radians.
    """
    c, s = np.cos(theta), np.sin(theta)
    rows = [
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ]
    return np.array(rows)

def rotation_matrix_z(theta):
    """Return the 3x3 matrix of a rotation by ``theta`` about the z-axis.

    ``theta`` is passed straight to ``np.cos``/``np.sin``, i.e. it is
    interpreted in radians.
    """
    c, s = np.cos(theta), np.sin(theta)
    rows = [
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ]
    return np.array(rows)

def rotate_vector(v, theta_x, theta_y, theta_z):
    """Right-multiply ``v`` by the combined rotation ``R = Rz . Ry . Rx``.

    The three per-axis matrices are built from the given angles (radians, as
    consumed by the ``rotation_matrix_*`` helpers) and composed in that order.
    The result is ``np.dot(v, R)``, so ``v`` is treated as a row vector, or as
    an array whose last axis holds the coordinates.
    """
    about_x = rotation_matrix_x(theta_x)
    about_y = rotation_matrix_y(theta_y)
    about_z = rotation_matrix_z(theta_z)

    # Same association as Rz . (Ry . Rx).
    combined = about_z.dot(about_y.dot(about_x))
    return np.dot(v, combined)

def TMSE(Points, Couplings):
    """Return the coupling-weighted sum of pairwise distances within ``Points``.

    Builds the Euclidean distance matrix of ``Points`` against itself
    (``cdist`` default metric), multiplies it element-wise by ``Couplings``
    and sums every entry. ``Couplings`` must therefore broadcast against an
    ``(len(Points), len(Points))`` matrix.
    """
    pairwise = sp.spatial.distance.cdist(Points, Points)
    weighted = Couplings * pairwise
    return np.sum(weighted)

def IR(Points, Couplings, quantile = 0.2):
    """Return the coupling mass placed on pairs of nearby points.

    The Euclidean distance matrix of ``Points`` against itself is computed
    once. The threshold is the ``quantile``-quantile of all its entries
    (diagonal zeros included). The result is the sum of ``Couplings`` over
    the entries whose distance is strictly below that threshold.

    Parameters
    ----------
    Points : array-like, shape (n, d)
        Point cloud whose internal distances define "nearby".
    Couplings : array-like
        Weights that broadcast against the ``(n, n)`` distance matrix.
    quantile : float, optional
        Quantile of the pairwise distances used as threshold. Defaults to
        0.2, the value previously hard-coded.

    Returns
    -------
    Scalar sum of the selected coupling entries.
    """
    # Computed once and reused for both the threshold and the mask.
    dist = sp.spatial.distance.cdist(Points, Points)
    threshold = np.quantile(dist.reshape(-1), quantile)
    near_mass = ((dist < threshold) * Couplings).sum()
    return near_mass


def Geo_Compare(X, Y, method = 'QDOT', rep_dim = 100, return_coupling = False):
    """Compare two point clouds with the selected discrepancy.

    Parameters
    ----------
    X, Y : array-like
        Point clouds (rows are points). The GW/EGW branches build their
        Euclidean self-distance matrices with ``cdist``; the other branches
        forward them to the respective solver.
    method : str
        One of ``'QDOT'``, ``'IQDOT'``, ``'GW'``, ``'EGW'``, ``'RISGW'``,
        ``'SGW'``.
    rep_dim : int
        Forwarded as ``n_quantile`` to QDOT/IQDOT and as ``nproj`` to
        RISGW/SGW. Not used by GW/EGW.
    return_coupling : bool
        When true, return ``(loss, P)``; otherwise return ``loss`` only.

    Returns
    -------
    ``loss`` or ``(loss, P)``. ``P`` is ``None`` for IQDOT, RISGW and SGW,
    which produce no coupling here.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'QDOT':
        # QDOT by EMD solver
        loss, P = QDOT(X, Y, n_quantile = rep_dim, initial = False, sigma = 1000)
        # try loss, P = QDOT(X, Y, n_quantile = rep_dim, EMD = False, initial = False, sigma = 1000) for a Sinkhorn solver
    elif method == 'IQDOT':
        # NOTE(review): the keyword is spelled `intergal`; presumably this
        # matches the QDOT signature -- confirm against QDOT.QDOT_numpy.
        loss = QDOT(X, Y, n_quantile = rep_dim, intergal = True, initial = False, sigma = 1000)
        P = None
    elif method in ('GW', 'EGW'):
        # Both Gromov-Wasserstein variants share the same inputs: intra-cloud
        # distance matrices scaled to a maximum of 1, with uniform marginals.
        C1 = sp.spatial.distance.cdist(X, X)
        C2 = sp.spatial.distance.cdist(Y, Y)
        C1 /= C1.max()
        C2 /= C2.max()
        p = ot.unif(len(X))
        q = ot.unif(len(Y))
        if method == 'GW':
            P, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, verbose=False, log=True)
        else:
            reg = 0.1
            P, log = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', epsilon=reg, log=True)
        loss = log['gw_dist']
    elif method == 'RISGW':
        device = torch.device('cpu')
        loss = risgw_gpu(torch.from_numpy(X).to(torch.float32).to(device),
                         torch.from_numpy(Y).to(torch.float32).to(device),
                         device, nproj=rep_dim)
        P = None
    elif method == 'SGW':
        loss = sgw_cpu(X, Y, nproj = rep_dim)
        P = None
    else:
        # SGW is supported above, so it is listed here as well.
        raise ValueError('Method not recognized. Choose from QDOT, IQDOT, GW, EGW, RISGW, SGW.')

    if return_coupling:
        return loss, P
    return loss

def _run_one_task_joblib(i, task_path, reference, class_name, coupling_compare = True):
    try:
        filename = f"{task_path}{i:02d}.obj"
        y_dist = np.array(trimesh.load(filename, force='mesh').vertices)[::2]

        if(coupling_compare):
            loss0, P0 = Geo_Compare(reference,            y_dist, method=class_name, return_coupling=True)
            loss3d0 = TMSE(reference, P0)
            ir0 = IR(reference, P0)
            del P0
            gc.collect()

            loss1, P1 = Geo_Compare(reference[:, [0, 1]], y_dist, method=class_name, return_coupling=True)
            loss3d1 = TMSE(reference, P1)
            ir1 = IR(reference, P1)
            del P1
            gc.collect()

            loss2, P2 = Geo_Compare(reference[:, [0, 2]], y_dist, method=class_name, return_coupling=True)
            loss3d2 = TMSE(reference, P2)
            ir2 = IR(reference, P2)
            del P2
            gc.collect()

            loss3, P3 = Geo_Compare(reference[:, [1, 2]], y_dist, method=class_name, return_coupling=True)
            loss3d3 = TMSE(reference, P3)
            ir3 = IR(reference, P3)
            del P3
            gc.collect()

            out = {
                'i': i,
                'loss':  (loss0, loss1, loss2, loss3),
                'loss3d':(loss3d0, loss3d1, loss3d2, loss3d3),
                'ir':    (ir0, ir1, ir2, ir3),
            }
        else:
            loss0 = Geo_Compare(reference,            y_dist, method=class_name, return_coupling=False)
            loss1 = Geo_Compare(reference[:, [0, 1]], y_dist, method=class_name, return_coupling=False)
            loss2 = Geo_Compare(reference[:, [0, 2]], y_dist, method=class_name, return_coupling=False)
            loss3 = Geo_Compare(reference[:, [1, 2]], y_dist, method=class_name, return_coupling=False)
            
            out = {
                'i': i,
                'loss':  (loss0, loss1, loss2, loss3),
                'loss3d':(0, 0, 0, 0),
                'ir':    (0, 0, 0, 0),
            }
        return out
    except Exception:
        return {'i': i, 'error': traceback.format_exc()}

class CrossSpaceTask:
    """Run one comparison method over a series of task meshes and keep the scores.

    Results live in ``self.LOSS``, a dict with the method name, the number of
    tasks, and three ``(4, num_tasks)`` arrays (``'Loss'``, ``'TMSE'``,
    ``'IR'``). Row 0 is the full reference; rows 1-3 are the column pairs
    ``[0, 1]``, ``[0, 2]`` and ``[1, 2]``. ``'Couplings'`` starts as four
    empty lists and is not filled by any method of this class.
    """

    def __init__(self, class_name, num_tasks = 48):
        """Create empty result tables for method ``class_name`` and ``num_tasks`` tasks."""
        self.class_name = class_name
        LOSS = {'Name': self.class_name, 
                'Num_tasks': num_tasks,
                'Loss': np.zeros([4, num_tasks]),
                'TMSE': np.zeros([4, num_tasks]),
                'IR': np.zeros([4, num_tasks]),
                'Couplings': [[], [], [], []]}
        self.LOSS = LOSS
        
    def Run_tasks(self,
                ref_parh = "./horse-gallop/horse-gallop-reference.obj",
                task_path = "./horse-gallop/horse-gallop-", 
                coupling_compare = True,
                n_jobs = 16):
        """Evaluate tasks ``1..Num_tasks`` in parallel and fill the result tables.

        Parameters
        ----------
        ref_parh : str
            Path of the reference mesh (parameter name kept for callers).
        task_path : str
            Prefix of the task meshes; the worker appends ``f"{i:02d}.obj"``.
        coupling_compare : bool
            Forwarded to the worker; enables the TMSE/IR scores.
        n_jobs : int
            Number of joblib workers. Defaults to 16, the value previously
            hard-coded.

        Raises
        ------
        RuntimeError
            If a task reported an error; the message carries its traceback.
        """
        reference = np.array(trimesh.load(ref_parh, force='mesh').vertices)
        # NOTE(review): the rotation helpers feed these angles to np.cos/np.sin,
        # which work in radians -- confirm 45 is not meant as degrees.
        # The rotated cloud is then thinned to every second vertex.
        reference = rotate_vector(reference, 45, 45, 0)[::2]

        num_tasks = int(self.LOSS['Num_tasks'])

        results = Parallel(n_jobs=n_jobs, verbose=5)(
            delayed(_run_one_task_joblib)(i, task_path, reference, self.class_name, coupling_compare)
            for i in range(1, num_tasks + 1)
        )

        # (table in self.LOSS, key in the worker's result dict)
        metric_keys = (('Loss', 'loss'), ('TMSE', 'loss3d'), ('IR', 'ir'))

        for out in results:
            i = out['i']
            if 'error' in out:
                raise RuntimeError(f"Task {i} failed:\n{out['error']}")

            idx = i - 1  # tasks are numbered from 1, columns from 0
            for table, key in metric_keys:
                for row, value in enumerate(out[key]):
                    self.LOSS[table][row, idx] = value

        # Scale each row of 'Loss' by that row's maximum over all tasks.
        self.LOSS['Loss'] = self.LOSS['Loss'] / self.LOSS['Loss'].max(axis=1)[:, np.newaxis]
    
    def Save_results(self, save_path = 'outputs/'):
        """Pickle ``self.LOSS`` to ``f'{save_path}{Name}_Loss.pkl'``.

        ``save_path`` is used as a plain string prefix. The directory part of
        the resulting file name is created if it does not exist, so finished
        results are not lost to a missing output folder.
        """
        out_file = f'{save_path}{self.LOSS["Name"]}_Loss.pkl'
        out_dir = os.path.dirname(out_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(out_file, 'wb') as f:
            pickle.dump(self.LOSS, f)
    
    def Load_results(self, load_path = 'outputs/'):
        """Replace ``self.LOSS`` with the dict pickled at ``f'{load_path}{Name}_Loss.pkl'``."""
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # written by Save_results from a trusted location.
        with open(f'{load_path}{self.LOSS["Name"]}_Loss.pkl', 'rb') as f:
            self.LOSS = pickle.load(f)
