import os
import argparse
import torch
import json
from data_processor import COCODatasetProcessor, FlickrDatasetProcessor, RSICDDatasetProcessor
from models.bidirectional_mlp import train_bidirectional_mlp
from models.generator import RAGGenerator
from models.clip_model import train_clip_model as train_albef  
from utils.knowledge_base import build_knowledge_base, build_cc3m_knowledge_base, build_flickr30k_knowledge_base
from utils.dataset import augment_dataset_img_txt_easy
from utils.evaluation import evaluate_retrieval_clip
from utils.easy import AlignParams, run_easy_alignment
def _make_processor(args):
    """Create the dataset processor selected by ``args.dataset``.

    For MScoco, downloads the COCO subset first if ``annotations`` or
    ``train2017`` is missing under ``args.data_dir``.

    Returns:
        (processor, kb_name): ``kb_name`` is the short name of the external
        knowledge base paired with the dataset: "cc3m" for MScoco and for the
        fallback Flickr branch, "nwpu" for RSICD.
    """
    if args.dataset == 'MScoco':
        processor = COCODatasetProcessor(data_dir=args.data_dir)
        has_annotations = os.path.exists(os.path.join(args.data_dir, 'annotations'))
        has_images = os.path.exists(os.path.join(args.data_dir, 'train2017'))
        if has_annotations and has_images:
            print("COCO dataset found. Skipping download.")
        else:
            print("COCO dataset not found. Downloading...")
            processor.download_coco_subset()
        return processor, 'cc3m'
    if args.dataset == 'RSICD':
        return RSICDDatasetProcessor(data_dir=args.data_dir), 'nwpu'
    # Any other dataset name falls through to the Flickr processor.
    return FlickrDatasetProcessor(data_dir=args.data_dir), 'cc3m'


def _build_external_kb(kb_name, kb_dir, clip_model, preprocess, device):
    """Build the external knowledge base identified by ``kb_name`` into ``kb_dir``.

    Only the KB that the alignment step will read is built. Previously CC3M,
    Flickr30k and NWPU were all built on every run, and their results were
    discarded.
    """
    if kb_name == 'nwpu':
        # The original script called this builder without importing it, which
        # raised NameError. It is imported lazily so that COCO/Flickr runs do
        # not depend on it.
        # NOTE(review): assumed to live in utils.knowledge_base next to the
        # other builders — confirm.
        from utils.knowledge_base import build_nwpu_knowledge_base
        print("Building external knowledge base from NWPU...")
        build_nwpu_knowledge_base(
            clip_model=clip_model,
            preprocess=preprocess,
            kb_dir=kb_dir,
            data_dir="./nwpu_data",
            num_samples=10000,
            device=device,  # was hard-coded "cuda", which broke CPU-only runs
        )
    else:
        print("Building external knowledge base from CC3M...")
        build_cc3m_knowledge_base(
            clip_model=clip_model,
            preprocess=preprocess,
            kb_dir=kb_dir,
            num_samples=10000,
            num_workers=16,
            device=device,
        )


def main(args):
    """Run the full pipeline end to end.

    Stages:
      1. Split the dataset.
      2. Build the internal and external knowledge bases.
      3. Run the unsupervised "easy" alignment.
      4. Train the bidirectional MLP.
      5. Augment the training data with the RAG generator.
      6. Train CLIP on the augmented data.
      7. Evaluate retrieval and write the metrics to JSON under
         ``./evaluation/<dataset>/<train_size>/<img_only_ratio>/``.

    Args:
        args: argparse namespace produced by this script's CLI. It provides
            dataset, data_dir, img_save_dir, data_save_dir, train_size,
            test_size, img_only_ratio, text_only_ratio, mode, N, k,
            mlp_epochs, clip_epochs and seed.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    processor, kb_name = _make_processor(args)
    kb_dir = f"./{kb_name}_kb"
    # Path of the KB that the external builder writes. The alignment step
    # reads it. NOTE(review): assumes the builder saves <kb_dir>/<kb_name>_kb.pt.
    external_path = f"{kb_dir}/{kb_name}_kb.pt"
    out_external_easy = f"{kb_dir}/{kb_name}_kb_easy_{args.dataset}_{args.img_only_ratio}.pt"

    # Create the train/test split.
    train_data, test_data = processor.create_subset_and_split(
        train_size=args.train_size,
        test_size=args.test_size,
        img_only_ratio=args.img_only_ratio,
        text_only_ratio=args.text_only_ratio,
    )

    print("Loading pretrained CLIP model...")
    import clip
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    base_save_path = f'{args.dataset}/{args.train_size}/{args.img_only_ratio}'
    os.makedirs(base_save_path, exist_ok=True)
    os.makedirs(args.img_save_dir, exist_ok=True)
    # Make sure the augmented-data file's parent exists; harmless if the
    # augmentation step also creates it.
    data_save_parent = os.path.dirname(args.data_save_dir)
    if data_save_parent:
        os.makedirs(data_save_parent, exist_ok=True)

    internal_kb_path = os.path.join(base_save_path, "internal_kb.pt")
    print("Building internal knowledge base...")
    build_knowledge_base(
        data=train_data.get('multimodal'),
        data_dir=args.data_dir,
        clip_model=clip_model,
        preprocess=preprocess,
        save_path=internal_kb_path,
        device=device,
    )

    _build_external_kb(kb_name, kb_dir, clip_model, preprocess, device)

    align_ckpt_dir = "./checkpoints/align_net"
    internal_easy_path, external_easy_path = run_easy_alignment(
        coco_path=internal_kb_path,
        cc3m_path=external_path,
        dataset_name=args.dataset,
        ratio=args.img_only_ratio,
        k_in=128, k_out=512,
        tau_c=0.7, tau_v=0.7,
        n_iter=5, k_csls=50, min_pairs=500,
        ckpt_dir=align_ckpt_dir,
        out_internal_easy=os.path.join(base_save_path, "internal_kb_easy.pt"),
        out_external_easy=out_external_easy,
        device=str(device),
    )

    internal_kb = torch.load(internal_easy_path)
    external_kb = torch.load(external_easy_path)
    # The alignment step writes its parameters under <ckpt_dir>/<dataset>/<ratio>/.
    params_dir = f"{align_ckpt_dir}/{args.dataset}/{args.img_only_ratio}"
    params_txt = AlignParams.load(f"{params_dir}/unsup_align_params_txt.pt")
    params_img = AlignParams.load(f"{params_dir}/unsup_align_params_img.pt")

    # Internal entries come first, then external ones. The three fields stay
    # index-aligned.
    combined_kb = {
        'text_embeds': torch.cat([internal_kb['text_embeds'], external_kb['text_embeds']], dim=0),
        'image_embeds': torch.cat([internal_kb['image_embeds'], external_kb['image_embeds']], dim=0),
        'meta': internal_kb['meta'] + external_kb['meta'],
    }

    print("Training bidirectional MLP...")
    mlp_model = train_bidirectional_mlp(
        clip_model, train_data, test_data, args.data_dir, args.dataset, args.img_only_ratio,
        epochs=args.mlp_epochs, device=device,
    )

    print("Initializing  generator...")
    generator = RAGGenerator(device=device)

    print("Augmenting dataset with generation...")
    augmented_data = augment_dataset_img_txt_easy(
        generator, clip_model, mlp_model, preprocess, train_data, args.data_dir,
        args.img_save_dir, args.data_save_dir,
        combined_kb, params_img, params_txt, device=device, k=args.k, n=args.N, lambda_score=0.5,
    )
    # Train on generated samples plus the original paired samples.
    augmented_data['multimodal'].extend(train_data['multimodal'])

    print("Training CLIP model...")
    # Training-time metrics are superseded by the dedicated evaluation below.
    albef_model, _train_results = train_albef(
        augmented_data.get('multimodal'), test_data, args.data_dir,
        epochs=args.clip_epochs, device=device,
    )

    print("Evaluating model...")
    results = evaluate_retrieval_clip(
        albef_model, test_data, args.data_dir,
        k_values=[1, 5], device=device,
    )

    dir_path = f'./evaluation/{args.dataset}/{args.train_size}/{args.img_only_ratio}'
    os.makedirs(dir_path, exist_ok=True)

    save_path = f'{dir_path}/evaluation_results_{args.mode}_{args.N}_{args.k}.json'
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print("Training and evaluation completed!")
    print(f"Results saved to {save_path}")

def parse_args(argv=None):
    """Parse the pipeline's command-line options.

    ``--img_save_dir`` and ``--data_save_dir`` default to paths derived from
    the *parsed* dataset / train_size / img_only_ratio / mode / N / k (and
    data_dir). Previously they were fixed at parser-construction time from
    hard-coded values, so runs with different settings wrote into the same
    MScoco/10000/0.35 locations. With no flags at all, the resulting paths
    are unchanged.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with every option filled in.
    """
    parser = argparse.ArgumentParser(description="Multimodal Learning Pipeline")

    default_img_only_ratio = 0.35

    parser.add_argument('--dataset', type=str, default='MScoco',
                        help='dataset')
    parser.add_argument('--data_dir', type=str, default='./coco_data',
                        help='Directory to store dataset')
    parser.add_argument('--img_save_dir', type=str, default=None,
                        help='Directory to store generated images '
                             '(default: ./Gen/<dataset>/<train_size>/<img_only_ratio>/Gen_<mode>_<N>_<k>)')
    parser.add_argument('--data_save_dir', type=str, default=None,
                        help='Path of the generated dataset .jsonl file '
                             '(default: <data_dir>/<train_size>/<img_only_ratio>/aug_data_Gen_<mode>_<N>_<k>.jsonl)')
    parser.add_argument('--train_size', type=int, default=10000,
                        help='Number of training samples')
    parser.add_argument('--test_size', type=int, default=1000,
                        help='Number of test samples')
    parser.add_argument('--img_only_ratio', type=float, default=default_img_only_ratio,
                        help='Ratio of image-only samples')
    # The text-only default intentionally mirrors the image-only default.
    parser.add_argument('--text_only_ratio', type=float, default=default_img_only_ratio,
                        help='Ratio of text-only samples')
    parser.add_argument('--mode', type=str, default='our')
    parser.add_argument('--N', type=int, default=10)
    parser.add_argument('--k', type=int, default=10)
    parser.add_argument('--mlp_epochs', type=int, default=20,
                        help='Number of epochs for MLP training')
    parser.add_argument('--mapper_epochs', type=int, default=15,
                        help='Number of epochs for mapper training')
    parser.add_argument('--clip_epochs', type=int, default=20,
                        help='Number of epochs for clip training')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    args = parser.parse_args(argv)

    # Derive output locations from the values actually in effect.
    run_tag = f"Gen_{args.mode}_{args.N}_{args.k}"
    if args.img_save_dir is None:
        args.img_save_dir = f'./Gen/{args.dataset}/{args.train_size}/{args.img_only_ratio}/{run_tag}'
    if args.data_save_dir is None:
        args.data_save_dir = f'{args.data_dir}/{args.train_size}/{args.img_only_ratio}/aug_data_{run_tag}.jsonl'
    return args


if __name__ == "__main__":
    main(parse_args())
