import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from timm import create_model
import time
import os
import copy
import argparse
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import random
import numpy as np

from custommodels import *
from local import LocalLoss
from adv_generator import *
from config import config_dict
from model import get_custom_model

def arg_parse(config=None):
    """Parse the training options and make sure the checkpoint directory exists.

    Args:
        config: Optional mapping of option name -> value. When truthy, its
            entries override the defaults declared below; values passed on the
            command line still win over both.

    Returns:
        The parsed ``argparse.Namespace``, extended with ``opt_path``: the
        directory ``./checkpionts/<source_model_name>_<witness_model_name>``,
        which is created if it is missing.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--gpu', type=str, default='0,1', help='gpu device.')
    parser.add_argument('--source_model_name', type=str, default='resnet50', help='')
    parser.add_argument('--witness_model_name', type=str, default='vit_base_patch16_224', help='')
    # Author's per-architecture notes for the next three options are kept verbatim.
    parser.add_argument('--warmup_epochs', type=int, default=5, help='')  #: CNN 5 ViT 30 SWIN 20
    parser.add_argument('--initial_lr', type=float, default=0.1, help='')  #: CNN 0.1 ViT 0.003 SWIN 0.001
    parser.add_argument('--batch_size', type=int, default=128, help='')  #: CNN 128 ViT 512 SWIN 1024
    # Softmax temperature and the loss weights read by train_model.
    parser.add_argument('--T', type=float, default=1., help='')
    parser.add_argument('--adv', type=float, default=1e-2 * 2, help='')
    parser.add_argument('--local', type=float, default=1e-1 * 2, help='')
    parser.add_argument('--adv_local', type=float, default=1e-2 * 2, help='')
    parser.add_argument('--optimizer', type=str, default='sgd', help='')

    if config:
        parser.set_defaults(**config)

    args = parser.parse_args()
    # NOTE: the directory name is spelled 'checkpionts'; it is left untouched so
    # checkpoints written by earlier runs are still found in the same place.
    args.opt_path = os.path.join('./checkpionts', f'{args.source_model_name}_{args.witness_model_name}')

    # exist_ok=True avoids the check-then-create race when several runs start
    # at the same time with the same model pair.
    os.makedirs(args.opt_path, exist_ok=True)
    return args

# Per-channel mean / std passed to Normalize by both pipelines (these are the
# values torchvision documents for its ImageNet-pretrained models).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training pipeline: random 224-pixel resized crop plus random horizontal flip.
_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])

# Validation pipeline: resize to 256, then a 224-pixel center crop (no randomness).
_val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])

# Keyed by split name; the script entry point uses the same keys as the
# dataset sub-folder names.
data_transforms = {
    'train': _train_transform,
    'val': _val_transform,
}

# Model identifiers that train_model treats as transformer backbones: gradient
# clipping is applied only when the source model's name appears in this list.
_VIT_PATCH16_NAMES = [
    'vit_tiny_patch16_224',   # ViT-Tiny/16
    'vit_small_patch16_224',  # ViT-Small/16
    'vit_base_patch16_224',   # ViT-Base/16
]
_SWIN_NAMES = [
    'swin_tiny_patch4_window7_224',  # Swin-Tiny
]
ViT_model_names = _VIT_PATCH16_NAMES + _SWIN_NAMES

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.
    """
    # Python's random module, NumPy, then PyTorch (CPU), in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)      # current CUDA device
    torch.cuda.manual_seed_all(seed)  # all CUDA devices
    # The two flags below improve reproducibility and may cost some performance.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _logits_of(output):
    """Return the logits of a model output that is either a tensor or a sequence whose first item is the logits."""
    return output if isinstance(output, torch.Tensor) else output[0]


def train_model(model, witness_model, local_criterion, KL_criterion, optimizer, num_epochs=1, start_epoch=0, args=None):
    """Train ``model`` (the source) against a frozen ``witness_model`` on clean and adversarial inputs.

    Batches come from the module-level ``dataloaders['train']`` (defined by the
    script entry point, not passed in) and are moved to CUDA. Per step the loss is

        loss_global + args.local * loss_local
        + args.adv * (loss_adv_global + args.adv_local * loss_adv_local)

    where the ``global`` terms are temperature-scaled KL divergences between
    source and witness logits, and the ``local`` terms come from
    ``local_criterion``. The ``adv`` terms use the output returned by
    ``adv_generator``.

    Args:
        model: Source network being optimised. Its output may be a tensor of
            logits or a sequence whose first element is the logits.
        witness_model: Reference network; only evaluated under ``torch.no_grad()``.
        local_criterion: Callable invoked as
            ``local_criterion(source_output, labels, epoch, num_epochs, witness_output)``.
        KL_criterion: Callable taking (log-probabilities, probabilities).
        optimizer: Optimizer over ``model``'s parameters; also handed to
            ``adv_generator``.
        num_epochs: Number of epochs to run.
        start_epoch: Offset used only when numbering the saved checkpoints.
        args: Namespace providing ``T``, ``adv``, ``adv_local``, ``local``,
            ``source_model_name`` and ``opt_path``. Required despite the
            ``None`` default.

    Returns:
        ``model`` itself, holding the weights of the last epoch. A checkpoint of
        ``model.state_dict()`` is written to ``args.opt_path`` after every epoch.
    """
    T = args.T
    adv_alpha = args.adv
    adv_local_alpha = args.adv_local
    local_alpha = args.local
    # Gradient clipping is applied only to the transformer backbones listed in ViT_model_names.
    clip_gradients = args.source_model_name in ViT_model_names

    for epoch in range(num_epochs):
        # Inner step: craft adversarial examples (KL loss). Outer step: pull both
        # source(clean) and source(adv) towards witness(clean) with the KL loss
        # and the dense (local) loss.
        start_time = time.time()

        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        train_progress_bar = tqdm(dataloaders['train'], desc=f'Epoch {epoch+1} Train', leave=False)

        for train_batch_data, train_batch_labels in train_progress_bar:
            train_batch_data = train_batch_data.cuda()
            train_batch_labels = train_batch_labels.cuda()
            batch_size = train_batch_data.size(0)

            optimizer.zero_grad()
            source_output = model(train_batch_data)
            source_logits = _logits_of(source_output)
            _, preds = torch.max(source_logits, 1)

            with torch.no_grad():
                witness_output = witness_model(train_batch_data)
            witness_logits = _logits_of(witness_output)

            # NOTE(review): adv_generator is defined elsewhere. It presumably returns the
            # source model's output on adversarial inputs, and it receives the optimizer;
            # if it leaves parameter gradients behind they would be added to this step's
            # gradients. Confirm against adv_generator.
            source_adv_output = adv_generator(model=model, teacher_logits=witness_logits, x_natural=train_batch_data,y=train_batch_labels, optimizer=optimizer,step_size=2/255.0,epsilon=8/255.0,perturb_steps=10)
            # Extract logits for tensor and sequence outputs alike, so the adversarial
            # logits are always defined for the current batch.
            source_adv_logits = _logits_of(source_adv_output)

            # Put the source model back in training mode after adversarial generation.
            model.train()

            witness_probs = F.softmax(witness_logits / T, dim=1)
            loss_global = KL_criterion(F.log_softmax(source_logits / T, dim=1), witness_probs) * (T ** 2)
            loss_local = local_criterion(source_output, train_batch_labels, epoch, num_epochs, witness_output)
            loss_adv_global = KL_criterion(F.log_softmax(source_adv_logits / T, dim=1), witness_probs) * (T ** 2)
            loss_adv_local = local_criterion(source_adv_output, train_batch_labels, epoch, num_epochs, witness_output)

            # saa loss
            loss = loss_global + local_alpha * loss_local + adv_alpha * (loss_adv_global + adv_local_alpha * loss_adv_local)

            loss.backward()
            if clip_gradients:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            # Sample-weighted accumulation so the epoch averages are per sample.
            running_loss += loss.item() * batch_size
            running_corrects += int((preds == train_batch_labels).sum())
            total_samples += batch_size

            # Show the running mean loss in the progress bar.
            train_progress_bar.set_description(f"Epoch {epoch+1} Train Loss: {running_loss/total_samples:.4f}")

        # Report the mean loss and the accuracy of the finished epoch.
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Save the model after every epoch.
        torch.save(model.state_dict(), os.path.join(args.opt_path,f'{args.source_model_name}_{epoch+start_epoch+1}.pth'))
        print(f'Successfully save model epoch {epoch+start_epoch + 1}')

        end_time = time.time()
        print(f"Epoch duration: {end_time - start_time:.2f} seconds")
        print()

    return model



if __name__ == "__main__":
    # Script entry point: parse options, build the ImageNet loaders and both
    # models, then train the source model against the witness model.
    args = arg_parse(config = config_dict)
    # CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
    # initialised, so it is assigned before any torch.cuda call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    set_seed(42)

    print(f'checkpoint saving at {args.opt_path}')

    data_dir = '/DATACENTER/ImageNet/'  # ImageNet dataset root with 'train' and 'val' sub-folders
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                    for x in ['train', 'val']}
    # NOTE: train_model reads this dict through the module-level name
    # `dataloaders`, so the name must not change. Only the training split is
    # shuffled; the validation split keeps a fixed order.
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=args.batch_size, shuffle=(x == 'train'), num_workers=4)
                for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes

    # Both networks get one output per class folder found in the training set.
    source_model = get_custom_model(model_name=args.source_model_name, num_classes=len(class_names))
    witness_model = get_custom_model(model_name=args.witness_model_name, num_classes=len(class_names))

    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
        source_model = nn.DataParallel(source_model)
        witness_model = nn.DataParallel(witness_model)

    source_model = source_model.cuda()
    witness_model = witness_model.cuda()

    # Loss setup.
    # NOTE(review): CE_criterion is created but not used anywhere in this file.
    CE_criterion = nn.CrossEntropyLoss()
    KL_criterion = nn.KLDivLoss(reduction='batchmean')
    local_criterion = LocalLoss(
        KL_criterion
    )

    # Optimizer selection: SGD with momentum for 'sgd', AdamW for anything else.
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(source_model.parameters(), lr=args.initial_lr, momentum=0.9)  #: CNN 0.1 ViT 0.003 SWIN 0.001
    else:
        optimizer = optim.AdamW(source_model.parameters(), lr=args.initial_lr)  #: CNN 0.1 ViT 0.0003 SWIN 0.0001

    # Cosine-decay schedule intended to accompany warmup.
    # NOTE(review): the scheduler is built but never stepped in this file and is
    # not passed to train_model, so the learning rate stays at args.initial_lr;
    # --warmup_epochs is likewise not read here.
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)

    # Only source_model's parameters are updated; the witness stays in eval mode.
    source_model.train()
    witness_model.eval()

    model = train_model(source_model,witness_model, local_criterion,KL_criterion, optimizer, num_epochs=9,start_epoch=0, args = args)
