import os
import json
import random
from typing import *
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from argparse import ArgumentParser
from tqdm import tqdm
import math

import unstructured_dataset
import structured_dataset
import task_dataset
import task_program
import output
import blackbox
from constants import *

from datetime import datetime
from time import time
import wandb

def data_to_device(data):
    """Move every value of ``data`` to ``DEVICE``, updating the mapping in place.

    Each value is replaced by the result of its own ``.to(DEVICE)`` call.
    Nothing is returned; callers keep using the object they passed in.
    """
    for name, value in data.items():
        # Only existing keys are reassigned, so the mapping's size never
        # changes while it is being iterated.
        data[name] = value.to(DEVICE)

class Dataset(torch.utils.data.Dataset):
    """``torch`` dataset facade that delegates to ``task_dataset.TaskDataset``."""

    def __init__(self, config: dict, train: bool):
        """Create the wrapped ``TaskDataset`` from ``config`` and ``train``."""
        self.dataset = task_dataset.TaskDataset(config, train)

    def __len__(self):
        """Return the length reported by the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the wrapped dataset's item at ``index`` unchanged."""
        return self.dataset[index]

    @staticmethod
    def collate_fn(batch):
        """Collate a list of dataset items into ``(imgs, results)``.

        Every item is indexed positionally: ``item[0]`` is a dict keyed by
        input name, ``item[1]`` is the iterable of input configs (only the
        first item's copy is used) and ``item[2]`` is the per-item result.

        Returns:
            A tuple ``(imgs, results)`` where ``imgs`` maps each input name to
            the output of that input's structured-dataset ``collate_fn`` and
            ``results`` is the list of ``item[2]`` values in batch order.
        """
        samples = [item[0] for item in batch]
        input_configs = batch[0][1]

        # Phase 1: resolve one collate function per input name.
        collators = {}
        for input_config in input_configs:
            static_dataset = structured_dataset.get_structured_dataset_static(
                input_config)
            collators[input_config[NAME]] = static_dataset.collate_fn

        # Phase 2: collate each input's column of the batch.
        imgs = {}
        for input_config in input_configs:
            name = input_config[NAME]
            column = [sample[name] for sample in samples]
            imgs[name] = collators[name](column, input_config)

        results = [item[2] for item in batch]
        return (imgs, results)


def train_test_loader(configuration, batch_size_train, batch_size_test):
    """Build the train and test ``DataLoader`` pair for ``configuration``.

    Both loaders wrap a ``Dataset`` built from ``configuration``, use
    ``Dataset.collate_fn`` and shuffle their data; they differ only in the
    ``train`` flag and the batch size.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def make_loader(train, batch_size):
        return torch.utils.data.DataLoader(
            Dataset(configuration, train=train),
            collate_fn=Dataset.collate_fn,
            batch_size=batch_size,
            shuffle=True,
        )

    # Tuple elements are evaluated left to right: train loader first.
    return make_loader(True, batch_size_train), make_loader(False, batch_size_test)


class TaskNet(nn.Module):
    """Combine the per-input networks with the black-box task program.

    One network is created per distinct unstructured-dataset name (through
    ``ud.net()``), so inputs sharing a name share a network.  The networks
    are kept in the plain dict ``nets_dict`` and are therefore not registered
    as ``nn.Module`` submodules; ``parameters``, ``train`` and ``eval`` are
    overridden so they reach those networks explicitly.
    """

    def __init__(
            self,
            unstructured_datasets: "List[unstructured_dataset.UnstructuredDataset]",
            task_config: dict,
            fn: Callable,
            output_mapping: "output.OutputMapping",
            sample_count: int,
            batch_size_train: int,
            caching: bool):
        """Create the networks and the ``blackbox.BlackBoxFunction`` evaluator.

        Args:
            unstructured_datasets: One entry per input; each provides ``name``
                and ``net()``.
            task_config: Task configuration; ``task_config[INPUTS]`` lists the
                per-input configs and ``LOSS_AGGREGATOR`` is optional
                (defaults to ``ADD_MULT``).
            fn: The program handed to ``BlackBoxFunction`` as ``function``.
            output_mapping: Forwarded to ``BlackBoxFunction``.
            sample_count: Forwarded to ``BlackBoxFunction``.
            batch_size_train: Forwarded to ``BlackBoxFunction`` as
                ``batch_size``.
            caching: Forwarded to ``BlackBoxFunction``.
        """
        super().__init__()

        input_configs = task_config[INPUTS]

        self.config = input_configs
        self.unstructured_datasets = unstructured_datasets
        self.structured_datasets = [
            structured_dataset.get_structured_dataset_static(input_config)
            for input_config in input_configs]

        # name -> network; must exist before get_nets_list() fills it.
        self.nets_dict = {}
        self.nets = self.get_nets_list()

        # Bind each structured dataset's forward to the network of the same
        # position, so forward() only has to supply the data.
        self.forward_fns = [partial(sd.forward, self.nets[i])
                            for i, sd in enumerate(self.structured_datasets)]
        input_mappings = tuple(
            sd.get_input_mapping(input_configs[i])
            for i, sd in enumerate(self.structured_datasets))
        loss_aggregator = task_config.get(LOSS_AGGREGATOR, ADD_MULT)
        self.eval_formula = \
            blackbox.BlackBoxFunction(function=fn,
                                      input_mappings=input_mappings,
                                      output_mapping=output_mapping,
                                      batch_size=batch_size_train,
                                      loss_aggregator=loss_aggregator,
                                      caching=caching,
                                      sample_count=sample_count)

    def get_nets_list(self):
        """Return one network per unstructured dataset, in dataset order.

        Networks are created lazily and cached in ``self.nets_dict`` by
        dataset name, so datasets with the same name get the same object.
        """
        nets = []
        for ud in self.unstructured_datasets:
            if ud.name not in self.nets_dict:
                self.nets_dict[ud.name] = ud.net()
            nets.append(self.nets_dict[ud.name])
        return nets

    def parameters(self):
        """Return a list with one ``parameters()`` iterator per distinct net.

        Note that this is a list of per-network iterables, not the flat
        parameter iterator that ``nn.Module.parameters`` yields.
        """
        return [net.parameters() for net in self.nets_dict.values()]

    def task_test(self, args, x):
        """Delegate to ``self.sampling.sample_test(args, data=x)``."""
        # NOTE(review): ``self.sampling`` is never assigned in this class in
        # the visible source, so this call raises AttributeError unless a
        # caller sets it first. Confirm the intended object before relying on
        # this method.
        return self.sampling.sample_test(args, data=x)

    def forward(self, x):
        """Run every input through its network and evaluate the program.

        Args:
            x: Mapping from input name to that input's batch data.  Keys are
                paired with ``forward_fns`` by iteration order.

        Returns:
            Whatever ``self.eval_formula`` returns for the prepared inputs.
        """
        keys = list(x)
        distrs = [self.forward_fns[i](x[key]) for i, key in enumerate(keys)]
        inputs = [self.structured_datasets[i].distrs_to_input(
                      distrs[i], x[keys[i]], input_config)
                  for i, input_config in enumerate(self.config)]
        return self.eval_formula(*inputs)

    def train(self, mode: bool = True):
        """Put every network in training mode (or eval mode if ``mode`` is False).

        Mirrors the ``nn.Module.train`` contract: accepts ``mode`` and returns
        ``self``.  Only the networks in ``nets_dict`` are switched, exactly as
        before; they are toggled through their own ``train()`` / ``eval()``.
        """
        for net in self.nets_dict.values():
            if mode:
                net.train()
            else:
                net.eval()
        self.training = mode
        return self

    def eval(self):
        """Put every network in evaluation mode and return ``self``."""
        return self.train(False)

    def confusion_matrix(self):
        """Ask the first unstructured dataset for a confusion matrix of its net."""
        # Just print one confusion matrix for the first UD
        self.unstructured_datasets[0].confusion_matrix(self.nets[0])


class Trainer():
    """Own a ``TaskNet``, its optimizers and the train/test loops."""

    def __init__(
            self,
            train_loader: torch.utils.data.DataLoader,
            test_loader: torch.utils.data.DataLoader,
            unstructured_datasets: "List[unstructured_dataset.UnstructuredDataset]",
            learning_rate: float,
            task_config: dict,
            fn: Callable,
            output_mapping: "output.OutputMapping",
            sample_count: int,
            batch_size_train: int,
            caching: bool):
        """Build the network and one Adam optimizer per distinct sub-network.

        Args:
            train_loader: Loader iterated by ``train_epoch``.
            test_loader: Loader iterated by ``test_epoch``.
            unstructured_datasets: Forwarded to ``TaskNet``.
            learning_rate: Learning rate of every Adam optimizer.
            task_config: Forwarded to ``TaskNet``.
            fn: Forwarded to ``TaskNet``.
            output_mapping: Forwarded to ``TaskNet`` and also used here to
                normalize labels.
            sample_count: Forwarded to ``TaskNet``.
            batch_size_train: Forwarded to ``TaskNet``.
            caching: Forwarded to ``TaskNet``.
        """
        self.network = TaskNet(unstructured_datasets=unstructured_datasets,
                               task_config=task_config,
                               fn=fn,
                               output_mapping=output_mapping,
                               sample_count=sample_count,
                               batch_size_train=batch_size_train,
                               caching=caching)
        self.output_mapping = output_mapping
        self.optimizers = [optim.Adam(net.parameters(), lr=learning_rate)
                           for net in self.network.nets_dict.values()]
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.loss_fn = F.binary_cross_entropy

    @staticmethod
    def _count_correct(y, y_pred):
        """Count rows whose argmax matches between ``y`` and ``y_pred``.

        Rows of ``y`` whose sum along dim 1 is not positive are never counted
        as correct.  Works on whatever device the tensors live on.
        """
        has_label = torch.sum(y, dim=1) > 0
        matches = torch.argmax(y, dim=1) == torch.argmax(y_pred, dim=1)
        return torch.sum(has_label & matches).item()

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader`` and log timings.

        Logs ``train_time_per_epoch`` and ``epoch`` to wandb once per batch,
        and ``total_epoch_time`` once at the end.
        """
        self.network.train()
        num_items = 0
        train_loss = 0
        total_correct = 0
        progress = tqdm(self.train_loader, total=len(self.train_loader))
        t_begin_total = time()
        for (i, (data, target)) in enumerate(progress):
            t_begin = time()
            data_to_device(data)
            (output_mapping, y_pred_sim, y_pred) = self.network(data)

            # Normalize label format
            batch_size = y_pred_sim.shape[0]
            norm_label, y = self.output_mapping.get_normalized_labels(
                y_pred_sim, target, output_mapping)

            if output_mapping:
                # Compute loss and update every sub-network
                loss = self.loss_fn(y_pred_sim, norm_label)
                for optimizer in self.optimizers:
                    optimizer.zero_grad()
                loss.backward()
                for optimizer in self.optimizers:
                    optimizer.step()
                loss_value = loss.item()
                # NaN losses are left out of the running total.
                if not math.isnan(loss_value):
                    train_loss += loss_value

                # Collect index and compute accuracy
                correct_count = self._count_correct(y, y_pred)
            else:
                # Batches with an empty output mapping are skipped entirely.
                correct_count = 0

            # Stats
            num_items += batch_size
            total_correct += correct_count
            perc = 100. * total_correct / num_items
            avg_loss = train_loss / (i + 1)
            # Measured from the top of this iteration, i.e. per batch, even
            # though the logged key says "per_epoch".
            epoch_time = time() - t_begin
            wandb.log({
                "train_time_per_epoch": epoch_time,
                "epoch": epoch,
            })
            # Prints
            progress.set_description(
                f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")
        total_epoch_time = time() - t_begin_total
        wandb.log({
            "epoch": epoch,
            "total_epoch_time": total_epoch_time,
        })
        print(f"Total Epoch Time: {total_epoch_time}")

    def test_epoch(self, epoch):
        """Evaluate on ``test_loader`` without gradients and log the results.

        Logs ``epoch``, ``test_accuracy`` and ``test_loss`` to wandb once.
        """
        self.network.eval()
        num_items = 0
        test_loss = 0
        total_correct = 0
        # Defined up front so the final log cannot hit an unbound name when
        # the loader yields no batches.
        perc = 0.0
        with torch.no_grad():
            progress = tqdm(self.test_loader, total=len(self.test_loader))
            for i, (data, target) in enumerate(progress):
                data_to_device(data)
                (output_mapping, y_pred_sim, y_pred) = self.network(data)

                # Normalize label format
                batch_size = y_pred_sim.shape[0]

                norm_label, y = self.output_mapping.get_normalized_labels(
                    y_pred_sim, target, output_mapping)

                # Compute loss
                loss = self.loss_fn(y_pred_sim, norm_label)
                loss_value = loss.item()
                if not math.isnan(loss_value):
                    test_loss += loss_value

                # Collect index and compute accuracy
                if output_mapping:
                    correct_count = self._count_correct(y, y_pred)
                else:
                    correct_count = 0

                # Stats
                num_items += batch_size
                total_correct += correct_count
                perc = 100. * total_correct / num_items
                avg_loss = test_loss / (i + 1)

                # Prints
                progress.set_description(
                    f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")
        # NOTE(review): "test_loss" is the summed batch loss, not the average
        # shown in the progress bar. Kept as is so logged values stay
        # comparable with earlier runs; confirm which one is intended.
        wandb.log({
            "epoch": epoch,
            "test_accuracy": perc,
            "test_loss": test_loss,
        })
        # self.network.confusion_matrix()

    def train(self, n_epochs):
        """Alternate ``train_epoch`` and ``test_epoch`` for epochs 1..n_epochs."""
        # self.test_epoch(0)
        for epoch in range(1, n_epochs + 1):
            self.train_epoch(epoch)
            self.test_epoch(epoch)


def parse_bool_flag(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` in argparse treats every non-empty string, including
    ``"False"``, as ``True``.  This converter understands the usual
    spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"``.  Matching ignores
            case and surrounding whitespace; the empty string is ``False``.

    Raises:
        ValueError: If the token is not a recognised boolean spelling
            (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Argument parser
    parser = ArgumentParser("neuro-symbolic-dataset")
    parser.add_argument("--n-epochs", type=int, default=100)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--n-samples", type=int, default=100)
    parser.add_argument("--configuration", type=str,
                        default="configuration.json")
    # parse_bool_flag makes "--caching False" actually disable caching.
    parser.add_argument("--caching", type=parse_bool_flag, default=True)
    # NOTE: --threaded is parsed but not read anywhere in this script.
    parser.add_argument("--threaded", type=int, default=0)
    args = parser.parse_args()

    # environment init
    torch.multiprocessing.set_start_method('spawn')

    # Read json (path is resolved relative to this file's directory)
    dir_path = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(dir_path, args.configuration)) as config_file:
        configuration = json.load(config_file)

    # Parameters
    n_epochs = args.n_epochs
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Dataloaders

    task_config = configuration["sum_4_mnist"]

    # Initialize the train and test loaders
    batch_size_train = task_config[BATCH_SIZE_TRAIN]
    batch_size_test = task_config[BATCH_SIZE_TEST]
    train_loader, test_loader = train_test_loader(
        task_config, batch_size_train, batch_size_test)

    # Run metadata recorded by wandb: log the values actually in effect
    # (the real batch size and the requested epoch count).
    config = {
        "sum_n": 4,
        "n_epochs": n_epochs,
        "batch_size": batch_size_train,
        "seed": args.seed,
        "experiment_type": "ISED",
    }
    timestamp = datetime.now()
    run_id = f'ised_sum10_{args.seed}_{timestamp.strftime("%Y-%m-%d %H-%M-%S")}'

    wandb.init(project="Baselines", config=config, id=run_id)

    wandb.define_metric("epoch")
    wandb.define_metric("total_epoch_time")
    wandb.define_metric("train_time_per_epoch", step_metric="epoch", summary="mean")
    wandb.define_metric("test_accuracy", step_metric="epoch", summary="max")
    wandb.define_metric("test_loss", step_metric="epoch", summary="min")

    # Set the output mapping
    om = output.get_output_mapping(task_config)

    # Create trainer and train
    py_func = task_config[PY_PROGRAM]
    learning_rate = task_config[LEARNING_RATE]
    fn = task_program.dispatcher[py_func]
    unstructured_datasets = [
        task_dataset.TaskDataset.get_unstructured_dataset(input_config, train=True)
        for input_config in task_config[INPUTS]]
    trainer = Trainer(train_loader=train_loader,
                      test_loader=test_loader,
                      unstructured_datasets=unstructured_datasets,
                      learning_rate=learning_rate,
                      task_config=task_config,
                      fn=fn,
                      output_mapping=om,
                      sample_count=args.n_samples,
                      batch_size_train=batch_size_train,
                      caching=args.caching)
    trainer.train(n_epochs)
