import time
import torch
import random
import datetime
import argparse
import numpy as np
from pathlib import Path
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from logs.configuration import *

import models_finetune
from models_finetune import VisionTransformer
from engine_finetune import train_one_epoch, eval_model
import torch.nn as nn
import copy
from dataloader import SUNDataset, DTDDataset


def get_args_parser():
    """Build the command-line parser for the trim-ratio evaluation script.

    Note: argparse applies %-formatting to help strings, so any literal ``%``
    in a help message must be written as ``%%``.

    :return: a configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--experiment', type=str, default='ft')
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus: [32*3, 64*3, 12*3, 12*3]')
    parser.add_argument('--epochs', default=500, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # ViT loss function
    parser.add_argument('--smoothing', type=float, default=0.0, help='Label smoothing (default: 0.0)')

    # training stage
    parser.add_argument('--is_linear', default=False, action='store_true', help="the type of the classifier")
    parser.add_argument('--augment', default=False, action='store_true', help='whether augment images')
    parser.add_argument('--sample_pos', default=False, action="store_true")
    parser.add_argument('--is_mix', default=False, action="store_true")
    parser.add_argument('--recon_mission', default=False, action='store_true')
    parser.add_argument('--global_pool', default=False, action='store_true')
    parser.add_argument('--linear_probe', default=False, action='store_true')

    # Hyperparameters of loss function
    parser.add_argument('--lambda1', type=float, default=1e-3, help='hyperparameter of adaptive contrastive loss')
    parser.add_argument('--lambda2', type=float, default=1.0, help='hyperparameter of classification loss')
    parser.add_argument('--tao', type=float, default=4e-1, help='temperature coefficient of infoNCE loss')
    parser.add_argument('--margin', type=float, default=8e-3, help='margin of reconstruction L2 loss')
    parser.add_argument('--ada_iter', type=int, default=0, help='iteration for adaptive contrastive loss')
    parser.add_argument('--decoder_drop', type=float, default=0.0, help="Dropout rate of the decoder")
    parser.add_argument('--inter_ada', type=int, default=25, help="training interval of the domain classifier")
    parser.add_argument('--max_epoch_ada', type=int, default=100, help="maximum adaptive training epoch")
    # NOTE(review): help text duplicates --warmup_epochs; confirm what switch_epochs controls.
    parser.add_argument('--switch_epochs', type=int, default=20, metavar='N', help='epochs to warmup LR')
    # Model parameters
    parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--num_classes', default=20, type=int, help='number of classification classes')
    parser.add_argument('--mask_ratio', default=0.75, type=float, help='Masking ratio (percentage of removed patches).')
    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr): [0.025, 0.05, 5e-5, 5e-5]')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N', help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--trial_seed', type=int, default=0,
                        help='Trial number (used for seeding split_dataset and random_hparams).')
    # '%%' is required: argparse %-formats help strings and a bare '%' breaks --help.
    parser.add_argument('--val_fraction', default=0.01, type=float,
                        help="fraction of labeled training data: [1%%, 5%%, 10%%, 100%%]")
    parser.add_argument('--holdout_fraction', type=float, default=0.01,
                        help='fraction of validation set/(validation set+training set): [0.991, 0.955, 0.91, 0.10]')
    parser.add_argument('--uda_holdout_fraction', type=float, default=0)
    parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
    parser.add_argument('--dataset', default='DomainNet', type=str, help='dataset')
    parser.add_argument('--data_path', default='../dataset', type=str, help='dataset path')
    parser.add_argument('--output_dir', default='../output_dir', help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='../output_dir', help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='/path/to/the/checkpoints', help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    return parser


class SphereSampler(nn.Module):
    """Learnable straight-through binary gate applied to its input.

    Despite the name, no sampling happens here: the module holds one learnable
    scalar ``mask_ratio``. In the forward pass the input is multiplied by a
    hard 0/1 mask whose value is 1 when ``sigmoid(5 * mask_ratio) > 0.5``
    (i.e. ``mask_ratio > 0``). Gradients reach ``mask_ratio`` through the
    sigmoid (straight-through estimator).
    """

    def __init__(self, shape=512, init_aug=0.5):
        """
        :param shape: unused; kept for backward compatibility with existing callers.
        :param init_aug: initial value of the learnable scalar ``mask_ratio``
            (converted to float so integer values are accepted).
        """
        super(SphereSampler, self).__init__()
        # float() so an int init value still yields a floating tensor that can require grad.
        self.mask_ratio = nn.Parameter(torch.tensor(float(init_aug)))

    def forward(self, x, num_samples=1):
        """Gate ``x`` with the learned hard mask.

        :param x: input tensor; multiplied elementwise by a scalar mask.
        :param num_samples: unused; kept for backward compatibility.
        :return: tensor with the same shape as ``x``: ``x`` when the gate is
            open, zeros when it is closed.
        """
        mask_prob = torch.sigmoid(self.mask_ratio * 5.0)  # sharpen toward a binary decision
        hard_mask = (mask_prob > 0.5).float()  # hard 0/1 threshold at 0.5
        # Straight-through: forward value equals hard_mask, backward gradient flows via mask_prob.
        mask = hard_mask + mask_prob - mask_prob.detach()
        return x * mask


class ModifiedMultiHeadAttention(nn.Module):
    """Wrap a frozen attention layer and gate its output during training.

    In training mode the wrapped layer's output is passed through a
    ``SphereSampler`` (``self.sample_aug``); in eval mode it is returned as is.
    """

    def __init__(self, original_mha, encoder_type='titan', init_aug=0.5):
        """
        :param original_mha: attention module to wrap; all of its parameters are frozen.
        :param encoder_type: currently unused.
        :param init_aug: initial gate value forwarded to ``SphereSampler``.
        """
        super().__init__()
        self.original_mha = original_mha  # the wrapped attention layer

        # Freeze the wrapped layer so only the gate parameter remains trainable.
        for weight in self.original_mha.parameters():
            weight.requires_grad = False

        self.sample_aug = SphereSampler(shape=768, init_aug=init_aug)

    def forward(self, hidden_states, attention_mask=None):
        """Apply the wrapped layer to ``hidden_states``; ``attention_mask`` is accepted but ignored."""
        attended = self.original_mha(hidden_states)
        if not self.training:
            return attended
        return self.sample_aug(attended)


class CLIPPredictor(nn.Module):
    """Classification head on top of a CLIP-style image encoder.

    Inputs are encoded with ``clip.encode_image`` and the resulting features
    are mapped to ``out_size`` logits by two stacked linear layers
    (``dim -> 2*dim -> out_size``) with no nonlinearity in between.
    """

    def __init__(self, clip, dim=512, out_size=20):
        super().__init__()

        self.enc = clip
        hidden_dim = int(dim * 2)
        self.out = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Linear(hidden_dim, out_size),
        )

    def forward(self, x):
        """Return head logits for the encoded images ``x``."""
        return self.out(self.enc.encode_image(x))


def prepare_model(model, num_classes=47, dim=512):
    """Wrap a CLIP model in a ``CLIPPredictor`` configured for LayerNorm-only tuning.

    All parameters of ``model.visual`` are frozen. The ``ln_1``/``ln_2``
    LayerNorms of every visual residual block that has them are then
    unfrozen, and the classification head is made trainable. Parameters
    outside ``model.visual`` are not touched here.

    :param model: CLIP-style model exposing ``encode_image`` and
        ``visual.transformer.resblocks``.
    :param num_classes: number of output logits (default 47; presumably the
        DTD class count used by this script — confirm).
    :param dim: feature size produced by ``encode_image`` (default 512;
        assumed to match the ViT-B-32 encoder — confirm).
    :return: the wrapping ``CLIPPredictor``.
    """
    predictor = CLIPPredictor(model, dim=dim, out_size=num_classes)
    visual = predictor.enc.visual

    visual.requires_grad_(False)
    for block in visual.transformer.resblocks:
        if hasattr(block, "ln_1"):
            block.ln_1.requires_grad_(True)
            block.ln_2.requires_grad_(True)
    predictor.out.requires_grad_(True)

    return predictor


def weight_filter(M, i=0.1):
    """Scale ``M`` by the factor ``i``.

    Used to rescale the fine-tuning delta of LayerNorm weights. The previous
    implementation multiplied by ``i`` twice (returning ``M * i**2``) while
    printing ``M * i``, and it dumped the full array to stdout on every call.

    :param M: array-like (numpy array or torch tensor) supporting ``*`` with a scalar.
    :param i: scale factor.
    :return: ``M * i``.
    """
    return M * i


def trim_encoder(ori_model, model, trim_ratio=0.25):
    """Rescale the tuned LayerNorm weights of ``model`` relative to ``ori_model``.

    For every visual residual block that has ``ln_1``, the ``ln_1`` and ``ln_2``
    LayerNorms are paired with their counterparts in ``ori_model`` (in block
    order). For each pair, the new weight is::

        original + weight_filter(tuned - original, i=trim_ratio)

    LayerNorm biases keep their tuned values; only weights are rescaled.
    ``ori_model`` is only read. ``model`` is modified in place, and values are
    written with ``copy_`` so each parameter keeps its own device and dtype.

    :param ori_model: reference model exposing ``visual.transformer.resblocks``.
    :param model: tuned wrapper exposing ``enc.visual.transformer.resblocks``.
    :param trim_ratio: scale factor forwarded to ``weight_filter``.
    :return: ``model`` (the same object).
    :raises ValueError: if the two models have a different number of
        LayerNorm-bearing blocks.
    """

    def _layernorms(resblocks):
        # ln_1 then ln_2 of each block that has LayerNorms, in block order.
        return [ln for blk in resblocks if hasattr(blk, "ln_1") for ln in (blk.ln_1, blk.ln_2)]

    ori_norms = _layernorms(ori_model.visual.transformer.resblocks)
    print('--------------------loaded original layernorm weights--------------------')
    tuned_norms = _layernorms(model.enc.visual.transformer.resblocks)
    print('--------------------loaded tuned layernorm weights--------------------')

    if len(ori_norms) != len(tuned_norms):
        raise ValueError(
            f"trim_encoder: original model has {len(ori_norms)} LayerNorms "
            f"but tuned model has {len(tuned_norms)}")
    if not tuned_norms:
        return model  # nothing to rescale

    with torch.no_grad():
        # Compute in float32 on CPU, as the previous numpy-based version did for fp32 weights.
        ori_w = torch.stack([ln.weight.detach().float().cpu() for ln in ori_norms])
        tuned_w = torch.stack([ln.weight.detach().float().cpu() for ln in tuned_norms])
        new_w = weight_filter(tuned_w - ori_w, i=trim_ratio) + ori_w
        for ln, w in zip(tuned_norms, new_w):
            ln.weight.copy_(w)  # copy_ casts/moves to the parameter's own dtype/device

    print('--------------------loaded trimed layernorm weights--------------------')
    return model


def main(args):
    """Evaluate a fine-tuned CLIP checkpoint across a sweep of LayerNorm trim ratios.

    Steps:

    1. Load the pretrained open_clip ViT-B-32 (laion2b_s34b_b79k) model.
    2. Wrap a copy with ``prepare_model`` and load the checkpoint
       ``{output_dir}/checkpoint-{int(holdout_fraction*100)}{epochs-1}.pth``.
    3. For each trim ratio, rescale a fresh copy of the tuned model with
       ``trim_encoder`` and evaluate it with ``eval_model`` on the DTD
       "TEST" split.

    :param args: namespace produced by ``get_args_parser().parse_args()``;
        ``args.model`` is overwritten with ``"Finetune"``.
    """
    configure_experiment(args.output_dir, prefix='_trim_scale_only_weight', rank=misc.get_rank())
    logger = get_logger()
    logger.info(args)

    args.model = "Finetune"
    import open_clip

    ori_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    ori_model.requires_grad_(False)

    tuned_model = prepare_model(copy.deepcopy(ori_model))
    ckpt_path = f"{args.output_dir}/checkpoint-{int(args.holdout_fraction * 100)}{args.epochs - 1}.pth"
    # map_location: load onto CPU regardless of the device the checkpoint was saved from.
    state_dict = torch.load(ckpt_path, map_location='cpu')['model']
    tuned_model.load_state_dict(state_dict, strict=True)

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device(args.device)

    valset = DTDDataset(image_split="TEST", ratio=(1 - args.holdout_fraction), transform=preprocess)
    # NOTE(review): shuffle=True is kept from the original; evaluation normally does not need it.
    valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers,
                                            pin_memory=args.pin_mem,
                                            drop_last=False)
    total_len = len(valloader)

    trim_ratios = [2.0, 1.5, 1.3, 1.2, 1.1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    for trim_ratio in trim_ratios:
        # trim_encoder only reads ori_model, so only the tuned model needs a fresh copy per ratio.
        model = trim_encoder(ori_model, copy.deepcopy(tuned_model), trim_ratio=trim_ratio)
        model.to(device)
        model.requires_grad_(False)

        with torch.no_grad():
            logger.info(f'trim ratio {trim_ratio},  \t {ckpt_path}')
            eval_model(model, valloader, device, args.epochs, args, total_len, logger, )


if __name__ == '__main__':
    # Script entry point: parse CLI args, make sure the output directory exists, run the sweep.
    args = get_args_parser().parse_args()
    if args.output_dir:
        # parents=True so nested output paths can be created in one call.
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
