
import argparse
import datetime
import math

import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json

from pathlib import Path

from timm.models import create_model

import engine

from datasets import build_dataset
from engine import evaluate
from cal_flops import calculate_flops
import models
import utils
from SVDLinear import SVDLinear

class Stdout_Font:
    """ANSI SGR escape sequences used to style text printed to the terminal.

    Wrap a message as ``Stdout_Font.Blue + msg + Stdout_Font.Reset`` so that
    the terminal style is restored after the message.
    """

    # Text attributes (SGR codes 0, 1, 3, 4).
    Reset = '\x1b[0m'
    Bold = '\x1b[1m'
    Italic = '\x1b[3m'
    Underline = '\x1b[4m'

    # Foreground colours (SGR codes 91-96).
    Red = '\x1b[91m'
    Green = '\x1b[92m'
    Yellow = '\x1b[93m'
    Blue = '\x1b[94m'
    Purple = '\x1b[95m'
    Cyan = '\x1b[96m'

    # Background colours (SGR codes 101-106).
    BG_Red = '\x1b[101m'
    BG_Green = '\x1b[102m'
    BG_Yellow = '\x1b[103m'
    BG_Blue = '\x1b[104m'
    BG_Purple = '\x1b[105m'
    BG_Cyan = '\x1b[106m'

@torch.no_grad()
def eval_runtime(model, device, batch_size, input_size=224, warmup=20, iters=100):
    """Measure, print and return the average forward-pass latency per image.

    The model is moved to ``device`` and put in eval mode. A random batch of
    shape ``(batch_size, 3, input_size, input_size)`` is fed through it
    ``warmup`` times (untimed) and then ``iters`` times (timed).

    Args:
        model: ``torch.nn.Module`` to benchmark.
        device: target device, as a ``str`` or ``torch.device``.
        batch_size: number of images per forward pass (must be > 0).
        input_size: side length of the square random input; defaults to 224,
            the value that was previously hard-coded.
        warmup: number of untimed warm-up forward passes.
        iters: number of timed forward passes (must be > 0).

    Returns:
        Average wall-clock time per image, in milliseconds.

    Raises:
        ValueError: if ``batch_size`` or ``iters`` is not positive.
    """
    if batch_size <= 0 or iters <= 0:
        raise ValueError("batch_size and iters must be positive")

    device = torch.device(device)
    model.to(device)
    model.eval()

    # CUDA kernels run asynchronously, so wait for them before reading the
    # clock. On non-CUDA devices there is nothing to wait for, and calling
    # torch.cuda.synchronize() without CUDA would raise.
    if device.type == 'cuda':
        def sync():
            torch.cuda.synchronize(device)
    else:
        def sync():
            pass

    inputs = torch.rand(batch_size, 3, input_size, input_size, device=device)

    # Warm-up: lets cudnn.benchmark / caching allocators settle before timing.
    sync()
    for _ in range(warmup):
        model(inputs)
    sync()

    tic = time.perf_counter()
    for _ in range(iters):
        model(inputs)
    sync()
    toc = time.perf_counter()

    avg_time = ((toc - tic) * 1000) / (iters * batch_size)
    print(Stdout_Font.Blue + 'Average time per image: {:.5f} (ms), '.format(avg_time) + Stdout_Font.Reset)
    return avg_time

def generate_steps(a: int, b: int, steps: int = 12) -> list:
    """Return ``steps`` integers spaced linearly from ``a`` to ``b`` inclusive.

    Each value is the linear interpolation rounded with Python's ``round``,
    so ties go to the even neighbour. ``decompose`` uses this to give each
    transformer block its own rank.

    Args:
        a: first value of the sequence.
        b: last value of the sequence.
        steps: number of values to produce.

    Returns:
        A list of ``steps`` ints. When ``steps == 1`` the list is
        ``[round(a)]``; when ``steps <= 0`` it is empty.
    """
    if steps <= 0:
        return []
    if steps == 1:
        # A single value has no spacing; the interpolation formula below
        # would divide by zero.
        return [int(round(a))]
    step_size = (b - a) / (steps - 1)
    return [int(round(a + step_size * i)) for i in range(steps)]

def base_decompose(model, device, args):
    """Swap every block's attention for a per-weight SVD (``Attention_SVD``).

    Each ``block.attn`` in ``model.blocks`` is replaced by an
    ``Attention_SVD`` built from it. The replacement uses
    ``rank_ratio=args.attn_compress_ratio`` and dropout ``args.drop``.
    Attention parameter counts (from each module's ``cal_params()``) are
    printed before and after, together with the relative reduction.

    Args:
        model: model exposing ``blocks``, each with an ``attn`` module that
            provides ``cal_params()`` and ``qkv_bias``.
        device: device the decomposed model is moved to.
        args: parsed CLI namespace (reads ``attn_compress_ratio`` and ``drop``).

    Returns:
        The same ``model`` object, modified in place and moved to ``device``.
    """
    from vision_transformer_svd import Attention_SVD

    print(Stdout_Font.Yellow + "decomposing model by naive..." + Stdout_Font.Reset)
    print(">>>> attn_compress_ratio : ", args.attn_compress_ratio)

    params_before = 0.0
    params_after = 0.0
    for blk in model.blocks:
        dense_attn = blk.attn
        params_before += dense_attn.cal_params()
        blk.attn = Attention_SVD(
            atten=dense_attn,
            rank_ratio=args.attn_compress_ratio,
            is_pretrained=True,
            attn2_with_bias=dense_attn.qkv_bias,
            attn_drop=0,
            drop=args.drop,
        )
        # NOTE: MLP (fc1/fc2) decomposition via SVDLinear with
        # args.mlp_compress_ratio is currently disabled.
        params_after += blk.attn.cal_params()

    # Report counts in millions of parameters.
    params_before = params_before / 10**6
    params_after = params_after / 10**6
    reduction_pct = round((1 - params_after / params_before) * 100)
    for label, value, unit in (
        ("Original MHA Params", params_before, "M"),
        ("Decomposed MHA Params", params_after, "M"),
        ("MHA Params Reduction", reduction_pct, "%"),
    ):
        print(Stdout_Font.Blue + f"{label} : {value} {unit}" + Stdout_Font.Reset)

    model.to(device)
    return model

def comcat_decompose(model, device, args):
    """Swap every block's attention for the combined decomposition (``Attention_newSVD``).

    Each ``block.attn`` in ``model.blocks`` is replaced by an
    ``Attention_newSVD`` built from it. The replacement uses
    ``rank_ratio=args.attn_compress_ratio`` and dropout ``args.drop``.
    Attention parameter counts (from each module's ``cal_params()``) are
    printed before and after, together with the relative reduction.

    Args:
        model: model exposing ``blocks``, each with an ``attn`` module that
            provides ``cal_params()`` and ``qkv_bias``.
        device: device the decomposed model is moved to.
        args: parsed CLI namespace (reads ``attn_compress_ratio`` and ``drop``).

    Returns:
        The same ``model`` object, modified in place and moved to ``device``.
    """
    from vision_transformer_svd import Attention_newSVD

    print(Stdout_Font.Yellow + "decomposing model by comcat..." + Stdout_Font.Reset)
    print(">>>> attn_compress_ratio : ", args.attn_compress_ratio)

    params_before = 0.0
    params_after = 0.0
    for blk in model.blocks:
        dense_attn = blk.attn
        params_before += dense_attn.cal_params()
        blk.attn = Attention_newSVD(
            atten=dense_attn,
            rank_ratio=args.attn_compress_ratio,
            is_pretrained=True,
            attn2_with_bias=dense_attn.qkv_bias,
            attn_drop=0,
            drop=args.drop,
        )
        # NOTE: MLP (fc1/fc2) decomposition via SVDLinear with
        # args.mlp_compress_ratio is currently disabled.
        params_after += blk.attn.cal_params()

    # Report counts in millions of parameters.
    params_before = params_before / 10**6
    params_after = params_after / 10**6
    reduction_pct = round((1 - params_after / params_before) * 100)
    for label, value, unit in (
        ("Original MHA Params", params_before, "M"),
        ("Decomposed MHA Params", params_after, "M"),
        ("MHA Params Reduction", reduction_pct, "%"),
    ):
        print(Stdout_Font.Blue + f"{label} : {value} {unit}" + Stdout_Font.Reset)

    model.to(device)
    return model

def decompose(model, device, args, qk_rank=(16, 64), vo_rank=(16, 64), mlp_rank=(272, 496)):
    """Swap every block's attention for ``Attention_UniSVD`` (our method).

    Per-block Q/K and V/O head dimensions are interpolated linearly with
    ``generate_steps``. They run from ``qk_rank[0]`` / ``vo_rank[0]`` at the
    first block to ``qk_rank[1]`` / ``vo_rank[1]`` at the last block.
    Attention parameter counts (from ``cal_params()``) are printed before and
    after, together with the relative reduction.

    Args:
        model: model exposing ``blocks``, each with an ``attn`` module that
            provides ``cal_params()`` and ``qkv_bias``.
        device: device the decomposed model is moved to.
        args: parsed CLI namespace (reads ``drop``).
        qk_rank: ``(first, last)`` Q/K head dim; any 2-item sequence works.
        vo_rank: ``(first, last)`` V/O head dim; any 2-item sequence works.
        mlp_rank: currently unused; reserved for the disabled MLP
            (fc1/fc2) decomposition.

    Returns:
        The same ``model`` object, modified in place and moved to ``device``.
    """
    from vision_transformer_svd import Attention_UniSVD

    print(Stdout_Font.Yellow + "Decomposing model by our method..." + Stdout_Font.Reset)

    blocks = list(model.blocks)
    # Use the real depth instead of a hard-coded 12. Otherwise deeper models
    # index past the rank lists, and shallower ones never reach the end rank.
    depth = len(blocks)
    qk_head_dim = generate_steps(qk_rank[0], qk_rank[1], depth)
    vo_head_dim = generate_steps(vo_rank[0], vo_rank[1], depth)

    original_attn_params = 0.0
    decompose_attn_params = 0.0
    for idx, block in enumerate(blocks):
        original_attn_params += block.attn.cal_params()
        block.attn = Attention_UniSVD(atten=block.attn, is_pretrained=True, attn2_with_bias=block.attn.qkv_bias,
                                      attn_drop=0, drop=args.drop,
                                      qk_head_dim=qk_head_dim[idx], vo_head_dim=vo_head_dim[idx])
        # NOTE: MLP (fc1/fc2) decomposition via SVDLinear with mlp_rank is
        # currently disabled.
        decompose_attn_params += block.attn.cal_params()

    # Compute the reduction from the exact totals; rounding to 0.01 M first
    # skews the percentage.
    if original_attn_params:
        relative_params = round((1 - decompose_attn_params / original_attn_params) * 100)
    else:
        relative_params = 0
    original_attn_params = round(original_attn_params / 10**6, 2)
    decompose_attn_params = round(decompose_attn_params / 10**6, 2)
    print(Stdout_Font.Blue + f"Original MHA Params : {original_attn_params} M" + Stdout_Font.Reset)
    print(Stdout_Font.Blue + f"Decomposed MHA Params : {decompose_attn_params} M" + Stdout_Font.Reset)
    print(Stdout_Font.Blue + f"MHA Params Reduction : {relative_params} %" + Stdout_Font.Reset)
    model.to(device)
    return model

def get_args_parser():
    """Build the command-line parser for this evaluation script.

    The parser is created with ``add_help=False`` so it can be passed as a
    ``parents=[...]`` entry to another ``ArgumentParser``, which then provides
    ``-h/--help``. Several training/augmentation options are kept for
    compatibility with the DeiT command line; not all of them are read by
    ``main`` itself.

    Returns:
        An ``argparse.ArgumentParser`` with all options registered.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int,
                        help='Batch size (the validation loader uses 1.5x this value)')
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # Decomposition hyperparameters.
    # UniSVD head dims are interpolated linearly from *_rank1 (first block)
    # to *_rank2 (last block).
    parser.add_argument('--qk_rank1', default=16, type=int,
                        help='UniSVD: Q/K head dim of the first block')
    parser.add_argument('--qk_rank2', default=32, type=int,
                        help='UniSVD: Q/K head dim of the last block')
    parser.add_argument('--vo_rank1', default=16, type=int,
                        help='UniSVD: V/O head dim of the first block')
    parser.add_argument('--vo_rank2', default=32, type=int,
                        help='UniSVD: V/O head dim of the last block')
    parser.add_argument('--attn_compress_ratio', default=1.1, type=float,
                        help='Rank ratio for the per-weight decomposition and the combined decomposition')
    parser.add_argument('--decomposed', action='store_true', default=False, help='UniSVD (Ours)')
    parser.add_argument('--base_decomposed', action='store_true', default=False, help='Per-weight decomposition')
    parser.add_argument('--comcat_decomposed', action='store_true', default=False, help='Combined decomposition')

    # Dataset parameters
    parser.add_argument('--data-path', default='/database1/dataset/ImageNet/ILSVRC2012', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Dataset to evaluate on')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='./',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--load', default='', help='only load model weights from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--use-attn2', action='store_true', default=False)
    parser.add_argument('--original_eval', action='store_true', default=False, help='Evaluate the original model')
    parser.add_argument('--decompose-pretrained-model', default='', help='decomposed pretrained model from checkpoint')
    parser.add_argument('--attn2-with-bias', action='store_true', default=False)
    parser.add_argument('--with-align', action='store_true', default=False)
    parser.add_argument('--batch-size-search', default=128, type=int, help='the batch size when searching ranks')
    parser.add_argument('--beta-search', default=1.5, type=float, help='beta')
    parser.add_argument('--target-params-reduction', default=0.5, type=float,
                        help='expected percentage of parameter reduction')
    return parser


def _make_val_sampler(dataset_val, args):
    """Choose the sampler for the validation set.

    A non-shuffling ``DistributedSampler`` is used only when running
    distributed *and* ``--dist-eval`` is set. In every other case each
    process iterates the whole set with a ``SequentialSampler``.
    """
    if not args.distributed:
        return torch.utils.data.SequentialSampler(dataset_val)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if not args.dist_eval:
        return torch.utils.data.SequentialSampler(dataset_val)

    if len(dataset_val) % num_tasks != 0:
        print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
              'This will slightly alter validation results as extra duplicate entries are added to achieve '
              'equal num of samples per-process.')
    return torch.utils.data.DistributedSampler(
        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)


def main(args):
    """Report size/FLOPs of a pretrained model and, with ``--eval``, evaluate a decomposed copy.

    Steps:
      1. Set up distributed mode, seeding and the validation loader.
      2. Create the pretrained timm model and freeze it. With
         ``--comcat_decomposed`` it is first converted via ``convert2attn2()``.
      3. Print its parameter count and FLOPs.
      4. Stop here unless ``--eval`` is set. Otherwise, optionally evaluate
         the original model (``--original_eval``), then apply the selected
         decomposition and print its size, FLOPs and accuracy. If several
         flags are given, comcat takes precedence over base, and base over
         UniSVD.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Offset the seed by rank so each process is reproducible but distinct.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_val, args.nb_classes = build_dataset(is_train=False, args=args)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=_make_val_sampler(dataset_val, args),
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    def report_accuracy(stats):
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {stats['acc1']:.1f}%")

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=True,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        img_size=args.input_size,
        use_attn2=False,
    )
    if args.comcat_decomposed:
        model.convert2attn2()

    model.requires_grad_(False)
    model.to(device)

    flops_shape = (1, 3, args.input_size, args.input_size)

    print(Stdout_Font.Green + ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Original <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + Stdout_Font.Reset)
    original_params = sum(p.numel() for p in model.parameters()) / 10**6
    print(f'-------------------------------- Original Model Params : {original_params:.2f} M --------------------------------')
    calculate_flops(model, input_shape=flops_shape)

    if not args.eval:
        return

    if args.original_eval:
        report_accuracy(evaluate(data_loader_val, model, device))

    if args.comcat_decomposed:
        model = comcat_decompose(model, device, args)
    elif args.base_decomposed:
        model = base_decompose(model, device, args)
    elif args.decomposed:
        model = decompose(model, device, args,
                          qk_rank=[args.qk_rank1, args.qk_rank2],
                          vo_rank=[args.vo_rank1, args.vo_rank2])

    # The decomposition modules create fresh parameters; freeze them too.
    model.requires_grad_(False)
    decomposed_params = sum(p.numel() for p in model.parameters()) / 10**6
    print(Stdout_Font.Green + ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Decomposition <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" + Stdout_Font.Reset)
    print()
    print(f'-------------------------------- Decomposed Model Params : {decomposed_params:.2f} M --------------------------------')
    calculate_flops(model, input_shape=flops_shape)

    report_accuracy(evaluate(data_loader_val, model, device))


if __name__ == '__main__':
    # get_args_parser() is built with add_help=False; wrapping it here as a
    # parent gives the CLI its -h/--help option.
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    output_dir = args.output_dir
    if output_dir:
        # Create the output directory up front; no-op if it already exists.
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    main(args)


