from functools import partial
from jax import random
import jax
import jax.numpy as jnp
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

#from train_exp import LRU_CIFAR10, SequenceLayer_CIFAR10
from models_exp import (
    LRU_spiral,
    SequenceLayer_spiral
)
from LRU.lru.model import BatchClassificationModel
from utils_data import create_spiral_dataset
from tqdm import tqdm
from flax.training import train_state
import wandb
import argparse
import datetime
import os
import pprint
from matplotlib import pyplot as plt

import optax
import numpy as np
import random as rndm

import orbax.checkpoint
from flax.training import orbax_utils

from LRU.lru.train_helpers import (
    create_train_state,
    linear_warmup,
    cosine_annealing,
    constant_lr,
    reduce_lr_on_plateau,
    train_epoch,
    validate,
    prep_batch,
    get_bound_spiral
)

# Set JAX's global default matmul precision option to "high" for the whole
# script; this runs at import time, before any computation is traced.
jax.config.update("jax_default_matmul_precision", "high")
#jax.config.update("jax_enable_x64", True)
#jax.config.update("jax_debug_nans", True)

def set_seed(seed: int = 42, cuda=False) -> None:
    """Seed the NumPy, stdlib ``random`` and PyTorch random generators.

    Args:
        seed: Value handed to every generator.
        cuda: When true, additionally seed the current CUDA device and put
            cuDNN into deterministic mode with benchmarking disabled.
    """
    for seed_fn in (np.random.seed, rndm.seed, torch.manual_seed):
        seed_fn(seed)

    if cuda:
        torch.cuda.manual_seed(seed)
        # Both cuDNN switches are needed for reproducible kernels.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    # Export the seed as the hash seed as well.
    # NOTE(review): the running interpreter fixed its hash seed at startup,
    # so this presumably only affects child processes - confirm intent.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

@jax.jit
@jax.vmap
def norm_2_2(u):
    """Per-sample norm: sqrt of the summed squared 2-norms along the last axis.

    The function is vmapped over the leading axis, so inside the body ``u``
    is a single sample, documented by the original author as
    ``(seq_len, dim)``.
    """
    step_norms = jnp.linalg.norm(u, ord=2, axis=-1)
    energy = jnp.sum(step_norms ** 2)
    return jnp.sqrt(energy)

@jax.jit
@jax.vmap
def norm_2_2_squared(u):
    """Per-sample sum of squared 2-norms taken along the last axis.

    Vmapped over the leading axis; inside the body ``u`` is one sample,
    documented by the original author as ``(seq_len, dim)``.
    """
    step_norms = jnp.linalg.norm(u, ord=2, axis=-1)
    return jnp.sum(step_norms ** 2)

#@jax.jit
#def l2_loss(y_true, y_pred):
#    # y shape:  (bs, seq_len, output_dim)
##    return jnp.mean(jnp.sum((y_true - y_pred)**2, axis=(1,2)))
#    return jnp.mean(norm_2_2_squared(y_true - y_pred))

@jax.jit
@jax.vmap
def norm_1_1(u):
    """Per-sample sum of 1-norms taken along the last axis.

    Vmapped over the leading axis; inside the body ``u`` is one sample,
    documented by the original author as ``(seq_len, dim)``.
    """
    step_norms = jnp.linalg.norm(u, ord=1, axis=-1)
    return jnp.sum(step_norms)

#@jax.jit
#def l1_loss(y_true, y_pred):
#    # y shape:  (bs, seq_len, output_dim)
##    return jnp.mean(jnp.sum(jnp.abs(y_true - y_pred), axis=(1,2)))
#    return jnp.mean(norm_1_1(y_true - y_pred))

@jax.jit
@jax.vmap
def norm_inf_inf(u):
    """Per-sample maximum of the inf-norms taken along the last axis.

    Vmapped over the leading axis; inside the body ``u`` is one sample,
    documented by the original author as ``(seq_len, dim)``.
    """
    step_norms = jnp.linalg.norm(u, ord=jnp.inf, axis=-1)
    return jnp.max(step_norms)

@partial(jax.jit, static_argnums=(2,))
def inference(u, state, model):
    """Run ``model.apply`` on ``u`` with the parameters stored in ``state``.

    ``model`` is a static jit argument (position 2), so it must be hashable
    and a distinct model triggers a separate compilation. No extra RNG
    streams are passed to ``apply``.
    """
    variables = {"params": state.params}
    return model.apply(variables, u)

@jax.jit
@jax.vmap
def binary_cross_entropy(score, label):
    """Per-sample binary cross-entropy between a raw score and a 0/1 label.

    The score is treated as a logit: the loss is
    ``-(label * log(sigmoid(score)) + (1 - label) * log(1 - sigmoid(score)))``.

    The log-probabilities are computed with ``jax.nn.log_sigmoid`` instead of
    ``log(sigmoid(.))``. The naive form saturates to ``log(0) = -inf`` for
    large ``|score|`` and then yields ``0 * -inf = NaN`` whenever the
    corresponding label weight is zero (e.g. score=100, label=1).

    Args:
        score: Raw model scores; vmapped over the leading axis.
        label: Target labels, same leading axis as ``score``.

    Returns:
        The element-wise loss, one entry per sample.
    """
    log_p = jax.nn.log_sigmoid(score)
    # log(1 - sigmoid(x)) == log(sigmoid(-x))
    log_not_p = jax.nn.log_sigmoid(-score)
    return -(label * log_p + (1 - label) * log_not_p)
@jax.jit
@jax.vmap
def l1_loss(score, label):
    """Per-sample absolute error ``|score - label|`` (vmapped over axis 0)."""
    residual = score - label
    return jnp.abs(residual)

@jax.jit
@jax.vmap
def l2_loss(score, label):
    """Per-sample squared error ``(score - label)**2`` (vmapped over axis 0)."""
    residual = score - label
    return jnp.square(residual)

def get_params(model, state, u, y):
    """Return the loss constants ``(K_l, L_l)`` for one batch.

    The loss is selected by the global ``args.loss``:

    * ``"bc"``: ``L_l = sqrt(2)`` and ``K_l`` is the largest per-sample
      binary cross-entropy on the batch.
    * ``"l1"``: ``L_l = 1`` and ``K_l`` is the largest per-sample absolute
      error on the batch.

    Both values are later passed to ``get_bound_spiral`` under the same
    names. NOTE(review): ``L_l`` is presumably a Lipschitz constant of the
    loss, judging by its name - confirm.

    Args:
        model: Model passed to ``inference`` (static jit argument).
        state: Train state whose ``params`` are used for inference.
        u: Batch of model inputs.
        y: Batch of labels.

    Returns:
        Tuple ``(K_l, L_l)``.

    Raises:
        ValueError: If ``args.loss`` is neither ``"bc"`` nor ``"l1"``.
            Previously this fell through and failed with an unbound-local
            error at the ``return`` statement.
    """
    if args.loss not in ("bc", "l1"):
        raise ValueError(f"Unsupported loss: {args.loss!r}")

    # One forward pass is shared by both branches.
    scores = inference(u, state, model)
    if args.loss == "bc":
        L_l = jnp.sqrt(2).item()
        K_l = jnp.max(binary_cross_entropy(scores, y))
    else:
        L_l = 1
        K_l = jnp.max(l1_loss(scores, y))

    return K_l, L_l

@jax.vmap
def compute_accuracy(score, label):
    """Per-sample correctness: does ``score > 0.5`` agree with ``label``?

    NOTE(review): the binary cross-entropy loss in this file applies a
    sigmoid to the same score, which suggests the score is a raw logit; in
    that case the matching decision threshold would be 0, not 0.5.
    Confirm against the model's output before relying on "bc" accuracies.
    """
    predicted_positive = jnp.greater(score, 0.5)
    return jnp.equal(predicted_positive, label)

def compute_accuracies(state, y, u, model):
    """Mean accuracy of ``model`` (with ``state``'s parameters) on a batch.

    Runs ``inference`` on ``u``, scores every sample with
    ``compute_accuracy`` against ``y`` and averages the result.
    """
    scores = inference(u, state, model)
    per_sample_hits = compute_accuracy(scores, y)
    return jnp.mean(per_sample_hits)


def str2bool(v):
    """Parse a command-line boolean.

    Booleans pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: If the string is not a known spelling.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

# Command-line interface. Arguments are parsed at module level, so `args` is
# a global that the helper functions above (e.g. get_params) and the rest of
# the script read directly.
parser = argparse.ArgumentParser()

# --- wandb logging ---
parser.add_argument("--use_wandb", type=str2bool, default=False,
                    help="log with wandb?")
parser.add_argument(
    "--wandb_project", type=str, default="spiral_v2", help="wandb project name"
)
# NOTE(review): --wandb_entity is parsed but not forwarded to wandb.init in
# this file.
parser.add_argument(
    "--wandb_entity",
    type=str,
    help="wandb entity name, e.g. username",
)

# --- seeding and dataset generation (forwarded to create_spiral_dataset) ---
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--T", type=int, default=5)
parser.add_argument("--N", type=int, default=200)
parser.add_argument("--radius_coeff", type=int, default=2)
parser.add_argument("--train_set_size", type=int, default=1600)

# --- model structure, loss and checkpointing switches ---
parser.add_argument("--n_layers", type=int, default=1)
parser.add_argument("--use_encoder", type=str2bool, default=True)
parser.add_argument("--use_decoder", type=str2bool, default=True)
parser.add_argument("--use_D", type=str2bool, default=False)
parser.add_argument("--save_models", type=str2bool, default=False)
parser.add_argument("--loss", type=str, default="bc", choices=["bc", "l1"])
parser.add_argument("--pooling", type=str, default="mean", choices=["mean",
                                                                    "last",
                                                                    "none"])
parser.add_argument("--sorting", type=str, default="none",
                    choices=["outward", "inward", "none"])

# --- bound parameter, data noise, optimisation and output location ---
parser.add_argument("--delta", type=float, default=0.5)
parser.add_argument("--noise_scale", type=float, default=2.5)
# Free-form tag: appended to the wandb run name and stored in the config.
parser.add_argument("--config_str", type=str, default="")
parser.add_argument("--optimizer", type=str, default="sgd",
                    choices=["sgd", "adam", "adamw"])
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--save_model_dir", type=str, default="")
parser.add_argument("--nu_log_coeff", type=float, default=1.0)

# --- LRU dimensions ---
# NOTE(review): --n_y is parsed, but the script later sets n_y = 1
# unconditionally, so this option currently has no effect.
parser.add_argument("--n_x", type=int, default=4)
parser.add_argument("--n_u", type=int, default=2)
parser.add_argument("--n_y", type=int, default=1)

args = parser.parse_args()

#assert args.save_model_dir != ""

# Run configuration handed to wandb: every CLI option, plus the free-form
# config string duplicated under the "other" key.
config = {**vars(args), "other": args.config_str}

set_seed(args.seed)

# wandb is always initialised: as a real run when requested, otherwise in
# offline mode.
if args.use_wandb:
    wandb_kwargs = dict(
        project=args.wandb_project,
        job_type="model_training",
        config=config,
        name=f"N={args.N} rad={args.radius_coeff} | {args.config_str}",
    )
else:
    wandb_kwargs = dict(mode="offline")
wandb.init(**wandb_kwargs)

#jax.config.update("jax_disable_jit", True)

# Sequence-length argument of the dataset generator (also used as SEQ_LEN).
# Original author's note: "this is actually T".
N = args.N
N_traj = 1000  # only referenced by commented-out code in the visible file
radius = args.radius_coeff * torch.pi

# Number of trajectories requested for each evaluation split.
N_test = 400
N_true = 30_000
N_valid = 5_000


def _build_split(n_traj, shuffle):
    """Generate one spiral dataset and wrap it in a single-batch DataLoader.

    Every split shares the generator settings taken from the command line;
    only the number of trajectories and the shuffle flag differ. The batch
    size equals the dataset length, so iterating the loader yields exactly
    one batch.
    """
    dataset, _ = create_spiral_dataset(N=N, N_traj=n_traj,
                                       radius=radius,
                                       noise_scale=args.noise_scale,
                                       sorting=args.sorting,
                                       T=args.T)
    loader = DataLoader(dataset,
                        batch_size=len(dataset),
                        shuffle=shuffle,
                        pin_memory=True)
    return dataset, loader


# The creation order (train, test, true, valid) is kept as-is because the
# generators are seeded globally.
train_set, train_loader = _build_split(args.train_set_size, shuffle=True)
test_set, test_loader = _build_split(N_test, shuffle=False)
true_set, true_loader = _build_split(N_true, shuffle=False)
valid_set, valid_loader = _build_split(N_valid, shuffle=False)

# --- Dimensions ---
n_in = 2                      # raw input dimension fed to the model
n_x, n_u = args.n_x, args.n_u
# Without an encoder the LRU input dimension must equal the raw input one.
assert args.use_encoder or n_u == n_in
n_y = 1
n_out = 1
# Without a decoder the LRU output dimension must equal the model output one.
assert args.use_decoder or n_y == n_out

D_HIDDEN = n_x        # state dim of lru
D_MODEL = n_u         # input / output dim of lru
OUTPUT_DIM = n_out    # output of the whole model after decoder
INPUT_DIM = n_in
BATCH_SIZE = 128      # used for the dummy input that initialises the model
SEQ_LEN = N

# --- LRU layer factory ---
lru_kwargs = dict(
    d_hidden=D_HIDDEN,   # hidden state dimension
    d_model=D_MODEL,     # input and output dim of LRU unit
    r_min=0.0,
    r_max=1.0,
    use_D=args.use_D,
    nu_log_coeff=args.nu_log_coeff,
)
lru = partial(LRU_spiral, **lru_kwargs)

# --- Full model factory; only `training` is left to be supplied later ---
model_kwargs = dict(
    lru=lru,
    d_output=OUTPUT_DIM,   # output dim of the last decoder
    d_model=D_MODEL,       # output dim of the initial encoder
    n_layers=args.n_layers,
    seq_layer_class=SequenceLayer_spiral,
    dropout=0.5,
    norm="none",
    multidim=1,
    use_decoder=args.use_decoder,
    use_encoder=args.use_encoder,
    pooling=args.pooling,
    return_score=True,
)
model_cls = partial(BatchClassificationModel, **model_kwargs)

# Derive three independent keys from the seed: parameter init, dropout at
# init time, and a running key that the training loop keeps splitting.
key = random.PRNGKey(args.seed)
init_rng, dropout_rng, rng = random.split(key, num=3)

# Parameters are initialised by tracing the model on an all-ones batch.
dummy_shape = (BATCH_SIZE, SEQ_LEN, INPUT_DIM)
dummy_input = jnp.ones(dummy_shape)

model = model_cls(training=True)
init_rngs = {"params": init_rng, "dropout": dropout_rng}
variables = model.init(init_rngs, dummy_input)

#def load_model(path, model, variables):
#    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
#    empty_state = train_state.TrainState.create(
#        apply_fn=model.apply,
#        params=jax.tree_util.tree_map(np.zeros_like, variables['params']),
#        tx=optax.sgd(0.01))
#
#    target = {'model': empty_state}
#    restored = orbax_checkpointer.restore(path)#, item=target)
#    return restored['model']
#
#path = ""
#print(jax.tree_util.tree_map(np.zeros_like, variables['params']))
##tmp = load_model(path, model, variables)

def _max_input_norm(loader):
    """Largest ``norm_2_2`` over all samples of ``loader``'s dataset.

    Only the first two entries along the last axis are used; the training
    loop reads index 2 of that axis as the label.
    """
    inputs = jnp.array(loader.dataset[:, :, :2].numpy())
    return jnp.max(norm_2_2(inputs))


K_u_train = _max_input_norm(train_loader)
K_u_test = _max_input_norm(test_loader)
K_u_valid = _max_input_norm(valid_loader)
K_u_true = _max_input_norm(true_loader)
print(f"{K_u_train=}, {K_u_test=}, {K_u_true=}, {K_u_valid=}")

params = variables["params"]

# Choose the optax optimizer named on the command line; anything other than
# "adam" / "adamw" uses plain SGD.
if args.optimizer == "adam":
    tx = optax.adam(learning_rate=args.lr)
elif args.optimizer == "adamw":
    tx = optax.adamw(learning_rate=args.lr,
                     weight_decay=0.7)
else:
    tx = optax.sgd(learning_rate=args.lr)

# Bundle apply function, parameters and optimizer into the train state.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)





#def save_img(u, y, type_str, epoch):
#    path = os.path.join(root, type_str, str(epoch))
#    if not os.path.exists(path):
#        os.mkdir(path)
#
#    f_u = inference(u, state, model)
#
#    for cc in range(u.shape[0]):
#        plt.figure()
#        plt.plot(u[cc], label = "u")
#        plt.plot(y[cc], label = "y")
#        plt.plot(f_u[cc], label = "f(u)")
#        plt.legend()
#        plt.savefig(os.path.join(path, f"{cc}.png"))
#        plt.close()

# Per-sample loss used by train_step / eval_step: binary cross-entropy for
# "bc", absolute error for every other value of --loss.
if args.loss == "bc":
    loss_fn = binary_cross_entropy
else:
    loss_fn = l1_loss

@partial(jax.jit, static_argnums=(4,))
def train_step(state, y, u, rng, model):
    """Take one gradient step on the batch ``(u, y)``.

    Args:
        state: Current train state (parameters and optimizer state).
        y: Batch labels.
        u: Batch inputs.
        rng: PRNG key supplied to the model as its "dropout" stream.
        model: Model to apply; static jit argument.

    Returns:
        Tuple ``(new_state, loss)`` where ``loss`` is the batch-mean loss
        evaluated before the update.
    """
    def mean_loss(params):
        scores = model.apply({"params": params}, u, rngs={"dropout": rng})
        return jnp.mean(loss_fn(scores, y))

    loss, grads = jax.value_and_grad(mean_loss)(state.params)
    return state.apply_gradients(grads=grads), loss

@partial(jax.jit, static_argnums=(3,))
def eval_step(state, y, u, model):
    """Return the batch-mean loss of ``model`` on ``(u, y)``.

    No RNG streams are passed to ``model.apply``; ``model`` is a static jit
    argument.
    """
    scores = model.apply({"params": state.params}, u)
    per_sample = loss_fn(scores, y)
    return jnp.mean(per_sample)


# Local-time timestamp taken just before training starts.
# NOTE(review): only the commented-out `root` path below refers to it, so it
# is unused in the visible code.
now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
#root = os.path.join("",
#                    f"{now}_N_{args.N}_rad_{args.radius_coeff}")
#os.mkdir(os.path.join(root, "train"))
#os.mkdir(os.path.join(root, "test"))
#os.mkdir(os.path.join(root, "true"))
#
#for data in train_loader:
#    assert data.shape[0] == len(train_loader.dataset)
#    u = jnp.array(data[:,:,:2].numpy())
#    y = jnp.array(data[:, 0, 2].numpy())
#    jnp.save(os.path.join(args.save_model_dir, "train_u.npy"), u)
#    jnp.save(os.path.join(args.save_model_dir, "train_y.npy"), y)
#
#for data in test_loader:
#    assert data.shape[0] == len(test_loader.dataset)
#    u = jnp.array(data[:,:,:2].numpy())
#    y = jnp.array(data[:, 0, 2].numpy())
#    jnp.save(os.path.join(args.save_model_dir, "test_u.npy"), u)
#    jnp.save(os.path.join(args.save_model_dir, "test_y.npy"), y)
#
#for data in valid_loader:
#    assert data.shape[0] == len(valid_loader.dataset)
#    u = jnp.array(data[:,:,:2].numpy())
#    y = jnp.array(data[:, 0, 2].numpy())
#    jnp.save(os.path.join(args.save_model_dir, "valid_u.npy"), u)
#    jnp.save(os.path.join(args.save_model_dir, "valid_y.npy"), y)
#
#for data in true_loader:
#    assert data.shape[0] == len(true_loader.dataset)
#    u = jnp.array(data[:,:,:2].numpy())
#    y = jnp.array(data[:, 0, 2].numpy())
#    jnp.save(os.path.join(args.save_model_dir, "true_u.npy"), u)
#    jnp.save(os.path.join(args.save_model_dir, "true_y.npy"), y)


# ---------------------------------------------------------------------------
# Training / evaluation loop.
#
# Every loader yields its whole dataset as one batch (asserted below). Each
# epoch performs one gradient step on the training batch, evaluates loss,
# accuracy and the (K_l, L_l) loss constants on the test / true / valid
# splits, computes the bounds returned by get_bound_spiral, and logs
# everything to wandb.
# ---------------------------------------------------------------------------

def _split_batch(data):
    """Convert one loader batch into JAX arrays ``(u, y)``.

    ``u`` takes the first two entries of the last axis for every step;
    ``y`` takes entry 2 of the last axis at step 0 of each sample.
    """
    u = jnp.array(data[:, :, :2].numpy())
    y = jnp.array(data[:, 0, 2].numpy())
    return u, y


def _evaluate(loader, state, model):
    """Evaluate ``model`` on ``loader``.

    Returns ``(mean_loss, K_l, L_l, accuracy)``; the last three come from
    the (single) batch of the loader.
    """
    losses = []
    for data in loader:
        assert data.shape[0] == len(loader.dataset)
        u, y = _split_batch(data)
        losses.append(eval_step(state, y, u, model))
        K_l, L_l = get_params(model, state, u, y)
        acc = compute_accuracies(state, y, u, model)
    return jnp.mean(jnp.array(losses)), K_l, L_l, acc


def _spiral_bound(state, loader, K_u, L_l, K_l):
    """Call ``get_bound_spiral`` with N set to the size of ``loader``'s dataset."""
    return get_bound_spiral(state,
                            N=len(loader.dataset),
                            K_u=K_u,
                            L_l=L_l,
                            K_l=K_l,
                            delta=args.delta,
                            nu_log_coeff=args.nu_log_coeff)


# The two model variants do not change between epochs, so build them once.
# `model` stays bound to the evaluation-mode model, as it was at the end of
# every epoch before.
train_model = model_cls(training=True)
model = model_cls(training=False)

for epoch in range(400):
    metrics = {}
    accs = {}
    print(f"{epoch=}")

    # --- one optimisation step on the full training batch ---
    train_losses = []
    for data in train_loader:
        assert data.shape[0] == len(train_loader.dataset)
        u, y = _split_batch(data)
        rng, drop_rng = jax.random.split(rng)

        state, loss = train_step(state, y, u, drop_rng, train_model)
        train_losses.append(loss)
        # Constants and accuracy are both measured with the evaluation-mode
        # model; the accuracy previously used the training-mode model.
        K_l_train, L_l_train = get_params(model, state, u, y)
        accs['Train acc'] = compute_accuracies(state, y, u, model)
    metrics["Train loss"] = jnp.mean(jnp.array(train_losses))

    # --- evaluation on the held-out splits ---
    (metrics["Test loss"], K_l_test, L_l_test,
     accs['Test acc']) = _evaluate(test_loader, state, model)
    (metrics["True loss"], K_l_true, L_l_true,
     accs['True acc']) = _evaluate(true_loader, state, model)
    (metrics["Valid loss"], K_l_valid, L_l_valid,
     accs['Valid acc']) = _evaluate(valid_loader, state, model)

    K_u = max(K_u_train, K_u_test, K_u_true, K_u_valid)
    K_l = max(K_l_train, K_l_test, K_l_true, K_l_valid)
    print(f"{K_l_train=}, {K_l_test=}, {K_l_true=}, {K_l_valid=}")
    L_l = max(L_l_train, L_l_test, L_l_true, L_l_valid)
    print(f"{L_l_train=}, {L_l_test=}, {L_l_true=}, {L_l_valid=}")

    # --- bounds: each split's constants are combined with the "true" split ---
    test_bound, metrics_bound = _spiral_bound(
        state, test_loader,
        K_u=max(K_u_test, K_u_true),
        L_l=max(L_l_test, L_l_true),
        K_l=max(K_l_test, K_l_true),
    )
    train_bound, _ = _spiral_bound(
        state, train_loader,
        K_u=max(K_u_train, K_u_true),
        L_l=max(L_l_true, L_l_train),
        K_l=max(K_l_train, K_l_true),
    )
    valid_bound, _ = _spiral_bound(
        state, valid_loader,
        K_u=max(K_u_valid, K_u_true),
        L_l=max(L_l_true, L_l_valid),
        K_l=max(K_l_valid, K_l_true),
    )
    # Only the auxiliary metrics of the test-bound call are logged.
    metrics = metrics | metrics_bound

    # Generalisation gaps are measured against the loss on the "true" split.
    metrics["test gen. gap"] = metrics["True loss"] - metrics["Test loss"]
    metrics["train gen. gap"] = metrics["True loss"] - metrics["Train loss"]
    metrics["valid gen. gap"] = metrics["True loss"] - metrics["Valid loss"]
    metrics["abs test gen. gap"] = abs(metrics["test gen. gap"])
    metrics["abs train gen. gap"] = abs(metrics["train gen. gap"])
    metrics["abs valid gen. gap"] = abs(metrics["valid gen. gap"])
    metrics["Train Bound"] = train_bound
    metrics["Test Bound"] = test_bound
    metrics["Valid Bound"] = valid_bound

    metrics["valid + bound"] = metrics["Valid loss"] + metrics["Valid Bound"]
    metrics["train + bound"] = metrics["Train loss"] + metrics["Train Bound"]
    metrics["test + bound"] = metrics["Test loss"] + metrics["Test Bound"]
    metrics = metrics | accs

    pprint.pprint(metrics)

    # --- inspect the first sequence layer's parameters ---
    seq_params = state.params['encoder']['layers_0']['seq']
    nu_log = seq_params['nu_log']
    gamma_log = seq_params['gamma_log']  # read but currently unused
    A_diag = 1 / (args.nu_log_coeff + nu_log**2)
    A_matrix = jnp.diag(A_diag)
    print(A_matrix)
    # A_matrix is diagonal, so these are its smallest / largest diagonal
    # entries.
    eigvals = jnp.sort(jnp.linalg.eigh(A_matrix).eigenvalues)
    lambda_n, lambda_1 = eigvals[0], eigvals[-1]
    metrics["lambda_1"] = lambda_1
    metrics["lambda_n"] = lambda_n

    B_re = seq_params['B_re']
    B = B_re
    C = seq_params['C_re']
    print(f"A = {A_matrix}")
    print(f"B = {B}")
    print(f"C = {C}")
    wandb.log(metrics)

    # --- optional per-epoch checkpoint of the full train state ---
    if args.save_models:
        ckpt = {'model': state}
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        save_args = orbax_utils.save_args_from_target(ckpt)
        orbax_checkpointer.save(os.path.join(args.save_model_dir,
                                             f"model_{epoch}.ckpt"),
                                ckpt, save_args=save_args)