# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning any 🤗 Transformers model for image classification leveraging 🤗 Accelerate."""
import argparse
import json
import logging
import math
import os
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm

import transformers
from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, SchedulerType, get_scheduler
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

import logging
import random
import numpy as np
import wandb
from collections import OrderedDict
from torch.profiler import profile, ProfilerActivity
from transformers.models.vit import linear_vit
import itertools

def parse_args():
    """Parse the command-line options of the fine-tuning script.

    Returns:
        argparse.Namespace: the parsed options. ``log_name`` falls back to
        ``output_dir`` when ``--log_name`` is not given.

    Side effects:
        When ``--output_dir`` is given, creates ``log/<output_dir>`` and
        ``log/<output_dir>/profile`` (existing directories are kept).
    """
    parser = argparse.ArgumentParser(description="Fine-tune a Transformers model on an image classification dataset")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="cifar10",
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset)."
        ),
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Percent to split off of train for validation",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        default="google/vit-base-patch16-224-in21k",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=32,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--wandb", action="store_true", help="Report metrics to Weights & Biases.")
    # main() passes args.log_name to wandb.init(); without this option the
    # attribute does not exist and --wandb fails with AttributeError.
    parser.add_argument(
        "--log_name",
        type=str,
        default=None,
        help="Run name reported to Weights & Biases (defaults to --output_dir).",
    )

    # Options consumed by the periodic variance estimation in main() and
    # forwarded to module-level settings of linear_vit.
    parser.add_argument(
        "--cal_var_freq",
        type=int,
        default=100,
        help="Run the variance estimation every this many optimization steps.",
    )
    parser.add_argument("--eval_freq", type=int, default=200, help="Evaluate every this many optimization steps.")
    parser.add_argument(
        "--cal_var_m",
        type=int,
        default=2,
        help=(
            "Number of batches, and of repeated sampled passes per batch, used by the variance estimation"
            " (also forwarded to linear_vit.CAL_VAR_M)."
        ),
    )
    parser.add_argument("--s", type=float, default=1.0, help="Forwarded to linear_vit.S.")
    parser.add_argument(
        "--act_var_tolerance", type=float, default=0.01, help="Forwarded to linear_vit.ACT_VAR_TOLERANCE."
    )
    parser.add_argument(
        "--weight_var_tolerance", type=float, default=0.01, help="Forwarded to linear_vit.WEIGHT_VAR_TOLERANCE."
    )
    parser.add_argument("--s_update_step", type=float, default=0.002, help="Forwarded to linear_vit.S_UPDATE_STEP.")
    parser.add_argument(
        "--weight_ratio_multiplier",
        type=float,
        default=0.95,
        help="Forwarded to linear_vit.WEIGHT_RATIO_MULTIPLIER.",
    )

    args = parser.parse_args()

    if args.log_name is None:
        args.log_name = args.output_dir

    if args.output_dir is not None:
        # The log file and profiler output live below log/<output_dir>.
        log_dir = os.path.join("log", args.output_dir)
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(os.path.join(log_dir, "profile"), exist_ok=True)

    return args

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same calls, same order as spelling them out one by one.
    for apply_seed in seeders:
        apply_seed(seed)

def update_summary(step, train_metrics, eval_metrics):
    """Log one row to Weights & Biases.

    The row holds ``step`` followed by every entry of ``train_metrics`` under
    a ``train_`` prefix and every entry of ``eval_metrics`` under an ``eval_``
    prefix.
    """
    row = OrderedDict(step=step)
    for prefix, metrics in (("train_", train_metrics), ("eval_", eval_metrics)):
        for metric_name, metric_value in metrics.items():
            row[prefix + metric_name] = metric_value
    wandb.log(row)

def eval(model, eval_dataloader, device):
    """Evaluate ``model`` on every batch of ``eval_dataloader``.

    The model is put in eval mode for the duration of the pass and switched
    back to train mode before returning.

    Args:
        model: callable module; ``model(**batch)`` must return an object with
            ``logits`` and ``loss`` attributes.
        eval_dataloader: iterable of dicts of tensors; each dict must contain
            a ``"labels"`` entry.
        device: device every batch tensor is moved to.

    Returns:
        tuple[float, float]: ``(loss, accuracy)`` averaged over samples.
        Both are NaN when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)

            predictions = torch.argmax(outputs.logits, dim=-1)
            batch_size = len(predictions)
            # Weight each batch by its size so a short final batch does not
            # count as much as a full one (a plain mean of batch means does).
            # Assumes outputs.loss is a per-batch mean -- TODO confirm.
            loss_sum += outputs.loss.mean().item() * batch_size
            num_correct += (predictions == batch["labels"]).sum().item()
            num_seen += batch_size
    model.train()

    if num_seen == 0:
        # Nothing to average over; report "no data" instead of dividing by zero.
        return float("nan"), float("nan")
    return loss_sum / num_seen, num_correct / num_seen

def main():
    """Fine-tune a ViT image classifier with periodic ``linear_vit`` variance estimation.

    The function

    1. parses the command line and copies the sampling hyper-parameters into
       module-level settings of ``linear_vit``;
    2. configures logging (to ``log/<output_dir>/<output_dir>.log``, or to
       stderr when ``--output_dir`` is not given) and, with ``--wandb``, a
       Weights & Biases run;
    3. loads the dataset, splitting a validation set off ``train`` when the
       dataset has none;
    4. builds the model, the torchvision transforms, the data loaders, the
       AdamW optimizer and the learning-rate scheduler;
    5. trains for ``num_train_epochs`` epochs; every ``cal_var_freq`` steps it
       runs the ``linear_vit`` variance estimation and every ``eval_freq``
       steps it evaluates on the validation set.

    NOTE(review): the timing code uses ``torch.cuda.Event`` and
    ``torch.cuda.synchronize()`` unconditionally, so the CPU fallback chosen
    for ``device`` is not expected to get through the training loop.
    """
    args = parse_args()

    # Forward the sampling hyper-parameters to linear_vit's module-level settings.
    linear_vit.CAL_VAR_M = args.cal_var_m
    linear_vit.S = args.s
    linear_vit.ACT_VAR_TOLERANCE = args.act_var_tolerance
    linear_vit.WEIGHT_VAR_TOLERANCE = args.weight_var_tolerance
    linear_vit.S_UPDATE_STEP = args.s_update_step
    linear_vit.WEIGHT_RATIO_MULTIPLIER = args.weight_ratio_multiplier

    # Telemetry is intentionally disabled in this script.
    # send_example_telemetry("run_image_classification_no_trainer", args)

    # parse_args() only creates log/<output_dir> when --output_dir is given;
    # without it, filename=None makes basicConfig log to stderr instead of
    # failing on os.path.join("log", None, ...).
    log_file = None
    if args.output_dir is not None:
        log_file = os.path.join("log", args.output_dir, f"{args.output_dir}.log")
    LOG_FORMAT = "[%(asctime)s] %(message)s"
    logging.basicConfig(
        format=LOG_FORMAT,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filename=log_file,
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        setup_seed(args.seed)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    PROJECT = "your_project"
    ENTITY = "your_entity"

    if args.wandb:
        # The parser may not define log_name; fall back to the output
        # directory name rather than raising AttributeError.
        run_name = getattr(args, "log_name", None) or args.output_dir
        wandb.init(name=run_name, project=PROJECT, config=args, entity=ENTITY)

    # Get the dataset from the hub (downloaded automatically on first use).
    logging.info(f"Loading dataset {args.dataset_name}")
    dataset = load_dataset(args.dataset_name, task="image-classification", cache_dir="/mnt/.cache/datasets")

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Prepare label mappings.
    # We'll include these in the model's config to get human readable labels in the Inference API.
    labels = dataset["train"].features["labels"].names
    label2id = {label: str(i) for i, label in enumerate(labels)}
    id2label = {str(i): label for i, label in enumerate(labels)}

    # Load pretrained model and image processor.
    logging.info(f"Loading model {args.model_name_or_path} and feature extractor")
    config = ViTConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=len(labels),
        # The keyword must be `id2label`; the former `i2label` spelling never
        # reached the config's label mapping.
        id2label=id2label,
        label2id=label2id,
        finetuning_task="image-classification",
    )
    feature_extractor = ViTFeatureExtractor.from_pretrained(args.model_name_or_path)
    model = ViTForImageClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
    )
    model.to(device)

    # Preprocessing the datasets

    # Define torchvision transforms to be applied to each image.
    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
    val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

    def preprocess_train(example_batch):
        """Apply the training transforms across a batch."""
        example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
        return example_batch

    def preprocess_val(example_batch):
        """Apply the validation transforms across a batch."""
        example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
        return example_batch

    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)
    # Set the validation transforms
    eval_dataset = dataset["validation"].with_transform(preprocess_val)

    # DataLoaders creation:
    def collate_fn(examples):
        """Stack per-example tensors and labels into one batch dict."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["labels"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.batch_size
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=max_train_steps,
    )

    # Train!
    total_batch_size = args.batch_size

    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {len(train_dataset)}")
    logging.info(f"  Num Epochs = {args.num_train_epochs}")
    logging.info(f"  Instantaneous batch size per device = {args.batch_size}")
    logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logging.info(f"  Total optimization steps = {max_train_steps}")
    progress_bar = tqdm(range(max_train_steps))
    completed_steps = 0

    # `start` marks the beginning of the segment currently being timed;
    # `end` is re-recorded whenever a segment is closed. Times are in ms.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    pure_train_time_elapsed = 0  # optimization steps only
    train_time_elapsed = 0  # optimization steps + variance estimation
    eval_time_elapsed = 0  # validation passes

    # Running mean of `ratio` over all variance estimations so far.
    ratio_N = 0
    ratio_avg = 0

    # Training loss accumulated since the last evaluation.
    total_loss = 0

    start.record()
    for epoch in range(args.num_train_epochs):
        model.train()
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss.mean()
            total_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            completed_steps += 1

            if completed_steps % args.cal_var_freq == 0:
                # Close the pure-training segment. `start` is deliberately
                # not restarted, so the train-time read below also covers it.
                end.record()
                torch.cuda.synchronize()
                pure_train_time_elapsed += start.elapsed_time(end)

                # NOTE(review): `test` / `sample` are opaque linear_vit flags;
                # presumably `test` turns on statistics recording and `sample`
                # toggles the sampled backward pass -- confirm in linear_vit.
                linear_vit.test = True
                linear_vit.prepare(device)

                # setup_seed() below overwrites the Python, NumPy, torch CPU
                # and CUDA generators. Save all of them so training resumes
                # with the random streams it had before the estimation.
                org_state = torch.get_rng_state().clone()
                py_rng_state = random.getstate()
                np_rng_state = np.random.get_state()
                cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

                for batch_idx, var_batch in enumerate(itertools.islice(train_dataloader, args.cal_var_m)):
                    var_batch = {k: v.to(device) for k, v in var_batch.items()}

                    # One pass with sampling off. Gradients are discarded;
                    # only what linear_vit records during the pass is kept.
                    linear_vit.sample = False
                    setup_seed(batch_idx)
                    var_outputs = model(**var_batch)
                    var_loss = var_outputs.loss.mean()
                    var_loss.backward()
                    optimizer.zero_grad()

                    # cal_var_m passes with sampling on, for the same batch.
                    # NOTE(review): every repeat reseeds with the same value;
                    # whether the sampled passes still differ depends on
                    # linear_vit -- confirm.
                    linear_vit.sample = True
                    for _ in range(args.cal_var_m):
                        setup_seed(batch_idx)
                        var_outputs = model(**var_batch)
                        var_loss = var_outputs.loss.mean()
                        var_loss.backward()
                        optimizer.zero_grad()
                linear_vit.test = False

                linear_vit.cal_var()
                sgd_var, act_var = linear_vit.update_activation_ratio()
                weight_var = linear_vit.update_weight_ratio()
                linear_vit.reset_dict()

                torch.set_rng_state(org_state)
                random.setstate(py_rng_state)
                np.random.set_state(np_rng_state)
                if cuda_rng_states is not None:
                    torch.cuda.set_rng_state_all(cuda_rng_states)

                S = linear_vit.S
                activation_ratio_schedule = linear_vit.activation_ratio_schedule
                weight_ratio_dict = linear_vit.weight_ratio_dict

                # The indexing below reads six weight ratios (6*j .. 6*j+5)
                # per activation-ratio entry j.
                # NOTE(review): `ratio` looks like an estimated cost relative
                # to exact training (kept ratios plus the amortised overhead
                # of this estimation) -- confirm against linear_vit.
                num_entries = len(activation_ratio_schedule)
                act_ratio_sum = sum(activation_ratio_schedule)
                weighted_sum = sum(
                    [
                        sum(weight_ratio_dict[i] for i in range(6 * j, 6 * j + 6)) / 6 * activation_ratio_schedule[j]
                        for j in range(num_entries)
                    ]
                )
                ratio = (
                    (1 + act_ratio_sum / num_entries + weighted_sum / num_entries) / 3
                    + args.cal_var_m / args.cal_var_freq
                    + args.cal_var_m ** 2 * (1 + 2 * act_ratio_sum / num_entries) / 3 / args.cal_var_freq
                )
                ratio_avg = (ratio_avg * ratio_N + ratio) / (ratio_N + 1)
                logging.info(f"ratio: {ratio}, ratio_avg: {ratio_avg}")
                ratio_N += 1
                sample_metric = {
                    "sgd_var": sgd_var,
                    "act_var": act_var,
                    "weight_var": weight_var,
                    "S": S,
                    "activation_ratio_first": activation_ratio_schedule[0],
                    "activation_ratio_last": activation_ratio_schedule[-1],
                    **{f"weight_ratio[{k}]": weight_ratio_dict[k] for k in range(6)},
                    "ratio": ratio,
                    "ratio_avg": ratio_avg,
                }
                if args.wandb:
                    update_summary(completed_steps, sample_metric, {})

                # Training steps + variance estimation since `start`.
                end.record()
                torch.cuda.synchronize()
                train_time_elapsed += start.elapsed_time(end)
                start.record()

            if completed_steps % args.eval_freq == 0:
                # Close the training segment that ends here. It holds only
                # optimization steps (about zero when a variance estimation
                # just restarted the timer), so it counts towards both totals.
                end.record()
                torch.cuda.synchronize()
                segment_ms = start.elapsed_time(end)
                train_time_elapsed += segment_ms
                pure_train_time_elapsed += segment_ms

                start.record()
                eval_loss, eval_acc = eval(model, eval_dataloader, device)
                # Convert to a plain float so logs and wandb get a number
                # rather than a device tensor.
                train_loss = float(total_loss) / args.eval_freq
                train_metric = {"loss": train_loss}
                eval_metric = {"loss": eval_loss, "acc": eval_acc}
                if args.wandb:
                    update_summary(completed_steps, train_metric, eval_metric)
                logging.info(f"\nEpoch {epoch} - Step {completed_steps} - Train loss: {train_loss} - Eval loss: {eval_loss} - Eval acc: {eval_acc}\n")
                total_loss = 0

                end.record()
                torch.cuda.synchronize()
                eval_time_elapsed += start.elapsed_time(end)
                start.record()

        # Close the tail of the epoch and restart the timer. Without the
        # restart the next epoch's first segment would count this tail again.
        end.record()
        torch.cuda.synchronize()
        segment_ms = start.elapsed_time(end)
        train_time_elapsed += segment_ms
        pure_train_time_elapsed += segment_ms
        start.record()

    progress_bar.close()
    logging.info(f"training finished, train time: {train_time_elapsed/1000/60} min, eval time: {eval_time_elapsed/1000/60} min, pure train time: {pure_train_time_elapsed/1000/60} min")



# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()