import argparse
import itertools
from datetime import datetime

import torch
import wandb
from dawgz import Job, schedule

from neural_mpm.data.data_manager import DataManager
from neural_mpm.experiments import experiments
from neural_mpm.nn import UNet, FNO, FFNO, create_model
from neural_mpm.train import train
from neural_mpm.util.model_logger import ModelLogger

"""
TODO:
- Toggleable noise
- Loss type
- LR schedulers

Maybe replace dict['...'] with dict.get('...', default_value).

wandb.init etc. should live in train; we could try to merge them.

"""
class Experiment:
    """Expand a named experiment into run configs and schedule training jobs.

    The experiment is looked up by name among the attributes of the
    ``experiments`` module and layered on top of ``experiments.DEFAULT``.
    The merged parameters are then expanded into one config dict per run
    (see ``gen_config_dicts``) and submitted as a dawgz job array (see
    ``run``).

    Attributes are documented inline in ``__init__``.
    """

    # Scheduler resources requested for each supported ``cluster`` value.
    # ``None`` and 'custom' share the same settings.
    _CLUSTER_RESOURCES = {
        None: {"cpus": 2, "ram": "12GB", "time": "14-00:00"},
        "custom": {"cpus": 2, "ram": "12GB", "time": "14-00:00"},
        "other": {"cpus": 4, "ram": "56GB", "time": "2-00:00"},
    }

    def __init__(self, experiment_name):
        """Load the parameters of ``experiment_name`` and build run configs.

        Args:
            experiment_name: Name of an attribute of the ``experiments``
                module holding the experiment-specific parameter overrides.

        Raises:
            ValueError: If ``experiments`` has no such (non-None) attribute.
        """
        # Backend name passed to ``run``; None until ``run`` is called.
        self.backend = None
        # List of per-run config dicts, filled by ``gen_config_dicts``.
        self.config_dicts = None
        self.experiment_name = experiment_name

        overrides = vars(experiments).get(experiment_name)
        if overrides is None:
            raise ValueError(f"Experiment {experiment_name} not found.")

        # Work on a copy: the parameters are updated, popped from and
        # rewritten below, and none of that may leak into the module-level
        # defaults shared by every other Experiment instance.
        self.experiment_params = dict(experiments.DEFAULT)
        self.experiment_params.update(overrides)

        # Generate job configs.
        self.gen_config_dicts()

    @staticmethod
    def _build_schedulers(optimizer, min_lr):
        """Create the linear-warmup and cosine-annealing LR schedulers.

        Returns a dict holding both scheduler objects together with the
        iteration counts they were configured with. How these entries are
        consumed is up to ``train`` (not visible from this file).
        """
        # TODO more sophisticated scheduler settings
        cosine_start = 1000
        total_iters = int(1e5)
        # The warmup may not extend past the start of the cosine phase.
        warmup_end = min(100, cosine_start)

        linear_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1e-2,
            end_factor=1.0,
            total_iters=warmup_end,
        )
        cos_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_iters - cosine_start,
            eta_min=min_lr,
            last_epoch=-1,
        )
        return {
            "linear": linear_scheduler,
            "cosine": cos_scheduler,
            "warmup_end": warmup_end,
            "cosine_start": cosine_start,
            "total_iters": total_iters,
        }

    def training_job(self, config_dict):
        """Run one training job described by ``config_dict``.

        Builds the model (moved to CUDA when available), an Adam optimizer,
        optional LR schedulers, the data manager, a wandb run and a model
        logger, then hands everything to ``train``.

        Args:
            config_dict: One entry of ``self.config_dicts``. The keys
                'model', 'lr', 'use_schedulers', 'data', 'batch_size',
                'grid_size', 'steps_per_call', 'autoregressive_steps',
                'sims_in_memory', 'epochs', 'passes_over_buffer',
                'particle_noise' and 'grid_noise' are required; 'min_lr' is
                required when 'use_schedulers' is truthy; 'run_name',
                'project' and 'save_every' are optional.
        """
        model = create_model(config_dict['model'], config_dict)
        if torch.cuda.is_available():
            model = model.cuda()

        optimizer = torch.optim.Adam(model.parameters(), config_dict['lr'])

        schedulers = None
        if config_dict['use_schedulers']:
            schedulers = self._build_schedulers(
                optimizer, config_dict['min_lr'])

        datamanager = DataManager(
            config_dict['data'],
            batch_size=config_dict['batch_size'],
            dim=2,
            grid_size=config_dict['grid_size'],
            steps_per_call=config_dict['steps_per_call'],
            autoregressive_steps=config_dict['autoregressive_steps'],
            sims_in_memory=config_dict['sims_in_memory'],
        )

        # The wandb config is a copy of the run config plus a few extra
        # entries; 'run_name' is removed from it and used as the run's name.
        wandb_config = config_dict.copy()
        wandb_config.update({
            "num_steps": datamanager.sim_length,
            "length": 0.015,
            "interp_fct": "constant",
            "experiment": self.experiment_name,
        })

        # Fall back to a timestamp when the config carries no run name.
        formatted_date_time = datetime.now().strftime("(%d_%m) %H:%M:%S")
        run_name = str(wandb_config.pop('run_name', formatted_date_time))

        print(f"Experiment {run_name} starting..")

        use_wandb = True
        project_name = config_dict.get('project', 'experiments')

        if use_wandb:
            wandb.init(
                project=project_name,
                entity="neuralmpm",
                name=run_name,
                config=wandb_config,
            )
            wandb.watch(model, log="all", log_freq=4)

        model_logger = ModelLogger(
            # Last path component of the data location.
            config_dict['data'].rstrip('/').split('/')[-1],
            config_dict,
            save_interval=config_dict.get('save_every', 10),
        )

        train(
            model=model,
            datamanager=datamanager,
            optimizer=optimizer,
            schedulers=schedulers,
            use_wandb=use_wandb,
            epochs=config_dict['epochs'],
            passes_over_buffer=config_dict['passes_over_buffer'],
            model_logger=model_logger,
            particle_noise=config_dict['particle_noise'],
            grid_noise=config_dict['grid_noise'],
            # Progress bars are only enabled for the "async" backend.
            progress_bars=self.backend == "async",
        )

    def gen_config_dicts(self):
        """Expand ``self.experiment_params`` into ``self.config_dicts``.

        The 'exp_type' entry (default 'explicit') is removed from the
        parameters and selects the expansion mode. Every value that is not
        already a tuple is treated as a one-element tuple.

        * 'explicit': run ``i`` takes the ``i``-th element of every tuple;
          one-element tuples are repeated for all runs. Unless the
          parameters define 'run_name', each run is named after the values
          of its varying parameters (``key=value``; for dict-valued
          parameters, ``k=value`` pairs where ``k`` is the first character
          of each inner key).
        * anything else ('combi'): one run per element of the Cartesian
          product of all tuples. No run name is generated.

        Raises:
            ValueError: In explicit mode, if a varying parameter does not
                provide exactly one value per run.
        """
        exp_type = self.experiment_params.pop('exp_type', 'explicit')

        params = {
            key: value if isinstance(value, tuple) else (value,)
            for key, value in self.experiment_params.items()
        }
        self.experiment_params = params
        self.config_dicts = []

        if exp_type == 'explicit':
            nb_runs = max(len(values) for values in params.values())

            expanded = {}
            varying_params = []
            for key, values in params.items():
                if len(values) == 1:
                    # Constant parameter: same value for every run.
                    expanded[key] = values * nb_runs
                    continue

                if len(values) != nb_runs:
                    # Would otherwise surface as an IndexError further down.
                    raise ValueError(
                        f"Parameter '{key}' has {len(values)} values but "
                        f"the experiment defines {nb_runs} runs."
                    )

                expanded[key] = values
                if isinstance(values[0], dict):
                    labels = [
                        ' '.join(
                            f"{k[0]}={val}" for k, val in entry.items()
                        )
                        for entry in values
                    ]
                else:
                    labels = [f"{key}={val}" for val in values]
                varying_params.append(labels)

            self.experiment_params = expanded
            run_names = [' '.join(parts) for parts in zip(*varying_params)]

            for i in range(nb_runs):
                config_dict = {
                    key: values[i] for key, values in expanded.items()
                }
                if run_names and 'run_name' not in config_dict:
                    config_dict['run_name'] = run_names[i]
                self.config_dicts.append(config_dict)
        else:  # exp_type == 'combi'
            # TODO Run name
            keys = list(params)
            for combination in itertools.product(*params.values()):
                self.config_dicts.append(dict(zip(keys, combination)))

    def run(self, data, backend="slurm", save_every=10, cluster=None):
        """Schedule one job per generated config as a dawgz job array.

        Args:
            data: If not None, overrides the 'data' entry of every config.
            backend: dawgz backend name forwarded to ``schedule``.
            save_every: If not None, overrides the 'save_every' entry of
                every config.
            cluster: Selects the requested resources; one of None,
                'custom' (same as None) or 'other'.

        Raises:
            ValueError: If ``cluster`` is not a supported value. Previously
                such a value scheduled nothing while still reporting the
                jobs as scheduled.
        """
        if cluster not in self._CLUSTER_RESOURCES:
            raise ValueError(
                f"Unknown cluster {cluster!r}; expected None, 'custom' "
                f"or 'other'."
            )
        resources = self._CLUSTER_RESOURCES[cluster]

        self.backend = backend

        # Apply the command-line overrides to each config_dict.
        for config_dict in self.config_dicts:
            if data is not None:
                config_dict['data'] = data
            if save_every is not None:
                config_dict['save_every'] = save_every

        num_jobs = len(self.config_dicts)

        j = Job(
            lambda i: self.training_job(self.config_dicts[i]),
            name=self.experiment_name,
            array=num_jobs,
            partition="gpu",
            cpus=resources["cpus"],
            gpus=1,
            ram=resources["ram"],
            time=resources["time"],
        )

        schedule(
            j,
            backend=backend,
            export="ALL",
            shell="/bin/sh",
            env=["export WANDB_SILENT=true"],
            time=resources["time"],
        )

        print(f"[{backend}]: Scheduled {num_jobs} job"
              f"{'s' if num_jobs > 1 else ''}.")


def main():
    """Parse the command line and launch the requested experiment.

    ``--local`` selects the "async" backend; otherwise jobs are submitted
    through "slurm". ``--data``, ``--save-every`` and ``--cluster`` are
    forwarded to ``Experiment.run``.
    """
    cli = argparse.ArgumentParser("Neural MPM Experiment Runner")
    cli.add_argument('-e', '--experiment', type=str, required=True)
    cli.add_argument(
        '-d', '--data', type=str, help="Override experiment's data.",
    )
    cli.add_argument('-l', "--local", action="store_true")
    # Optional override of the save interval.
    cli.add_argument('-s', "--save-every", type=int, default=60)
    cli.add_argument("--cluster", type=str, help='Which cluster to run on')

    options = cli.parse_args()

    if options.local:
        backend = "async"
    else:
        backend = "slurm"

    Experiment(options.experiment).run(
        options.data,
        backend=backend,
        save_every=options.save_every,
        cluster=options.cluster,
    )

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()