import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# utilities
import datetime
import random
import os
import time
from pathlib import Path

from tqdm import tqdm

# pretrained models
import torchvision.models
import Resnets3D.models.resnet

# our modules
from load_data import *
from datasets.frame_dataset import *
from datasets.raw_video_dataset import * 
from logger import Logger
from crnn_wrapper import CRNN_Wrapper
from options import get_args, prettyprint_args
from inference import evaluate
from utils import save_tensor_to_video
from datasets.data_utils import batch_mean_sub, batch_mean_add

from sklearn.metrics import accuracy_score
from apex import amp

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Using device", device)

# Seed every random number generator this script draws from, and force cuDNN
# into deterministic mode (with autotuning off), so runs are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Residual blocks per stage for each ResNet depth (the standard four-stage
# ResNet configurations).
resnet_block_sizes = dict(
    resnet10=[1, 1, 1, 1],
    resnet18=[2, 2, 2, 2],
    resnet34=[3, 4, 6, 3],
    resnet50=[3, 4, 6, 3],
    resnet101=[3, 4, 23, 3],
    resnet152=[3, 8, 36, 3],
    resnet200=[3, 24, 36, 3],
)


def _build_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` (case-insensitive) for ``model``.

    Raises:
        ValueError: if ``args.optimizer`` is neither "sgd" nor "adam".
    """
    name = args.optimizer.lower()
    if name == "sgd":
        return optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                         dampening=args.dampening, weight_decay=args.weight_decay)
    if name == "adam":
        return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    raise ValueError("Only SGD and Adam are currently supported as optimizers.")


def train(data, model, n_epochs, args, metrics=None):
    """Train ``model`` for up to ``n_epochs`` epochs, logging and saving the result.

    Output paths live under ``args.base_path`` and are stamped with the start
    time:

    - the training log goes to ``train_logs/<dataset>/<time>.log``;
    - unless ``args.lightweight`` is set, the final state dict goes to
      ``models/<dataset>/<time>.pth``;
    - unless ``args.lightweight`` is set, a checkpoint is written to
      ``checkpoints/<dataset>/`` every 10 epochs.

    The learning rate is reduced when the validation loss plateaus. Training
    stops early once the validation loss has not improved for
    ``args.early_stopping`` consecutive epochs.

    When ``args.adversarial`` is set, the model and optimizer are wrapped with
    apex AMP. Each batch is then replaced by a ``random_fgsm`` adversarial
    batch with probability ``args.apply_prob``.

    Args:
        data: dataset object handed to ``build_dataloader``.
        model: network to train; its parameters are updated in place.
        n_epochs: maximum number of epochs to run.
        args: parsed command-line options.
        metrics: metric names tracked by the ``Logger``. Defaults to
            acc/loss/precision/recall/f1 for both training and validation.

    Returns:
        The trained model. This is the object returned by ``amp.initialize``
        when adversarial training is enabled.
    """
    # None instead of a list literal: a mutable default would be shared across calls.
    if metrics is None:
        metrics = ["acc", "loss", "precision", "recall", "f1",
                   "val_acc", "val_loss", "val_precision", "val_recall", "val_f1"]

    time_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    model_dir = os.path.join(args.base_path, 'models', args.dataset)
    checkpoints_dir = os.path.join(args.base_path, 'checkpoints', args.dataset)
    log_dir = os.path.join(args.base_path, 'train_logs', args.dataset)
    for directory in (model_dir, checkpoints_dir, log_dir):
        Path(directory).mkdir(parents=True, exist_ok=True)
    model_path = os.path.join(model_dir, time_str + ".pth")
    log_path = os.path.join(log_dir, time_str + ".log")
    if os.path.isfile(model_path):
        print("***WARNING!*** The path you have specified is already taken. If this is not intentional, you may want to abort the program at this time and specify a different file path.")
    print("Saving model, log to:", model_path, ",", log_path)

    # phase 1: setup data
    train_loader, val_loader, train_size, val_size = build_dataloader(data, args)
    print("Train on {} samples, validate on {}".format(train_size, val_size))
    # Report n_epochs (what the loop actually runs), not args.epochs.
    print("Training for {} epochs, batch size {} ({} training batches)".format(n_epochs, args.bs, len(train_loader)))

    # phase 2: setup loss and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(model, args)
    if args.adversarial:
        # random_fgsm backpropagates through amp.scale_loss, which needs AMP initialised.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # phase 3: create callbacks
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=args.lr_patience, verbose=True)
    logger = Logger(train_size, val_size, args, metrics=metrics)
    best_val_loss = float('inf')
    n_epochs_no_improve = 0

    # phase 4: training loop
    start = time.time()
    for epoch in range(n_epochs):
        if n_epochs_no_improve >= args.early_stopping:
            print("Failed to improve for {} epochs; stopping early".format(args.early_stopping))
            break
        print("begin epoch {}/{}".format(epoch + 1, n_epochs))
        # NOTE(review): evaluate() is defined elsewhere and may switch the model
        # to eval mode. Re-enabling training mode every epoch keeps dropout and
        # batch-norm in training behaviour after the first validation pass.
        model.train()
        total_loss = 0.
        train_preds = []
        train_y = []
        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            X, y = batch[0].to(device), batch[1].to(device)
            # Sanity-check clip layout: (batch, channels, frames, height, width).
            assert X.size(1) == 3, "expected 3 colour channels"
            assert X.size(2) == args.max_frames, "unexpected number of frames"
            assert (X.size(3), X.size(4)) == (args.load_height, args.load_width), "unexpected frame size"
            if args.adversarial and random.random() < args.apply_prob:
                X = random_fgsm(model, X, y, loss_fn, optimizer, eps=args.pgd_eps, clamp=(0, 1))

            # This must come after the attack: random_fgsm's backward pass also
            # writes gradients into the model parameters.
            optimizer.zero_grad()
            out = model(X)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            if args.adversarial:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            train_preds.append(out.detach().argmax(dim=1))
            train_y.append(y)
            if i + 1 == args.limit:
                print(i + 1, "batches finished; skipping rest of epoch")
                break

        train_preds = torch.cat(train_preds).cpu()
        train_y = torch.cat(train_y).cpu()
        val_loss, val_preds, val_y, _ = evaluate(val_loader, model, logger=logger)
        scheduler.step(val_loss)

        # phase 4.5: evaluate callbacks
        logger.end_epoch(total_loss, train_preds, train_y, val_loss, val_preds, val_y)
        logger.report()
        latest_val_loss = logger.log['val_loss'][-1]
        if latest_val_loss < best_val_loss:
            best_val_loss = latest_val_loss
            n_epochs_no_improve = 0
            print("Validation loss improved")
        else:
            n_epochs_no_improve += 1
        if not args.lightweight and (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(checkpoints_dir, "{}_{}.pth".format(time_str, epoch))
            print("Saving checkpoint to", checkpoint_path)
            torch.save(model.state_dict(), checkpoint_path)

    elapsed = time.time() - start
    print("Training time: {}".format(datetime.timedelta(seconds=elapsed)))
    logger.log['time'] = elapsed
    logger.save(log_path)
    if not args.lightweight:
        torch.save(model.state_dict(), model_path)
    return model

def random_fgsm(model, X, y, loss_fn, optimizer, eps=8/255, clamp=(0, 1)):
    """Craft an FGSM adversarial batch from a random start ("fast" adversarial training).

    The procedure has four steps:

    1. Draw a perturbation uniformly from the L-inf ball of radius ``eps``.
    2. Take one signed-gradient step of size ``1.25 * eps``.
    3. Project the perturbation back into the ``eps`` ball.
    4. Clamp the perturbed batch to the valid pixel range ``clamp``.

    The clamp is applied after ``batch_mean_add`` and then undone with
    ``batch_mean_sub``. NOTE(review): these helpers presumably add back and
    remove the dataset mean; confirm in datasets.data_utils.

    The backward pass goes through ``amp.scale_loss``, so AMP must already be
    initialised for ``model``/``optimizer``; ``train`` does this when
    ``args.adversarial`` is set. The pass also fills the model parameters'
    ``.grad``; the caller must zero it before its own update.

    Args:
        model: network under attack.
        X: input batch, in the space ``batch_mean_add`` expects.
        y: labels for ``X``.
        loss_fn: loss whose gradient drives the attack.
        optimizer: optimizer registered with AMP (needed by ``amp.scale_loss``).
        eps: L-inf radius of the perturbation.
        clamp: ``(min, max)`` valid range after ``batch_mean_add``.

    Returns:
        The adversarial batch, detached, with the same shape as ``X``.
    """
    # Sample the random start directly on X's device and dtype.
    delta = torch.empty_like(X).uniform_(-eps, eps)
    step_size = 1.25 * eps  # step size recommended by Wong et al. (2020) for FGSM with random init

    x_adv = batch_mean_add((X + delta).detach()).clamp(*clamp)
    x_adv = batch_mean_sub(x_adv).detach().requires_grad_(True)
    loss = loss_fn(model(x_adv), y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    with torch.no_grad():
        # Project back into the eps-ball; the previous unprojected update
        # allowed perturbations of up to 2.25 * eps.
        delta = (delta + step_size * x_adv.grad.sign()).clamp(-eps, eps)
        x_adv = batch_mean_add(X + delta).clamp(*clamp)
        x_adv = batch_mean_sub(x_adv)
    return x_adv.detach()


def projected_gradient_descent(model, X, y, loss_fn, optimizer, num_steps=40, eps=8/255, step_size=0.01, norm='inf', clamp=(0, 1), y_target=None):
    """Run a multi-step projected gradient descent (PGD) attack on ``X``.

    Each of the ``num_steps`` iterations does three things:

    1. Takes a gradient step of size ``step_size``: the gradient sign for
       ``norm='inf'``, otherwise the gradient rescaled to unit ``norm`` per
       sample.
    2. Projects the perturbation back into the ``norm`` ball of radius ``eps``
       around ``X``.
    3. Clamps to the valid pixel range ``clamp``. The clamp is applied after
       ``batch_mean_add`` and then undone with ``batch_mean_sub``.

    Untargeted attacks ascend the loss on ``y``. When ``y_target`` is given,
    the attack instead descends the loss towards ``y_target``.

    Every step backpropagates through ``amp.scale_loss``, so AMP must already
    be initialised for ``model``/``optimizer``. The model's parameter
    ``.grad`` fields accumulate across steps and are not zeroed here.

    Args:
        model: network under attack.
        X: input batch with the batch dimension first.
        y: true labels, used when ``y_target`` is None.
        loss_fn: loss whose gradient drives the attack.
        optimizer: optimizer registered with AMP (needed by ``amp.scale_loss``).
        num_steps: number of attack iterations.
        eps: radius of the perturbation ball.
        step_size: size of each gradient step.
        norm: ``'inf'`` or any order accepted by ``Tensor.norm``.
        clamp: ``(min, max)`` valid range after ``batch_mean_add``.
        y_target: optional target labels for a targeted attack.

    Returns:
        The adversarial batch, detached, with the same shape as ``X``.
    """
    target = y if y_target is None else y_target
    x_adv = X.detach().clone()
    for _ in tqdm(range(num_steps)):
        x_step = x_adv.clone().requires_grad_(True)
        loss = loss_fn(model(x_step), target)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        # The update and projection need no autograd graph. Building one here
        # would chain every iteration together and grow memory with num_steps.
        with torch.no_grad():
            grad = x_step.grad
            # Shape that broadcasts one scalar per sample over all non-batch
            # dims. It works for 4D images and 5D clips alike; the old
            # view(-1, num_channels, 1, 1, 1) mis-broadcast the norms.
            per_sample = (-1,) + (1,) * (grad.dim() - 1)
            if norm == 'inf':
                step = grad.sign() * step_size
            else:
                grad_norm = grad.reshape(grad.shape[0], -1).norm(norm, dim=-1)
                # Guard against an all-zero gradient, which would otherwise give NaNs.
                step = grad * step_size / grad_norm.clamp_min(1e-12).view(per_sample)

            # Targeted attacks minimise the loss on y_target; untargeted ones maximise it on y.
            x_adv = x_adv - step if y_target is not None else x_adv + step

            # Project back into the eps-ball around X.
            if norm == 'inf':
                x_adv = torch.max(torch.min(x_adv, X + eps), X - eps)
            else:
                delta = x_adv - X
                delta_norm = delta.reshape(delta.shape[0], -1).norm(norm, dim=1)
                # eps / max(||delta||, eps) is 1 inside the ball and pulls
                # outside samples onto its surface.
                delta = delta * (eps / delta_norm.clamp_min(eps)).view(per_sample)
                x_adv = X + delta

            x_adv = batch_mean_sub(batch_mean_add(x_adv).clamp(*clamp))
    return x_adv.detach()

def load_model(args):
    """Build the 3D ResNet named by ``args.model_classname`` and move it to ``device``.

    Without ``args.load_pretrained``, the model is freshly initialised with
    ``args.n_classes`` outputs.

    With ``args.load_pretrained`` set (a path relative to ``args.base_path``),
    the file is first treated as a checkpoint dict. Its ``'state_dict'`` is
    loaded into a 1039-class model, and the final ``fc`` layer is then replaced
    by a new ``args.n_classes``-way linear layer.

    If that fails (no ``'state_dict'`` key, or weights that do not match), the
    file is instead loaded as a raw state dict. The target is a 51-class model
    for ``hmdb51`` and a 101-class model for any other dataset.

    Returns:
        The model, moved to ``device``.
    """
    model_class = getattr(Resnets3D.models.resnet, args.model_classname)
    if not args.load_pretrained:
        model = model_class(sample_size=args.load_width, sample_duration=args.max_frames, num_classes=args.n_classes)
        return model.to(device)

    model_path = os.path.join(args.base_path, args.load_pretrained)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    # map_location lets GPU-saved checkpoints load on CPU-only machines.
    pretrain = torch.load(model_path, map_location=device)
    try:
        model = model_class(sample_size=args.load_width, sample_duration=args.max_frames, num_classes=1039)

        print("Loading pretrained model from", model_path, "of type", model_class)
        model.load_state_dict(pretrain['state_dict'])
        model.fc = nn.Linear(model.fc.in_features, args.n_classes)
    except (KeyError, RuntimeError):
        # Not a 1039-class checkpoint dict. KeyError means there is no
        # 'state_dict' key; RuntimeError means the weights do not match.
        # Fall back to a raw state dict of an hmdb51/ucf101-sized model.
        model = model_class(sample_size=args.load_width, num_classes=51 if args.dataset == 'hmdb51' else 101, sample_duration=args.max_frames)
        model.load_state_dict(pretrain)
    return model.to(device)


def main():
    """Entry point: parse options, echo them, then load data and model and train."""
    args = get_args()
    print(prettyprint_args(args))
    # Arguments are evaluated left to right, so data is loaded before the model.
    train(load_data(args), load_model(args), args.epochs, args)


# Only train when executed as a script, not when imported.
if __name__ == '__main__':
    main()
