from dataset import get_dataset
from model import get_model, NN
import torch
from functools import cached_property
from tqdm.autonotebook import tqdm
import torch.nn as nn
import torch.nn.functional as F
import os
import pickle
from convert_model import convert, split_sequential
from collections import OrderedDict
import random
import numpy as np
from torch.optim.lr_scheduler import StepLR
from itertools import product

# Fallback training hyper-parameters; train_exp_model merges its **kwargs
# over these, so any key given there takes precedence.
default_args = dict(
    epochs=10,
    batch_size=64,
    test_batch_size=1000,
    lr=0.01,
    gamma=0.9,
)


class Experiment:
    """Bundle a dataset with a model, plus helpers to train, evaluate and pickle them.

    Datasets are built lazily via ``get_dataset``. The model is either passed in
    directly or loaded lazily from ``model_path`` (the value ``"alexnet"`` is
    special-cased to a torch-hub AlexNet). When ``split_layer`` is set, the
    loaded model is split by ``split_sequential`` into ``(embedder, model)``.

    Extra keyword arguments are stored verbatim as instance attributes.
    """

    def __init__(
        self,
        dataset_name,
        model_path=None,
        train_kwargs=None,
        dataset_kwargs=None,
        split_layer=None,
        model=None,
        embedder=None,
        **other_kwargs,
    ):
        self.dataset_name = dataset_name
        self.model_path = model_path
        # Fresh dict per instance: a shared ``{}`` default would leak mutations
        # (e.g. train_kwargs["verbose"] = ...) from one experiment into others.
        self.train_kwargs = {} if train_kwargs is None else train_kwargs
        self.dataset_kwargs = {} if dataset_kwargs is None else dataset_kwargs
        self._model = model
        self._embedder = embedder
        self.split_layer = split_layer
        for key, value in other_kwargs.items():
            setattr(self, key, value)

    @cached_property
    def train_dataset(self):
        """Training split, built on first access and cached."""
        return get_dataset(self.dataset_name, train=True, **self.dataset_kwargs)

    @cached_property
    def test_dataset(self):
        """Test split, built on first access and cached."""
        return get_dataset(self.dataset_name, train=False, **self.dataset_kwargs)

    @property
    def model(self):
        """The model (the post-split part when ``split_layer`` is set), loaded lazily."""
        if self._model is None:
            # Pass split_layer so a lazily (re)loaded model is split the same way
            # as an explicit ``get_model(split_layer=...)`` call would split it.
            self.get_model(split_layer=self.split_layer, device="cuda" if torch.cuda.is_available() else "cpu")
        return self._model

    @property
    def embedder(self):
        """The pre-split part of the model, or ``None`` when no ``split_layer`` is set."""
        # Without a split there is no embedder; reloading here would only
        # overwrite the (possibly trained) in-memory model and still return None.
        if self._embedder is None and self.split_layer is not None:
            self.get_model(split_layer=self.split_layer, device="cuda" if torch.cuda.is_available() else "cpu")
        return self._embedder

    def get_model(self, split_layer=None, device="cuda" if torch.cuda.is_available() else "cpu", root="."):
        """Load the model into ``self._model`` and optionally split it.

        Args:
            split_layer: If not None, split the loaded model with
                ``split_sequential`` at this layer, store the parts in
                ``self._embedder`` / ``self._model`` and record ``self.split_layer``.
            device: Device the loaded model is moved to.
            root: Directory ``model_path`` is resolved against (unused for ``"alexnet"``).

        Returns:
            ``(embedder, model)`` when ``split_layer`` is given, otherwise the model.

        Raises:
            ValueError: If ``model_path`` is None, so there is nothing to load.
        """

        def get_children(module):
            # Flatten nested containers into a list of leaf modules,
            # unwrapping containers that hold a single child.
            children = tuple(module.children())
            if not children:
                return [module]
            if len(children) == 1:
                return get_children(children[0])
            return [leaf for child in children for leaf in get_children(child)]

        if self.model_path == "alexnet":
            model = torch.hub.load("pytorch/vision:v0.10.0", "alexnet", verbose=False).to(device)
            layers = OrderedDict()
            for layer in get_children(model):
                layers[str(len(layers))] = layer
                # Presumably mirrors the flatten AlexNet's forward() performs after
                # its adaptive pool, which a flat layer stack would otherwise lose.
                if isinstance(layer, nn.AdaptiveAvgPool2d):
                    layers[str(len(layers))] = nn.Flatten()
            self._model = NN(layers=layers, input_shape=(3, 224, 224)).to(device)
        elif self.model_path is None:
            raise ValueError("Experiment has no model_path to load the model from")
        else:
            # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
            # rejects pickled nn.Modules; pass weights_only=False if loading fails.
            self._model = torch.load(os.path.join(root, self.model_path), map_location=device).to(device)

        if split_layer is None:
            self._model.eval()
            return self._model

        self.split_layer = split_layer
        self._embedder, self._model = split_sequential(self._model, split_layer)
        self._embedder.eval()
        self._model.eval()
        return self._embedder, self._model

    @cached_property
    def converted_model(self):
        """``convert(self.model)``, computed on first access and cached."""
        return convert(self.model)

    def train_model(self):
        """Train ``self.model`` with ``train_exp_model`` using ``self.train_kwargs``."""
        train_exp_model(self.model, self.train_dataset, self.test_dataset, **self.train_kwargs)

    def save_model(self, filename):
        """Save ``self.model`` to ``filename`` and record it as ``model_path``.

        Note: the model is moved to CPU before saving (``nn.Module.to`` works in
        place, so it stays there). When the experiment is split, only the
        post-split ``model`` is saved, not the embedder.
        """
        self.model_path = filename
        self.model.to("cpu")
        torch.save(self.model, filename)

    def save(self, filename):
        """Pickle this experiment (without its in-memory model) to ``filename``."""
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(filename):
        """Unpickle an experiment from ``filename``.

        Security: pickle can execute arbitrary code; only load trusted files.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def train_accuracy(self):
        """Accuracy of ``self.model`` on the training split."""
        return self.train_dataset.get_accuracy(self.model)

    def test_accuracy(self):
        """Accuracy of ``self.model`` on the test split."""
        return self.test_dataset.get_accuracy(self.model)

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __reduce__(self):
        """Pickle without the model/embedder; they are reloaded lazily from ``model_path``.

        Works on a copy of ``__dict__`` so pickling does not strip the model
        from the live object.
        """
        state = dict(self.__dict__)
        state["_model"] = None
        state["_embedder"] = None
        return (
            Experiment,
            (
                self.dataset_name,
                self.model_path,
                self.train_kwargs,
                self.dataset_kwargs,
                self.split_layer,
            ),
            state,
        )


class ExperimentArray:
    """A keyed, insertion-ordered collection of experiments with bulk helpers."""

    # TODO: Phase out

    def __init__(self):
        self.exps = dict()
        # Insertion-ordered keys of ``self.exps``; kept for existing callers.
        self.keys = []

    def add_exp(self, exp, key):
        """Add ``exp`` under ``key``; re-adding a key replaces the experiment."""
        # Avoid duplicate entries in ``self.keys`` when a key is re-added.
        if key not in self.exps:
            self.keys.append(key)
        self.exps[key] = exp

    def save_models(self, filename):
        """Save each experiment's model as ``<filename>/<key>.pt`` (creates the directory)."""
        os.makedirs(filename, exist_ok=True)
        for key, exp in self.exps.items():
            exp.save_model(os.path.join(filename, f"{key}.pt"))

    def save_exps(self, filename):
        """Pickle each experiment as ``<filename>/<key>.pkl`` (creates the directory)."""
        os.makedirs(filename, exist_ok=True)
        for key, exp in self.exps.items():
            exp.save(os.path.join(filename, f"{key}.pkl"))

    def train_models(self):
        """Train every experiment in insertion order, with a progress bar."""
        for exp in tqdm(self.exps.values(), total=len(self)):
            exp.train_model()

    def save(self, filename):
        """Pickle the whole array to ``filename``."""
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(filename):
        """Unpickle an array from ``filename``.

        Security: pickle can execute arbitrary code; only load trusted files.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def __len__(self):
        return len(self.exps)

    def __iter__(self):
        return iter(self.exps.values())

    def __getitem__(self, key):
        return self.exps[key]

    def items(self):
        return self.exps.items()


def train(model, train_loader, criterion, optimizer, verbose=0, max_grad_norm=1000):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        model: Module to train; must expose ``device`` and ``dtype`` attributes
            (presumably the project ``NN`` class — confirm for other models).
        train_loader: Iterable of ``(data, target)`` batches.
        criterion: Loss ``criterion(output, target)``; assumed to return the
            batch mean (the default reduction of F.cross_entropy / F.mse_loss).
        optimizer: Optimizer stepping ``model``'s parameters.
        verbose: Shows the per-batch progress bar when ``>= 2``.
        max_grad_norm: Gradient-norm clipping threshold applied before each step.

    Returns:
        Average loss per sample seen this epoch (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    pbar = tqdm(train_loader, total=len(train_loader), disable=verbose < 2, leave=False, desc="Batches")
    for data, target in pbar:
        data, target = data.to(model.device, model.dtype), target.to(model.device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.squeeze(), target.squeeze())
        # Weight the batch-mean loss by batch size so the epoch average is
        # per sample (the last batch may be smaller).
        batch_size = data.shape[0]
        total_loss += loss.item() * batch_size
        n_seen += batch_size
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

    return total_loss / max(n_seen, 1)


def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(model.device), target.to(model.device)
            output = model(data)
            test_loss += criterion(output.squeeze(), target.squeeze()).sum().item()
    return test_loss / len(test_loader.dataset)


def train_exp_model(model, train_dataset, test_dataset, verbose=0, use_scheduler=False, **kwargs):
    """Train ``model`` in place with plain SGD, reporting losses each epoch.

    Args:
        model: Module to train; moved to CUDA when available. ``train`` reads
            its ``device`` and ``dtype`` attributes.
        train_dataset: Training data; its ``task`` attribute picks the loss
            (``"classification"`` -> cross-entropy, anything else -> MSE).
        test_dataset: Data evaluated after every epoch.
        verbose: 0 hides all progress bars; 1 shows the epoch bar; 2 or more
            also shows the per-batch bar inside ``train``.
        use_scheduler: Multiply the learning rate by ``gamma`` after every
            epoch (``StepLR`` with ``step_size=1``).
        **kwargs: Overrides for ``default_args`` (``epochs``, ``batch_size``,
            ``test_batch_size``, ``lr``, ``gamma``).
    """
    args = default_args | kwargs  # TODO: Just put default_args as default parameters

    train_loader = torch.utils.data.DataLoader(train_dataset, args["batch_size"], shuffle=True)
    # Evaluation uses its own batch size ("test_batch_size" was previously ignored).
    test_loader = torch.utils.data.DataLoader(test_dataset, args["test_batch_size"], shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = F.cross_entropy if train_dataset.task == "classification" else F.mse_loss
    optimizer = torch.optim.SGD(model.parameters(), lr=args["lr"])
    scheduler = StepLR(optimizer, step_size=1, gamma=args["gamma"]) if use_scheduler else None

    pbar = tqdm(range(1, args["epochs"] + 1), desc="Training", leave=False, disable=(verbose == 0))
    for _ in pbar:
        train_loss = train(model, train_loader, criterion, optimizer, verbose=verbose)
        test_loss = test(model, test_loader, criterion)
        if scheduler is not None:
            scheduler.step()
        # Always show the losses; append the current LR only when a scheduler
        # is active (the old conditional expression blanked everything otherwise).
        postfix = f"Train Loss: {train_loss:.2E} Test Loss: {test_loss:.2E}"
        if scheduler is not None:
            postfix += f" LR: {scheduler.get_last_lr()[0]:.2E}"
        pbar.set_postfix_str(postfix)


def train_real_exps():
    """Train, evaluate and save the real-dataset experiments.

    For every (dataset, model, training kwargs) entry: reseed all RNGs, train,
    print test metrics, save the model under ``models/`` and the experiment
    under ``experiments/``, then reload the pickled experiment and print its
    metrics again as a round-trip check. Finally pickles an imagenette
    experiment whose ``model_path`` is ``"alexnet"``.
    """
    print("\n\n ========= Training Real Experiments =========\n\n")

    configs = [
        ("mnist", "mnist_fc", {"epochs": 50, "lr": 0.1, "use_scheduler": True}),
        ("california_housing_reg", "california_housing_reg", {"epochs": 60, "lr": 0.001}),
        ("cifar10", "cifar10_cnn", {"epochs": 30, "lr": 0.01, "use_scheduler": True, "verbose": 2, "batch_size": 4}),
    ]
    for directory in ("experiments", "models"):
        os.makedirs(directory, exist_ok=True)

    def report(target, task):
        # Classification gets accuracy; any other task is reported as regression.
        if task == "classification":
            print(f"Test Accuracy: {target.test_accuracy()}")
            return
        print(f"Test MSE: {target.test_dataset.get_mse(target.model)}")
        print(f"Test R2: {target.test_dataset.get_rsquared(target.model)}")

    for dataset_name, model_name, train_kwargs in configs:
        # Reseed per experiment so each run is reproducible on its own.
        random.seed(1)
        np.random.seed(1)
        torch.manual_seed(1)

        print(f"Running experiment - Dataset: {dataset_name}, Model: {model_name}")

        net = get_model(model_name)
        net.to("cuda" if torch.cuda.is_available() else "cpu")

        train_kwargs.setdefault("verbose", 1)
        exp = Experiment(dataset_name, model=net, train_kwargs=train_kwargs)
        exp.train_model()

        report(exp, exp.test_dataset.task)

        model_file = f"models/{model_name}.pt"
        exp_file = f"experiments/{model_name}.pkl"
        exp.save_model(model_file)
        exp.save(exp_file)

        reloaded = Experiment.load(exp_file)
        report(reloaded, exp.test_dataset.task)

    Experiment("imagenette", "alexnet", {}).save("experiments/alexnet.pkl")


def train_synthetic_exps():
    """Build, train and save a grid of MLP experiments on synthetic blob data.

    The grid spans input dimension, hidden width, hidden depth and trial; the
    trial index doubles as the dataset's ``random_state``. Models and the
    pickled ``ExperimentArray`` are written to ``experiments/<exp_dir>``.
    """
    print("\n\n ========= Training Synthetic Experiments =========\n\n")

    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)

    ncenters = 3
    ntrials = 5
    base_widths = [4, 8, 16]
    nhiddens = range(1, 5)
    ds = range(2, 6)
    train_models = True
    exp_dir = "exp_0/"

    exp_array = ExperimentArray()
    grid = product(ds, base_widths, nhiddens, range(ntrials))
    for idx, (d, base_width, nhidden, trial) in enumerate(grid):
        # Input layer of size d, `nhidden` hidden layers of equal width, one output per center.
        layer_widths = [d, *([base_width] * nhidden), ncenters]
        exp = Experiment(
            dataset_name="blobs",
            dataset_kwargs={"d": d, "n": 100 * ncenters, "centers": ncenters, "random_state": trial},
            train_kwargs={"epochs": 20, "lr": 0.01},
            model=get_model("mlp", widths=layer_widths),
            width=base_width,
            nhidden=nhidden,
            trial=trial,
            dim=d,
        )
        exp_array.add_exp(exp, key=f"blobs_mlp_{idx}")

    if train_models:
        exp_array.train_models()

    if exp_dir:
        out_dir = "experiments/" + exp_dir
        os.makedirs(out_dir, exist_ok=True)
        exp_array.save_models(out_dir)
        exp_array.save(out_dir + "exp_array.pkl")


if __name__ == "__main__":
    # Shared output root, created before either training stage runs.
    os.makedirs("experiments", exist_ok=True)
    for run_stage in (train_real_exps, train_synthetic_exps):
        run_stage()
