# structure of the code is derived from https://github.com/andyjm3/rhgd
import torch
from torch import nn
import time
import numpy as np
from geoopt import linalg
import pickle

import os
import sys
from datetime import datetime

# Make the local ``rhgd`` checkout importable by appending ``<cwd>/rhgd`` to the
# module search path. The path is derived from the current working directory
# (not from this file's location), so the script has to be launched from the
# directory that contains the ``rhgd`` folder.
notebook_dir = os.getcwd()
rhgd_dir = notebook_dir + '/rhgd'
sys.path.append(rhgd_dir)

from torch.autograd import Function as F
from geoopt import linalg
import argparse

from geoopt import Stiefel, ManifoldParameter
from rhgd.manifolds import EuclideanMod
from rhgd.optimizer import RHGDstep
from scipy.io import loadmat
from src.AdaRHD_utils import compute_hypergrad_v

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

def vec(X):
    """Vectorise symmetric matrices by keeping their upper-triangular entries.

    Args:
        X: tensor whose last two dimensions form square ``(d, d)`` matrices;
            any leading dimensions are treated as batch dimensions.

    Returns:
        Tensor of shape ``X.shape[:-2] + (d * (d + 1) // 2,)`` containing the
        upper triangle (diagonal included) of every matrix, row by row.
    """
    d = X.shape[-1]
    # Build the index pair once (it used to be computed twice) and directly on
    # X's device so that indexing never has to move the indices around.
    rows, cols = torch.triu_indices(d, d, device=X.device)
    return X[..., rows, cols]

class Re_op():
    """Eigenvalue rectification (ReLU-like floor) and its derivative.

    Values that do not exceed ``_threshold`` are replaced by ``_threshold``;
    larger values pass through unchanged.
    """
    _threshold = 1e-4

    @classmethod
    def fn(cls, S, param=None):
        """Return ``S`` with every entry not above the threshold set to it.

        ``param`` is accepted for interface compatibility and is not used.
        """
        floor = cls._threshold
        return nn.functional.threshold(S, floor, floor)

    @classmethod
    def fn_deriv(cls, S, param=None):
        """Return a float32 mask: 1 where ``S`` exceeds the threshold, else 0.

        ``param`` is accepted for interface compatibility and is not used.
        """
        return torch.gt(S, cls._threshold).float()

def add_id_matrix(P, alpha):
    '''
    Regularise square matrices by adding a trace-scaled identity:
    P + alpha * trace(P) * Id, computed independently for every matrix.

    Input P: a single (n,n) matrix or any batch of them, e.g.
             (batch_size,1,n,n); only the last two dimensions are used as
             matrix dimensions.
    Input alpha: scalar factor applied to trace(P) * Id.
    Output: new tensor with the same shape as P (P itself is not modified).
    '''
    n = P.shape[-1]
    identity = torch.eye(n, dtype=P.dtype, device=P.device)
    # Tensor.trace() only accepts 2-D input, so the trace is taken as the sum
    # of the diagonal over the last two dimensions; this also covers batches.
    trace = P.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    # Trailing singleton axes let the per-matrix trace broadcast against Id.
    return P + alpha * trace[..., None, None] * identity

def modeig_forward(P, op, eig_mode='svd', param=None):
    '''
    Generic forward function of non-linear eigenvalue modification
    (shared by LogEig, ReEig, etc.).

    Input P: (batch_size,channels) SPD matrices of size (n,n)
    Input op: object exposing fn(S, param), applied to the eigenvalues
    Input eig_mode: 'svd' (default) decomposes P + 1e-5 * trace(P) * Id with
                    an SVD; 'eig' uses a symmetric eigendecomposition of P
    Input param: forwarded unchanged to op.fn
    Output X: (batch_size,channels) modified symmetric matrices of size (n,n)
    Output U: (batch_size,channels,n,n) eigen-/singular vectors
    Output S: (batch_size,channels,n) eigen-/singular values
    Output S_fn: op.fn(S, param)

    Raises ValueError if P is not 4-D or eig_mode is not recognised.
    '''
    if P.dim() != 4:
        raise ValueError(
            f"P must have shape (batch_size, channels, n, n), got {tuple(P.shape)}")

    if eig_mode == 'eig':
        # torch.eig is deprecated and no longer available in current PyTorch
        # releases. The inputs are symmetric, so the symmetric solver is the
        # appropriate replacement; it handles the whole batch in one call and
        # returns real eigenvalues directly.
        S, U = torch.linalg.eigh(P)
    elif eig_mode == 'svd':
        # Same trace-scaled identity regularisation as add_id_matrix(., 1e-5),
        # written with broadcasting so that all matrices are decomposed in a
        # single batched SVD instead of a Python double loop.
        n = P.shape[-1]
        identity = torch.eye(n, dtype=P.dtype, device=P.device)
        trace = P.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
        U, S, _ = torch.svd(P + 1e-5 * trace[..., None, None] * identity)
    else:
        # Previously an unknown mode silently produced all-zero U and S.
        raise ValueError(f"Unknown eig_mode {eig_mode!r}; expected 'svd' or 'eig'")

    S_fn = op.fn(S, param)
    # Reassemble U diag(f(S)) U^T for every matrix of the batch.
    X = U.matmul(torch.diag_embed(S_fn)).matmul(U.transpose(2, 3))
    return X, U, S, S_fn


def BatchDiag(P):
    """
    Turn batched vectors into batched diagonal matrices.

    Input P: (batch_size,channels) vectors of size (n)
    Output Q: (batch_size,channels) diagonal matrices of size (n,n), with the
              same dtype and device as P

    Raises ValueError if P is not 3-D.
    """
    if P.dim() != 3:
        raise ValueError(
            f"P must have shape (batch_size, channels, n), got {tuple(P.shape)}")
    # diag_embed places the last dimension on the diagonal of a new trailing
    # (n, n) block for every batch entry, replacing the Python double loop.
    return torch.diag_embed(P)

def modeig_backward(dx, U, S, S_fn, op, param=None):
    '''
    Generic backward function of non-linear eigenvalue modification
    (shared by LogEig, ReEig, etc.).

    Input dx: (batch_size,channels) gradients w.r.t. the forward output, of size (n,n)
    Input U, S, S_fn: decomposition and modified eigenvalues as returned by
                      modeig_forward
    Input op: object exposing fn_deriv(S, param)
    Input param: forwarded unchanged to op.fn_deriv
    Output dp: (batch_size,channels) gradients w.r.t. the forward input, of size (n,n)
    '''
    # Derivative of the eigenvalue function, placed on the diagonal.
    deriv_diag = BatchDiag(op.fn_deriv(S, param))

    # Pairwise differences through broadcasting: entry (i, j) is s_i - s_j
    # (respectively f(s_i) - f(s_j)).
    eig_gap = S.unsqueeze(-1) - S.unsqueeze(-2)
    fn_gap = S_fn.unsqueeze(-1) - S_fn.unsqueeze(-2)
    loewner = fn_gap / eig_gap

    # Wherever the eigenvalue gap is zero (always on the diagonal, and for
    # repeated eigenvalues off it) the quotient is nan or +/-inf; those
    # entries are set to zero before the diagonal derivative is added.
    loewner = torch.where(torch.isfinite(loewner), loewner,
                          torch.zeros_like(loewner))
    loewner = loewner + deriv_diag

    # Rotate dx into the eigenbasis, weight element-wise, rotate back.
    U_t = U.transpose(2, 3)
    weighted = loewner * U_t.matmul(dx).matmul(U)
    return U.matmul(weighted).matmul(U_t)

class ReEig(F):
    """
    Autograd function that rectifies the eigenvalues of SPD matrices.

    Input P: (batch_size,h) SPD matrices of size (n,n)
    Output X: (batch_size,h) of rectified eigenvalues matrices of size (n,n)
    """
    @staticmethod
    def forward(ctx, P):
        """Apply Re_op to the spectrum of P and stash what backward needs."""
        X, U, S, S_fn = modeig_forward(P, Re_op)
        ctx.save_for_backward(U, S, S_fn)
        return X

    @staticmethod
    def backward(ctx, dx):
        """Map the gradient w.r.t. X back to a gradient w.r.t. P."""
        # ``saved_tensors`` is the supported accessor for tensors stored with
        # ``save_for_backward``; ``saved_variables`` is its deprecated alias.
        U, S, S_fn = ctx.saved_tensors
        return modeig_backward(dx, U, S, S_fn, Re_op)

def SPDnet(hparams, params, A):
    """Three-layer SPD network producing one feature vector per input matrix.

    Each of the first two layers applies the congruence ``W^T A W`` followed by
    ReEig; the third applies the congruence, the symmetric matrix logarithm
    and ``vec``.

    Args:
        hparams: sequence whose first three entries are the weight matrices
            ``W1``, ``W2``, ``W3`` of the three layers.
        params: unused here; kept so that the signature matches the
            ``(hparams, params, data)`` convention of the loss functions.
        A: batch of input matrices, shape ``[batch, d, d]``.

    Returns:
        Tensor of shape ``[batch, k]`` with ``k`` the number of
        upper-triangular entries of the last layer's output matrices.
    """
    W1, W2, W3 = hparams[0], hparams[1], hparams[2]

    features = A
    for W in (W1, W2):
        Z = W.transpose(-1, -2) @ features @ W
        # ReEig works on (batch, channels, n, n); add and remove a singleton
        # channel axis around the call.
        features = ReEig.apply(Z.unsqueeze(1)).squeeze(1)

    Z3 = W3.transpose(-1, -2) @ features @ W3
    # No blanket ``.squeeze()`` here: it removed the batch axis for a batch of
    # one (and the feature axis when k == 1), which breaks the ``pred @ gamma``
    # / ``argmax(dim=1)`` code that consumes this output.
    return vec(linalg.sym_logm(Z3))

def loss_lower_mfd(hparams, params, data):
    """Lower-level objective: cross-entropy plus an L2 penalty on the classifier.

    Args:
        hparams: network weights forwarded to ``SPDnet``.
        params: list whose first entry is the classifier matrix ``gamma``.
        data: pair ``(inputs, targets)``.

    Returns:
        Scalar tensor ``CE(SPDnet(inputs) @ gamma, targets)
        + 0.5 * lam * ||gamma||^2``.

    Note:
        ``lam`` is read from module scope; in this file it is only assigned
        inside the ``__main__`` section.
    """
    inputs, targets = data
    classifier = params[0]
    logits = SPDnet(hparams, params, inputs) @ classifier
    data_term = nn.functional.cross_entropy(logits, targets)
    ridge_term = 0.5 * lam * torch.norm(classifier) ** 2
    return data_term + ridge_term

def loss_upper_mfd(hparams, params, data):
    """Upper-level objective: plain cross-entropy of the network's predictions.

    Args:
        hparams: network weights forwarded to ``SPDnet``.
        params: list whose first entry is the classifier matrix ``gamma``.
        data: pair ``(inputs, targets)``.

    Returns:
        Scalar tensor ``CE(SPDnet(inputs) @ gamma, targets)``.
    """
    inputs, targets = data
    classifier = params[0]
    logits = SPDnet(hparams, params, inputs) @ classifier
    return nn.functional.cross_entropy(logits, targets)

def compute_acc(network, hparams, params, data):
    """Fraction of samples whose arg-max class score matches the target.

    Args:
        network: callable ``network(hparams, params, inputs)`` returning one
            feature row per sample.
        hparams: forwarded to ``network``.
        params: list whose first entry is the classifier matrix ``gamma``.
        data: pair ``(inputs, targets)``.

    Returns:
        0-d tensor: number of correct predictions divided by ``targets.shape[0]``.
    """
    inputs, targets = data
    classifier = params[0]
    scores = network(hparams, params, inputs) @ classifier
    hits = torch.eq(scores.argmax(dim=1), targets)
    return hits.sum() / targets.size(0)

# Function to load data from a directory
def load_data_from_directory(directory, data_list, labels_list):
    """Append every ``.mat`` sample found under ``directory`` to the given lists.

    ``directory`` is expected to contain one sub-folder per class whose name
    is the integer class label. Sub-folders with a non-integer name are
    skipped with a message; files that fail to load are reported and skipped.

    Folders and files are visited in sorted (lexicographic) name order:
    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order,
    which would make the sample order -- and every seeded split derived from
    it -- differ between machines.

    Args:
        directory: root folder containing the per-class sub-folders.
        data_list: list extended in place with one array per loaded file.
        labels_list: list extended in place with the matching integer labels.

    Returns:
        None; results are delivered through the two list arguments.
    """
    for label in sorted(os.listdir(directory)):
        label_path = os.path.join(directory, label)
        if not os.path.isdir(label_path):
            continue

        # Keep the try block tight so that only a failed label conversion is
        # reported as a non-numeric folder.
        try:
            label_num = int(label)  # Convert folder name to integer label
        except ValueError:
            print(f"Skipping non-numeric folder: {label}")
            continue
        print(f"Processing label: {label_num}")

        for file_name in sorted(os.listdir(label_path)):
            if not file_name.endswith('.mat'):
                continue
            file_path = os.path.join(label_path, file_name)
            try:
                mat_data = loadmat(file_path)
                # Assuming the actual data is in the first non-standard key
                # (loadmat adds bookkeeping entries whose names start with '__').
                data_key = [k for k in mat_data.keys() if not k.startswith('__')][0]
                data_list.append(mat_data[data_key])
                labels_list.append(label_num)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole load.
                print(f"Error loading {file_path}: {e}")

# Create a class to simultaneously write to stdout and a file
class Logger(object):
    """File-like object that duplicates everything written to it.

    Each message goes to the ``sys.stdout`` that was active when the instance
    was created and to a log file opened (and truncated) at construction.
    """

    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, 'w')

    def write(self, message):
        """Write ``message`` to both sinks and flush the log file right away."""
        for sink in (self.terminal, self.log):
            sink.write(message)
        # Flush the file on every write so the log is up to date on disk.
        self.log.flush()

    def flush(self):
        """Flush both sinks; required because ``print`` may call ``flush``."""
        for sink in (self.terminal, self.log):
            sink.flush()

if __name__ == '__main__':
    # Script entry point: load the SPD face data, then for every step size and
    # every seed train the SPD network with RHGD and pickle the statistics.
    import traceback

    def _str2bool(value):
        """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is True
        because any non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_ratio', type=float, default=0.05)
    parser.add_argument('--lower_iter', type=int, default=50)
    parser.add_argument('--epoch', type=int, default=301)
    parser.add_argument('--hygrad_opt', type=str, default='cg', choices=['hinv', 'cg', 'ns', 'ad'])
    parser.add_argument('--ns_gamma', type=float, default=0.05)
    parser.add_argument('--ns_iter', type=int, default=30)
    parser.add_argument('--cg_gamma', type=float, default=0.)
    parser.add_argument('--cg_iter', type=int, default=50)
    parser.add_argument('--compute_hg_error', type=_str2bool, default=False)
    parser.add_argument('--stochastic', type=_str2bool, default=False)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--runs', type=int, default=5)
    parser.add_argument('--log', type=_str2bool, default=True)
    args = parser.parse_args()

    if args.log:
        # Create a log file with timestamp; from here on everything printed is
        # also written to that file.
        timestamp = datetime.now().strftime("%Y%m%d")
        log_filename = f"deep_hyrep_spd_robust_RHGD_{args.lower_iter}_log_{timestamp}.txt"
        sys.stdout = Logger(log_filename)

    # Define paths
    train_path = 'data/spdface_400_inter_histeq/train'
    val_path = 'data/spdface_400_inter_histeq/val'

    # Initialize lists to store data and labels
    data_X = []
    y = []
    data_X_test = []
    test_y = []

    # Load training and validation data (the 'val' folder serves as test set below)
    print("Loading training data...")
    load_data_from_directory(train_path, data_X, y)
    print("Loading validation data...")
    load_data_from_directory(val_path, data_X_test, test_y)

    # Convert lists to numpy arrays
    data_X = np.array(data_X)
    y = np.array(y)
    data_X_test = np.array(data_X_test)
    test_y = np.array(test_y)

    # Folder labels are shifted down by one, presumably because the folders
    # are numbered from 1 -- verify against the dataset layout.
    data_X = torch.from_numpy(data_X).float()
    data_y = torch.from_numpy(y-1).long()
    data_X_test = torch.from_numpy(data_X_test).float()
    data_y_test = torch.from_numpy(test_y-1).long()

    # Print dataset statistics before filtering
    print(f"Original dataset size - Training: {len(data_X)}, Testing: {len(data_X_test)}")
    print(f"Original unique labels - Training: {torch.unique(data_y)}, Testing: {torch.unique(data_y_test)}")

    # Define which labels to keep
    labels_to_keep = [0, 1, 2, 3, 4, 5, 6]

    # Filter training data
    train_mask = torch.zeros(data_y.shape, dtype=torch.bool)
    for label in labels_to_keep:
        train_mask = train_mask | (data_y == label)

    data_X = data_X[train_mask]
    data_y = data_y[train_mask]

    # Filter test data
    test_mask = torch.zeros(data_y_test.shape, dtype=torch.bool)
    for label in labels_to_keep:
        test_mask = test_mask | (data_y_test == label)

    data_X_test = data_X_test[test_mask]
    data_y_test = data_y_test[test_mask]

    # Create mapping from original labels to new consecutive labels
    label_mapping = {original: idx for idx, original in enumerate(labels_to_keep)}

    # Map labels to consecutive indices 0 .. len(labels_to_keep) - 1
    data_y = torch.tensor([label_mapping[label.item()] for label in data_y])
    data_y_test = torch.tensor([label_mapping[label.item()] for label in data_y_test])

    # Print dataset statistics after filtering and mapping
    print(f"Filtered dataset size - Training: {len(data_X)}, Testing: {len(data_X_test)}")
    print(f"Filtered unique labels - Training: {torch.unique(data_y)}, Testing: {torch.unique(data_y_test)}")
    print(f"Label mapping: {label_mapping}")

    num_class = len(labels_to_keep)
    # Regularisation weight; read as a module-level global by loss_lower_mfd.
    lam = 0.1

    # Generate random seeds for multiple runs
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 10000, size=args.runs).tolist()
    print(f"Using seeds: {seeds}")
    data_ratio = args.data_ratio

    for step_size in [5, 1, 0.5, 0.1, 0.05]:
        # The step sizes are attached to ``args`` because the whole namespace
        # is handed to RHGDstep.
        args.eta_x = args.eta_y = step_size
        print(f"Step size: {step_size}")
        print(f"Training with {data_ratio} of the data")
        acc_all_runs = []
        test_acc_all_runs = []
        loss_all_runs = []
        runtime_all_runs = []
        total_hgradnorm_all_runs = []
        val50_time_runs = []
        val70_time_runs = []
        val85_time_runs = []

        for run in range(args.runs):
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(seeds[run])
            print(device)
            torch.random.manual_seed(seeds[run])
            np.random.seed(seeds[run])
            torch.backends.cudnn.deterministic = True

            train_X = []
            val_X = []
            train_y = []
            val_y = []

            # Per class: the first ``data_ratio`` share of a random permutation
            # becomes training data, the next ``data_ratio`` share validation.
            for c in labels_to_keep:
                data_sub = data_X[data_y == c]
                idx_all = torch.randperm(data_sub.shape[0])
                train_idx = idx_all[:int(idx_all.shape[0] * data_ratio)]
                val_idx = idx_all[int(idx_all.shape[0] * data_ratio):int(idx_all.shape[0]*2*data_ratio)]
                train_X.append(data_sub[train_idx])
                val_X.append(data_sub[val_idx])

                train_y.extend([c]*train_idx.shape[0])
                val_y.extend([c]*val_idx.shape[0])

            train_X = torch.cat(train_X, dim=0)
            val_X = torch.cat(val_X, dim=0)
            train_y = torch.tensor(train_y, dtype=torch.long)
            val_y = torch.tensor(val_y, dtype=torch.long)

            # Shuffle so that the classes are interleaved, then move to device.
            idx_train = torch.randperm(train_y.shape[0])
            train_X = (train_X[idx_train]).to(device)
            train_y = train_y[idx_train].to(device)
            idx_val = torch.randperm(val_y.shape[0])
            val_X = val_X[idx_val].to(device)
            val_y = val_y[idx_val].to(device)
            test_X = data_X_test.to(device)
            test_y = data_y_test.to(device)

            euclidean = EuclideanMod(ndim=2)
            stiefel = Stiefel(canonical=False)

            # Layer widths: d -> d1 -> d2 -> d3.
            d = 400
            d1 = 100
            d2 = 20
            d3 = 5

            print('initialization with manifold network')
            hparams = [ManifoldParameter(stiefel.random(d, d1, device=device), manifold=stiefel),
                ManifoldParameter(stiefel.random(d1, d2, device=device), manifold=stiefel),
                ManifoldParameter(stiefel.random(d2, d3, device=device), manifold=stiefel)]
            # The classifier maps the d3*(d3+1)/2 vectorised features to class scores.
            params = [ManifoldParameter(euclidean.random(int(d3 * (d3 + 1) / 2), num_class, device=device), manifold=euclidean)]
            mfd_params = [param.manifold for param in params]
            loss_upper = loss_upper_mfd
            loss_lower = loss_lower_mfd
            network = SPDnet

            data_lower = [train_X, train_y]
            data_upper = [val_X, val_y]
            data_test = [test_X, test_y]

            # Statistics at initialisation (epoch 0).
            epochs_all = [0]
            loss_u_all = [loss_upper(hparams, params, data_upper).item()]
            runtime = [0]
            acc_all = [compute_acc(network, hparams, params, data_upper).item()]
            test_acc_all = [compute_acc(network, hparams, params, data_test).item()]
            hypergrad, _, _ = compute_hypergrad_v(loss_lower, loss_upper, hparams, params,
                                    data_lower, data_upper, v0=None, hygrad_opt=args.hygrad_opt, true_hessinv=None,
                                    iter = 50, gd_gamma = None, v_reg=0.0, tol=1e-10, verbose=False)
            mfd_hparams = [hparam.manifold for hparam in hparams]
            hgradnorm = 0
            for mfd, hparam, hg in zip(mfd_hparams, hparams, hypergrad):
                hgradnorm += mfd.inner(hparam, hg).item() / len(hparams)
            total_hgradnorm_all = [hgradnorm * len(hparams)]
            print(f"trainning starts with initial hgradnorm: {total_hgradnorm_all[0]:.4e}")

            start_time = time.time()
            val50_time = None
            val70_time = None
            val85_time = None

            for ep in range(1, args.epoch+1):
                try:
                    hparams, params, loss_u, hgradnorm, step_time, hg_error = RHGDstep(loss_lower, loss_upper, hparams, params,
                                                                                        args,
                                                                                        data=[data_lower, data_upper],
                                                                                        )
                    total_hgradnorm = hgradnorm * len(hparams)
                    with torch.no_grad():
                        val_acc = compute_acc(network, hparams, params, data_upper)
                        test_acc = compute_acc(network, hparams, params, data_test)

                    loss_u_all.append(loss_u)
                    runtime.append(step_time)
                    epochs_all.append(ep)
                    acc_all.append(val_acc.item())
                    test_acc_all.append(test_acc.item())
                    total_hgradnorm_all.append(total_hgradnorm)

                    print(f"Epoch {ep}: "
                            f"loss upper: {loss_u:.4e}, "
                            f"total_hgradnorm: {total_hgradnorm:.4e}, "
                            f"Val acc: {val_acc*100:.2f}, "
                            f"Test acc: {test_acc*100:.2f}, "
                            f"runtime: {step_time:.2f}")

                    # Record when validation accuracy reaches specific thresholds
                    current_time = time.time() - start_time
                    if val50_time is None and val_acc >= 0.50:
                        val50_time = current_time
                        print(f"Validation accuracy reached 50% at {val50_time:.2f} seconds")
                    if val70_time is None and val_acc >= 0.70:
                        val70_time = current_time
                        print(f"Validation accuracy reached 70% at {val70_time:.2f} seconds")
                    if val85_time is None and val_acc >= 0.85:
                        val85_time = current_time
                        print(f"Validation accuracy reached 85% at {val85_time:.2f} seconds")

                except Exception:
                    # Stop this run but keep what was collected so far. The
                    # traceback is printed (to stdout, so that it also reaches
                    # the log file) instead of being silently discarded.
                    print("Error during RHGD training step")
                    print(traceback.format_exc())
                    break

            acc_all_runs.append(torch.tensor(acc_all).unsqueeze(0))
            test_acc_all_runs.append(torch.tensor(test_acc_all).unsqueeze(0))
            loss_all_runs.append(torch.tensor(loss_u_all).unsqueeze(0))
            runtime_all_runs.append(torch.tensor(runtime).unsqueeze(0))
            total_hgradnorm_all_runs.append(torch.tensor(total_hgradnorm_all).unsqueeze(0))
            # A threshold that was never reached is stored as None.
            if val50_time is not None:
                val50_time_runs.append(torch.tensor(val50_time).unsqueeze(0))
            else:
                val50_time_runs.append(None)
            if val70_time is not None:
                val70_time_runs.append(torch.tensor(val70_time).unsqueeze(0))
            else:
                val70_time_runs.append(None)
            if val85_time is not None:
                val85_time_runs.append(torch.tensor(val85_time).unsqueeze(0))
            else:
                val85_time_runs.append(None)

            stats = {'runtime': runtime_all_runs, 'loss_upper': loss_all_runs,
                        'accuracy': acc_all_runs, 'test_accuracy': test_acc_all_runs,
                        'val50_time': val50_time_runs, 'val70_time': val70_time_runs, 'val85_time': val85_time_runs,
                        'total_hgradnorm': total_hgradnorm_all_runs}

            # The results file is rewritten after every run, so a crash in a
            # later run does not lose the runs already completed.
            res_folder = './results/'
            os.makedirs(res_folder, exist_ok=True)

            filename = res_folder + 'hyrep_spd_RHGD_lr' + str(args.eta_x) + '_loweriter' + str(args.lower_iter) + 'data_ratio' + str(data_ratio) + '.pkl'

            with open(filename, 'wb') as handle:
                pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL)