import os
import argparse
import torch
import torch.nn as nn
from tqdm import tqdm, trange
from vl_distill_utils import load_or_process_file
from epoch import epoch, epoch_test, itm_eval
import copy
import warnings
import datetime
from data import get_dataset_flickr, textprocess, textprocess_train
from networks import CLIPModel_linear
import numpy as np


warnings.filterwarnings("ignore", category=DeprecationWarning)

def main(args):
    if args.dataset == 'flickr':
        args.image_root = '/root/autodl-tmp/data/Flickr30k'
        args.ann_root = '/root/autodl-tmp/data/Flickr30k_ann'
    elif args.dataset == 'coco':
        args.image_root = '/root/autodl-tmp/data/COCO'
        args.ann_root = '/root/autodl-tmp/data/Flickr30k_ann'
    else:
        return
    
    args.dsa = True if args.dsa == 'True' else False
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.distributed = torch.cuda.device_count() > 1
    
    print('Hyper-parameters: \n', args.__dict__)

    save_dir = os.path.join(args.buffer_path, args.dataset)
    save_dir = os.path.join(save_dir, args.image_encoder+"_"+args.text_encoder, args.loss_type)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        
    ''' organize the datasets '''
    trainloader, testloader, train_dataset, test_dataset = get_dataset_flickr(args)
    
    train_sentences = train_dataset.get_all_captions() 
    _ = load_or_process_file('text', textprocess, args, testloader)
    _ = load_or_process_file('train_text', textprocess_train, args, train_sentences)

    data = np.load(f'statistics/{args.dataset}_{args.text_encoder}_text_embed.npz')
    bert_test_embed_loaded = data['bert_test_embed']
    bert_test_embed = torch.from_numpy(bert_test_embed_loaded).cpu()

    net = CLIPModel_linear(args).to(args.device)

    img_net = net.image_projection
    txt_net = net.text_projection

    img_net.train()
    txt_net.train()

    # weights = [*img_net.parameters(), *txt_net.parameters()]
    # param_groups = [dict(params=weights, use_muon=True, lr=0.01, weight_decay=0.0005)]
    # optimizer = SingleDeviceMuon(param_groups)
    # optimizer.zero_grad()

    optimizer = torch.optim.SGD([
        {'params': img_net.parameters(), 'lr': args.lr_net_img},
        {'params': txt_net.parameters(), 'lr': args.lr_net_txt},
    ], lr=0, momentum=0.9)
 
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.train_epochs//2 + 1], gamma=0.1)

    for e in trange(args.train_epochs):
        train_loss, train_acc = epoch(e, trainloader, net, optimizer, args)
        scheduler.step()
        
        score_val_i2t, score_val_t2i = epoch_test(testloader, net, args.device, bert_test_embed)
        val_result = itm_eval(score_val_i2t, score_val_t2i, testloader.dataset.txt2img, testloader.dataset.img2txt)  

        print("Epoch={} Train Acc={} | Img R@1={} R@5={} R@10={} | Txt R@1={} R@5={} R@10={} | R@Mean={}".format(
            e, train_acc,
            val_result['img_r1'], val_result['img_r5'], val_result['img_r10'],
            val_result['txt_r1'], val_result['txt_r5'], val_result['txt_r10'], val_result['r_mean'])) 
        
        n = 0
        while os.path.exists(os.path.join(save_dir, "pretrain_{}.pt".format(n))):
            n += 1
        torch.save(net.state_dict(), os.path.join(save_dir, "pretrain_{}.pt".format(n)))


def make_buffer_parser():
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='flickr', choices=['flickr', 'coco'], help='dataset')    
    parser.add_argument('--lr_net_img', type=float, default=0.1, help='learning rate for updating network parameters')
    parser.add_argument('--lr_net_txt', type=float, default=0.1, help='learning rate for updating network parameters')
    
    # parser.add_argument('--batch_size_train', type=int, default=1024, help='batch size for training networks')
    # parser.add_argument('--batch_size_test', type=int, default=256, help='batch size for evaluating networks')
    
    parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
                        help='whether to use differentiable Siamese augmentation.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
                        help='differentiable Siamese augmentation strategy')

    parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')
    parser.add_argument('--train_epochs', type=int, default=10)
    parser.add_argument('--zca', action='store_true')
    parser.add_argument('--decay', action='store_true')
    
    parser.add_argument('--mom', type=float, default=0.5, help='momentum')
    parser.add_argument('--l2', type=float, default=0, help='l2 regularization') 
    parser.add_argument('--save_interval', type=int, default=10)
    current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    parser.add_argument('--name', type=str, default=current_time, help='name of wandb run')

    parser.add_argument('--text_pretrained', type=bool, default=True, help='text_pretrained')
    parser.add_argument('--image_pretrained', type=bool, default=True, help='image_pretrained')

    parser.add_argument('--text_trainable', type=bool, default=False, help='text_trainable')
    parser.add_argument('--image_trainable', type=bool, default=False, help='image_trainable') 

    parser.add_argument('--projection_dim', type=int, default=512, help='dimension of projection head')

    parser.add_argument('--batch_size_train', type=int, default=128, help='batch_size_train')
    parser.add_argument('--batch_size_test', type=int, default=128, help='batch_size_test')
    
    parser.add_argument('--image_size', type=int, default=224, help='image_size')
    parser.add_argument('--k_test', type=int, default=128, help='k_test')
    parser.add_argument('--load_npy', type=bool, default=False, help='load_npy')
    
    parser.add_argument('--image_encoder', type=str, default='nfnet', help='image encoder')
    #, choices=['nfnet', 'resnet18_gn', 'vit_tiny', 'nf_resnet50', 'nf_regnet'])
    parser.add_argument('--text_encoder', type=str, default='bert', choices=['bert', 'clip', 'distilbert','gpt1'], help='text encoder')
    
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--measure', default='cosine',
                    help='Similarity measure used (cosine|order)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--only_has_image_projection', type=bool, default=False, help='None')
    parser.add_argument('--grounding', type=bool, default=False, help='None')
    
    parser.add_argument('--distill', type=bool, default=False, help='whether distill')
    parser.add_argument('--loss_type', type=str, default="InfoNCE")

    parser.add_argument('--eval_freq', type=int, default=5, help='eval_freq')
    parser.add_argument('--no_aug', action='store_true', help='no_aug')
    parser.add_argument('--skip_save', action='store_true', help='skip save buffer')
    parser.add_argument('--disabled_wandb', type=bool, default=True)
    return parser


if __name__ == '__main__':
    # Script entry point: parse the CLI, attach a default image root, and train.
    parser = make_buffer_parser()
    args = parser.parse_args()

    # NOTE(review): main() reassigns args.image_root for both supported
    # datasets, so the value set here is overwritten before it is used.
    # Confirm which location is authoritative.
    default_image_roots = dict(
        flickr="distill_utils/data/Flickr30k/",
        coco="distill_utils/data/COCO/",
    )
    args.image_root = default_image_roots[args.dataset]

    main(args)

