from typing import Sequence

from dataclasses import dataclass
import itertools
import os
import datetime
import pickle

import torch
from torch.utils.data import DataLoader, Dataset

from neuralop.models import FNO
from neuralop import Trainer
from neuralop.training import AdamW
from neuralop.utils import count_model_params
from neuralop import LpLoss

import poisson_solutions
# Compute device used for training and evaluation.  Apple's 'mps' backend is
# preferred (the original hard-coded choice); when it is unavailable fall back
# to 'cuda' and finally 'cpu' so the script still runs on other machines.
# Assign a literal ('mps', 'cuda' or 'cpu') here to force a specific device.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Not referenced anywhere else in the visible part of this file.
N_data = 10_000
# Key selecting one entry of ``train_data``; also the default layer size of
# ``MyLinear``.
N_grid = 22
# Indexed below as ``train_data[N_grid][key]``, where each entry is a dict
# holding the arrays 'f' and 'u' consumed by ``DictionaryDataset``.
train_data = poisson_solutions.create_dataset_dict()


class DictionaryDataset(Dataset):
    """Wrap a dictionary of numpy arrays in the Dataset class used by NeuralOp.

    Args:
        data_dict: Non-empty mapping from name to array. All arrays must have
            the same length along their first axis. Item access additionally
            requires the keys ``'f'`` (input) and ``'u'`` (target).

    Raises:
        ValueError: If ``data_dict`` is empty or its arrays differ in length.
    """

    def __init__(self, data_dict):
        if not data_dict:
            raise ValueError("data_dict must contain at least one array")
        self.data_dict = data_dict
        # The first array defines the dataset length; every other array has to
        # agree with it.  Raised explicitly (not asserted) so the check is not
        # stripped when Python runs with -O.
        self.length = len(next(iter(data_dict.values())))
        for key, array in data_dict.items():
            if len(array) != self.length:
                raise ValueError(f"Array with key '{key}' has different length")

    def __len__(self):
        """Return the number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'x': ..., 'y': ...}`` float32 tensors.

        ``'x'`` is built from ``data_dict['f'][idx]`` and ``'y'`` from
        ``data_dict['u'][idx]``.  Both are reshaped to ``(1, n)`` where ``n``
        is the last-axis length of the ``'f'`` sample, i.e. a single channel
        of 1-D data.
        """
        f_sample = self.data_dict['f'][idx]
        u_sample = self.data_dict['u'][idx]
        # The shape is taken from 'f' and applied to both tensors.
        shp = (1, f_sample.shape[-1])
        return {
            'x': torch.from_numpy(f_sample.astype('float32')).reshape(shp),
            'y': torch.from_numpy(u_sample.astype('float32')).reshape(shp),
        }


class MSELossWrapper(torch.nn.MSELoss):
    """Sum-reduced ``MSELoss`` adapted to the NeuralOp loss-call convention.

    The loss always uses ``reduction='sum'``, and calling it accepts arbitrary
    extra keyword arguments, which are ignored.
    """

    def __init__(self, *args, **kwargs):
        # ``reduction`` is fixed to 'sum'; passing it again (by keyword, or as
        # the third positional argument) raises TypeError.
        super().__init__(*args, reduction='sum', **kwargs)

    def __call__(self, y_pred, y, **kwargs):
        # Only the prediction and the target reach MSELoss; any additional
        # keywords are dropped (presumably supplied by the NeuralOp trainer --
        # confirm against its API).
        return super().__call__(y_pred, y)


class MyLinear(torch.nn.Module):
    """A bias-free linear model to test the NeuralOp API and training loop.

    Args:
        n_grid: Number of input and output features of the single linear
            layer.  Defaults to the module-level ``N_grid`` when omitted, which
            matches the previous hard-coded behaviour.
    """

    def __init__(self, n_grid=None):
        super().__init__()
        if n_grid is None:
            n_grid = N_grid
        self.linear = torch.nn.Linear(n_grid, n_grid, bias=False)
        # Start from the zero map so the initial prediction is identically 0.
        torch.nn.init.zeros_(self.linear.weight)

    def forward(self, x, **kwargs):
        """Apply the linear layer to ``x``; extra keyword arguments are ignored."""
        return self.linear(x)


def train_a_model(model, p_train, loaders, n_epochs=50, lr=1e-2, wd=1e-3,
                  schedule: "Sequence[float] | None" = None, which_loss='l2'):
    """Train ``model`` on one loader and evaluate it on every loader.

    Args:
        model: Module to train; it is moved to the module-level ``device``.
        p_train: Key into ``loaders`` selecting the training loader.
        loaders: Mapping from key to DataLoader.  ``loaders[p_train]`` is used
            for training and the whole mapping for the final evaluation.
        n_epochs: Number of training epochs.
        lr: Learning rate passed to AdamW.
        wd: Weight decay passed to AdamW.
        schedule: Fractions of ``n_epochs`` at which the learning rate is
            multiplied by 0.1 (``MultiStepLR``).  ``None`` (the default)
            selects cosine annealing over ``n_epochs`` instead.
        which_loss: ``'l2'`` trains with ``LpLoss``; any other value trains
            with the sum-reduced MSE loss.

    Returns:
        The result of ``trainer.evaluate_all`` on ``loaders`` after training
        (the caller in this file treats it as a dict of tensors).
    """
    model = model.to(device)
    train_loader = loaders[p_train]

    n_params = count_model_params(model)
    print(f"\nOur model has {n_params} parameters.")

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=wd)
    if schedule is not None:
        # Step schedule: entries are fractions of the run, converted here to
        # the epoch indices at which the learning rate drops by a factor 10.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(fraction * n_epochs) for fraction in schedule],
            gamma=0.1,
        )
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    mse_loss = MSELossWrapper()
    # NOTE(review): d=2 is used although DictionaryDataset yields 1-D samples;
    # confirm this is the intended LpLoss dimensionality.
    l2loss = LpLoss(d=2, p=2)

    train_loss = l2loss if which_loss == 'l2' else mse_loss
    eval_losses = {"mse": mse_loss, "l2": l2loss}

    trainer = Trainer(
        model=model,
        n_epochs=n_epochs,
        device=device,
        wandb_log=False,
        eval_interval=10,
        use_distributed=False,
        verbose=True,
    )

    # An empty ``test_loaders`` is passed on purpose (presumably to skip
    # evaluation during training -- confirm); the value returned by ``train``
    # is not used, all loaders are evaluated once below instead.
    trainer.train(
        train_loader=train_loader,
        test_loaders={},
        optimizer=optimizer,
        scheduler=scheduler,
        regularizer=False,
        training_loss=train_loss,
        eval_losses=eval_losses,
    )
    with torch.no_grad():
        metrics = trainer.evaluate_all(
            epoch=n_epochs, eval_losses=eval_losses, test_loaders=loaders
        )
    return metrics


@dataclass
class ModelRun:
    """Record of one training run: training key, hyperparameters and metrics."""
    # Key of the loader (function space) the model was trained on.
    p_train: str
    # Hyperparameter name -> value used to construct the model.
    hyperparams: dict[str, int]
    # Evaluation metrics of the run; this script stores plain Python numbers
    # obtained via ``.item()``.
    metrics: dict



# Number of training epochs for every run of the sweep.
N_epochs = 10_000

# One shuffled DataLoader per entry of ``train_data[N_grid]``.  The same
# ``N_grid`` key is used here and for the list of function spaces below, so
# both stay consistent when ``N_grid`` is changed.
loaders = {key: DataLoader(DictionaryDataset(data_dict), batch_size=1024, shuffle=True)
           for key, data_dict in train_data[N_grid].items()}

# Accumulators filled by the sweep: full run records, (p_train, metrics)
# pairs, and the trained models themselves.
runs = []
simple_results = []
models = []

# Example of a restricted sweep (kept for reference):
# subset = ['fem', 1, 4, 7, 's1', 's4', 's7', 'c1', 'c4', 'c7']
# N_seeds = 4  # +1 from the other sweep.
# Defaults to running one seed on all function spaces.
all_function_spaces = train_data[N_grid].keys()  # named so builtin all() is not shadowed
N_seeds = 1
function_spaces_to_run = all_function_spaces

# Sweep over the hyperparameter grid (every list currently holds one value)
# and, for each combination, over all (function space, seed) pairs.
for n_modes, n_hidden, n_layers, projection_channel_ratio in itertools.product(
    [64], [32], [2], [1]
):
    for p_train, seed in itertools.product(function_spaces_to_run, range(N_seeds)):
        # ``seed`` only labels the repetition; no RNG is seeded from it here.
        print("Working on ", p_train, seed)

        # A fresh record per run so stored runs never share one dict.
        hyperparams = dict(
            n_modes=n_modes,
            n_hidden=n_hidden,
            n_layers=n_layers,
            projection_channel_ratio=projection_channel_ratio,
        )

        # Single-channel FNO with one spatial dimension (one entry in n_modes).
        fno_kwargs = dict(
            n_modes=(n_modes,),
            in_channels=1,
            out_channels=1,
            hidden_channels=n_hidden,
            n_layers=n_layers,
            projection_channel_ratio=projection_channel_ratio,
        )
        model = FNO(**fno_kwargs)

        raw_metrics = train_a_model(
            model,
            p_train,
            loaders,
            n_epochs=N_epochs,
            lr=1e-3,
            wd=1e-4,
            schedule=[0.5, 0.8],
        )
        # Convert every metric to a plain Python number before storing it.
        metrics = {name: value.item() for name, value in raw_metrics.items()}

        runs.append(ModelRun(p_train, hyperparams, metrics))
        simple_results.append((p_train, metrics))
        models.append(model)

# Directory that receives both pickle files.
save_path = '.'

# Both files share one timestamp of the form YYYYMMDD_HHMMSS.
now = datetime.datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")

# 1) The (p_train, metrics) pairs.
filename = os.path.join(save_path, f'neuralop_poisson_runs_{timestamp}.pkl')
with open(filename, mode='wb') as out_file:
    pickle.dump(simple_results, out_file)
print(f"Saved {len(simple_results)} runs to {filename}")

# 2) The trained models, pickled as whole objects.
# NOTE(review): the models are pickled on whatever device they were trained
# on; loading them on a machine without that device may need extra care.
filename = os.path.join(save_path, f'neuralop_poisson_models_{timestamp}.pkl')
with open(filename, mode='wb') as out_file:
    pickle.dump(models, out_file)
print(f"Saved {len(models)} models to {filename}")
