import argparse
import json
import os
import shutil
import signal
import traceback
from datetime import datetime

import torch
import numpy as np
from tensorboardX import SummaryWriter
from torchvision.transforms import RandomAffine

from datasets import get_dataset
from models import get_available_models, get_model
from optimizers import get_optimizer, get_available_optimizers

from datasets import get_available_datasets
from utils import get_accuracy, get_git_revision_short_hash


class DelayedKeyboardInterrupt:
    """Context manager that postpones SIGINT until the ``with`` body has finished.

    On entry the current SIGINT handler is replaced by :meth:`handler`, which
    only records the signal.  On exit the previous handler is restored and, if
    a SIGINT arrived in the meantime, it is forwarded to that previous handler.
    """

    def __enter__(self):
        """Install the recording handler and return this context manager."""
        self.signal_received = False
        self.old_handler = signal.signal(signal.SIGINT, self.handler)
        return self

    def handler(self, sig, frame):
        """Record the received signal instead of acting on it immediately."""
        self.signal_received = (sig, frame)
        print('SIGINT received. Delaying KeyboardInterrupt.')

    def __exit__(self, type, value, traceback):
        """Restore the previous handler and replay a delayed SIGINT, if any."""
        signal.signal(signal.SIGINT, self.old_handler)
        if not self.signal_received:
            return
        if callable(self.old_handler):
            self.old_handler(*self.signal_received)
        elif self.old_handler == signal.SIG_DFL:
            # The previous disposition was the OS default, which is not a
            # Python callable: re-deliver the signal now that it is restored.
            os.kill(os.getpid(), self.signal_received[0])
        # Remaining case is signal.SIG_IGN: the signal was being ignored
        # before entering, so the delayed one is dropped as well.

def build_parser():
    """Create the argument parser for the training script.

    Returns:
        argparse.ArgumentParser: parser with every training option registered.
        ``--save-dir`` is the only required argument; the choices for
        ``--optimizer``, ``--dataset`` and ``--model`` come from the
        corresponding ``get_available_*`` helpers.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    def add_flag(name):
        # Boolean switch: absent -> False, present -> True.
        add(name, action='store_true', default=False)

    # Output location and optimizer selection.
    add('--save-dir', required=True, type=str)
    add('--optimizer', default='sgd', type=str, choices=get_available_optimizers())

    # Optimization hyper-parameters.
    add('--lr', default=1e-1, type=float)
    add_flag('--scheduled-decay')
    add('--weight-decay', default=5e-4, type=float)
    add('--momentum', default=0.9, type=float)

    # Data, architecture and run naming.
    add('--dataset', default='cifar10', type=str, choices=get_available_datasets())
    add('--model', default='resnet18', type=str, choices=get_available_models())
    add_flag('--append-date-to-title')

    # Training length and checkpoint cadence.
    add('--epochs', type=int, default=200)
    add('--additional-epochs', type=int, default=0)
    add('--save-interval', type=int, default=0)

    # Reproducibility and batching.
    add('--seed', type=int, default=42)
    add('--batch-size', default=256, type=int)
    add_flag('--fix-seed-each-epoch')

    add_flag('--no-augmentation')
    add_flag('--no-shuffle')

    add('--random-gradient-subset-size-ratio', type=float, default=None)

    # Options for the optional second model.
    add('--second-lr', default=None, type=float)
    add('--second-repeat', default=None, type=int)

    # --first-repeat and --start-repeat-at-epoch were commented out in the
    # original and remain disabled.
    add('--stop-repeat-at-epoch', default=None, type=int)

    # Model-distance driven stopping / syncing.
    add('--stop-on-model-distance', default=None, type=float)
    add('--sync-on-model-distance', default=None, type=float)
    add('--sync-interval', default=None, type=int)
    add('--be-efficient-after-model-distance', default=None, type=float)
    add('--force-sync-until-epoch', default=None, type=int)
    add('--stop-on-step', default=None, type=int)
    add_flag('--save-model-every-step')
    add_flag('--plot-model-distance')

    # Resuming from a checkpoint.
    add('--starting-checkpoint', type=str, default=None)
    add('--starting-checkpoint-params-idx', type=int, default=None)
    add_flag('--ignore-mismatch-between-checkpoint-and-args')

    # Gradient-norm regularization of the second model.
    add('--second-reg', default=0., type=float)
    add_flag('--adjust-perturbation-factor-by-lr')

    add('--augment-with-shearx', choices=['always', 'low-gradient'], default=None)
    return parser


class DummyScheduler:
    """Placeholder scheduler whose ``step`` does nothing.

    Used in place of a learning-rate scheduler when ``--scheduled-decay`` is
    not set, so the training loop can call ``step()`` unconditionally.
    """

    def step(self, *args, **kwargs):
        """Accept any arguments, ignore them, and return ``None``."""
        return None


class RandomAffineWithSwitch(RandomAffine):
    """``RandomAffine`` transform that can be switched off at runtime.

    While ``enabled`` is true the call is delegated to ``RandomAffine``;
    otherwise the input image is returned untouched.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``RandomAffine`` and start enabled."""
        super().__init__(*args, **kwargs)
        self.enabled = True

    def __call__(self, img):
        """Return the transformed image, or ``img`` itself when disabled."""
        if not self.enabled:
            return img
        return super().__call__(img)

def _copy_parameters(source_model, target_model):
    """Overwrite ``target_model``'s parameters in place with ``source_model``'s."""
    for source_param, target_param in zip(source_model.parameters(), target_model.parameters()):
        target_param.data.copy_(source_param.data)


def _squared_parameter_distance(first_model, second_model):
    """Return the sum of squared norms of the parameter differences of two models."""
    distance = 0
    # Pure measurement: no autograd graph is needed for the differences.
    with torch.no_grad():
        for p1, p2 in zip(first_model.parameters(), second_model.parameters()):
            distance += torch.norm(p1 - p2) ** 2
    return distance


def _evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, mean batch accuracy)`` of ``model`` over ``loader``.

    The model is switched to eval mode and left in that mode.
    """
    batch_accuracies = []
    batch_losses = []
    model.eval()
    # No gradients are needed for evaluation; skipping the graph saves memory.
    with torch.no_grad():
        for data_x, data_y in loader:
            data_x = data_x.to(device)
            data_y = data_y.to(device)

            model_y = model(data_x)
            loss = criterion(model_y, data_y)
            batch_accuracy = get_accuracy(model_y, data_y)

            batch_accuracies.append(batch_accuracy.item())
            batch_losses.append(loss.item())
    return np.mean(batch_losses), np.mean(batch_accuracies)


def run(args):
    """Train one model per configured learning rate and log/checkpoint the results.

    One model is built for ``args.lr`` and, if ``args.second_lr`` is given, a
    second one for that rate.  Without a starting checkpoint all models start
    from the first model's parameters.  Every epoch each model is trained on
    the same batches, evaluated on the test loader, logged to TensorBoard and
    saved to ``checkpoint_last.pt`` (plus ``checkpoint<epoch>.pt`` every
    ``args.save_interval`` epochs).  ``checkpoint_final.pt`` is written after
    training finishes.

    Args:
        args: namespace produced by ``build_parser().parse_args()``.
            ``args.save_dir`` is updated in place when
            ``args.append_date_to_title`` is set.

    Raises:
        FileExistsError: from ``os.makedirs`` when the experiment directory
            already exists and was not marked as interrupted.
        ValueError: when ``--second-repeat`` or ``--second-reg`` is used
            without ``--second-lr``.
        Exception: when the arguments disagree with the checkpoint config and
            ``--ignore-mismatch-between-checkpoint-and-args`` is not set.
        Any exception (including ``KeyboardInterrupt``) raised during training
        is re-raised after an ``exp_interrupted`` marker file has been written.
    """
    print("Running with arguments:")
    args_dict = {
        "commit_id": get_git_revision_short_hash()
    }
    for key in vars(args):
        if key == "default_function":
            continue
        args_dict[key] = getattr(args, key)
    for key in args_dict:
        print(key, ": ", args_dict[key])
    print("---")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    title = os.path.basename(args.save_dir)
    if args.append_date_to_title:
        execution_date = datetime.now().strftime('%b%d_%H-%M-%S')
        args.save_dir = os.path.join(args.save_dir, execution_date)
    exp_dir = args.save_dir

    # A previous run that crashed leaves an "exp_interrupted" marker; move that
    # directory aside so this run starts from an empty one.
    if os.path.exists(exp_dir) and os.path.exists(os.path.join(exp_dir, "exp_interrupted")):
        execution_date = datetime.now().strftime('%b%d_%H-%M-%S')
        old_exp_dir = "{}_interrupted_experiment_before_{}".format(exp_dir, execution_date)
        print("Moving interrupted experiment directory to {}".format(old_exp_dir))
        shutil.move(exp_dir, old_exp_dir)

    summary_writer = None
    os.makedirs(exp_dir)
    try:

        with open(os.path.join(exp_dir, "config.json"), "w") as f:
            json.dump(args_dict, f, indent=4, sort_keys=True, default=lambda x: x.__name__)

        summary_writer = SummaryWriter(exp_dir)
        additional_loader_keys = {}
        if args.augment_with_shearx:
            shear_augmentation = RandomAffine(0, shear=(-4, 4, 0, 0))
            additional_loader_keys["additional_augmentations"] = [shear_augmentation]
        loaders = get_dataset(args.dataset, batch_size=args.batch_size,
                              augment=not args.no_augmentation,
                              shuffle_train=not args.no_shuffle,
                              **additional_loader_keys)

        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print("CUDA Recognized")
        else:
            device = torch.device('cpu')

        lrs = [args.lr]
        if args.second_lr is not None:
            lrs.append(args.second_lr)

        model_names = []
        models = []
        optimizers = []
        schedulers = []

        if args.starting_checkpoint is not None:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            starting_checkpoint = torch.load(args.starting_checkpoint, map_location='cpu')
            last_epoch = starting_checkpoint['epoch']
            for key in args_dict:
                if key not in starting_checkpoint["config"]:
                    print(f"WARNING: {key} not found in checkpoint config. ignoring...")
                    continue
                if key in ["commit_id", "starting_checkpoint", "ignore_mismatch_between_checkpoint_and_args", "save_dir", "save_interval"]:
                    continue
                if args_dict[key] != starting_checkpoint["config"][key]:
                    err_msg = f"{key} is set to {args_dict[key]} but is equal to {starting_checkpoint['config'][key]} in checkpoint config."
                    if args.ignore_mismatch_between_checkpoint_and_args:
                        print(f"WARNING: {err_msg} ignoring since not strict...")
                    else:
                        raise Exception(err_msg)
        else:
            starting_checkpoint = None
            last_epoch = -1

        for idx, lr in enumerate(lrs):
            model = get_model(args.model, num_classes=loaders["num_classes"])
            if starting_checkpoint is not None:
                if args.starting_checkpoint_params_idx is not None:
                    model.load_state_dict(starting_checkpoint['model'][args.starting_checkpoint_params_idx])
                elif isinstance(starting_checkpoint['model'], list):
                    model.load_state_dict(starting_checkpoint['model'][idx])
                else:
                    model.load_state_dict(starting_checkpoint['model'])
            model = model.to(device)
            optimizer = get_optimizer(args.optimizer, model.parameters(), lr=lr, momentum=args.momentum,
                                      weight_decay=args.weight_decay)
            # Optimizer state is only restored when every model loads its own
            # parameters (no explicit params index was requested).
            if starting_checkpoint is not None and args.starting_checkpoint_params_idx is None:
                if isinstance(starting_checkpoint['optimizer'], list):
                    optimizer.load_state_dict(starting_checkpoint['optimizer'][idx])
                else:
                    optimizer.load_state_dict(starting_checkpoint['optimizer'])

            if args.scheduled_decay:
                # Milestones 80/120/160 are defined for a 200-epoch run and
                # rescaled proportionally to args.epochs.
                scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, gamma=0.1, milestones=[int(x / 200. * args.epochs) for x in [80, 120, 160]],
                                                                 last_epoch=last_epoch)
            else:
                scheduler = DummyScheduler()

            model_name = "initlr{}".format(lr)
            model_names.append(model_name)
            models.append(model)
            optimizers.append(optimizer)
            schedulers.append(scheduler)

        # Fresh runs start every model from the first model's initialization.
        if starting_checkpoint is None:
            for other_model in models[1:]:
                _copy_parameters(models[0], other_model)

        repeats = [1] * len(lrs)
        if args.second_repeat is not None:
            if len(lrs) < 2:
                raise ValueError("--second-repeat requires --second-lr to be set")
            repeats[1] = args.second_repeat

        regularizations = [0] * len(lrs)
        if args.second_reg != 0:
            if len(lrs) < 2:
                raise ValueError("--second-reg requires --second-lr to be set")
            regularizations[1] = args.second_reg

        criterion = torch.nn.CrossEntropyLoss().to(device)

        steps = 0

        if args.augment_with_shearx == 'low-gradient':
            gradient_norm_moving_average = 0
            last_epoch_gradient_norm = 0

        efficient_mode = False

        if starting_checkpoint is not None and 'random_state' in starting_checkpoint:
            torch.random.set_rng_state(starting_checkpoint['random_state']['torch_state'])
            np.random.set_state(starting_checkpoint['random_state']['numpy_state'])

        for epoch in range(last_epoch + 1 if last_epoch != -1 else 1, args.epochs + args.additional_epochs + 1):
            # Evaluation at the end of the previous epoch left every model in
            # eval mode, so all of them must be switched back to train mode.
            for train_model in models:
                train_model.train()
            print(f"{title} - Epoch {epoch}")

            if args.fix_seed_each_epoch:
                torch.manual_seed(args.seed + epoch)
                np.random.seed(args.seed + epoch)

            if args.augment_with_shearx == 'low-gradient':
                print("Checking gradient norm for shearX")
                print("Previous epoch grad norm: {} Current grad norm: {}".format(last_epoch_gradient_norm, gradient_norm_moving_average))
                if gradient_norm_moving_average < last_epoch_gradient_norm:
                    print("Enabling ShearX")
                    shear_augmentation.shear = (-4., 4., 0., 0.,)
                else:
                    print("Disabling ShearX")
                    shear_augmentation.shear = (0., 0., 0., 0.)
                last_epoch_gradient_norm = gradient_norm_moving_average

            accuracies = [[] for _ in models]
            losses = [[] for _ in models]

            should_stop_training = False
            for batch_idx, (data_x, data_y) in enumerate(loaders["train_loader"]):
                data_x = data_x.to(device)
                data_y = data_y.to(device)

                for model_idx, (model, optimizer, repeat, regularization) in enumerate(zip(models, optimizers, repeats, regularizations)):
                    if args.stop_repeat_at_epoch is not None and epoch >= args.stop_repeat_at_epoch:
                        repeat = 1
                    for repeat_cntr in range(repeat):
                        optimizer.zero_grad()
                        model_y = model(data_x)
                        loss = criterion(model_y, data_y)
                        if regularization != 0:
                            # Penalize the squared gradient norm; create_graph
                            # lets the penalty itself be differentiated.
                            grad_list = torch.autograd.grad(loss, model.parameters(), create_graph=True, retain_graph=True)
                            l_w = 0
                            for grad in grad_list:
                                l_w += torch.sum(torch.square(grad))
                            normalization_factor = 1
                            if args.adjust_perturbation_factor_by_lr:
                                normalization_factor *= args.lr / optimizers[0].param_groups[0]['lr']
                            loss = loss + (regularization / normalization_factor) * l_w
                        batch_accuracy = get_accuracy(model_y, data_y)
                        loss.backward()
                        optimizer.step()
                        if not efficient_mode and args.save_model_every_step:
                            torch.save({
                                'model': [m.state_dict() for m in models],
                                'config': args_dict,
                                'epoch': epoch,
                            }, os.path.join(exp_dir, "checkpoint{}_step{}_model{}repeat{}.pt".format(epoch,steps,model_idx, repeat_cntr)))

                    # Only the metrics of the last repeat are recorded.
                    accuracies[model_idx].append(batch_accuracy.item())
                    losses[model_idx].append(loss.item())

                if args.save_model_every_step:
                    torch.save({
                        'model': [m.state_dict() for m in models],
                        'config': args_dict,
                        'epoch': epoch,
                    }, os.path.join(exp_dir, "checkpoint{}_step{}.pt".format(epoch, steps)))

                steps += 1

                if args.augment_with_shearx == 'low-gradient':
                    grad_norm = 0
                    for p in models[0].parameters():
                        # Parameters that received no gradient contribute nothing.
                        if p.grad is None:
                            continue
                        grad_norm += torch.norm(p.grad) ** 2
                    beta = 0.9
                    gradient_norm_moving_average = gradient_norm_moving_average * beta + grad_norm * (1 - beta)

                if not efficient_mode:
                    synced = False
                    if args.sync_on_model_distance is not None or args.stop_on_model_distance is not None or args.plot_model_distance:
                        for model_name, model in zip(model_names[1:], models[1:]):
                            distance_norm = _squared_parameter_distance(models[0], model)
                            summary_writer.add_scalar("distance_{}_{}".format(model_names[0], model_name), distance_norm, steps)
                            if args.stop_on_model_distance is not None and distance_norm > args.stop_on_model_distance:
                                should_stop_training = True
                            if args.sync_on_model_distance is not None and distance_norm <= args.sync_on_model_distance:
                                _copy_parameters(models[0], model)
                                synced = True
                            if args.be_efficient_after_model_distance is not None and distance_norm > args.be_efficient_after_model_distance:
                                efficient_mode = True
                    if args.sync_interval is not None and not synced and steps % args.sync_interval == 0:
                        for model in models[1:]:
                            _copy_parameters(models[0], model)
                if args.stop_on_step is not None and steps >= args.stop_on_step:
                    should_stop_training = True
                if should_stop_training:
                    break
            if should_stop_training:
                break

            for model_name, model_accuracies, model_losses in zip(model_names, accuracies, losses):
                train_loss = np.mean(model_losses)
                train_accuracy = np.mean(model_accuracies)
                print("Model {} Train accuracy: {} Train loss: {}".format(model_name, train_accuracy, train_loss))
                summary_writer.add_scalar("train_loss_{}".format(model_name), train_loss, epoch)
                summary_writer.add_scalar("train_accuracy_{}".format(model_name), train_accuracy, epoch)

            # Un-suffixed scalars and checkpoint metrics refer to the first model.
            train_loss = np.mean(losses[0])
            train_accuracy = np.mean(accuracies[0])
            summary_writer.add_scalar("train_loss", train_loss, epoch)
            summary_writer.add_scalar("train_accuracy", train_accuracy, epoch)
            summary_writer.add_scalar("steps", steps, epoch)

            test_loss = None
            test_accuracy = None
            for model_idx, (model_name, model) in enumerate(zip(model_names, models)):
                model_test_loss, model_test_accuracy = _evaluate(model, loaders["test_loader"], criterion, device)
                print("Model {} Test accuracy: {} Test loss: {}".format(model_name, model_test_accuracy, model_test_loss))
                summary_writer.add_scalar("test_loss_{}".format(model_name), model_test_loss, epoch)
                summary_writer.add_scalar("test_accuracy_{}".format(model_name), model_test_accuracy, epoch)

                if model_idx == 0:
                    # Keep the first model's metrics so the checkpoint's test
                    # values describe the same model as its train values.
                    test_loss = model_test_loss
                    test_accuracy = model_test_accuracy
                    summary_writer.add_scalar("test_loss", test_loss, epoch)
                    summary_writer.add_scalar("test_accuracy", test_accuracy, epoch)

            for scheduler in schedulers:
                scheduler.step()

            checkpoint_names = ["checkpoint_last.pt"]

            if args.save_interval > 0 and epoch % args.save_interval == 0:
                checkpoint_names.append(f'checkpoint{epoch}.pt')
            for name in checkpoint_names:
                torch.save({
                    'model': [m.state_dict() for m in models],
                    'config': args_dict,
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'train_accuracy': train_accuracy,
                    'test_loss': test_loss,
                    'test_accuracy': test_accuracy,
                    'optimizer': [opt.state_dict() for opt in optimizers],
                    'random_state': {
                        'numpy_state': np.random.get_state(),
                        'torch_state': torch.random.get_rng_state()
                    }
                }, os.path.join(exp_dir, name))
    except (Exception, KeyboardInterrupt) as e:
        # Leave a marker (with the error text) so the next run can detect and
        # move this directory; SIGINT is delayed while the marker is written.
        with DelayedKeyboardInterrupt():
            with open(os.path.join(exp_dir, "exp_interrupted"), "w") as f:
                f.write(str(e))
                f.write(traceback.format_exc())
        raise
    finally:
        if summary_writer is not None:
            summary_writer.close()
    torch.save({
        'model': [m.state_dict() for m in models],
        'config': args_dict,
        'optimizer': [opt.state_dict() for opt in optimizers]
    }, os.path.join(exp_dir, 'checkpoint_final.pt'))



if __name__ == "__main__":
    # Script entry point: parse the command line and start training.
    run(build_parser().parse_args())



