import argparse
import math
import os
from collections import defaultdict
from os.path import join

import higher
import numpy as np
import pandas as pd
import torch
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torchvision import transforms

import data
import models
import utils
import augment


class DataIterator:
    """Endless batch source that pairs each clean batch with augmented copies.

    Wraps a ``DataLoader`` (or any re-iterable yielding ``(x, y)`` pairs) and
    transparently restarts it when it runs out, so callers can draw an
    unbounded number of batches.
    """

    def __init__(self, data_loader):
        self.loader = data_loader
        self.iter = iter(data_loader)

    def __call__(self, aug_func, device):
        """Return the next clean batch concatenated with its augmented copy.

        Args:
            aug_func: callable applied to the clean inputs ``x``.
            device: target device for both inputs and labels.

        Returns:
            ``(x_all, y_all)``. ``x_all`` is ``[x; aug_func(x)]``. ``y_all`` is
            the loader's labels followed by ones for the augmented half.

        Raises:
            StopIteration: if the loader yields no batches at all.
        """
        try:
            x, y = next(self.iter)
        except StopIteration:
            # Keep the fresh iterator. Otherwise every later call would hit the
            # exhausted one, rebuild the loader, and only ever return its first
            # batch.
            self.iter = iter(self.loader)
            x, y = next(self.iter)

        x = x.to(device)
        y = y.to(device)
        x_all = torch.cat([x, aug_func(x)])
        y_all = torch.cat([y, torch.ones_like(y)])
        return x_all, y_all


def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Training options that default to ``None`` (``--warm-start``,
    ``--num-epochs``, ``--num-updates``, ``--batch-size``) are left unset
    here. The caller is expected to fill them with dataset-specific values.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # --- Data ---
    add('--root', default=join(utils.ROOT, 'data'), type=str)
    add('--data', default='mvtec', type=str)
    add('--obj-type', default='carpet', type=str)
    add('--ano-type', default='synthetic', type=str)

    # --- Environment: paths, seed, verbosity ---
    add('--load', default=join(utils.ROOT, 'out'), type=str)
    add('--out', default=join(utils.ROOT, 'out'), type=str)
    add('--seed', default=2023, type=int)
    add('--verbose', default=True, type=utils.str2bool)

    # --- Environment: compute resources ---
    add('--threads', default=16, type=int)
    add('--cuda', type=utils.str2bool, default=True)
    add('--gpu', type=int, default=0)

    # --- Augmentation function ---
    add('--augment', default='cutdiff', type=str)
    add('--init-scale', default=0.01, type=float)
    add('--init-angle', default=135, type=float)

    # --- Technical parameters ---
    add('--val-size', default=256, type=int)
    add('--anom-ratio', default=0.5, type=float)

    # --- Training parameters (None -> filled in by the caller) ---
    add('--warm-start', default=None, type=int)
    add('--num-epochs', default=None, type=int)
    add('--num-updates', default=None, type=int)
    add('--batch-size', default=None, type=int)
    add('--test-epochs', default=10, type=int)  # visualization interval only

    # --- Ablation study ---
    add('--val-loss', default='mean', type=str)  # random | fixed | mean | mmd
    add('--aug-norm', default=True, type=utils.str2bool)
    add('--second-order', default=True, type=utils.str2bool)

    # --- Demonstrative examples ---
    add('--syn-scale', default=0.01, type=float)
    add('--syn-ratio', default=4.0, type=float)

    namespace = parser.parse_args()
    return namespace


def to_random_augment(aug_name, device):
    """Build an augmentation function with randomly sampled hyper-parameters.

    Sampling:
    - ``scale`` is log-uniform in ``[1e-4, 1)``.
    - ``ratio`` has a log10 drawn uniformly from a symmetric interval whose
      width shrinks as ``scale`` approaches 1.
    - ``angle`` is uniform in ``[0, 360)``.

    The values are passed to ``augment.to_aug_function`` and the result is
    moved to ``device``.

    Raises:
        ValueError: if ``aug_name`` is not a supported augmentation.
    """
    if aug_name not in ('cutout', 'cutpaste', 'cutdiff', 'rotate'):
        raise ValueError(aug_name)

    scale = 10 ** np.random.uniform(-4, 0)
    # Equals log10(scale) / 4, which is <= 0. The ratio exponent is therefore
    # drawn from [half_width, -half_width].
    half_width = math.log10(math.sqrt(scale)) / 2
    ratio = 10 ** np.random.uniform(half_width, -half_width)
    angle = np.random.uniform(0, 360)

    aug = augment.to_aug_function(
        aug_name, scale=scale, ratio=ratio, angle=angle,
    )
    return aug.to(device)


def to_accuracy(logits, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals ``labels``.

    The computation runs without gradient tracking, and the result is a plain
    Python float.
    """
    with torch.no_grad():
        predictions = logits.argmax(dim=1)
        hits = predictions.eq(labels)
        return hits.float().mean().item()


def train(args, evaluator: models.Evaluator):
    """Train the detector and, optionally, the augmentation function.

    Detector training:
    - ``models.ProjectionNet`` (``num_classes=2``) is trained with
      cross-entropy on batches from ``DataIterator``.
    - Each batch holds the loader's clean samples with their own labels, plus
      ``aug_func`` copies labelled 1.

    When ``args.val_loss`` is neither ``'fixed'`` nor ``'random'``, a learnable
    augmentation (``CutDiff`` or ``Rotate``) is updated every epoch:
    - A validation loss from ``evaluator`` is computed on the detector after
      ``args.num_updates`` inner steps run through ``higher``.
    - That loss is back-propagated into the augmentation.
    - If ``args.second_order`` is false, the updated detector is frozen before
      validation, so no gradient flows into the detector weights.

    Non-learned modes:
    - ``'random'``: a new random augmentation is drawn every epoch.
    - ``'fixed'``: one random augmentation is drawn once and kept.

    Side effects:
    - Saves ``model.state_dict()`` to ``<out>/models/<obj>-<ano>.tch``
      whenever ``trn_loss + val_loss`` reaches a new minimum.
    - Writes per-epoch logs to ``<out>/logs/<obj>-<ano>.tsv``.
    - Calls ``evaluator.visualize_images`` before training and every
      ``args.test_epochs`` epochs.

    Args:
        args: parsed options. Besides the CLI flags, this reads
            ``args.device`` and ``args.syn_args``.
        evaluator: validation helper. Provides the validation loss, the score
            variance (``to_variance``) and image visualizations.

    Returns:
        ``(aug_func, model, trn_out)``. ``trn_out`` holds the last epoch's
        training accuracy, training loss, validation loss and ``val_dist``.
    """
    obj_type = args.obj_type
    ano_type = args.ano_type
    num_epochs = args.num_epochs
    device = args.device

    # Output locations: best detector weights, per-epoch TSV log, images.
    model_path = join(args.out, 'models', f'{obj_type}-{ano_type}.tch')
    log_path = join(args.out, 'logs', f'{obj_type}-{ano_type}.tsv')
    img_path = join(args.out, 'images')

    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    # Mild colour jitter applied to the training images.
    train_transform = transforms.ColorJitter(
        brightness=0.1,
        contrast=0.1,
        saturation=0.1,
        hue=0.1
    )
    trn_data = data.load_data(args.root, args.data, obj_type, ano_type,
                              train_transform, mode='train',
                              syn_args=args.syn_args)
    # Drop the last partial batch only when the dataset exceeds one batch.
    # Otherwise a small dataset would yield no batches at all.
    drop_last = len(trn_data) > args.batch_size
    data_loader = DataLoader(trn_data, args.batch_size, drop_last=drop_last,
                             shuffle=True)
    data_iterator = DataIterator(data_loader)

    model = models.ProjectionNet([512, 128], num_classes=2).to(device)
    det_opt = optim.SGD(
        model.parameters(), lr=0.03, momentum=0.9, weight_decay=0.00003
    )
    trn_loss_fn = torch.nn.CrossEntropyLoss()
    normalize = data.Normalize(args.data)

    # 'fixed' and 'random' use non-learned augmentations. Every other
    # val_loss mode learns the augmentation parameters.
    is_aug_learning = args.val_loss not in ['fixed', 'random']
    # Without augmentation learning there is no outer gradient, so the
    # second-order path is never needed.
    is_first_order = not args.second_order or not is_aug_learning

    if is_aug_learning:
        if args.augment == 'cutdiff':
            aug_func = augment.CutDiff(args.init_scale, requires_grad=True)
        elif args.augment == 'rotate':
            aug_func = augment.Rotate(args.init_angle, requires_grad=True)
        else:
            raise ValueError()
        aug_func = aug_func.to(device)
        aug_opt = optim.SGD(aug_func.parameters(), lr=1e-2)
        schedular = CosineAnnealingWarmRestarts(aug_opt, num_epochs)
    else:
        aug_func = to_random_augment(args.augment, device)
        aug_opt = None
        schedular = None

    log_dict = defaultdict(lambda: [])

    with torch.no_grad():  # Visualize the initial model
        model.eval()
        evaluator.run_model(model, aug_func, args.val_size, args.anom_ratio)
        val_loss = evaluator.to_validation_loss(args.val_loss, args.aug_norm)
        evaluator.visualize_images(aug_func, path_out=join(
            img_path, f'{obj_type}-{ano_type}', '0.png'
        ))

    if is_aug_learning:
        # Warm start: plain detector SGD steps (no `higher` context) before
        # the augmentation starts being learned.
        for epoch in range(args.warm_start):
            model.train()
            x, y = data_iterator(aug_func, device)
            _, logits = model(normalize(x))
            trn_loss = trn_loss_fn(logits, y)
            det_opt.zero_grad()
            trn_loss.backward()
            det_opt.step()

    if args.verbose:
        # Header row: fixed columns, then one column per augmentation parameter.
        print('  epoch\ttr_loss\tva_loss\tscr_var', end='')
        for i in range(len(aug_func.get_parameters())):
            print(f"\t{f'a{i}':>7}", end='')
        print()

    trn_loss, trn_acc, best_loss = None, None, np.inf
    # NOTE(review): if num_epochs == 0, trn_loss stays None and val_dist is
    # never assigned, so building trn_out after the loop fails.
    for epoch in range(num_epochs):
        if args.val_loss == 'random':
            aug_func = to_random_augment(args.augment, device)

        # Inner loop: detector updates through `higher` so the validation loss
        # can be differentiated with respect to the augmentation.
        with higher.innerloop_ctx(model, det_opt) as (f_model, diff_opt):
            f_model.train()
            for _ in range(args.num_updates):
                x, y = data_iterator(aug_func, device)
                _, logits = f_model(normalize(x))
                trn_loss = trn_loss_fn(logits, y)
                diff_opt.step(trn_loss)
                trn_acc = to_accuracy(logits, y)

            if is_first_order:
                # First-order: copy the updated weights into `model` and
                # freeze them, so validation gradients do not flow into the
                # detector weights.
                model.load_state_dict(f_model.state_dict())
                model.requires_grad_(False)
                f_model = model

            if is_aug_learning:
                # Outer step: validation loss on the updated detector drives
                # the augmentation parameters.
                f_model.eval()
                evaluator.run_model(
                    f_model, aug_func, args.val_size, args.anom_ratio
                )
                val_loss = evaluator.to_validation_loss(
                    args.val_loss, args.aug_norm
                )

                aug_opt.zero_grad()
                val_loss.backward()
                aug_opt.step()
                if schedular is not None:
                    schedular.step(epoch)

            if is_first_order:
                model.requires_grad_(True)
            else:
                # Second-order: persist the inner-loop weights into `model`.
                model.load_state_dict(f_model.state_dict())

        # Checkpoint on a new minimum of (last inner-step training loss +
        # validation loss).
        # NOTE(review): for 'fixed'/'random', val_loss is the value computed
        # on the initial model before training and is never refreshed.
        if trn_loss + val_loss < best_loss:
            torch.save(model.state_dict(), model_path)
            best_loss = (trn_loss + val_loss).item()

        aug_param = aug_func.get_parameters()
        with torch.no_grad():
            val_dist = evaluator.to_variance()

        log_dict['trn_acc'].append(trn_acc)
        log_dict['trn_loss'].append(trn_loss.item())
        log_dict['val_loss'].append(val_loss.item())
        log_dict['val_dist'].append(val_dist)
        for i, e in enumerate(aug_param):
            log_dict[f'aug_p{i}'].append(e)

        if args.verbose:
            print(f'{epoch + 1:7d}\t'
                  f'{trn_loss:7.4f}\t'
                  f'{val_loss:7.4f}\t'
                  f'{val_dist:7.1f}', end='\t')
            print('\t'.join(f"{e:7.4f}" for e in aug_func.get_parameters()))

        if (epoch + 1) % args.test_epochs == 0:
            evaluator.visualize_images(aug_func, path_out=join(
                img_path, f'{obj_type}-{ano_type}', f'{epoch + 1}.png'
            ))

    trn_out = dict(
        trn_acc=trn_acc,
        trn_loss=trn_loss.item(),
        val_loss=val_loss.item(),
        val_dist=val_dist,
    )

    df = pd.DataFrame.from_dict(log_dict)
    df.index.name = 'epoch'
    df.to_csv(log_path, sep='\t')
    return aug_func, model, trn_out


def set_if_none(old_value, new_value):
    """Return ``new_value`` when ``old_value`` is ``None``, else ``old_value``.

    Only ``None`` counts as unset. Falsy values such as ``0`` or ``False``
    are kept.
    """
    return new_value if old_value is None else old_value


def main():
    """Parse options, train the detector, then evaluate each anomaly type.

    For every evaluated anomaly type, this saves embeddings under
    ``<out>/embeddings`` and appends one row to ``<out>/out.tsv``. The row
    holds: object type, anomaly type, training stats, score variance, test
    AUC, and the augmentation parameters.
    """
    args = parse_args()
    assert not (args.cuda and args.gpu is None)
    args.device = 'cpu' if not args.cuda else f'cuda:{args.gpu}'
    args.syn_args = dict(scale=args.syn_scale, ratio=args.syn_ratio, angle=0)

    # Per-dataset fallbacks for options the user did not set explicitly.
    dataset_defaults = {
        'mvtec': dict(batch_size=32, warm_start=20, num_epochs=500,
                      num_updates=1),
        'svhn': dict(batch_size=256, warm_start=40, num_epochs=100,
                     num_updates=5),
    }
    if args.data not in dataset_defaults:
        raise ValueError()
    for option, fallback in dataset_defaults[args.data].items():
        setattr(args, option, set_if_none(getattr(args, option), fallback))

    if args.val_loss in ('random', 'fixed'):
        # Non-learned augmentations: fold the inner updates into epochs. This
        # keeps the total number of detector steps unchanged.
        args.num_epochs = args.num_epochs * args.num_updates
        args.num_updates = 1
        args.test_epochs = args.num_epochs

    utils.set_environment(args.seed, args.threads)
    os.makedirs(args.out, exist_ok=True)
    utils.save_json(vars(args), join(args.out, 'args.json'))

    train_evaluator = models.Evaluator(
        args.root, args.data, args.obj_type, args.ano_type, args.device,
        syn_args=args.syn_args
    )
    aug_func, model, trn_out = train(args, train_evaluator)

    ano_types = (utils.get_anomaly_types(args.data, args.obj_type)
                 if args.ano_type == 'all' else [args.ano_type])

    model.eval()
    for ano_type in ano_types:
        ano_evaluator = models.Evaluator(
            args.root, args.data, args.obj_type, ano_type, args.device,
            syn_args=args.syn_args
        )
        with torch.no_grad():
            ano_evaluator.run_model(
                model, aug_func, args.val_size, args.anom_ratio
            )
            embed_dir = join(args.out, 'embeddings',
                             f'{args.obj_type}-{ano_type}')
            ano_evaluator.save_embeddings(embed_dir)
        test_auc = ano_evaluator.measure_auc()
        val_dist = ano_evaluator.to_variance()
        row = [
            args.obj_type,
            ano_type,
            trn_out['trn_acc'],
            trn_out['trn_loss'],
            trn_out['val_loss'],
            val_dist,
            test_auc,
            *aug_func.get_parameters(),
        ]
        pd.DataFrame([row]).to_csv(
            join(args.out, 'out.tsv'), mode='a', sep='\t', index=False,
            header=False,
        )


if __name__ == '__main__':
    main()
