import datetime
from pathlib import Path
import torch
import lightning as pl
import numpy as np
from model import GraphNetModel
from data import DataModulePath, DataModuleTree
import json
from sklearn.model_selection import KFold
from torch.utils.data import Subset, Dataset
from torch_geometric.data import Data

from lightning.pytorch.loggers import TensorBoardLogger


from dataset_opts import DATASET_OPTS
from data import caterpillar_collate




class dataset_from_pt(Dataset):
    """In-memory dataset loaded from a pre-processed ``.pt`` file.

    The file is looked up under
    ``data/<dataset_name>[/<subname>][/<subset>]/<processed_dir>/`` where
    ``<processed_dir>`` is ``processed-<force_suffix>`` when ``force_suffix``
    is a string, otherwise ``processed-H<height>`` followed by ``_<m>`` for
    every entry ``m`` of ``mode``. Only the ``torch_geometric.data.Data``
    entries of the loaded list are kept.

    Args:
        dataset_name: top-level directory under ``data/``.
        subname: optional sub-directory below ``dataset_name``.
        subset: optional sub-directory below ``subname``.
        height: value embedded as ``H<height>`` in the processed directory name.
        mode: iterable of tags appended as ``_<tag>``; ``None`` means no tags.
        split: file stem to load (``<split>.pt``); ``None`` loads ``data.pt``.
        force_suffix: if a string, replaces the ``H<height><modes>`` naming.

    Attributes:
        processed_dir: directory the ``.pt`` file was read from.
        data: list of the ``Data`` objects found in the file.
    """

    def __init__(self, dataset_name, subname=None, subset=None, height=None, mode=None, split=None, force_suffix=None):
        super().__init__()

        root = Path('data') / dataset_name
        if subname is not None:
            root = root / subname
        if subset is not None:
            root = root / subset

        if isinstance(force_suffix, str):
            self.processed_dir = root / f"processed-{force_suffix}"
        else:
            # ``mode=None`` replaces a shared mutable ``[]`` default.
            suffix = "".join(f"_{m}" for m in (mode or ()))
            self.processed_dir = root / f"processed-H{height}{suffix}"

        path = self.processed_dir / ("data.pt" if split is None else f"{split}.pt")

        # weights_only=False unpickles arbitrary Python objects: only load
        # files produced by this project's own preprocessing, never files
        # from an untrusted source.
        data_list = torch.load(path, weights_only=False)
        self.data = [d for d in data_list if isinstance(d, Data)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    

def get_monitored_metric(task):
    """Return the validation metric name to monitor for ``task``.

    The metric is chosen by the task's prefix, checked in this order:
    ``auc*`` -> ``val_auc``, ``classification*`` -> ``val_acc``,
    ``regression*`` -> ``val_mae``.

    Raises:
        ValueError: if ``task`` starts with none of the known prefixes.
    """
    prefix_to_metric = (
        ("auc", "val_auc"),
        ("classification", "val_acc"),
        ("regression", "val_mae"),
    )
    for prefix, metric_name in prefix_to_metric:
        if task.startswith(prefix):
            return metric_name
    raise ValueError(f"Invalid task: {task}")


def get_mode(task):
    """Return the optimisation direction (``"max"``/``"min"``) for ``task``'s metric.

    ``auc*`` and ``classification*`` tasks are maximised; ``regression*``
    tasks are minimised.

    Raises:
        ValueError: if ``task`` starts with none of the known prefixes.
    """
    if task.startswith(("auc", "classification")):
        return "max"
    if task.startswith("regression"):
        return "min"
    raise ValueError(f"Invalid task: {task}")
    


def _resolve_dataset_opts(dataset):
    """Parse a ``name`` / ``name:subname`` spec and return its ``DATASET_OPTS`` entry.

    Returns:
        ``(dataset_name, subname, opts)``; ``subname`` is ``None`` when the
        spec contains no ``:``. ``opts`` is a new top-level dict, so removing
        keys from it never touches the shared ``DATASET_OPTS`` table.

    Raises:
        ValueError: if the dataset requires a subname and none was given, or
            the resolved options have no ``__task`` entry.
    """
    if ':' in dataset:
        dataset_name, subname = dataset.split(':')
    else:
        dataset_name, subname = dataset, None

    # Shallow copy so the ``pop`` below does not mutate the global table.
    opts = dict(DATASET_OPTS[dataset_name])
    print(opts)

    if opts.get('__subname', False):
        if subname is None:
            raise ValueError(f"Dataset {dataset_name} requires a subname.")
        opts = opts.get(subname, {}) | {'name': subname}
    else:
        opts.pop('__subname', None)

    if "__task" not in opts:
        raise ValueError(
            f"Dataset {dataset_name}:{subname} does not have a `__task` attribute {list(opts.keys())}"
        )
    return dataset_name, subname, opts


def _fit_and_score(*, task, model_opts, optim_opts, train_dataset, val_dataset,
                   dataloader_settings, dev_args, seed):
    """Train one freshly initialised path model and return its best validation score.

    The score is the best value of ``get_monitored_metric(task)`` tracked by
    the ``ModelCheckpoint`` callback, converted to a Python float.

    Raises:
        RuntimeError: if the checkpoint callback never recorded a best score
            (e.g. no validation epoch ran).
    """
    # Re-seed per run so every fold/repetition is reproducible on its own.
    pl.seed_everything(seed)

    model = GraphNetModel(
        task=task,
        model_type="path",
        model_opts=model_opts,
        optim_opts=optim_opts,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        **dataloader_settings,
        shuffle=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        **dataloader_settings,
        shuffle=False,
    )

    checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
        monitor=get_monitored_metric(task),
        mode=get_mode(task),
        save_top_k=1,
        filename='best-checkpoint',
        save_weights_only=True,
    )
    # Early stopping watches the loss, independently of the checkpoint metric.
    early_stopping_callback = pl.pytorch.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=20,
    )

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback],
        **dev_args,
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    best = checkpoint_callback.best_model_score
    if best is None:
        raise RuntimeError(
            f"No best {get_monitored_metric(task)!r} score was recorded; "
            "did any validation epoch run?"
        )
    return float(best)


def _mean_std(scores):
    """Return the mean and population standard deviation of a non-empty ``scores`` list."""
    mean = sum(scores) / len(scores)
    std = (sum((x - mean) ** 2 for x in scores) / len(scores)) ** 0.5
    return mean, std


def main(
        dataset,
        threads,
        height,
        **flags
        ):
    """Train path models on ``dataset`` and append a score summary to a text file.

    Evaluation protocol:
      * ``flags["kfold"]`` (int): K-fold cross-validation over the single
        ``data.pt`` file; takes precedence over ``reps``.
      * ``flags["reps"]`` (int): repeated runs on the ``train``/``val`` splits.

    Run ``i`` is seeded with ``42 + i``. Its best validation score
    (``get_monitored_metric(task)``) is collected. After the last run, a
    tab-separated ``n_runs mean std`` line is appended to a ``.txt`` file
    under ``logs_tb/<dataset_name>/...``.

    Args:
        dataset: dataset spec ``name`` or ``name:subname`` (key of ``DATASET_OPTS``).
        threads: torch CPU threads and number of DataLoader workers.
        height: height used in the ``processed-H<height>...`` directory name.
        **flags: remaining CLI options (``model``, ``layers``, ``opts``,
            ``optim``, ``batch_size``, ``epochs``, ``mode``, ``name``,
            ``kfold``, ``reps``).

    Raises:
        ValueError: on missing/invalid ``kfold``/``reps``, an unknown model
            type or task, a dataset that needs a subname, or dataset options
            without ``__task``.
        NotImplementedError: for ``model == "tree"``, which this loop cannot train.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    torch.set_num_threads(threads)

    model_type = flags['model']
    num_layers = flags['layers']

    if flags["reps"] is None and flags["kfold"] is None:
        raise ValueError("Either `reps` or kfold must be provided.")

    # The training loop below only builds path models fed by
    # caterpillar_collate, so reject ``tree`` before any expensive work.
    if model_type == "tree":
        raise NotImplementedError(
            "This training loop only supports `--model path`; tree models are not wired up."
        )
    if model_type != "path":
        raise ValueError(f"Invalid model type: {model_type}")

    dataset_name, subname, opts = _resolve_dataset_opts(dataset)
    task = opts["__task"]
    subset = "subset" if opts.get("subset", False) else None
    metric = get_monitored_metric(task)  # also validates ``task`` early

    # Build a fresh dict: merging in place with ``|=`` would mutate the
    # ``__model_opts`` dict stored inside DATASET_OPTS.
    model_opts = {
        **opts.get("__model_opts", {}),
        **flags['opts'],
        "layers": num_layers,
    }
    print("Model opts: ", model_opts)
    print("Optimizer opts: ", flags["optim"])

    if torch.cuda.is_available():
        dev_args = dict(
            devices=torch.cuda.device_count(),
            accelerator="gpu",
        )
    else:
        dev_args = dict(accelerator="cpu")
    dev_args["max_epochs"] = flags["epochs"]
    dev_args["gradient_clip_val"] = 1.0

    suffix = "".join(f"_{mode}" for mode in flags["mode"])

    logs_folder = Path("logs_tb") / dataset_name
    if subname is not None:
        logs_folder = logs_folder / subname
    elif opts.get("subset", False):
        logs_folder = logs_folder / "subset"
    else:
        logs_folder = logs_folder / "full"
    print("Logging to ", logs_folder)

    if isinstance(flags["name"], str):
        name = flags["name"]
    else:
        # ``datetime`` is imported as a module, so the class must be spelled out.
        name = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # NOTE(review): this single logger instance is shared by every run below,
    # so all runs log into the same TensorBoard version directory.
    dev_args["logger"] = TensorBoardLogger(
        save_dir=logs_folder,
        name=f"{model_type}_{name}{suffix}",
    )

    dataloader_settings = dict(
        batch_size=flags["batch_size"],
        num_workers=threads,
        collate_fn=caterpillar_collate,
    )
    dataset_kwargs = dict(
        dataset_name=dataset_name,
        subname=subname,
        subset=subset,
        height=height,
        mode=flags["mode"],
    )

    seed = 42
    if isinstance(flags["kfold"], int):
        # K-fold CV for datasets without a predefined split
        # (usually of limited size).
        full_dataset = dataset_from_pt(**dataset_kwargs)
        processed_name = full_dataset.processed_dir.name
        n_runs = flags["kfold"]
        run_label = "Fold"
        kf = KFold(n_splits=n_runs, shuffle=True, random_state=seed)
        runs = (
            (Subset(full_dataset, train_idx), Subset(full_dataset, val_idx))
            for train_idx, val_idx in kf.split(full_dataset)
        )
    elif isinstance(flags["reps"], int):
        # Repeated runs on publicly available train/val splits.
        if flags["reps"] < 1:
            raise ValueError("`reps` must be at least 1.")
        train_dataset = dataset_from_pt(**dataset_kwargs, split='train')
        val_dataset = dataset_from_pt(**dataset_kwargs, split='val')
        processed_name = train_dataset.processed_dir.name
        n_runs = flags["reps"]
        run_label = "Rep"
        runs = ((train_dataset, val_dataset) for _ in range(n_runs))
    else:
        raise ValueError("Invalid value for `reps` or `kfold`")

    output_file = logs_folder / f"{processed_name}_{model_type}_{name}.txt"
    print("Output file: ", output_file)

    results = []
    for run, (train_split, val_split) in enumerate(runs):
        print(f"\n--- {run_label} {run + 1} ---")
        results.append(_fit_and_score(
            task=task,
            model_opts=model_opts,
            optim_opts=flags["optim"],
            train_dataset=train_split,
            val_dataset=val_split,
            dataloader_settings=dataloader_settings,
            dev_args=dev_args,
            seed=seed + run,
        ))
        mean_score, std_score = _mean_std(results)
        print(f"\nRunning {run + 1} {metric}: {mean_score:.4f} ± {std_score:.4f}")

    print(f"\nFinal {n_runs}-{run_label} {metric}: {mean_score:.4f} ± {std_score:.4f}")
    print("Saving to ", output_file)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "a", encoding="utf-8") as f:
        f.write(f"{n_runs}\t{mean_score:.4f}\t{std_score:.4f}\n")




def build_parser():
    """Build the command-line parser for this training script.

    Every option is forwarded as a keyword argument to ``main`` via
    ``main(**vars(args))``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--threads", "-j", type=int, default=12)
    parser.add_argument(
        "--model", "-m", type=str, default="path", choices=["path", "tree"]
    )
    parser.add_argument("--batch_size", "-b", type=int, default=4)
    # A non-string default is used as-is: argparse only applies ``type`` to
    # string defaults, so this dict is not passed through json.loads.
    parser.add_argument("--optim", type=json.loads, default={
        "lr": 1e-3,
        "weight_decay": 1e-4,
    })
    parser.add_argument("--opts", "-o", type=json.loads, default="{}")
    parser.add_argument("--layers", "-l", type=int, default=6)
    parser.add_argument("--epochs", "-e", type=int, default=10)
    parser.add_argument("--name", "-n", type=str, default=None)
    parser.add_argument("--height", type=int, default=1)
    parser.add_argument('--mode', type=str, nargs='*', default=[], choices=['gcnnorm', 'justmp'], help='Mode (default: [])')
    parser.add_argument("--kfold", type=int, default=None,
                        help="Number of K-fold splits; takes precedence over --reps")
    parser.add_argument("--reps", type=int, default=1,
                        help="Repetitions on the predefined train/val splits (default: 1)")
    parser.add_argument('--dataset', type=str, default='ZINC/subset',
                        help='Dataset as name[:subname] (default: ZINC/subset)')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(**vars(args))
