# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
from argparse import ArgumentParser
import os
import json
import matplotlib.pyplot as plt

import sys

sys.path.append("enter your path")

import numpy as np
from tqdm import tqdm

import torch
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn import CrossEntropyLoss, MSELoss

from datetime import datetime

from multiquery_randomized_smoothing.src.models.two_query_arch import TWO_QUERY_ARCH
from multiquery_randomized_smoothing.src.models.single_query_arch import SINGLE_QUERY_ARCH
from multiquery_randomized_smoothing.src.train_utils import (get_save_directory_path,
                                                            init_logfile,
                                                             log,
                                                             set_seed,
                                                             get_mask_shape,
                                                             get_image_size)
from multiquery_randomized_smoothing.src.dataset_utils import get_dataset

from torchsummary import summary

import matplotlib
matplotlib.use('Agg')  # Must be before importing matplotlib.pyplot or pylab!

######################################################################

def _count_arr(arr, length):
    counts = np.zeros(length, dtype=int)
    for idx in arr:
        counts[idx] += 1
    return counts

def train(epoch, args, saved_dir, train_loader, model):
    """Run one training epoch over ``train_loader``.

    Each step updates the base classifier via ``optimizer``. When ``t_params``
    is non-empty it also updates the transformation parameters via
    ``t_optimizer``. For the two-query architecture with a learnt budget split,
    it updates the first-query budget fraction via ``fqbf_optimizer``.

    After each step the parameters are projected back into their valid range:
    the learnable first-query mask into [0, 1], and the learnt budget fraction
    into [0.01, 0.99].

    NOTE: this function reads module-level names created in the ``__main__``
    block: ``optimizer``, ``t_params``, ``t_optimizer``, ``fqbf_optimizer``,
    ``criterion``, ``loss_file``, ``train_sigma_log_file`` and
    ``budget_query_split_log_file``. It can only be called after that setup
    has run.

    Args:
        epoch: current epoch number, forwarded to the model's logging trackers.
        args: parsed command-line arguments.
        saved_dir: directory passed to the model through ``logging_trackers``
            (the caller passes an empty list on epochs where nothing is saved).
        train_loader: DataLoader yielding ``(data, targets)`` batches.
        model: the single- or two-query architecture being trained.
    """
    # fqbf_optimizer only exists for the two-query, learnt-split setup, so it
    # must be guarded by both conditions (not just budget_split).
    learn_budget_split = args.num_queries == 2 and args.budget_split == "learnt"
    has_t_params = len(t_params) > 0
    recon_criterion = MSELoss().to(args.device)

    model.train()
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(args.device), targets.to(args.device)

        logging_trackers = {
            'epoch': epoch,
            'batch_idx': batch_idx,
            'mode': 'train',
            'saved_dir': saved_dir,
            'sigma_log_file': train_sigma_log_file,
        }
        if args.num_queries == 2:
            logging_trackers['budget_query_split_log_file'] = budget_query_split_log_file

        output_pred = model(data, logging_trackers)
        if args.mask_recon > 0:
            # Channels 1:4 of the second-query mask model's output are trained
            # to reconstruct the input image (regularisation term).
            mask_outputs = model.second_query_mask_model(data)
            recon = mask_outputs[:, 1:4, :, :]

        # reset gradients of every optimizer that takes a step below
        optimizer.zero_grad()
        if has_t_params:
            t_optimizer.zero_grad()
        if learn_budget_split:
            fqbf_optimizer.zero_grad()

        # compute loss
        total_loss = criterion(output_pred, targets)
        if args.mask_recon > 0:
            total_loss = total_loss + args.mask_recon * recon_criterion(recon, data)

        # log the scalar value, not the tensor repr
        log(loss_file, "{}".format(total_loss.item()))

        total_loss.backward()

        # optim step
        optimizer.step()
        if has_t_params:
            t_optimizer.step()
        if learn_budget_split:
            fqbf_optimizer.step()

        # projected-gradient step: clamp mask and budget fraction back into range
        with torch.no_grad():
            if args.first_query_with_mask:
                model.first_query_mask.clamp_(min=0., max=1.)
            if learn_budget_split:
                model.first_query_budget_frac.clamp_(min=0.01, max=0.99)

def test(epoch, args, saved_dir, test_loader, model, num_copies=20):
    """Evaluate ``model`` on ``test_loader`` by majority vote over noisy copies.

    Each test example is repeated ``num_copies`` times. Every copy is
    classified, and the most frequent predicted class (ties go to the lowest
    class index) is the final prediction.

    Accuracy is computed over the examples actually evaluated. If the loader
    drops a final partial batch, those examples are not counted as errors.
    The result is appended to ``logfilename``.

    NOTE: reads module-level names created in the ``__main__`` block:
    ``logfilename``, ``test_sigma_log_file`` and
    ``budget_query_split_log_file``.

    Args:
        epoch: current epoch number (logged alongside the accuracy).
        args: parsed command-line arguments.
        saved_dir: directory passed to the model through ``logging_trackers``.
        test_loader: DataLoader yielding ``(data, targets)`` batches.
        model: the single- or two-query architecture being evaluated.
        num_copies: number of copies of each example used for the vote.

    Returns:
        Test accuracy in percent (0.0 if the loader yielded no examples).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(test_loader):
            data, targets = data.to(args.device), targets.to(args.device)

            logging_trackers = {
                'epoch': epoch,
                'batch_idx': batch_idx,
                'mode': 'test',
                'saved_dir': saved_dir,
                'sigma_log_file': test_sigma_log_file,
            }
            if args.num_queries == 2:
                logging_trackers['budget_query_split_log_file'] = budget_query_split_log_file

            batch_size = data.size(0)
            # Copies of the same example are contiguous: [x0]*n, [x1]*n, ...
            repeated_batch = data.repeat_interleave(num_copies, dim=0)
            outputs = model(repeated_batch, logging_trackers)

            # Size the vote histogram from the logits, not a hard-coded 10.
            num_classes = outputs.size(1)
            votes = outputs.argmax(dim=1).reshape(batch_size, num_copies).cpu().numpy()
            counts = np.apply_along_axis(np.bincount, axis=1, arr=votes, minlength=num_classes)
            predictions = np.argmax(counts, axis=1)

            correct += int(np.equal(predictions, targets.cpu().numpy()).sum())
            total += batch_size

    test_acc = 100. * correct / total if total else 0.0
    # fixed-point format: "{:.3}" would render 100.0 as "1e+02"
    log(logfilename, "{} \t {:.2f}".format(epoch, test_acc))
    return test_acc

######################################################################

# process images
def convert_image_np(tens):
    """Convert a channels-first image tensor into a numpy array for ``imshow``.

    The tensor's data is copied to the CPU as a numpy array, and the leading
    axis (presumably channels) is moved to the end: (A, B, C) -> (B, C, A).
    """
    host_array = tens.data.cpu().numpy()
    return np.transpose(host_array, (1, 2, 0))

# def convert_image_np(inp):
#     """Convert a Tensor to numpy image."""
#     inp = inp.data.cpu().numpy().transpose((1, 2, 0))
#     mean = np.array([0.485, 0.456, 0.406])
#     std = np.array([0.229, 0.224, 0.225])
#     inp = std * inp + mean
#     inp = np.clip(inp, 0, 1)
#     return inp

def _plot_saved_grid(ax, saved_dir, filename, title):
    """Load a batch saved with ``torch.save`` and draw its first 100 images as one grid on ``ax``."""
    images = torch.load(os.path.join(saved_dir, filename))
    grid = convert_image_np(torchvision.utils.make_grid(images[:100]))
    ax.imshow(grid)
    ax.set_title(title)


def vizualize_images(saved_dir):
    """Save a side-by-side figure of the image batches dumped in ``saved_dir``.

    Panels, left to right:
      * the original images and the images after the first query;
      * for two queries, the second-query masks and the second-query images
        (before and after averaging when ``--average_queries`` is set);
      * the learnt first-query mask as a heatmap, when
        ``--first_query_with_mask`` is set.

    The figure is written to ``saved_dir/first_batch.png``.

    NOTE: reads the module-level ``args`` created in the ``__main__`` block.
    """
    with torch.no_grad():
        # (file name, title) of every image-grid panel, in plotting order
        panels = [
            ("original_images.pt", 'Original'),
            ("q_1_images.pt", '1q'),
        ]
        if args.num_queries == 2:
            panels.append(("post_clamped_second_query_masks.pt", '2q masks'))
            if args.average_queries:
                panels.append(("q_2_images_before_averaging.pt", '2q bef avg'))
                panels.append(("q_2_images_after_averaging.pt", '2q aft avg'))
            else:
                panels.append(("q_2_images.pt", '2q'))

        # Reserve one extra column for the first-query mask heatmap. This
        # applies for either number of queries, not just the single-query case.
        num_cols = len(panels) + (1 if args.first_query_with_mask else 0)

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, num_cols)
        try:
            for ax, (filename, title) in zip(axarr, panels):
                _plot_saved_grid(ax, saved_dir, filename, title)

            if args.first_query_with_mask:
                mask_ax = axarr[len(panels)]
                first_query_mask_tensor = torch.load(
                    os.path.join(saved_dir, "first_query_mask.pt"))
                first_query_mask_np = first_query_mask_tensor.data.cpu().numpy()
                img = mask_ax.pcolormesh(first_query_mask_np, vmin=0., vmax=1.)
                mask_ax.invert_yaxis()
                mask_ax.set_aspect("equal")
                mask_ax.set_title('mask (1st)')
                f.colorbar(img, ax=mask_ax, label="weight", orientation="horizontal",
                           ticks=[0., 0.25, 0.5, 0.75, 1.])

            plt.tight_layout()
            f.savefig(os.path.join(saved_dir, "first_batch.png"))
        finally:
            # always release the figure, even if a saved tensor is missing
            plt.close(f)

if __name__ == "__main__":
    argparser = ArgumentParser()

    argparser.add_argument("--seed", type=int, default=42, help="random seed")
    argparser.add_argument("--device", type=str, default="cuda")
    argparser.add_argument("--epochs", type=int, default=100)
    argparser.add_argument("--dataset", type=str, default=None)
    argparser.add_argument("--num_classes", type=int, default=None)
    argparser.add_argument("--dataset_path", type=str, default=None)

    # dataset transformation params
    argparser.add_argument("--pad_size", type=int, help="amount of padding on single side of the image", default=0)
    argparser.add_argument("--num_image_locations", type=str, help="1, 2, 4, 8  or random number of positions the center image is placed in", default=None)
    argparser.add_argument("--background", type=str, help="'black', 'nature_5bg', 'nature_20kbg'; background type on which an image is padded", default=None)

    # some train params
    argparser.add_argument("--base_classifier", type=str, default=None)
    argparser.add_argument("--train_batch_size", type=int, default=None)
    argparser.add_argument("--test_batch_size", type=int, default=400)

    # optim params for all parameters apart from transformation
    argparser.add_argument("--optim_lr", type=float, default=None)
    argparser.add_argument("--weight_decay", type=float, default=None)
    argparser.add_argument("--momentum", type=float, default=None)
    argparser.add_argument("--step_size", type=int, default=None)
    argparser.add_argument("--gamma", type=float, default=None)

    # transformation optim params
    argparser.add_argument("--t_optim_lr", type=float, default=None)
    argparser.add_argument("--t_weight_decay", type=float, default=None)
    argparser.add_argument("--t_momentum", type=float, default=None)
    argparser.add_argument("--t_step_size", type=int, default=None)
    argparser.add_argument("--t_gamma", type=float, default=None)

    # DP and model arch related args
    argparser.add_argument("--num_queries", type=int,
                           default=None, help="number of queries: 1 or 2")
    argparser.add_argument("--budget_split", type=str,
                           default=None, help="fixed or learnt")
    argparser.add_argument("--first_query_budget_frac", type=float,
                           default=None, help="used when budget split is fixed")

    argparser.add_argument("--first_query_with_mask", action='store_true',
                           help="learn an average mask in the first query")
    argparser.add_argument("--second_query_mask_model", type=str,
                           help="architecture used to output per-input mask in the second query", default="None")
    argparser.add_argument("--mask_output", type=str, 
                           help="whether to apply mask penalty or use a sigmoid layer at end of mask model 'penalty' or 'sigmoid'")
    argparser.add_argument("--average_queries", action='store_true',
                           help="whether to average noisy images across queries")
    argparser.add_argument("--averaging_style", type=str, default=None,
                           help="use old or new averaging")

    argparser.add_argument("--mask_init", type=str, help="how to initialize the mask 'random' or 'identity'", default=None)
    argparser.add_argument("--wm_across_channels", type=str,
                           help="whether to learn a different mask per channel 'same' or 'different'", default=None)

    argparser.add_argument("--mask_recon", type=float, help="whether to add regulization for loss, mask will also learn to reconstruct the image, this is the corresponding coefficient", default=0)
    argparser.add_argument("--total_sigma", type=float, default=None, help="total sigma in our design (which is what we take vanilla sigma to be)") 

    argparser.add_argument("--run_description", type=str,
                           help="short compressed description to create a run directory", default=None)
    args = argparser.parse_args()

    print(args)
    # set the seeds
    set_seed(args.seed)

    save_dir = get_save_directory_path(args)
    os.makedirs(save_dir, exist_ok=True)

    # save args (the `with` block closes the file)
    with open(os.path.join(save_dir, "config.txt"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        f.write("\n")

    # DATASET TRANSFORMATION
    transform_params = {
        "pad_size": args.pad_size,
        "num_image_locations": args.num_image_locations,
        "background": args.background
    }

    # get training/test datasets and loaders
    dataset_dir = os.path.join(args.dataset_path, args.dataset)
    train_dataset = get_dataset(dataset=args.dataset,
                                split="train",
                                path=dataset_dir,
                                transform_params=transform_params)
    train_loader = DataLoader(train_dataset, shuffle=True,
                              batch_size=args.train_batch_size, num_workers=4, drop_last=True)

    test_dataset = get_dataset(dataset=args.dataset,
                               split='test',
                               path=dataset_dir,
                               transform_params=transform_params)
    # drop_last=True keeps every test batch exactly test_batch_size long
    test_loader = DataLoader(test_dataset, shuffle=False,
                             batch_size=args.test_batch_size, num_workers=4, drop_last=True)

    if args.num_queries == 1:
        model = SINGLE_QUERY_ARCH(args).to(args.device)
    else:
        model = TWO_QUERY_ARCH(args).to(args.device)

    image_size = get_image_size(args)

    if args.num_queries == 2:
        second_query_mask_model_summary = summary(model.second_query_mask_model, (3, image_size, image_size))
        torch.save(second_query_mask_model_summary, os.path.join(save_dir, "second_query_mask_model_summary.pt"))

    base_classifier_summary = summary(model.base_classifier, (3, image_size, image_size))
    torch.save(base_classifier_summary, os.path.join(save_dir, "base_classifier_summary.pt"))

    ######################################################################
    # Training the model:
    ######################################################################

    # All transformation parameters. The list stays empty (and no
    # t_optimizer/t_scheduler is created) for a single query without a mask.
    t_params = []
    if args.num_queries in (1, 2) and args.first_query_with_mask:
        t_params.append(
            {'params': model.first_query_mask}
        )
    if args.num_queries == 2:
        t_params.append(
            {'params': model.second_query_mask_model.parameters()}
        )
    if t_params:
        t_optimizer = optim.AdamW(t_params,
                                  lr=args.t_optim_lr,
                                  weight_decay=args.t_weight_decay)
        t_scheduler = StepLR(t_optimizer,
                             step_size=args.t_step_size,
                             gamma=args.t_gamma)

    # base classifier parameters
    params = [
        {'params': model.base_classifier.parameters()}
    ]
    optimizer = optim.AdamW(params,
                            lr=args.optim_lr,
                            weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.step_size,
                       gamma=args.gamma)

    # If we are learning the first-query budget fraction, it gets its own
    # optimizer (appending it to `params` after AdamW was built had no effect).
    if args.num_queries == 2 and args.budget_split == "learnt":
        fqbf_optimizer = optim.AdamW([model.first_query_budget_frac],
                                     lr=0.0001,
                                     weight_decay=0)
        # NOTE(review): this scheduler is created but never stepped (the step
        # call was commented out in the original) — confirm this is intended.
        fqbf_scheduler = StepLR(fqbf_optimizer,
                                step_size=args.step_size,
                                gamma=args.gamma)

    # create loss function
    criterion = CrossEntropyLoss().to(args.device)

    # create train log file
    logfilename = os.path.join(save_dir, 'train_log.txt')
    init_logfile(logfilename, "epoch \t test_acc")

    train_sigma_log_file = os.path.join(save_dir, 'train_sigma_log.txt')
    test_sigma_log_file = os.path.join(save_dir, 'test_sigma_log.txt')
    if args.num_queries == 1:
        init_logfile(train_sigma_log_file, 'sigma_1')
        init_logfile(test_sigma_log_file, 'sigma_1')
    elif args.num_queries == 2:
        init_logfile(train_sigma_log_file, 'sigma_1 \t sigma_2 \t total_sigma')
        init_logfile(test_sigma_log_file, 'sigma_1 \t sigma_2 \t total_sigma')

        budget_query_split_log_file = os.path.join(save_dir, 'budget_query_split_log.txt')
        init_logfile(budget_query_split_log_file, 'query_1 \t query_2')

    loss_file = os.path.join(save_dir, 'train_loss.txt')
    init_logfile(loss_file, 'total_loss')

    best_test_acc = 0.0
    for epoch in tqdm(range(1, args.epochs + 1)):
        # only dump per-epoch artifacts on the first and every 10th epoch, to store less data
        save_artifacts = epoch == 1 or epoch % 10 == 0
        if save_artifacts:
            train_saved_dir = os.path.join(save_dir, "train", "epoch_" + str(epoch))
            test_saved_dir = os.path.join(save_dir, "test", "epoch_" + str(epoch))
            os.makedirs(train_saved_dir, exist_ok=True)
            os.makedirs(test_saved_dir, exist_ok=True)
        else:
            # presumably ignored by the model on these epochs
            train_saved_dir = []
            test_saved_dir = []

        train(epoch, args, train_saved_dir, train_loader, model)
        print("done training epoch {}!".format(epoch))
        curr_test_acc = test(epoch, args, test_saved_dir, test_loader, model)
        print("done evaluating epoch {}!".format(epoch))

        # Save the best model. best_test_acc must be updated, otherwise every
        # epoch with non-zero accuracy would overwrite the checkpoint.
        if curr_test_acc > best_test_acc:
            best_test_acc = curr_test_acc
            torch.save(model.state_dict(), os.path.join(save_dir, "model_sd.pt"))

        # Visualize and save train and test images
        if save_artifacts:
            vizualize_images(train_saved_dir)
            vizualize_images(test_saved_dir)

        scheduler.step()
        if t_params:
            t_scheduler.step()
