import math
import os
import time
from pathlib import Path
from pprint import pprint
import torch
import numpy as np
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from diff_metrics import AUMCalculator
import torch.nn as nn
from utils import (
    AverageMeter,
    set_seed,
    expected_calibration_error,
    parse_args,
)
from cifar10h_data import CrowdCIFAR
from aggregate_votes import get_aggregated_votes

# Directory that contains this file, and the path of a "data" folder next to it.
# NOTE(review): main() reads args.data_dir; `data_dir` is not referenced in the
# visible code — confirm whether it is still needed by importers.
current = os.path.dirname(__file__)
data_dir = os.path.join(current, "data")


def train_step(
    args,
    summary_writer,
    metrics,
    aum_calculator,
    log_interval,
    batch_step,
    num_batches,
    batch,
    epoch,
    num_epochs,
    global_step,
    model,
    optimizer,
    device,
):
    """Run one optimisation step on a single training batch.

    The batch is unpacked as ``(input, target, ttarget, sample_ids)``. The
    model is updated from the module-level ``Loss`` criterion applied to
    ``target``; the running ``metrics`` meters are updated in place and the
    last batch time is written to TensorBoard.

    Args:
        args: Not read by this function.
        summary_writer: Object exposing ``add_scalar(tag, value, step)``.
        metrics: Mapping with "error", "loss" and "batch_time" meters, each
            exposing ``update`` (and ``val`` for "batch_time").
        aum_calculator: Object exposing ``update(logits, labels, ids)``.
        log_interval: Not read by this function.
        batch_step: Not read by this function.
        num_batches: Not read by this function.
        batch: 4-tuple ``(input, target, ttarget, sample_ids)``.
        epoch: Current epoch index; AUM is only updated while
            ``epoch <= num_epochs / 2``.
        num_epochs: Total number of epochs.
        global_step: Step used as the x-axis of the TensorBoard scalar.
        model: Network being trained (switched to train mode here).
        optimizer: Optimizer stepping the model parameters.
        device: Device the batch tensors are moved to.
    """
    tic = time.time()
    model.train()
    with torch.enable_grad():
        optimizer.zero_grad()

        images, labels, true_labels, sample_ids = batch
        # true_labels is moved alongside the others but is not used for the
        # training update or the metrics below.
        images, labels, true_labels = (
            tensor.to(device) for tensor in (images, labels, true_labels)
        )

        # Forward pass and loss against the (possibly soft) training labels
        logits = model(images)
        loss = Loss(logits, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Timing stops here so metric bookkeeping is not counted
        elapsed = time.time() - tic
        n_samples = labels.size(0)
        top1 = logits.data.cpu().topk(1, dim=1)[1]
        # Labels with more than one dimension are reduced to their argmax
        # class before being compared with the prediction.
        hard_labels = (
            torch.argmax(labels, dim=1) if labels.dim() > 1 else labels
        )
        n_wrong = (
            torch.ne(top1.squeeze(), hard_labels.cpu()).float().sum().item()
        )
        error = n_wrong / n_samples

        metrics["error"].update(error, n_samples)
        metrics["loss"].update(loss.item(), n_samples)
        metrics["batch_time"].update(elapsed)

        # AUM is only tracked during the first half of training
        if epoch <= num_epochs / 2:
            aum_calculator.update(logits, hard_labels, sample_ids.tolist())

        # TensorBoard logging
        summary_writer.add_scalar(
            "train/batch_time", metrics["batch_time"].val, global_step
        )


def eval_step(
    args,
    regime,
    metrics,
    log_interval,
    batch_step,
    num_batches,
    batch,
    epoch,
    num_epochs,
    model,
    device,
):
    """Evaluate the model on one batch and update the running metrics.

    ``batch`` is either ``(input, target, ttarget, sample_ids)`` or
    ``(input, target, sample_ids)``. The reference labels used for the loss,
    the error rate and the ECE are ``ttarget`` when it is present and
    ``target`` otherwise (previously the 3-element form crashed with a
    ``NameError`` because ``ttarget`` was never bound).

    Args:
        args: Not read by this function.
        regime: Label printed at the start of each console line
            (e.g. "VAL" or "TEST").
        metrics: Mapping with "error", "loss", "batch_time" and "ECE" meters
            exposing ``update``, ``val`` and ``avg``.
        log_interval: A console line is printed every ``log_interval`` batches.
        batch_step: Index of this batch within the current pass.
        num_batches: Number of batches in the pass (used for printing only).
        batch: 3- or 4-element batch as described above.
        epoch: Epoch index (used for printing only).
        num_epochs: Total number of epochs (used for printing only).
        model: Network to evaluate (switched to eval mode here).
        device: Device the batch tensors are moved to.
    """
    start = time.time()
    model.eval()
    with torch.no_grad():
        if len(batch) == 4:
            inputs, _, reference, _ = batch
        else:
            inputs, reference, _ = batch
        inputs = inputs.to(device)
        reference = reference.to(device)

        # Compute output
        output = model(inputs)
        loss = Loss(output, reference)

        # Measure accuracy & record loss
        end = time.time()
        batch_size = reference.size(0)
        # Reduce soft (2-D) references to a class index, as train_step does,
        # so the error rate and ECE always compare against hard labels.
        hard_labels = (
            reference.argmax(dim=1) if reference.dim() > 1 else reference
        ).cpu()
        _, pred = output.cpu().topk(1, dim=1)
        # squeeze(1) keeps the batch dimension even when batch_size == 1
        error = (
            torch.ne(pred.squeeze(1), hard_labels).float().sum().item()
            / batch_size
        )

        metrics["error"].update(error, batch_size)
        metrics["loss"].update(loss.item(), batch_size)
        metrics["batch_time"].update(end - start)
        metrics["ECE"].update(
            expected_calibration_error(
                output.softmax(dim=1).cpu().numpy(),
                hard_labels.numpy(),
            ),
            batch_size,
        )
        # log to console
        if (batch_step + 1) % log_interval == 0:
            results = "\t".join(
                [
                    regime,
                    f"Epoch: [{epoch}/{num_epochs}]",
                    f"Batch: [{batch_step}/{num_batches}]",
                    f'Time: {metrics["batch_time"].val:.3f} ({metrics["batch_time"].avg:.3f})',
                    f'Loss: {metrics["loss"].val:.3f} ({metrics["loss"].avg:.3f})',
                    f'Error: {metrics["error"].val:.3f} ({metrics["error"].avg:.3f})',
                ]
            )
            print(results)


def _new_metrics():
    """Return a fresh set of running meters keyed by metric name."""
    return {
        "loss": AverageMeter(),
        "error": AverageMeter(),
        "batch_time": AverageMeter(),
        "ECE": AverageMeter(),
    }


def _evaluate(
    args, regime, loader, epoch, num_epochs, model, device, show_progress=False
):
    """Run eval_step over every batch of ``loader`` and return the meters."""
    metrics = _new_metrics()
    num_batches = len(loader)
    batches = enumerate(loader)
    if show_progress:
        batches = tqdm(batches, total=num_batches)
    for batch_step, batch in batches:
        eval_step(
            args,
            regime,
            metrics,
            args.log_interval,
            batch_step,
            num_batches,
            batch,
            epoch,
            num_epochs,
            model,
            device,
        )
    return metrics


def _log_eval_metrics(summary_writer, split, metrics, global_step):
    """Write the averaged evaluation meters to TensorBoard under ``split/``."""
    for key in ("error", "loss", "batch_time", "ECE"):
        summary_writer.add_scalar(
            f"{split}/{key}", metrics[key].avg, global_step
        )


def main(args):
    """Train and test one model per (dataset, aggregation method) pair.

    For every dataset returned by ``get_aggregated_votes`` and every method in
    ``args.method`` this loads the aggregated labels from CSV, trains a
    ResNet-18 with SGD and a two-milestone step schedule, validates every 20
    epochs, keeps the checkpoint with the lowest validation error, and finally
    reports test metrics for that checkpoint.

    Args:
        args: Parsed run configuration. The attributes read here are
            output_dir, seed, method, data_dir, train_batch_size,
            val_batch_size, architecture, learning_rate, momentum,
            weight_decay, num_epochs, detailed_aum and log_interval.

    Raises:
        ValueError: If ``args.architecture`` is not "resnet18".
    """
    # train_step / eval_step read the criterion from this module-level global.
    global Loss

    pprint(vars(args))

    # Setup experiment folder structure
    # Create output folder if it doesn't exist
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # save out args
    with open(os.path.join(args.output_dir, "args.txt"), "w") as f:
        pprint(vars(args), f)

    # Set seeds
    set_seed(args.seed)

    # Data transforms
    train_transforms = transforms.Compose([transforms.ToTensor()])
    test_transforms = transforms.Compose([transforms.ToTensor()])

    # Type the label file of each aggregation method is cast to after loading
    type_target = {
        "MV": int,
        "soft": float,
        "DS": int,
        "DSrmSpam": int,
        "GLAD": float,
    }
    # Validation cadence, in epochs
    val_interval = 20

    Loss = torch.nn.CrossEntropyLoss()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pin_memory = torch.cuda.is_available()

    # Datasets
    train_sets, val_sets, test_sets, names, SEED = get_aggregated_votes(
        SEED=args.seed, method=args.method, train_transform=train_transforms
    )
    for train_set, val_set, test_set, name in zip(
        train_sets, val_sets, test_sets, names
    ):
        for method in args.method:
            print("######## Begin train on ", method)
            path_res = os.path.join(
                args.data_dir, f"{name}", f"{name}_{method}_seed_{SEED}.csv"
            )
            method_dir = os.path.join(args.output_dir, f"{name}", f"{method}")
            best_path = os.path.join(method_dir, "best.pt")

            # Setup summary writer
            summary_writer = SummaryWriter(
                log_dir=os.path.join(method_dir, "tb_logs")
            )

            targets = np.loadtxt(path_res, delimiter=",").astype(
                type_target[method]
            )
            train_set_ = CrowdCIFAR(
                set=train_set, targets=targets, transform=train_transforms
            )
            val_set_ = CrowdCIFAR(
                set=val_set,
                targets=val_set.c10h_c10_targets,
                transform=test_transforms,
            )
            # Held-out data goes through the evaluation transform (the
            # original passed train_transforms here).
            test_set_ = CrowdCIFAR(
                set=test_set,
                targets=test_set.c10h_c10_targets,
                transform=test_transforms,
            )

            train_loader = DataLoader(
                train_set_,
                batch_size=args.train_batch_size,
                shuffle=True,
                pin_memory=pin_memory,
                num_workers=0,
            )
            val_loader = DataLoader(
                val_set_,
                batch_size=args.val_batch_size,
                shuffle=False,
                pin_memory=pin_memory,
            )
            test_loader = DataLoader(
                test_set_,
                batch_size=args.val_batch_size,
                shuffle=False,
                pin_memory=pin_memory,
            )

            # Load Model
            if args.architecture == "resnet18":
                model = resnet18(num_classes=10)
                # Replace the stem conv with a 3x3 / stride-1 one because
                # CIFAR images are small.
                # NOTE(review): padding=3 with a 3x3 kernel enlarges the
                # feature map (padding=1 would keep its size); left unchanged
                # so results stay comparable — confirm intent.
                model.conv1 = nn.Conv2d(
                    3,
                    64,
                    kernel_size=3,
                    stride=1,
                    padding=3,
                    bias=False,
                )
            else:
                raise ValueError("Architecture not implemented yet")
            model.maxpool = nn.Identity()  # avoid hard downsampling
            model = model.to(device)
            num_params = sum(
                x.numel() for x in model.parameters() if x.requires_grad
            )
            print(model)
            print(f"Number of parameters: {num_params}")
            print(len(train_loader), len(val_loader), len(test_loader))

            # Create optimizer & lr scheduler
            parameters = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.SGD(
                parameters,
                lr=args.learning_rate,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov=args.momentum > 0,
            )
            # Milestones must be integer epoch indices: a fractional value
            # (e.g. 0.75 * 50 = 37.5) never equals an epoch counter, so that
            # decay would silently be skipped. Rounding up also keeps the
            # milestones above 0 for very short runs.
            milestones = [
                math.ceil(0.5 * args.num_epochs),
                math.ceil(0.75 * args.num_epochs),
            ]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=milestones, gamma=0.1
            )

            # Keep track of AUM
            aum_calculator = AUMCalculator(
                method_dir,
                compressed=(not args.detailed_aum),
            )

            # Keep track of things
            global_step = 0
            best_error = math.inf
            # Kept separate from the validation batch count so train_step
            # receives the length of the training loader.
            num_train_batches = len(train_loader)

            print("Beginning training")
            # Baseline validation pass before any parameter update
            val_metrics = _evaluate(
                args,
                "VAL",
                val_loader,
                0,
                args.num_epochs,
                model,
                device,
                show_progress=True,
            )
            _log_eval_metrics(summary_writer, "val", val_metrics, global_step)

            for epoch in tqdm(range(args.num_epochs), total=args.num_epochs):
                train_metrics = _new_metrics()
                for batch_step, batch in enumerate(train_loader):
                    train_step(
                        args,
                        summary_writer,
                        train_metrics,
                        aum_calculator,
                        args.log_interval,
                        batch_step,
                        num_train_batches,
                        batch,
                        epoch,
                        args.num_epochs,
                        global_step,
                        model,
                        optimizer,
                        device,
                    )
                    summary_writer.add_scalar(
                        "train/error", train_metrics["error"].avg, global_step
                    )
                    summary_writer.add_scalar(
                        "train/loss", train_metrics["loss"].avg, global_step
                    )

                    global_step += 1

                scheduler.step()

                if epoch % val_interval == 0:
                    val_metrics = _evaluate(
                        args,
                        "VAL",
                        val_loader,
                        epoch,
                        args.num_epochs,
                        model,
                        device,
                    )

                    # log eval metrics to tensorboard
                    _log_eval_metrics(
                        summary_writer, "val", val_metrics, global_step
                    )

                    # Save best model. val_metrics only changes on validation
                    # epochs, so the comparison is only needed here.
                    if val_metrics["error"].avg < best_error:
                        best_error = val_metrics["error"].avg
                        os.makedirs(method_dir, exist_ok=True)
                        torch.save(model.state_dict(), best_path)

                if epoch % args.log_interval == 0:
                    results = "\t".join(
                        [
                            "TRAIN",
                            f"Epoch: [{epoch} / {args.num_epochs}]",
                            f'Batch time: {train_metrics["batch_time"].avg:.3f}',
                            f'Train Loss: {train_metrics["loss"].avg:.3f}',
                            f'Train Error: {train_metrics["error"].avg:.3f}',
                            f'Val Loss: {val_metrics["loss"].avg:.3f}',
                            f'Val Error: {val_metrics["error"].avg:.3f}',
                        ]
                    )
                    print(results)

            # Finalize aum calculator
            aum_calculator.finalize()

            # Eval best model on test set; map_location keeps the load on the
            # device the model currently lives on.
            model.load_state_dict(torch.load(best_path, map_location=device))
            test_metrics = _evaluate(
                args, "TEST", test_loader, -1, -1, model, device
            )

            # log eval metrics to tensorboard
            _log_eval_metrics(summary_writer, "test", test_metrics, global_step)

            # log test metrics to console
            results = "\t".join(
                [
                    "FINAL TEST RESULTS",
                    f'Loss: {test_metrics["loss"].avg:.3f}',
                    f'Error: {test_metrics["error"].avg:.3f}',
                ]
            )
            print(results)

            # Flush pending events and release the writer before the next
            # method creates its own.
            summary_writer.close()


# Script entry point: build the run configuration with parse_args() (imported
# from utils) and hand it to main().
if __name__ == "__main__":
    args = parse_args()
    main(args)
