import sys
import os
from functools import reduce
import argparse

def _str2bool(value):
    """Parse a command-line boolean such as "True", "false", "yes" or "0".

    argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so every
    non-empty string (including "False") would turn the flag on.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # "" is accepted as False to match what bool("") used to give.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Argument parser setup
parser = argparse.ArgumentParser(description='Run Main Experiments')

# Adding arguments for hyperparameters with their default values
parser.add_argument('--path_save', type=str, default="", help='Path where the results will be saved')
parser.add_argument('--batch_size', type=int, default=20, help='Batch size for training')
parser.add_argument('--lr', type=float, default=50, help='Learning rate')
parser.add_argument('--t_epochs', type=int, default=20, help='Number of training epochs')
parser.add_argument('--tau', type=float, default=1e-4, help='Norm regularization parameter')
parser.add_argument('--beta', type=float, default=1e-6, help='Noise parameter')
parser.add_argument('--data_std', type=float, default=4, help='Standard deviation of the data')
parser.add_argument('--granular', type=int, default=5, help='Granularity for saving particles')
parser.add_argument('--teacher_mode', type=str, default='free', help='Teacher mode (e.g. "free", "weak", "strong")')
parser.add_argument('--n_reps', type=int, default=10, help='Number of times to repeat the experiments')
parser.add_argument('--n_p', type=int, default=5, help='Number of particles of the student network')
parser.add_argument('--n_t', type=int, default=5, help='Number of particles for the teacher_network')
# `--fix_teacher False` now really disables the flag; a bare `--fix_teacher`
# (no value) means True.
parser.add_argument('--fix_teacher', type=_str2bool, nargs='?', const=True, default=True, help='Boolean flag for Fixing the Teacher Particles')

args = parser.parse_args()

# Extract arguments into the module-level constants used by the rest of the script
BATCH_SIZE = args.batch_size
LR = args.lr
T_EPOCHS = args.t_epochs
TAU = args.tau
BETA = args.beta
DATA_STD = args.data_std
GRANULAR = args.granular
EQUIV_INIT = False  # hard-coded; not exposed on the command line
TEACHER_MODE = args.teacher_mode
N_reps = args.n_reps
N_p = args.n_p
N_t = args.n_t
FIX_TEACHER = args.fix_teacher
path_save = args.path_save

import jax.numpy as jnp
from jax import vmap

import objax
from emlp.reps import V,sparsify_basis
from emlp.groups import SO,O,S,Z

import src.modules
from src.visualization import vis, particle_plot, plot_losses, particle_plot_animation
from src.theory_utils import equivariance_err, Wasserstein_Distance, rel_measure_distance
from src.modules import ShallowMLPNoLinearOut, FA_Model
from src.train_eval_utils import random_compare, training_loop
from src.utils import ExpData, CumData


# Symmetry setup: the permutation group S(2) (emlp) acting on its base
# representation V, used for both the network input and output.
G = S(2)
repin = V(G)
repout = V(G)
# Representation on the parameter space of linear maps repin -> repout
# (emlp's `>>` operator).
rep_params = (repin>>repout)
# Projector onto, and basis of, the subspace of equivariant parameters.
P_params = rep_params.equivariant_projector()
base_params = rep_params.equivariant_basis()
# Dense matrix of the first discrete generator of G acting on the parameters.
G_generator = rep_params.rho_dense(G.discrete_generators[0])


# Vectorized application of maps: each function applies its matrix to every
# row (particle) of the input array via vmap over the leading axis.
vP_params = vmap(lambda x: P_params@x)
vbase_params = vmap(lambda x: base_params@x)
vbase_paramsT = vmap(lambda x: base_params.T@x)

# The orbit maps are, unfortunately, restricted to this specific case of S(2):
# only the first discrete generator is applied, which gives the complete orbit
# only because S(2) has a single non-identity element.
vG_generator = vmap(lambda x: jnp.dot(G_generator,x))
vorbit = (lambda x: jnp.vstack([x, vG_generator(x)])) # This generates an array of "double the amount of particles", but with the complete orbit of each point.

# Multiplier applied below to the hard-coded fixed teacher particles.
scale_factor = 0.5

def MultiDirectionProjection(directions):
    """Build the orthogonal projector onto the span of ``directions``.

    Args:
        directions: array of shape (k, d) with one direction per row. A single
            direction of shape (d,) is also accepted.

    Returns:
        (P, vP): the (d, d) projection matrix, and ``vmap(lambda x: P @ x)``,
        which projects every row of an (n, d) array.

    Zero-length rows and rows linearly dependent on the others contribute
    nothing, so the rank of P equals the rank of ``directions``.
    """
    directions = jnp.atleast_2d(jnp.asarray(directions))

    # Normalise the non-zero rows so every direction weighs equally in the
    # rank decision. Zero rows are divided by 1 and stay zero instead of
    # becoming NaN.
    norms = jnp.linalg.norm(directions, axis=1, keepdims=True)
    normalized_directions = directions / jnp.where(norms > 0, norms, 1.0)

    # Why SVD and not QR: QR always returns k orthonormal columns, even for
    # dependent or zero directions, which would add a spurious direction to P.
    # Only the left singular vectors with non-negligible singular values span
    # the same space as the inputs.
    U, s, _ = jnp.linalg.svd(normalized_directions.T, full_matrices=False)
    tol = jnp.max(s, initial=0.0) * max(normalized_directions.shape) * jnp.finfo(s.dtype).eps
    keep = (s > tol).astype(U.dtype)

    # Sum of u u^T over the kept singular vectors (masking stays jit-friendly).
    P = (U * keep) @ U.T

    # Vectorized projection of each row.
    vP_oblique_params = vmap(lambda x: P @ x)

    return P, vP_oblique_params



def create_model(N_p, activation_fn, mode="free", fixed_init = None):
    """Build the teacher network for the requested symmetry ``mode``.

    Args:
        N_p: number of base particles. The script calls this with ``N_t``;
            previously the global ``N_t`` was read and this argument ignored.
        activation_fn: activation passed to ``ShallowMLPNoLinearOut``.
        mode: "strong"/"strong-equivariant" builds an ``equivariant=True``
            model with ``N_p`` units. "weak"/"weak-equivariant" builds an
            unconstrained model with ``2*N_p`` units whose particles are the
            ``vorbit`` orbits of ``N_p`` base particles. Any other value builds
            an unconstrained model with ``N_p`` units.
        fixed_init: optional base particles used instead of the model's own
            initialisation (orbited first in weak mode).

    Returns:
        The model, after ``set_particles`` with the chosen particles.
    """
    if mode in ("strong", "strong-equivariant"):
        n_units, equivariant, weak = N_p, True, False
    elif mode in ("weak", "weak-equivariant"):
        n_units, equivariant, weak = 2 * N_p, False, True
    else:
        n_units, equivariant, weak = N_p, False, False

    model = ShallowMLPNoLinearOut(N=n_units, rep_in=repin, rep_out=repout, activation=activation_fn, alternative=True, use_bias=False, equivariant=equivariant)

    if not weak:
        init_particles = model.get_particles() if fixed_init is None else fixed_init
    elif fixed_init is None:
        # vorbit doubles the number of rows (x stacked with g.x). Orbit only the
        # first N_p of the model's 2*N_p random particles so the result has
        # exactly 2*N_p rows. Orbiting all of them would give 4*N_p rows.
        init_particles = vorbit(model.get_particles()[:N_p])
    else:
        init_particles = vorbit(fixed_init)

    model.set_particles(init_particles)
    return model
  
def create_Heuristic(N_p, activation_fn, proj = None, seed=0):
    """Build an unconstrained student network with seeded random particles.

    The model is a ``ShallowMLPNoLinearOut`` with ``N_p`` units and
    ``equivariant=False``, so it is NOT restricted to the space of equivariant
    parameters. Its particles are replaced by standard-normal draws from a
    generator seeded with ``seed``, divided by 4, and passed through ``proj``
    when one is given.
    """
    model = ShallowMLPNoLinearOut(
        N=N_p,
        rep_in=repin,
        rep_out=repout,
        activation=activation_fn,
        alternative=True,
        use_bias=False,
        equivariant=False,
    )

    particle_shape = model.get_particles().shape
    rng = objax.random.Generator(seed=seed)
    weights = objax.random.normal(particle_shape, generator=rng) / 4

    if proj is not None:
        weights = proj(weights)

    model.set_particles(weights)
    return model
  
def equiv_error(models, labels, sample):
    """Return ``{label: equivariance error}`` for each zipped (model, label) pair.

    Each value is ``equivariance_err(sample, model, G, repin, repout)``
    converted to a Python scalar with ``.item()``. Pairing uses ``zip``, so
    surplus models or labels are ignored.
    """
    return {
        label: equivariance_err(sample, model, G, repin, repout).item()
        for model, label in zip(models, labels)
    }
  
def teacher_error(model_t, models_s, labels, sample):
    """Mean per-sample RMS gap between the teacher output and each student output.

    Args:
        model_t: teacher model, called as ``model_t(sample)``.
        models_s: a single student model or a list of student models.
        labels: names for the students; ``labels[j]`` names the j-th student.
        sample: input batch fed to every model.

    Returns:
        dict mapping each label to the mean over rows of
        ``sqrt(mean((y_s - y_t) ** 2, axis=1))``, as a Python scalar.
    """
    students = models_s if isinstance(models_s, list) else [models_s]
    teacher_out = model_t(sample)
    student_outs = [student(sample) for student in students]

    result = {}
    for idx, student_out in enumerate(student_outs):
        per_row_rmse = jnp.sqrt(((student_out - teacher_out) ** 2).mean(axis=1))
        result[labels[idx]] = per_row_rmse.mean().item()
    return result
    
def new_rel_measure_distance(m1, m2, p=2, root = False, mode = "trace", return_Wasserstein=False):
    """Relative distance between two empirical particle measures.

    Computes ``2 * W / D``, where ``W = Wasserstein_Distance(m1, m2, p=p, root=root)``
    and ``D = M1 + M2``, with ``Mi`` the mean over particles (rows) of
    ``sum(|particle| ** p)``. When ``root`` is True, ``D = (M1 + M2) ** (1/p)``.

    Args:
        m1, m2: particle arrays with one particle per row.
        p: order of the Wasserstein distance and of the moments.
        root: forwarded to ``Wasserstein_Distance``; also selects the rooted
            denominator.
        mode: unused; accepted so calls match those made to
            ``rel_measure_distance``.
        return_Wasserstein: if True, also return the raw Wasserstein distance.

    Returns:
        The relative distance as a Python float, or ``(relative, W)`` when
        ``return_Wasserstein`` is True. If both measures have zero moment (all
        particles at the origin, e.g. zero-initialised particles) the relative
        distance is 0.0 rather than the NaN that 0/0 would give.
    """
    W_dist = Wasserstein_Distance(m1, m2, p=p, root=root)
    # |.| ** p keeps the p-th moment non-negative for odd p; identical for even p.
    mom1 = (jnp.abs(m1) ** p).sum(axis=1).mean()
    mom2 = (jnp.abs(m2) ** p).sum(axis=1).mean()
    total = mom1 + mom2
    total_variation = jnp.power(total, 1 / p) if root else total

    # Guard 0/0: two measures with zero total moment are both concentrated at
    # the origin, hence identical, so their relative distance is 0.
    positive = total_variation > 0
    safe_denominator = jnp.where(positive, total_variation, 1.0)
    relative = jnp.where(positive, 2 * W_dist / safe_denominator, 0.0).item()

    return (relative, W_dist) if return_Wasserstein else relative
    
def distances(particles1, particles2, new=False):
    """Compare two sequences of particle sets pair by pair.

    Each zipped pair is passed to ``new_rel_measure_distance`` when ``new`` is
    True and to ``rel_measure_distance`` otherwise, always with
    ``root=False, mode="trace", return_Wasserstein=True``.

    Returns:
        (RM, WD): lists holding, per pair, the relative measure distance and
        the Wasserstein distance.
    """
    measure_fn = new_rel_measure_distance if new else rel_measure_distance
    RM, WD = [], []
    for first, second in zip(particles1, particles2):
        relative, wasserstein = measure_fn(first, second, root=False, mode="trace", return_Wasserstein=True)
        RM.append(relative)
        WD.append(wasserstein)
    return RM, WD

def one_trial_Heuristic(teacher_network, N_p, activation_fn, seed, train_params, new =False):
    """Run one seed of the three-phase projection heuristic.

    Phase 0 ("fase0"): the student starts from all-zero particles (the zero
    projection is applied at initialisation) and trains with that projection
    map and ``beta=0``.
    Phase 1 ("fase1"): the mean of the last recorded phase-0 particle set
    defines a direction. The student is re-initialised with particles
    projected onto its span and trained with that projection and ``BETA``.
    Phase 2 ("fase2"): the mean residual of the last phase-1 particles
    (the part removed by the phase-1 projection) becomes a second direction,
    and phase 1 is repeated with the projection onto both directions.

    At the start ("_s") and end ("_e") of each phase, the equivariance error,
    the distance to the teacher and ``RMD_comparisons_Heuristic`` are recorded
    on a test sample drawn once per trial without an explicit generator.

    Args:
        teacher_network: teacher model the students are trained against.
        N_p: number of student particles; the epoch count is ``T_EPOCHS * N_p``.
        activation_fn: activation for the student networks.
        seed: seed for the student initialisation and for ``training_loop``.
        train_params: dict providing BATCH_SIZE, LR, T_EPOCHS, TAU, BETA,
            DATA_STD and GRANULAR.
        new: if True, measure comparisons use ``new_rel_measure_distance``.

    Returns:
        (eq_error, dist_teacher, comparisons_RMD, comparisons_WD, train_losses,
        particles). The first four are keyed by "fase{0,1,2}_{s,e}", the last
        two by "fase{0,1,2}".
    """
    BATCH_SIZE, LR, T_EPOCHS = train_params["BATCH_SIZE"], train_params["LR"], train_params["T_EPOCHS"]
    TAU, BETA, DATA_STD = train_params["TAU"], train_params["BETA"], train_params["DATA_STD"]
    GRANULAR = train_params["GRANULAR"]
    EPOCHS = T_EPOCHS*N_p
    PRINT_EVERY = EPOCHS/2
    # Stride used to subsample recorded losses/particles. N_p // GRANULAR is 0
    # when N_p < GRANULAR, and a zero slice step raises ValueError.
    STRIDE = max(1, N_p // GRANULAR)
    LABELS = ["vanilla"]

    def vPzeros(x):
        # Projection onto the zero subspace.
        return jnp.zeros(x.shape)

    # ---- Phase 0 -------------------------------------------------------
    model = create_Heuristic(N_p, activation_fn, proj=vPzeros, seed=seed)
    # Drawn after the phase-0 model is built, preserving the original order.
    test_sample = objax.random.normal((100, repin.size()), stddev=DATA_STD)

    eq_error, dist_teacher, comparisons_RMD, comparisons_WD = {}, {}, {}, {}

    def _record(key, current_model, particles, proj):
        # Store the three diagnostics for one checkpoint (e.g. "fase0_s").
        eq_error[key] = equiv_error(models=[current_model], labels=LABELS, sample=test_sample)
        dist_teacher[key] = teacher_error(teacher_network, [current_model], labels=LABELS, sample=test_sample)
        comparisons_RMD[key], comparisons_WD[key] = RMD_comparisons_Heuristic(particles, proj, new=new)

    def _train(current_model, proj, beta):
        # Train under projection map `proj` with noise `beta`, then keep every
        # STRIDE-th loss and particle snapshot.
        losses, snapshots = training_loop(
            modelo=current_model, teacher_network=teacher_network, N=N_p,
            input_dim=repin.size(), BS=BATCH_SIZE, lr=LR, NUM_EPOCHS=EPOCHS,
            tau=TAU, beta=beta, stddev=DATA_STD, print_every=PRINT_EVERY,
            proj_map=proj, seed=seed, return_particles=True)
        return losses[::STRIDE], snapshots[::STRIDE]

    _record("fase0_s", model, [model.get_particles()], vPzeros)
    train_losses0, particles0 = _train(model, vPzeros, beta=0)
    _record("fase0_e", model, particles0, vPzeros)

    # ---- Phase 1: project onto the mean phase-0 direction ---------------
    direction_previous = particles0[-1].mean(axis=0)
    vP_oblique_params = MultiDirectionProjection(direction_previous.reshape(1, -1))[1]

    model = create_Heuristic(N_p, activation_fn, proj=vP_oblique_params, seed=seed)
    _record("fase1_s", model, [model.get_particles()], vP_oblique_params)
    train_losses1, particles1 = _train(model, vP_oblique_params, beta=BETA)
    _record("fase1_e", model, particles1, vP_oblique_params)

    # ---- Phase 2: add the mean residual outside the phase-1 span ---------
    new_dir = (particles1[-1] - vP_oblique_params(particles1[-1])).mean(axis=0)
    dirs = jnp.vstack([direction_previous, new_dir])
    vP_newdir_params = MultiDirectionProjection(dirs)[1]

    model = create_Heuristic(N_p, activation_fn, proj=vP_newdir_params, seed=seed)
    _record("fase2_s", model, [model.get_particles()], vP_newdir_params)
    train_losses2, particles2 = _train(model, vP_newdir_params, beta=BETA)
    _record("fase2_e", model, particles2, vP_newdir_params)

    train_losses = {"fase0":train_losses0, "fase1":train_losses1, "fase2":train_losses2}
    particles = {"fase0":particles0, "fase1":particles1, "fase2":particles2}

    return eq_error, dist_teacher, comparisons_RMD, comparisons_WD, train_losses, particles

def RMD_comparisons_Heuristic(particles, proj, new=False):
    """Compare each particle set with three transformed copies of itself.

    The particle sets are stacked with ``jnp.array`` and, for each transform
    below, mapped with ``vmap``. ``distances`` then compares each original set
    with its transformed counterpart:
      * "V vs. P_0(V)": the given ``proj``,
      * "V vs. G(V)":   ``vorbit`` (the set stacked with its generator image),
      * "V vs. P_E(V)": ``vP_params`` (projection onto equivariant parameters).

    Returns:
        (comparisons_RMD, comparisons_WD): dicts with the keys above, each
        value holding one entry per particle set.
    """
    # TODO: maybe compute the Wasserstein distance to the teacher in terms of the particles themselves.
    stacked = jnp.array(particles)
    transforms = (
        ("V vs. P_0(V)", proj),
        ("V vs. G(V)", vorbit),
        ("V vs. P_E(V)", vP_params),
    )

    comparisons_RMD, comparisons_WD = {}, {}
    for key, transform in transforms:
        transformed = vmap(transform)(stacked)
        comparisons_RMD[key], comparisons_WD[key] = distances(particles, transformed, new=new)
    return comparisons_RMD, comparisons_WD
    
    
activation_fn = objax.functional.sigmoid # objax.functional.tanh # objax.functional.selu # None

LABELS = ["vanilla"]
train_params = dict(
    BATCH_SIZE = BATCH_SIZE,
    LR = LR,
    T_EPOCHS = T_EPOCHS,
    TAU = TAU, # This is the "norm" regularization parameter
    BETA = BETA, # This is the "noise" parameter
    DATA_STD = DATA_STD,
    GRANULAR = GRANULAR,
    EQUIV_INIT=EQUIV_INIT,
    TEACHER_MODE=TEACHER_MODE,
    N_reps = N_reps)
train_params["P_MAP"] = vP_params if EQUIV_INIT else None

# Fixed teacher particles (only defined for N_t == 5). The strong-mode set
# has 2 columns and the other set 4. The mode aliases checked here must match
# the ones create_model accepts, otherwise "strong-equivariant" would receive
# the 4-column set.
if FIX_TEACHER and N_t == 5:
    if TEACHER_MODE in ("strong", "strong-equivariant"):
        fixed_teacher_particles = scale_factor * jnp.array(
            [[1, 0], [0.5, 1], [-0.5, 0.3], [0, -1], [0.7, 0.7]])
    else:
        fixed_teacher_particles = scale_factor * jnp.array(
            [[-1, 0, 0, 0.5], [0.5, 1, 0, 1], [-0.5, 0.3, 1, 0], [0, -1, -0.5, 1], [0.7, -0.7, 0.5, 0.7]])
else:
    fixed_teacher_particles = None
teacher_network = create_model(N_t, activation_fn, mode=TEACHER_MODE, fixed_init=fixed_teacher_particles)
train_params["TEACHER_FIXED"] = fixed_teacher_particles is not None

NEW = True


out_dict = ExpData(["eq_errors", "dist_to_teacher", "comparisons_RMD", "comparisons_WD",
                      "train_losses", "particles"], N_p, train_params)
# SEED is kept as the loop variable's module-level name.
for SEED in range(N_reps):
    trial_out = one_trial_Heuristic(teacher_network, N_p, activation_fn, SEED, train_params, new=NEW)
    out_dict.append(trial_out)

out_dict.save(path_save +"Heuristic")