"""
Training script for Linear, Deep Linear, and Shallow NN models. Uses tinygrad.
"""
from collections import defaultdict
import datetime
import os
import pickle

import numpy as np
import tqdm

import tinygrad
from tinygrad import Tensor, nn


# This requires adding the tinygrad repo to your python path.
from extra.lr_scheduler import OneCycleLR
# Local import. This also requires adding PYTHONPATH=. to your environment variables.
import poisson_solutions
# PYTHONPATH=.:/path/to/tinygrad/repo/ python3.11 train_tiny_models.py


# Default dtype for the data tensors: the dataset arrays are converted with
# Tensor(..., dtype=dtype) when the training data is loaded. This variable is NOT plumbed
# into the layer initializers (nn.Linear creates its weights with tinygrad's default float),
# so when changing it also change tinygrad's default-float environment variable
# (DEFAULT_FLOAT in recent tinygrad versions -- TODO confirm for the pinned version) to keep
# weights and data in the same dtype.
# Training in float64 allows the error to drop below 10^-30, but is not necessary to observe
# the phenomenon.
dtype = tinygrad.dtypes.float32


# Two simple architectures.
class DoubleLinear:
    """Two stacked ``nn.Linear`` layers with no nonlinearity in between ("deep linear").

    The composition is still an affine map from ``in_dim`` to ``out_dim`` (linear when
    ``bias=False``); the hidden layer only changes how that map is parameterized.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int=32, bias=True):
        # L1 is constructed (and randomly initialized) before L2, as the layers are applied.
        self.L1, self.L2 = (
            nn.Linear(in_dim, hidden_dim, bias=bias),
            nn.Linear(hidden_dim, out_dim, bias=bias),
        )

    def __call__(self, x: Tensor) -> Tensor:
        hidden = self.L1(x)
        return self.L2(hidden)

class MLP:
    """Single-hidden-layer perceptron computing ``L2(act(L1(x)))``.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hidden_dim: width of the hidden layer.
        bias: whether both linear layers carry a bias term.
        act: elementwise activation applied between the two layers.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int=32, bias=True, act=Tensor.relu):
        self.L1, self.L2 = (
            nn.Linear(in_dim, hidden_dim, bias=bias),
            nn.Linear(hidden_dim, out_dim, bias=bias),
        )
        self.act = act

    def __call__(self, x: Tensor) -> Tensor:
        pre_activation = self.L1(x)
        return self.L2(self.act(pre_activation))
    

def get_train_batch(train_data: dict[str, "Tensor"],
                    batch_size: int=256) -> tuple["Tensor", "Tensor", "Tensor"]:
    """Return an ``(xs, inputs, labels)`` minibatch drawn from ``train_data``.

    Args:
        train_data: dict holding row-aligned 2-D arrays under ``"x"``, ``"f"`` and ``"u"``
            (tinygrad Tensors in this script); rows are samples.
        batch_size: number of rows to sample. If it is at least the number of samples,
            the full arrays are returned unchanged (full-batch training).

    Returns:
        ``(xs, inputs, labels)`` taken from ``"x"``, ``"f"`` and ``"u"`` respectively.
    """
    n_samples = train_data['u'].shape[0]
    if batch_size >= n_samples:  # Full batch training.
        return train_data['x'], train_data['f'], train_data['u']
    # Perform sampling to pick a batch without replacement. We do not worry about each epoch looping over
    # the training set because the datasets are "infinite" size.
    idcs = Tensor(np.random.choice(n_samples, batch_size, replace=False))
    xs = train_data['x'][idcs, :]
    inputs = train_data['f'][idcs, :]
    labels = train_data['u'][idcs, :]
    return xs, inputs, labels


def train_a_model(
    G,
    train_data,
    all_data,
    wd=0.01,
    lr=1e-1,
    final_lr=1e-4,
    N_epochs=10_000,
    N_batch = None,
    description = '',
    save_metrics_history = False
):
    """Train ``G`` on mean-squared error with AdamW and a OneCycleLR schedule.

    Args:
        G: tinygrad model, called as ``G(inputs)``; its parameters are collected with
            ``nn.state.get_parameters``.
        train_data: dict with ``"x"``, ``"f"`` and ``"u"`` entries; ``"f"`` is fed to the
            model and ``"u"`` is the regression target.
        all_data: mapping from dataset name to a dict with ``"f"``/``"u"`` entries; every
            entry is evaluated after training (and periodically when
            ``save_metrics_history`` is set).
        wd: AdamW weight decay.
        lr: starting (and maximum) learning rate.
        final_lr: learning rate targeted at the end of the schedule.
        N_epochs: number of epochs; an epoch is ``N_data // N_batch`` steps, at least 1.
        N_batch: minibatch size; ``None`` means full-batch training.
        description: prefix for the tqdm progress bar.
        save_metrics_history: if True, also record the loss on every dataset in
            ``all_data`` about 50 times during training.

    Returns:
        ``(G, metrics_history, test_losses)``. ``metrics_history`` maps ``"train_loss"``
        (and ``"eval_<name>"`` when enabled) to lists of losses. ``test_losses`` maps each
        dataset name in ``all_data`` to its final loss.
    """
    N_data = train_data["u"].shape[0]
    if N_batch is None: N_batch = N_data  # Full batch training.
    # get_train_batch returns the full dataset whenever N_batch >= N_data, so an epoch is
    # always at least one step. Without the clamp, an oversized N_batch gave N_steps == 0.
    N_steps_per_epoch = max(1, N_data // N_batch)
    N_steps = N_epochs * N_steps_per_epoch
    # Aim for ~50 evaluations per run; clamped so short runs never compute `i % 0`.
    eval_freq = max(1, N_steps // 50)
    # Refresh the loss shown in the progress bar every 10 epochs.
    log_freq = 10 * N_steps_per_epoch

    def loss(pred, target):
        """Mean squared error over all elements."""
        return ((pred - target) ** 2).mean()

    optimizer = nn.optim.AdamW(nn.state.get_parameters(G), lr=lr, weight_decay=wd)
    # div_factor=1 and pct_start=0 remove the warm-up phase: the schedule starts at `lr` and
    # is intended to decay to final_lr (final_div_factor = lr / final_lr) over N_steps.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=lr,
        div_factor=1.0,
        final_div_factor=lr / final_lr,
        total_steps=int(N_steps),
        pct_start=0.0,
    )

    @tinygrad.TinyJit
    @Tensor.train()
    def train_step(inputs, labels):
        """Take one optimizer and scheduler step; return the batch loss before the update."""
        y = G(inputs)
        loss_val = loss(y, labels)
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
        scheduler.step()
        return loss_val

    # NOTE(review): TinyJit replays a captured graph, so this presumably requires every
    # dataset in all_data to have identical shapes -- confirm if differently sized sets are added.
    @tinygrad.TinyJit
    def eval_step(inputs, labels):
        """Return the loss of G on (inputs, labels) without updating anything."""
        return loss(G(inputs), labels)

    def evaluate_all():
        """Return {dataset name: loss} for every dataset in all_data, in iteration order."""
        return {p: eval_step(d["f"], d["u"]).item() for p, d in all_data.items()}

    metrics_history = defaultdict(list)
    for i in (t:=tqdm.tqdm(range(0, N_steps), desc=description)):
        _, inputs, labels = get_train_batch(train_data, N_batch)
        loss_val_np = train_step(inputs, labels).numpy()
        if i % log_freq == 0:
            t.set_description(f"{description} Loss: {float(loss_val_np):1.3e}")

        metrics_history["train_loss"].append(loss_val_np)
        # Evaluating every dataset is expensive, so the periodic history is opt-in.
        if save_metrics_history and (i % eval_freq == 0 or i == N_steps - 1):
            for p, eval_loss in evaluate_all().items():
                metrics_history[f"eval_{p}"].append(eval_loss)
    # Evaluate test loss across all datasets one last time.
    test_losses = evaluate_all()
    return G, metrics_history, test_losses


#
# Script starts here.
#
def parse_arguments(argv=None):
    """Parse the command-line options for the training sweep.

    Args:
        argv: optional list of argument strings, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``. Passing a list lets tests and
            notebooks call this without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``model_type``, ``n_seeds``, ``N_grid`` and ``save_path``.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Train tiny models for spatial interpolation')
    parser.add_argument('--model_type', type=str, default='linear', 
                        choices=['linear', 'double_linear', 'mlp'],
                        help='Type of model to train')
    parser.add_argument('--n_seeds', type=int, default=5,
                        help='Number of random runs to perform. All are saved together with no reproducibility seed logged.')
    parser.add_argument('--N_grid', type=int, default=22,
                        help='Grid size for the spatial discretization')
    parser.add_argument('--save_path', type=str, default='.',
                        help='Path to save the trained models and metrics. Defaults to local directory.')
    return parser.parse_args(argv)


# Parse command line arguments once, at import time; the module-level script code below
# reads `args` (model type, number of repeats, grid size, save path).
args = parse_arguments()


# The three tinygrad architectures for commandline identification.
# The hyperparameters used for the presented results are hard coded for simplicity.
def model_factory(model_type: str, N_grid: int):
    match model_type:
        case "linear":
            G = nn.Linear(N_grid, N_grid, bias=False)
            G.weight[:,:] = 0.0  # Initialize the weights to zero.
            return G
        case "double_linear":
            return DoubleLinear(N_grid, N_grid, hidden_dim=100)
        case "mlp":
            # A large ff is needed because the leakyrelu makes it harder to fit reliable.
            return MLP(N_grid, N_grid, hidden_dim=1024, act=Tensor.leaky_relu)


def get_learning_schedule(model_type: str):
    match model_type:
        case "linear":
            return {'lr': 1e-1, 'final_lr': 1e-5, 'N_epochs': 2_000, 'N_batch': None}
        case "double_linear":
            return {'lr': 1e-1, 'final_lr': 1e-5, 'N_epochs': 2_000, 'N_batch': None}
        case "mlp":
            # The MLP needs a slower learning rate and stochasic gradient descent to fit to a very low
            # error rate. With full batch training, the error is 1e-9.
            return {'lr': 1e-2, 'final_lr': 1e-6, 'N_epochs': 5_000, 'N_batch': 256}


# Grid size requested on the command line (22 unless overridden). NOTE(review): the dataset
# may only contain the 22-point grid by default -- confirm in poisson_solutions.
N_grid = args.N_grid
train_data = poisson_solutions.create_dataset_dict(root_spread=0.0)

# Mirror train_data's nested layout, train_data[grid_size][function_space][array_name], with
# every array converted to a tinygrad Tensor of the script-wide dtype.
train_data_tiny = defaultdict(dict)
for grid_size, spaces in train_data.items():
    for space_name, arrays in spaces.items():
        converted = {}
        for array_name, values in arrays.items():
            converted[array_name] = Tensor(values, dtype=dtype)
        train_data_tiny[grid_size][space_name] = converted


print(f"Kicking off a sweep of {args.n_seeds} repeats of {args.model_type} on a {N_grid} point grid.")
runs = []  # One (p_train, {dataset name: final test loss}) entry per trained model.
A_matrices = []  # One (p_train, matrix extracted from the trained model) entry per trained model.
n_seeds = args.n_seeds
# train_data_tiny is a defaultdict, so a missing grid size would silently produce an empty
# sweep (and empty output files). Fail loudly instead.
if N_grid not in train_data_tiny:
    raise ValueError(
        f"No training data for N_grid={N_grid}; available grid sizes: {list(train_data_tiny)}"
    )
# Optional smaller sweep: set function_spaces_to_run = subset to train only on these spaces.
subset = ['fem', 1, 4, 7, 's1', 's4', 's7', 'c1', 'c4', 'c7']
all_function_spaces = train_data_tiny[N_grid].keys()
# By default every function space is trained on, n_seeds times each.
function_spaces_to_run = all_function_spaces
wd = 0.0  # No weight decay.
l_hyper = get_learning_schedule(args.model_type)  # Same schedule for every run.
for p_train in function_spaces_to_run:
    td = train_data_tiny[N_grid][p_train]
    for seed in range(n_seeds):
        # A fresh model per repeat; no RNG seed is set anywhere in this script.
        G = model_factory(args.model_type, N_grid)
        G, metrics_history, test_losses = train_a_model(
            G, td, train_data_tiny[N_grid], wd,
            lr=l_hyper['lr'], final_lr=l_hyper['final_lr'], N_epochs=l_hyper['N_epochs'],
            N_batch=l_hyper['N_batch'],
            description=f"p_train={p_train}")
        runs.append((p_train, test_losses))
        if args.model_type == 'linear':
            A_matrix = G.weight.numpy()
        else:
            # nn.Linear computes x @ W.T, so feeding the identity returns the *transpose* of
            # the effective (out, in) matrix. Transpose back so every model type is stored in
            # the same orientation as G.weight in the linear branch.
            # NOTE(review): DoubleLinear/MLP are built with bias=True, so this probe also
            # includes the bias terms (not subtracted here) -- confirm that is intended.
            A_matrix = G(Tensor.eye(N_grid)).numpy().T
        A_matrices.append((p_train, A_matrix))



# Generate a timestamp in the format YYYYMMDD_HHMMSS so successive sweeps (at one-second
# resolution) write to distinct files.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
# Make sure the destination exists; otherwise open() would fail only after the whole sweep
# has already run. An empty save_path means the current directory, so skip it then.
if args.save_path:
    os.makedirs(args.save_path, exist_ok=True)
for artifact, payload, label in (("runs", runs, "runs"), ("A_matrices", A_matrices, "A matrices")):
    filename = os.path.join(args.save_path, f'tiny_{args.model_type}_poisson_{artifact}_{timestamp}.pkl')
    with open(filename, 'wb') as f:
        pickle.dump(payload, f)
    print(f"Saved {len(payload)} {label} to {filename}")
