import time
import torch
import random
import datetime
import argparse
import numpy as np
from pathlib import Path
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from logs.configuration import *
# from dataloader import prepare_data, prepare_finetune_data

import models_finetune
from engine_finetune import train_one_epoch, eval_model

import torch.nn as nn

from dataloader import SUNDataset, DTDDataset
from torchvision import transforms


def get_args_parser():
    """Build the command-line argument parser for the fine-tuning script.

    Note: argparse applies %-formatting to every help string, so a literal
    percent sign must be written as ``%%``.

    :return: configured ``argparse.ArgumentParser`` (arguments not yet parsed).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--experiment', type=str, default='ft')
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus: [32*3, 64*3, 12*3, 12*3])')
    parser.add_argument('--epochs', default=500, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # ViT loss function
    parser.add_argument('--smoothing', type=float, default=0.0, help='Label smoothing (default: 0.0)')

    # training stage
    parser.add_argument('--is_linear', default=False, action='store_true', help="the type of the classifier")
    parser.add_argument('--augment', default=False, action='store_true', help='whether augment images')
    parser.add_argument('--sample_pos', default=False, action="store_true")
    parser.add_argument('--is_mix', default=False, action="store_true")
    parser.add_argument('--recon_mission', default=False, action='store_true')
    parser.add_argument('--global_pool', default=False, action='store_true')
    parser.add_argument('--linear_probe', default=False, action='store_true')

    # Hyperparameters of loss function
    parser.add_argument('--lambda1', type=float, default=1e-3, help='hyperparameter of adaptive contrastive loss')
    parser.add_argument('--lambda2', type=float, default=1.0, help='hyperparameter of classification loss')
    parser.add_argument('--tao', type=float, default=4e-1, help='temperature coefficient of infoNCE loss')
    parser.add_argument('--margin', type=float, default=8e-3, help='margin of reconstruction L2 loss')
    parser.add_argument('--ada_iter', type=int, default=0, help='iteration for adaptive contrastive loss')
    parser.add_argument('--decoder_drop', type=float, default=0.0, help="Dropout rate of the decoder")
    parser.add_argument('--inter_ada', type=int, default=25, help="training interval of the domain classifier")
    parser.add_argument('--max_epoch_ada', type=int, default=100, help="maximum adaptive training epoch")

    # Model parameters
    parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--num_classes', default=20, type=int, help='number of classification classes')
    parser.add_argument('--mask_ratio', default=0.75, type=float, help='Masking ratio (percentage of removed patches).')
    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr): [0.025, 0.05, 5e-5, 5e-5]')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N', help='epochs to warmup LR')

    parser.add_argument('--switch_epochs', type=int, default=20, metavar='N',
                        help='number of epochs after which training alternates between the classifier head and the encoder')
    # Dataset parameters
    parser.add_argument('--trial_seed', type=int, default=0,
                        help='Trial number (used for seeding split_dataset and random_hparams).')
    # '%%' renders as a literal '%'; a bare '%' makes argparse's help formatting raise.
    parser.add_argument('--val_fraction', default=0.01, type=float,
                        help="fraction of labeled training data: [1%%, 5%%, 10%%, 100%%]")
    parser.add_argument('--holdout_fraction', type=float, default=0.01,
                        help='fraction of validation set/(validation set+training set): [0.991, 0.955, 0.91, 0.10]')
    parser.add_argument('--uda_holdout_fraction', type=float, default=0)
    parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
    parser.add_argument('--dataset', default='DomainNet', type=str, help='dataset')
    parser.add_argument('--data_path', default='../dataset', type=str, help='dataset path')
    parser.add_argument('--output_dir', default='../output_dir', help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='../output_dir', help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='/path/to/the/checkpoints', help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    return parser


class SphereSampler(nn.Module):
    """Learnable scalar gate that either passes its input through or zeroes it.

    Despite the name, nothing is sampled. The module holds one learnable
    scalar logit (``mask_ratio``) and multiplies its input by a hard 0/1 mask
    derived from it. Gradients reach the logit through a straight-through
    estimator.
    """

    def __init__(self, shape=512, init_aug=0.5):
        """
        :param shape: unused; kept for interface compatibility.
        :param init_aug: initial value of the gate logit ``mask_ratio``.
            Positive values start with the gate open, and values <= 0 start
            with it closed. Converted to float so integer inputs also yield a
            trainable parameter.
        """
        super().__init__()
        self.mask_ratio = nn.Parameter(torch.tensor(float(init_aug)))

    def forward(self, x, num_samples=1):
        """Apply the gate to ``x``.

        :param x: tensor of any shape.
        :param num_samples: unused; kept for interface compatibility.
        :return: tensor with the same shape as ``x``. It equals ``x`` times the
            hard mask, i.e. ~``x`` when the gate is open and zeros when closed.
        """
        # Multiplying the logit by 5 sharpens the sigmoid around the 0.5 threshold.
        mask_prob = torch.sigmoid(self.mask_ratio * 5.0)
        hard_mask = (mask_prob > 0.5).float()
        # Straight-through estimator: the forward value is the hard 0/1 mask,
        # while the gradient flows through the soft probability.
        mask = hard_mask + mask_prob - mask_prob.detach()
        return x * mask


class ModifiedMultiHeadAttention(nn.Module):
    """Freeze a wrapped module and gate its output during training.

    Despite the name, any module can be wrapped; in this file it wraps the
    encoder's final block. In training mode the wrapped module's output goes
    through a learnable ``SphereSampler`` gate. In eval mode the output is
    returned untouched.
    """

    def __init__(self, original_mha, encoder_type='titan', init_aug=0.5):
        """
        :param original_mha: module to wrap; all of its parameters are frozen.
        :param encoder_type: unused; accepted for interface compatibility.
        :param init_aug: initial gate logit passed to ``SphereSampler``.
        """
        super().__init__()
        self.original_mha = original_mha
        # Freeze every parameter of the wrapped module.
        self.original_mha.requires_grad_(False)
        self.sample_aug = SphereSampler(shape=768, init_aug=init_aug)

    def forward(self, hidden_states, attention_mask=None):
        """Run the wrapped module, gating its output only in training mode.

        :param hidden_states: input forwarded unchanged to the wrapped module.
        :param attention_mask: unused; accepted for interface compatibility.
        """
        block_out = self.original_mha(hidden_states)
        if not self.training:
            return block_out
        return self.sample_aug(block_out)


class CLIPPredictor(nn.Module):
    """Encoder followed by a two-layer linear classification head.

    The head maps ``dim -> 2*dim -> out_size``. There is no non-linearity
    between the two ``nn.Linear`` layers, so the head is a factored linear map.
    """

    def __init__(self, clip, dim=512, out_size=20):
        """
        :param clip: encoder module stored as ``self.enc``. Its output is fed
            straight to the head, so it presumably returns ``(batch, dim)``
            features (TODO confirm for each encoder used).
        :param dim: feature dimension produced by ``clip``.
        :param out_size: number of output logits.
        """
        super().__init__()

        self.enc = clip
        hidden = int(dim * 2)
        self.out = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.Linear(hidden, out_size),
        )

        def init_linear_weights(m):
            # Small-normal weights and zero biases for every linear layer of the head.
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        # Module.apply visits every submodule. Calling the initializer on the
        # Sequential itself would match no nn.Linear and initialise nothing.
        self.out.apply(init_linear_weights)

    def forward(self, x):
        """Encode ``x`` with ``self.enc`` and return the head's logits."""
        return self.out(self.enc(x))


def prepare_model(model, num_classes=47, dim=768, init_aug=0.01):
    """Wrap an encoder with a classification head and set up partial fine-tuning.

    After this call the trainable parameters are:
    - the ``norm1``/``norm2`` modules of every encoder block except the last
      (the last block is re-frozen by ``ModifiedMultiHeadAttention``);
    - the classification head ``.out``;
    - the gate of the wrapper that replaces the final block.

    Everything else in the encoder is frozen.

    :param model: encoder exposing an indexable ``.blocks`` container.
    :param num_classes: number of output classes. The default 47 presumably
        matches the DTD dataset used in ``main`` (397 was used for SUN).
    :param dim: feature dimension produced by the encoder.
    :param init_aug: initial gate logit for the final-block wrapper.
    :return: the ``CLIPPredictor`` wrapping ``model``.
    """
    predictor = CLIPPredictor(model, dim=dim, out_size=num_classes)
    encoder = predictor.enc

    # Freeze the whole encoder, then re-enable only the per-block norms.
    encoder.requires_grad_(False)
    for block in encoder.blocks:
        for norm_name in ("norm1", "norm2"):
            norm = getattr(block, norm_name, None)
            if norm is not None:
                norm.requires_grad_(True)

    predictor.out.requires_grad_(True)

    # The wrapper freezes all parameters of the block it wraps, including the
    # norms re-enabled above.
    encoder.blocks[-1] = ModifiedMultiHeadAttention(encoder.blocks[-1], encoder_type="vit_base",
                                                    init_aug=init_aug)
    return predictor


def _seed_everything(seed):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_dataloaders(args):
    """Build the DTD train (TRAIN split) and evaluation (TEST split) DataLoaders."""
    # NOTE(review): these are CLIP's normalisation statistics, while main()
    # loads a DINOv2 encoder — confirm this pairing is intended.
    preprocess = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)),
    ])

    ratio = 1 - args.holdout_fraction
    trainset = DTDDataset(image_split="TRAIN", ratio=ratio, transform=preprocess)
    valset = DTDDataset(image_split="TEST", ratio=ratio, transform=preprocess)

    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.num_workers,
                         pin_memory=args.pin_mem)
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, drop_last=True, **loader_kwargs)
    # Aggregate evaluation metrics presumably do not depend on sample order,
    # so the evaluation set is not shuffled.
    valloader = torch.utils.data.DataLoader(valset, shuffle=False, drop_last=False, **loader_kwargs)
    return trainloader, valloader


def main(args):
    """Fine-tune a DINOv2 ViT-B/14 (registers variant) with a classification head on DTD.

    Training alternates every ``args.switch_epochs`` epochs between the head
    optimizer and the encoder optimizer. After training, the model is
    evaluated once on the TEST split, the result is logged and a checkpoint
    is saved.
    """
    configure_experiment(args.output_dir, rank=misc.get_rank())
    logger = get_logger()
    logger.info(args)
    args.model = "Finetune"

    # Seed before building the model so the head's random init is reproducible.
    _seed_everything(args.seed)

    if args.lr is None:
        # --lr defaults to None; derive it from --blr as the argument help
        # documents (absolute_lr = base_lr * total_batch_size / 256). No
        # distributed setup happens here, so the world size is taken as 1.
        args.lr = args.blr * args.batch_size * args.accum_iter / 256
        logger.info("lr not given; using blr-derived lr {:.3e}".format(args.lr))

    encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
    encoder.requires_grad_(False)
    model = prepare_model(encoder)

    device = torch.device(args.device)
    model.to(device)

    optimizer_enc = torch.optim.AdamW([p for p in model.enc.parameters() if p.requires_grad],
                                      lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
    optimizer = torch.optim.AdamW([p for p in model.out.parameters() if p.requires_grad],
                                  lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)

    start_time = time.time()
    logger.info("Start training for {} epochs".format(args.epochs))
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("Parameters {:.2f}M".format(num_trainable / 1e6))

    trainloader, valloader = _build_dataloaders(args)
    total_len = len(trainloader)

    # Initialised so evaluation/saving below still work when no epochs run.
    epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        logger.info("Epoch:[{}/{}], Lr:{:.5f}".format(epoch, args.epochs, args.lr))
        # Even phases of `switch_epochs` epochs step the head optimizer; odd
        # phases step the encoder optimizer.
        if (epoch // args.switch_epochs) % 2 == 0:
            print('training classifer')
            train_one_epoch(model, trainloader, optimizer, epoch, args, total_len, logger)
        else:
            print('training encoder')
            train_one_epoch(model, trainloader, optimizer_enc, epoch, args, total_len, logger)

    current_perf = eval_model(model, valloader, device, epoch, args, total_len, logger)
    # Only one evaluation is performed, so it is also the best one.
    best_perf, record_epoch = current_perf, epoch

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {} | Best perf in validation set: {:.2f}% in epoch {}'.format(
        total_time_str, best_perf, record_epoch))
    misc.save_model(args, epoch, model, optimizer, prefix=int(args.holdout_fraction * 100))
    logger.info("+" * 50)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    if args.output_dir:
        # parents=True so nested output paths (e.g. runs/exp1/seed0) can be created.
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
