import hydra
from omegaconf import DictConfig, open_dict
import numpy as np
import torch
from tqdm import tqdm
import wandb

from src.utils import (
    seed_everything,
    print_flattened_config,
    flatten_config,
    best_epoch_select_strategy,
    get_class_from_path,
    move_to_device,
    get_dataset,
    generate_random_string,
)
from src.loss.loss import compute_loss
from src.logger import EpochLogger


class Trainer:
    """Full-batch trainer for a single dataset split.

    Construction calls ``seed_everything``, loads the dataset for
    ``split_index``, builds the model, optimizer, scheduler and loggers, and
    writes derived values (input/output sizes, parameter count, ...) back
    into ``cfg.model``. ``train`` runs the epoch loop and ``cleanup`` closes
    the Wandb run.
    """

    # Batch dict with the whole dataset moved to the GPU; populated by
    # ``_move_all_data_to_gpu`` at the start of ``train``.
    batch_on_gpu = None

    def __init__(self, cfg: DictConfig, split_index: int):
        """Build every component needed to train on one split.

        Args:
            cfg: Run configuration; stored on the instance and extended in
                place by the setup helpers.
            split_index: Index of the dataset split handled by this trainer.
        """
        self.cfg, self.split_index = cfg, split_index
        self.dtype = get_class_from_path(cfg.dtype)
        seed_everything(cfg.seed)

        # Order matters: the dataset setup writes model dimensions into the
        # config before the model is built from it, and the optimizer is
        # created from the model's parameters.
        for setup_step in (
            self._setup_dataset,
            self._setup_model,
            self._setup_optimizer,
            self._setup_loggers,
        ):
            setup_step()

        print_flattened_config(self.cfg)

    def _setup_dataset(self):
        """Load the dataset and record data-dependent model settings in the config."""
        self.dataset = get_dataset(
            data_root=self.cfg.data_root,
            dataset_cfg=self.cfg.dataset,
            eigen_subset_cfg=self.cfg.eigen_subset,
            split_index=self.split_index,
        )
        data = self.dataset.data

        with open_dict(self.cfg):
            self.cfg.model.split_index = self.split_index
            self.cfg.model.dim_in = data.x.shape[1]
            self.cfg.model.num_eigenvecs = self.dataset.eigenvecs.shape[1]

            # Largest label + 1 is used as the number of classes; a binary
            # task gets a single (1d) output instead of two.
            num_classes = data.y.max().item() + 1
            self.cfg.model.dim_out = 1 if num_classes == 2 else num_classes

            print(f"Metric: {self.cfg.dataset.metric}")

    def _setup_model(self):
        """Instantiate the configured model on the GPU and store its size in the config."""
        model_cls = get_class_from_path(self.cfg.model.base_class)
        self.model = model_cls(self.cfg).to("cuda")

        # Only parameters that receive gradients are counted.
        trainable_sizes = [p.numel() for p in self.model.parameters() if p.requires_grad]
        with open_dict(self.cfg):
            self.cfg.model.num_params = sum(trainable_sizes)

    def _setup_optimizer(self):
        """Create the optimizer and the LR scheduler named in the config."""
        optimizer_cls = get_class_from_path(self.cfg.optimizer.name)
        self.optimizer = optimizer_cls(self.model.parameters(), **self.cfg.optimizer.params)

        scheduler_cls = get_class_from_path(self.cfg.scheduler.name)
        self.scheduler = scheduler_cls(self.optimizer, **self.cfg.scheduler.params)

    def _setup_loggers(self):
        """Start the Wandb run and create one ``EpochLogger`` per data partition."""
        self.wandb = wandb.init(**self.cfg.wandb, dir=self.cfg.out_dir)
        self.wandb.config.update(flatten_config(self.cfg))

        self.loggers = {
            partition: EpochLogger(self.cfg, partition)
            for partition in ("Train", "Val", "Test")
        }

    def cleanup(self):
        """Finish the Wandb run opened by ``_setup_loggers``."""
        self.wandb.finish()

    def _move_all_data_to_gpu(self):
        """Pack the dataset into one batch dict and move it to the GPU.

        The result is stored in ``self.batch_on_gpu``.
        """
        data = self.dataset.data
        num_nodes = data.num_nodes

        batch = {
            "x_feat": data.x,
            "graph_pos": data.eigenvecs,
            "eigenvals": data.eigenvals,
            "edge_index": data.edge_index,
            # for node classification, we only have 1 graph, so every node
            # gets graph index 0
            "batch_index": torch.zeros(num_nodes, dtype=torch.long),
            "node_seqlen": [num_nodes],
            # Masks
            "train_mask": data.train_mask,
            "val_mask": data.val_mask,
            "test_mask": data.test_mask,
            # Readout
            "node_ids": torch.arange(num_nodes),
            "task_label": data.y,
        }

        # Move to GPU
        self.batch_on_gpu = move_to_device(batch, "cuda")

    def _nodeclass_epoch(self):
        """Run one epoch: a single training step, then Val/Test evaluation.

        Each partition's labels, predictions and loss are handed to the
        matching logger in ``self.loggers``.
        """
        batch = self.batch_on_gpu

        # -- train: one forward/backward pass over the full graph, with the
        # loss restricted to the training nodes.
        self.model.train()
        self.optimizer.zero_grad()
        with torch.amp.autocast("cuda", dtype=self.dtype):
            out, label = self.model(batch)
            train_mask = batch["train_mask"]
            train_loss, train_pred = compute_loss(out[train_mask], label[train_mask], self.cfg)
        train_loss.backward()
        self.optimizer.step()
        self.scheduler.step()  # the scheduler advances once per epoch
        self.loggers["Train"].update_stats(
            label[train_mask].cpu(), train_pred.cpu(), train_loss.item()
        )

        # -- val and test: one shared forward pass in eval mode, scored
        # separately on each mask.
        with torch.no_grad():
            self.model.eval()
            with torch.amp.autocast("cuda", dtype=self.dtype):
                out, label = self.model(batch)

            for partition, mask_key in (("Val", "val_mask"), ("Test", "test_mask")):
                mask = batch[mask_key]
                loss, pred = compute_loss(out[mask], label[mask], self.cfg)
                self.loggers[partition].update_stats(label[mask].cpu(), pred.cpu(), loss.item())

    def train(self):
        """Train for ``cfg.max_epochs`` epochs and log every epoch to Wandb.

        Returns:
            A list with one dict per epoch. Each dict holds ``"epoch"``, the
            values written by every logger for that epoch, ``"lr"``, and
            ``"Best/"``-prefixed copies of the values from the epoch chosen
            by ``best_epoch_select_strategy`` on the validation history.
        """
        self._move_all_data_to_gpu()

        stats = []
        perf_dict = {"Train": [], "Val": [], "Test": []}
        pbar = tqdm(range(self.cfg.max_epochs), desc="Epochs", position=0, leave=True)
        for epoch in pbar:
            self._nodeclass_epoch()

            # Collect this epoch's record from every logger.
            log_stats = {"epoch": epoch}
            for logger in self.loggers.values():
                epoch_record = logger.write_epoch()
                perf_dict[logger.name].append(epoch_record)
                log_stats.update(epoch_record)
            log_stats["lr"] = self.optimizer.param_groups[0]["lr"]

            # find best epoch
            best_epoch = best_epoch_select_strategy(
                perf=perf_dict["Val"],  # best selection based on validation results
                val_name="Val",
                metric=self.cfg.dataset.metric,
                strategy=self.cfg.best_selection.strategy,
                agg=self.cfg.best_selection.agg_param,
            )

            # add best metrics using best epoch
            for logger in self.loggers.values():
                best_record = perf_dict[logger.name][best_epoch]
                log_stats.update({f"Best/{key}": value for key, value in best_record.items()})

            pbar.set_description(
                f"Epoch {epoch} | "
                f"LR: {log_stats['lr']:.2e} | "
                f"Train Loss: {log_stats['Train/average_loss']:.4f} | "
                f"Train {self.cfg.dataset.metric}: {log_stats[f'Train/{self.cfg.dataset.metric}']:.4f} | "
                f"Val {self.cfg.dataset.metric}: {log_stats[f'Val/{self.cfg.dataset.metric}']:.4f} | "
                f"Test {self.cfg.dataset.metric}: {log_stats[f'Test/{self.cfg.dataset.metric}']:.4f} "
            )

            self.wandb.log(log_stats)
            stats.append(log_stats)

        return stats


def run_training(cfg: DictConfig):
    """Train one ``Trainer`` per configured split and print aggregate results.

    A freshly generated random string is stored in ``cfg.group_id`` before
    any trainer is built (used to group the splits in Wandb).

    Args:
        cfg: Run configuration. ``cfg.dataset.split_index`` is iterated to
            obtain the split indices and ``cfg.dataset.metric`` names the
            metric that is reported.

    Returns:
        A list with one entry per configured split, in the order the splits
        appear in ``cfg.dataset.split_index``. Each entry is the list of
        per-epoch stat dicts returned by ``Trainer.train``.
    """
    # Generate a group_id to group splits in Wandb
    group_id = generate_random_string()
    with open_dict(cfg):
        cfg.group_id = group_id

    val_key = f"Best/Val/{cfg.dataset.metric}"
    test_key = f"Best/Test/{cfg.dataset.metric}"

    # Train 1 split at a time. Results are collected by position instead of
    # being written to ``split_stats[split_index]``: split indices are values,
    # not positions, so a selection such as ``[3]`` or ``[1, 2]`` would
    # otherwise raise IndexError or leave ``None`` holes in the list.
    split_stats = []
    for split_index in cfg.dataset.split_index:

        trainer = Trainer(cfg, split_index)
        stats = trainer.train()
        trainer.cleanup()

        split_stats.append(stats)
        # The last epoch's record carries the "Best/..." values of the epoch
        # selected over the whole run.
        print(f"{val_key} = {stats[-1][val_key]}")
        print(f"{test_key} = {stats[-1][test_key]}")

    split_val_metric = [stats[-1][val_key] for stats in split_stats]
    split_test_metric = [stats[-1][test_key] for stats in split_stats]

    print()
    print(f"Mean Best/Val = {np.mean(split_val_metric)} +/- {np.std(split_val_metric)}")
    print(f"Mean Best/Test = {np.mean(split_test_metric)} +/- {np.std(split_test_metric)}")

    return split_stats


@hydra.main(config_path="./configs", config_name="train")
def main(cfg):
    """Hydra entry point: run training with the config supplied by Hydra.

    The decorator points Hydra at the ``./configs`` directory and the
    ``train`` config name.

    Returns:
        Whatever ``run_training`` returns for ``cfg``.
    """
    split_stats = run_training(cfg)
    return split_stats


# Script entry point: call the Hydra-decorated ``main`` only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
