import os
import json
import argparse
import torch
import torch.backends.cudnn as cudnn
import torchvision.models as models
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from tqdm import tqdm
import pickle
import torch.nn as nn
import torch.distributed as dist
from ibot.models import vit_large
from get_swin_ft import load_swin_l, load_swin_b, load_swin_s, load_swin_t
import timm
from eva.modeling_finetune import mae_vit_large_patch16
from lavis.models import load_model_and_preprocess

def get_args():
    """Parse the command-line options of the feature-extraction script.

    Returns:
        argparse.Namespace: the parsed options, grouped as data/model/output
        locations, input size and batching, and distributed-run settings.
    """
    # Pass the text as `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would put the sentence in the usage line.
    parser = argparse.ArgumentParser(description='Feature extraction script using customized ResNet50')
    # Data, Model, Output
    parser.add_argument('--out_dir', type=str, default='patent_feature', help='Directory to save features')
    parser.add_argument('--json_path', type=str, default='', help='Path to JSON file with index')
    parser.add_argument('--model', default='r50', type=str, help='Model for extract feature')
    parser.add_argument('--model_path', type=str, required=False, default=None, help='Path to the model checkpoint')
    parser.add_argument('--image_path', type=str, required=False, help='Path to image paths')
    parser.add_argument('--item_type', type=str, default='product',
                        help='Item type name (e.g. product or patent); used as prefix of the output sub-directories')
    parser.add_argument('--item_num_recordings', type=str, default=None,
                        help='Path to a CSV file with "Index" and "img_num" columns; '
                             'when given, batch-wise features are split into item-wise files')
    # Input Image, batch
    parser.add_argument('--input_size', default=224, type=int, help='Input size for the model')
    parser.add_argument('--batch_size', default=256, type=int, help='Number of patents to process in parallel')
    parser.add_argument('--num_workers', default=12, type=int, help='Number of workers for DataLoader')
    parser.add_argument('--num_gpus', default=1, type=int, help='Number of GPUs to use')
    # GPU
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    # Accept both spellings: newer PyTorch launchers pass `--local-rank`,
    # older ones `--local_rank`. The destination stays `args.local_rank`.
    parser.add_argument("--local_rank", "--local-rank", default=0, type=int,
                        help="Please ignore and do not set this argument.")
    return parser.parse_args()

def setup_for_distributed(is_master):
    """
    Replace the built-in ``print`` so that only the master process prints.

    After this call, ``print(...)`` is a no-op unless ``is_master`` is true
    or the call passes ``force=True``. The ``force`` keyword is consumed here
    and never forwarded to the real ``print``.
    """
    import builtins

    original_print = builtins.print

    def rank_filtered_print(*args, **kwargs):
        # Always strip `force` so the underlying print never sees it.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = rank_filtered_print

def init_distributed_mode(args):
    """Initialise the NCCL process group and bind this process to a GPU.

    The launch mode is detected from the environment, in this order:
      1. ``RANK`` and ``WORLD_SIZE`` are set (``LOCAL_RANK`` is then read too).
      2. ``SLURM_PROCID`` is set; the world size is taken from
         ``SLURM_NTASKS`` or, failing that, from a pre-set ``args.world_size``.
      3. CUDA is available: single-process run on GPU 0.
      4. Otherwise the process exits with status 1.

    Args:
        args: namespace providing ``dist_url``; ``rank``, ``world_size`` and
            ``gpu`` are written onto it in place.

    Raises:
        RuntimeError: in the SLURM branch when the world size cannot be
            determined.
    """
    # Local import: `sys` is not imported at file level, so the exit path
    # below would otherwise fail with NameError.
    import sys

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # init_process_group below needs args.world_size, which this branch
        # previously never set.
        if 'SLURM_NTASKS' in os.environ:
            args.world_size = int(os.environ['SLURM_NTASKS'])
        elif not hasattr(args, 'world_size'):
            raise RuntimeError(
                'SLURM_PROCID is set but the world size is unknown: '
                'set SLURM_NTASKS or args.world_size.')
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '49500'
    else:
        print('Does not support training without GPU.')
        sys.exit(1)
    # CUDA initialize
    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.barrier()
    # From here on only rank 0 prints (unless print(..., force=True)).
    setup_for_distributed(args.rank == 0)

def custom_collate_fn(batch):
    """Collate (image, path) samples, dropping samples whose image is None.

    Returns a stacked image tensor and a tuple of the matching paths. When no
    sample survives the filter, returns an empty tensor and an empty list.
    """
    kept_images = []
    kept_paths = []
    for img, path in batch:
        if img is None:
            continue
        kept_images.append(img)
        kept_paths.append(path)

    if not kept_images:
        return torch.Tensor(), []

    stacked = torch.stack(kept_images, dim=0)
    return stacked, tuple(kept_paths)

class CustomDataset(Dataset):
    """Dataset over images listed one path per line in a text file.

    Each sample is an ``(image, path)`` pair. Images that cannot be opened
    are replaced by a black 224x224 RGB image, and their path is appended to
    ``damaged_images.txt`` in the current working directory.
    """

    def __init__(self, data_path_file, transform):
        """Read the image paths from ``data_path_file``.

        Args:
            data_path_file: text file with one image path per line. Every
                line becomes an entry after whitespace stripping, blank
                lines included.
            transform: callable applied to each loaded PIL image, or a falsy
                value to return the PIL image unchanged.
        """
        with open(data_path_file, 'r') as f:
            self.image_paths = [line.strip() for line in f]
        self.transform = transform

    def __len__(self):
        """Return the number of listed image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, path)`` for the ``idx``-th listed path."""
        img_path = self.image_paths[idx]
        # Use black image to replace some damaged images
        try:
            # The context manager closes the underlying file as soon as the
            # converted copy exists, instead of leaving it to the garbage
            # collector.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
        except OSError as e:  # IOError is an alias of OSError
            print(f"Error loading image {img_path}: {e}")
            with open('damaged_images.txt', 'a') as er:
                er.write(img_path + '\n')
            image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        if self.transform:
            image = self.transform(image)
        return image, img_path

def split_batch_to_single(batch_size, item_nums, item_names, raw_item_path, target_item_path):
    """Regroup fixed-size batch feature files into one feature file per item.

    One patent/product may have more than one image. The raw features are
    read from ``<raw_item_path>/<7-digit batch index>.pkl``. Item ``i`` owns
    the ``item_nums[i]`` consecutive global rows that follow the rows of the
    previous items, and its rows are written to
    ``<target_item_path>/<item_names[i]>.pkl``.

    Args:
        batch_size: number of rows per raw batch file, used to map a global
            row index to (file index, row inside the file).
        item_nums: number of images (rows) per item, in item order.
        item_names: item identifiers, aligned with ``item_nums``.
        raw_item_path: directory holding the batch-wise ``.pkl`` files.
        target_item_path: directory receiving the item-wise ``.pkl`` files.

    Raises:
        FileNotFoundError: if a batch file that an item needs is missing.
        IndexError: if ``item_names`` has more entries than ``item_nums``.
    """
    def load_record(record_index):
        # NOTE: pickle is only safe on trusted files; these are expected to
        # be the batch files written by this script.
        record_file = os.path.join(raw_item_path, str(record_index).zfill(7) + '.pkl')
        with open(record_file, 'rb') as f:
            return pickle.load(f)

    # Batch file currently held in memory. Items are consecutive, so the
    # needed file index never decreases and each file is loaded at most once.
    record_num = 0
    recording = load_record(record_num)

    begin_index = 0
    for i, item_name in enumerate(tqdm(item_names)):
        end_index = begin_index + item_nums[i]  # exclusive global row index

        if end_index > begin_index:
            first_record = begin_index // batch_size
            # Use the last row the item owns (end_index - 1): an item ending
            # exactly on a batch boundary must not touch the next file, which
            # may not exist.
            last_record = (end_index - 1) // batch_size
            pieces = []
            for record_index in range(first_record, last_record + 1):
                if record_index != record_num:
                    recording = load_record(record_index)
                    record_num = record_index
                record_offset = record_index * batch_size
                # Clamp the global range to the rows of this batch file.
                row_start = max(begin_index - record_offset, 0)
                row_stop = min(end_index - record_offset, batch_size)
                pieces.append(recording[row_start:row_stop, :])
            item_feature = pieces[0] if len(pieces) == 1 else np.concatenate(pieces, 0)
        else:
            # Item without images: empty feature block, same feature width.
            item_feature = recording[:0, :]

        output_path = os.path.join(target_item_path, f"{item_name}.pkl")
        with open(output_path, 'wb') as f:
            pickle.dump(item_feature, f)

        begin_index = end_index
                    



def main():
    """Extract image features batch by batch and optionally regroup them.

    Steps: parse the arguments, initialise the distributed process group,
    build the selected model (optionally loading a checkpoint), run it over
    the images listed in ``--image_path`` and pickle each batch of float16
    features to ``<out_dir>/<item_type>_raw/<7-digit batch index>.pkl``.
    When ``--item_num_recordings`` is given, the batch files are then split
    into per-item files under ``<out_dir>/<item_type>_itemwise``.

    Raises:
        ValueError: if ``--model`` is not one of the supported names.
    """
    args = get_args()
    init_distributed_mode(args)
    cudnn.benchmark = True
    transform = Compose([
        Resize((args.input_size, args.input_size)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load model
    # torchvision constructors are looked up by name so that only the
    # selected architecture has to exist in the installed torchvision.
    torchvision_archs = {
        'r50': 'resnet50',
        'r18': 'resnet18',
        'r101': 'resnet101',
        'swin_t': 'swin_t',
        'swin_s': 'swin_s',
        'swin_b': 'swin_b',
        'vit': 'vit_b_16',
        'mae': 'vit_b_16',
    }
    if args.model in torchvision_archs:
        model = getattr(models, torchvision_archs[args.model])(pretrained=True)
    elif args.model in ('dino', 'ibot'):
        model = vit_large(16)
    elif args.model == 'clip':
        model, _, _ = load_model_and_preprocess("clip_feature_extractor", model_type="ViT-L-14-336", is_eval=True)
        model = model.visual
    else:
        # Previously an unknown name left `model` unbound and failed later
        # with an unrelated UnboundLocalError.
        raise ValueError(f"Unsupported --model '{args.model}'")
    # Load checkpoint if need
    if args.model_path is not None:
        # Deserialize on CPU so the checkpoint does not depend on the device
        # it was saved from; the model is moved to the GPU below.
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'))

    # Load dataset
    test_dataset = CustomDataset(args.image_path, transform=transform)
    sampler = torch.utils.data.DistributedSampler(test_dataset, shuffle=False)
    # Ordering is controlled by the sampler, so no `shuffle` argument here.
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             sampler=sampler,
                             num_workers=args.num_workers,
                             drop_last=False,
                             collate_fn=custom_collate_fn)

    model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[args.gpu])
    model.eval()
    os.makedirs(args.out_dir, exist_ok=True)
    # Save fixed-size batch features
    raw_feature_path = os.path.join(args.out_dir, args.item_type+'_raw')
    os.makedirs(raw_feature_path, exist_ok=True)
    # Save patent/product-wise features
    itemwise_feature_path = os.path.join(args.out_dir, args.item_type+'_itemwise')
    os.makedirs(itemwise_feature_path, exist_ok=True)

    # NOTE(review): the file name depends only on this process's batch
    # counter, so with more than one rank the processes write to the same
    # file names; the item-wise split below also assumes consecutive,
    # gap-free batch files. Confirm multi-GPU runs are not expected here.
    for i, (images, _) in enumerate(tqdm(test_loader)):
        if images.numel() == 0:
            continue
        if torch.cuda.is_available():
            images = images.cuda()

        with torch.no_grad():
            features = model(images).cpu().numpy().astype(np.float16)

        output_path = os.path.join(raw_feature_path, str(i).zfill(7)+".pkl")
        with open(output_path, 'wb') as f:
            pickle.dump(features, f)

    if args.item_num_recordings is not None:
        print('Begin Split Feature from Batch-wise to Item-wise')
        item_num_recordings = pd.read_csv(args.item_num_recordings)
        split_batch_to_single(args.batch_size,
                              list(item_num_recordings['img_num']),
                              list(item_num_recordings['Index']),
                              raw_feature_path,
                              itemwise_feature_path)

if __name__ == '__main__':
    main()