import math
import os
import time
from pathlib import Path
from pprint import pprint
from simu_dataset import Simulator
import numpy as np
import torch
import torch.nn.functional as F
from model import LogisticRegression
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from diff_metrics import AUMCalculator, DatasetWithIndex
from utils import (
    AverageMeter,
    set_seed,
    expected_calibration_error,
    parse_args,
)
from aggregate_votes import get_aggregated_votes

# Directory containing this module, and a default "data" folder beside it.
# NOTE(review): main() reads args.data_dir, so data_dir is unused in this file.
current = os.path.split(__file__)[0]
data_dir = os.path.join(current, "data")


def train_step(
    args,
    summary_writer,
    metrics,
    aum_calculator,
    log_interval,
    batch_step,
    num_batches,
    batch,
    epoch,
    num_epochs,
    global_step,
    model,
    optimizer,
    device,
):
    """Perform one optimizer update on ``batch`` and record its statistics.

    The criterion is the module-level ``Loss`` object, which must be assigned
    before this is called. After the update, the batch error, loss and
    wall-clock time are pushed into ``metrics``. The logits, hard labels and
    sample ids are fed to ``aum_calculator``, and the batch time is written
    to TensorBoard at ``global_step``.

    Args:
        args: Parsed CLI options; only ``args.architecture`` is read.
        summary_writer: TensorBoard writer (``add_scalar`` is called once).
        metrics: Dict of meters keyed ``"error"``, ``"loss"`` and
            ``"batch_time"``.
        aum_calculator: Object whose ``update(logits, labels, ids)`` is called.
        log_interval, batch_step, num_batches, epoch, num_epochs: Accepted for
            signature parity with ``eval_step``; not used here.
        batch: 8-tuple ``(input, target, ttarget, confu, diff, distrib,
            theo_distrib, sample_ids)``; only ``input``, ``target`` and
            ``sample_ids`` are used.
        global_step: Step index for the TensorBoard scalar.
        model: Network being trained.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device that input and target are moved to.
    """
    tic = time.time()
    model.train()
    with torch.enable_grad():
        optimizer.zero_grad()

        (
            features,
            labels,
            _ttarget,
            _confu,
            _diff,
            _distrib,
            _theo_distrib,
            ids,
        ) = batch
        if args.architecture == "logreg":
            # The logistic-regression model consumes flat 16*16*3 vectors.
            features = features.view(-1, 16 * 16 * 3)
        features = features.to(device)
        labels = labels.to(device)

        # Forward, backward and parameter update.
        logits = model(features)
        loss = Loss(logits, labels)
        loss.backward()
        optimizer.step()

        toc = time.time()
        n = labels.size(0)
        top1 = logits.data.cpu().topk(1, dim=1)[1].squeeze()
        # Targets with more than one dimension (e.g. per-class
        # distributions) are reduced to their argmax class, so error and AUM
        # always see one hard label per sample.
        hard_labels = labels if labels.dim() <= 1 else torch.argmax(labels, dim=1)
        n_wrong = torch.ne(top1, hard_labels.cpu()).float().sum().item()

        metrics["error"].update(n_wrong / n, n)
        metrics["loss"].update(loss.item(), n)
        metrics["batch_time"].update(toc - tic)

        aum_calculator.update(logits, hard_labels, ids.tolist())

        summary_writer.add_scalar(
            "train/batch_time", metrics["batch_time"].val, global_step
        )


def eval_step(
    args,
    regime,
    metrics,
    log_interval,
    batch_step,
    num_batches,
    batch,
    epoch,
    num_epochs,
    model,
    device,
):
    """Score ``model`` on one batch without gradients and record metrics.

    The loss is cross-entropy against ``theo_distrib``; error and ECE are
    computed against ``ttarget``. Every ``log_interval``-th batch, a
    tab-separated progress line prefixed with ``regime`` is printed.

    Args:
        args: Parsed CLI options; only ``args.architecture`` is read.
        regime: Label printed at the start of the progress line
            (e.g. ``"VAL"``/``"TEST"``).
        metrics: Dict of meters keyed ``"error"``, ``"loss"``,
            ``"batch_time"`` and ``"ECE"``.
        log_interval: Print progress when ``(batch_step + 1)`` is a multiple.
        batch_step, num_batches, epoch, num_epochs: Only used in the
            progress line.
        batch: 8-tuple ``(input, target, ttarget, confu, diff, distrib,
            theo_distrib, sample_ids)``.
        model: Network being evaluated.
        device: Device tensors are moved to.
    """
    tic = time.time()
    model.eval()
    with torch.no_grad():
        (
            features,
            _target,
            true_labels,
            _confu,
            _diff,
            _distrib,
            ref_distrib,
            _ids,
        ) = batch
        if args.architecture == "logreg":
            # Flatten to the 3 * 16 * 16 feature vector expected by logreg.
            features = features.view(-1, 3 * 16 * 16)
        features = features.to(device)
        true_labels = true_labels.to(device)
        ref_distrib = ref_distrib.to(device)

        logits = model(features)
        loss = F.cross_entropy(logits, ref_distrib)

        toc = time.time()
        n = true_labels.size(0)
        top1 = logits.data.cpu().topk(1, dim=1)[1]
        n_wrong = torch.ne(top1.squeeze(), true_labels.cpu()).float().sum().item()

        metrics["error"].update(n_wrong / n, n)
        metrics["loss"].update(loss.item(), n)
        metrics["batch_time"].update(toc - tic)
        probs = logits.softmax(dim=1).data.cpu().numpy()
        metrics["ECE"].update(
            expected_calibration_error(probs, true_labels.cpu().numpy()),
            n,
        )

        # Periodic console progress line.
        if (batch_step + 1) % log_interval == 0:
            bt = metrics["batch_time"]
            ls = metrics["loss"]
            er = metrics["error"]
            fields = (
                regime,
                f"Epoch: [{epoch}/{num_epochs}]",
                f"Batch: [{batch_step}/{num_batches}]",
                f"Time: {bt.val:.3f} ({bt.avg:.3f})",
                f"Loss: {ls.val:.3f} ({ls.avg:.3f})",
                f"Error: {er.val:.3f} ({er.avg:.3f})",
            )
            print("\t".join(fields))


def _new_metrics():
    """Return a fresh dict of running meters keyed by metric name."""
    return {
        "loss": AverageMeter(),
        "error": AverageMeter(),
        "batch_time": AverageMeter(),
        "ECE": AverageMeter(),
    }


def _lr_milestones(num_epochs):
    """Return integer MultiStepLR milestones at 50% and 75% of training.

    ``MultiStepLR`` only decays when the integer epoch counter equals a
    milestone, so float milestones such as ``0.75 * 10 == 7.5`` are never
    reached.

    Milestones are adjusted as follows:

    * floored to ints;
    * clamped to at least 1, because a milestone of 0 would decay the LR as
      soon as the scheduler is constructed;
    * de-duplicated, because a repeated milestone decays by ``gamma`` once
      per copy.
    """
    return sorted({max(1, int(frac * num_epochs)) for frac in (0.5, 0.75)})


def _evaluate(args, regime, loader, epoch, num_epochs, model, device):
    """Run ``eval_step`` over every batch of ``loader``; return the meters."""
    metrics = _new_metrics()
    num_batches = len(loader)
    for batch_step, batch in enumerate(loader):
        eval_step(
            args,
            regime,
            metrics,
            args.log_interval,
            batch_step,
            num_batches,
            batch,
            epoch,
            num_epochs,
            model,
            device,
        )
    return metrics


def _log_eval_metrics(summary_writer, prefix, metrics, global_step):
    """Write averaged error/loss/batch_time/ECE scalars under ``prefix/``."""
    for key in ("error", "loss", "batch_time", "ECE"):
        summary_writer.add_scalar(f"{prefix}/{key}", metrics[key].avg, global_step)


def main(args):
    """Train, validate and test one model per (split, aggregation method).

    For every split returned by ``get_aggregated_votes`` and every method in
    ``args.method``:

    * aggregated training labels are read from
      ``<args.data_dir>/<name>/<name>_<method>_seed_<SEED>.csv``;
    * the model selected by ``args.architecture`` is trained with SGD and a
      step LR schedule while AUM statistics are collected;
    * the weights with the lowest validation error are saved to
      ``<args.output_dir>/<name>/<method>/best.pt`` and evaluated on the
      test split.

    Metrics are written to TensorBoard under ``<method>/tb_logs`` and
    summarised on stdout.

    Raises:
        ValueError: If ``args.architecture`` is not ``"logreg"``.
    """
    # train_step reads the criterion through this module-level name.
    global Loss

    pprint(vars(args))

    # Create the output folder and save the run configuration inside it.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(args.output_dir, "args.txt"), "w+") as f:
        pprint(vars(args), f)

    set_seed(args.seed)

    train_transforms = transforms.Compose([transforms.ToTensor()])
    test_transforms = transforms.Compose([transforms.ToTensor()])

    # dtype used to parse each method's aggregated-label CSV.
    type_target = {
        "MV": int,
        "soft": float,
        "DS": int,
        "DSrmSpam": int,
        "GLAD": float,
    }
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pin_memory = torch.cuda.is_available()

    train_sets, val_sets, test_sets, names, SEED = get_aggregated_votes(
        nw=args.nw, SEED=args.seed, method=args.method
    )
    for train_set, val_set, test_set, name in zip(
        train_sets, val_sets, test_sets, names
    ):
        for method in args.method:
            print("######## Begin train on ", method)
            run_dir = os.path.join(args.output_dir, f"{name}", f"{method}")
            best_path = os.path.join(run_dir, "best.pt")
            path_res = os.path.join(
                args.data_dir, f"{name}", f"{name}_{method}_seed_{SEED}.csv"
            )
            summary_writer = SummaryWriter(log_dir=os.path.join(run_dir, "tb_logs"))
            # Close the writer even on failure so buffered events are flushed.
            try:
                ttype = type_target[method]
                targets = np.loadtxt(path_res, delimiter=",").astype(ttype)
                Loss = torch.nn.CrossEntropyLoss()

                # Only the training split receives the aggregated labels.
                # eval_step scores on ttarget/theo_distrib, so the test split
                # is built exactly like the validation split.
                train_set_ = Simulator(
                    set=train_set, targets=targets, transform=train_transforms
                )
                val_set_ = Simulator(set=val_set, transform=test_transforms)
                test_set_ = Simulator(set=test_set, transform=test_transforms)

                train_set_, val_set_, test_set_ = (
                    DatasetWithIndex(train_set_),
                    DatasetWithIndex(val_set_),
                    DatasetWithIndex(test_set_),
                )

                train_loader = DataLoader(
                    train_set_,
                    batch_size=args.train_batch_size,
                    shuffle=True,
                    pin_memory=pin_memory,
                    num_workers=0,
                )
                val_loader = DataLoader(
                    val_set_,
                    batch_size=args.val_batch_size,
                    shuffle=False,
                    pin_memory=pin_memory,
                )
                test_loader = DataLoader(
                    test_set_,
                    batch_size=args.val_batch_size,
                    shuffle=False,
                    pin_memory=pin_memory,
                )

                # Model
                if args.architecture == "logreg":
                    model = LogisticRegression(16 * 16 * 3, 8)
                else:
                    raise ValueError("Architecture not implemented yet")
                model = model.to(device)
                num_params = sum(
                    p.numel() for p in model.parameters() if p.requires_grad
                )
                print(model)
                print(f"Number of parameters: {num_params}")

                # Optimizer & LR schedule (decay x0.1 at 50% and 75% of training).
                parameters = [p for p in model.parameters() if p.requires_grad]
                optimizer = torch.optim.SGD(
                    parameters,
                    lr=args.learning_rate,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov=args.momentum > 0,
                )
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, milestones=_lr_milestones(args.num_epochs), gamma=0.1
                )

                aum_calculator = AUMCalculator(
                    run_dir, compressed=(not args.detailed_aum)
                )

                global_step = 0
                best_error = math.inf
                best_saved = False
                num_train_batches = len(train_loader)

                print("Beginning training")
                # Baseline validation of the untrained model, logged at step 0.
                val_metrics = _evaluate(
                    args, "VAL", val_loader, 0, args.num_epochs, model, device
                )
                _log_eval_metrics(summary_writer, "val", val_metrics, global_step)

                for epoch in range(args.num_epochs):
                    train_metrics = _new_metrics()
                    for batch_step, batch in enumerate(train_loader):
                        train_step(
                            args,
                            summary_writer,
                            train_metrics,
                            aum_calculator,
                            args.log_interval,
                            batch_step,
                            num_train_batches,
                            batch,
                            epoch,
                            args.num_epochs,
                            global_step,
                            model,
                            optimizer,
                            device,
                        )
                        summary_writer.add_scalar(
                            "train/error", train_metrics["error"].avg, global_step
                        )
                        summary_writer.add_scalar(
                            "train/loss", train_metrics["loss"].avg, global_step
                        )
                        global_step += 1

                    scheduler.step()

                    val_metrics = _evaluate(
                        args, "VAL", val_loader, epoch, args.num_epochs, model, device
                    )
                    _log_eval_metrics(summary_writer, "val", val_metrics, global_step)

                    # Keep the checkpoint with the lowest validation error.
                    if val_metrics["error"].avg < best_error:
                        best_error = val_metrics["error"].avg
                        os.makedirs(run_dir, exist_ok=True)
                        torch.save(model.state_dict(), best_path)
                        best_saved = True

                    if epoch % args.log_interval == 0:
                        results = "\t".join(
                            [
                                "TRAIN",
                                f"Epoch: [{epoch} / {args.num_epochs}]",
                                f'Batch time: {train_metrics["batch_time"].avg:.3f}',
                                f'Train Loss: {train_metrics["loss"].avg:.3f}',
                                f'Train Error: {train_metrics["error"].avg:.3f}',
                                f'Val Loss: {val_metrics["loss"].avg:.3f}',
                                f'Val Error: {val_metrics["error"].avg:.3f}',
                            ]
                        )
                        print(results)

                aum_calculator.finalize()

                # Evaluate the best checkpoint from THIS run. If none was
                # saved (e.g. num_epochs == 0), test the current weights
                # instead of crashing or loading a stale file.
                if best_saved:
                    model.load_state_dict(torch.load(best_path, map_location=device))
                else:
                    print(
                        f"No checkpoint saved at {best_path}; "
                        "testing the current model weights"
                    )
                test_metrics = _evaluate(
                    args, "TEST", test_loader, -1, -1, model, device
                )
                _log_eval_metrics(summary_writer, "test", test_metrics, global_step)

                results = "\t".join(
                    [
                        "FINAL TEST RESULTS",
                        f'Loss: {test_metrics["loss"].avg:.3f}',
                        f'Error: {test_metrics["error"].avg:.3f}',
                    ]
                )
                print(results)
            finally:
                summary_writer.close()


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the whole pipeline.
    main(parse_args())
