import os
import sys
import urllib.request
import zipfile
import tarfile

from PIL import Image
import torch as th
from torchvision import transforms
import torchvision
from torchvision.datasets import ImageFolder
from tqdm import tqdm

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)


# Number of shards the (reference image, prompt) list is divided into inside
# main(); each invocation of main() handles the shard selected by its `id`
# argument, so valid ids are 0 .. num_gpu - 1.
num_gpu = 4


def download_and_extract(url, download_dir, extract_dir):
    """Download an archive and unpack it into ``download_dir``.

    The call is a no-op when ``extract_dir`` already exists. Otherwise the
    archive is fetched (unless a file with the same name is already present in
    ``download_dir``) and extracted. Only ``.zip``, ``.tar.gz`` and ``.tgz``
    archives are extracted.

    Args:
        url: Location of the archive; its basename is used as the local
            file name.
        download_dir: Directory that receives the archive and its extracted
            contents. Created if it does not exist.
        extract_dir: Path whose existence means "already extracted". It is
            only checked here, never created; the archive itself is expected
            to produce it inside ``download_dir``.
    """
    # Check if the dataset has already been extracted
    if os.path.exists(extract_dir):
        print(f"Dataset already exists at {extract_dir}, skipping download and extraction.")
        return

    # urlretrieve does not create missing parent directories.
    if download_dir:
        os.makedirs(download_dir, exist_ok=True)

    # Download the dataset
    file_name = os.path.join(download_dir, os.path.basename(url))
    if not os.path.exists(file_name):
        print(f"Downloading {file_name}...")
        # Download under a temporary name and rename on success, so that an
        # interrupted transfer never leaves a truncated archive that a later
        # run would mistake for a complete download.
        partial_name = file_name + ".part"
        try:
            urllib.request.urlretrieve(url, partial_name)
            os.replace(partial_name, file_name)
        finally:
            if os.path.exists(partial_name):
                os.remove(partial_name)
        print("Download complete.")

    # Extract the dataset
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use this with archives from a trusted source.
    print(f"Extracting {file_name}...")
    if file_name.endswith(".zip"):
        with zipfile.ZipFile(file_name, 'r') as zip_ref:
            zip_ref.extractall(download_dir)
    elif file_name.endswith((".tar.gz", ".tgz")):
        with tarfile.open(file_name, 'r:gz') as tar_ref:
            tar_ref.extractall(download_dir)
    else:
        # Do not report success when nothing was unpacked.
        print(f"Unsupported archive format: {file_name}, nothing extracted.")
        return
    print("Extraction complete.")

def load_tiny_imagenet_labels(words_file_path):
    """Load Tiny ImageNet WordNet IDs and their corresponding label names.

    Each useful line of the file has the form ``<wnid>\\t<label>[, <label>...]``.
    Only the first comma-separated label is kept. Blank lines and lines
    without a tab separator are skipped instead of aborting the load.

    Args:
        words_file_path: Path to the tab-separated ``words.txt`` file.

    Returns:
        dict mapping each WordNet ID to its first label name.
    """
    label_map = {}
    with open(words_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # partition() splits on the first tab only, so a stray tab inside
            # the label text cannot break the two-field unpacking.
            wnid, separator, label_name = line.partition('\t')
            if not separator:
                continue
            label_map[wnid] = label_name.split(',')[0]  # Take the first label if there are multiple
    return label_map

def _ensure_label_dirs(output_dir, label_names):
    """Create ``output_dir`` and one sub-directory per label name."""
    os.makedirs(output_dir, exist_ok=True)
    for label_name in label_names:
        os.makedirs(os.path.join(output_dir, label_name), exist_ok=True)


def _save_rgb_copy(src_path, dst_path):
    """Re-save the image at ``src_path`` as RGB at ``dst_path``, closing the source file."""
    with Image.open(src_path) as img:
        img.convert('RGB').save(dst_path)


def _load_cifar(dataset_name, dataset_dir):
    """Return the torchvision CIFAR10/CIFAR100 training set rooted at ``./data``."""
    dataset_cls, default_dir = {
        'CIFAR10': (torchvision.datasets.CIFAR10, './data/cifar-10-batches-py'),
        'CIFAR100': (torchvision.datasets.CIFAR100, './data/cifar-100-python'),
    }[dataset_name]
    dataset_dir = dataset_dir or default_dir
    already_present = os.path.exists(dataset_dir)
    if already_present:
        print(f"{dataset_name} dataset already exists at {dataset_dir}, skipping download.")
    # NOTE(review): the root is always './data' even when a custom
    # dataset_dir is supplied; dataset_dir only decides whether to download.
    return dataset_cls(
        root='./data',
        train=True,
        download=not already_present,
        transform=transforms.ToTensor(),
    )


def _export_tiny_imagenet(dataset_dir, output_dir):
    """Copy Tiny ImageNet train and val images into ``output_dir/<label>/``.

    Returns the sorted, de-duplicated label names that received a directory.
    """
    if not dataset_dir:
        dataset_dir = './data/tiny-imagenet-200'
        download_and_extract(
            url='http://cs231n.stanford.edu/tiny-imagenet-200.zip',
            download_dir='./data',
            extract_dir=dataset_dir
        )

    label_map = load_tiny_imagenet_labels(os.path.join(dataset_dir, 'words.txt'))

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # A set keeps the result aligned with the directories actually created:
    # two WordNet IDs sharing the same first label share one directory.
    used_labels = set()

    # Process the training images
    train_dir = os.path.join(dataset_dir, 'train')
    for wnid in os.listdir(train_dir):
        wnid_dir = os.path.join(train_dir, wnid, 'images')
        if not (os.path.isdir(wnid_dir) and wnid in label_map):
            continue
        label_name = label_map[wnid]
        used_labels.add(label_name)
        label_dir = os.path.join(output_dir, label_name)
        os.makedirs(label_dir, exist_ok=True)

        for image_file in os.listdir(wnid_dir):
            if image_file.endswith(".JPEG"):
                _save_rgb_copy(os.path.join(wnid_dir, image_file),
                               os.path.join(label_dir, image_file))

    # Process the validation images
    val_annotations_path = os.path.join(dataset_dir, 'val', 'val_annotations.txt')
    val_images_dir = os.path.join(dataset_dir, 'val', 'images')

    with open(val_annotations_path, 'r') as f:
        for line in f:
            tokens = line.strip().split('\t')
            if len(tokens) < 2:
                # Blank or malformed annotation line.
                continue
            image_file, wnid = tokens[0], tokens[1]
            if wnid not in label_map:
                continue
            label_name = label_map[wnid]
            used_labels.add(label_name)
            label_dir = os.path.join(output_dir, label_name)
            os.makedirs(label_dir, exist_ok=True)
            _save_rgb_copy(os.path.join(val_images_dir, image_file),
                           os.path.join(label_dir, image_file))

    print(f"Tiny ImageNet images saved in human-readable directories at {output_dir}.")

    sorted_labels = sorted(used_labels)
    print(sorted_labels)
    return sorted_labels


def _export_cinic10(dataset_dir, output_dir):
    """Write the CINIC-10 training images to ``output_dir/<label>/<index>.png``.

    Returns the sorted class names.
    """
    if not dataset_dir:
        dataset_dir = './data/cinic-10'
        # NOTE(review): download_and_extract unpacks into './data', so this
        # relies on the archive containing a top-level 'cinic-10' directory
        # -- confirm against the actual archive layout.
        download_and_extract(
            url='https://datashare.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz',
            download_dir='./data',
            extract_dir=dataset_dir
        )
    dataset = ImageFolder(root=os.path.join(dataset_dir, 'train'), transform=transforms.ToTensor())
    label_names = sorted(dataset.classes)

    # Create directory for saving images
    _ensure_label_dirs(output_dir, label_names)

    # Save images one at a time instead of first decoding the whole dataset
    # into a list, which keeps memory use flat.
    to_pil = transforms.ToPILImage()
    for i, (image, label) in enumerate(dataset):
        image_path = os.path.join(output_dir, label_names[label], f'{i}.png')
        to_pil(image).save(image_path)

    print("CINIC10 images saved, label array generated.")
    print(label_names)
    return label_names


def save_images(dataset_name='CIFAR10', dataset_dir=None, output_dir='./images'):
    """Export a dataset's images into one sub-directory per label.

    Args:
        dataset_name: One of 'CIFAR10', 'CIFAR100', 'TinyImageNet' or
            'CINIC10'.
        dataset_dir: Location of an existing copy of the dataset. When falsy,
            a default under ``./data`` is used and the dataset is downloaded
            if needed.
        output_dir: Directory that receives the ``<label>/<image>`` tree.

    Returns:
        The list of label names. For TinyImageNet this is the sorted list of
        unique label names that received an output directory; for the other
        datasets it is the dataset's class list.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Select loading method based on dataset name
    if dataset_name == 'TinyImageNet':
        return _export_tiny_imagenet(dataset_dir, output_dir)
    if dataset_name == 'CINIC10':
        return _export_cinic10(dataset_dir, output_dir)
    if dataset_name not in ('CIFAR10', 'CIFAR100'):
        raise ValueError("Invalid dataset name. Please choose 'CIFAR10', 'CIFAR100', 'TinyImageNet', or 'CINIC10'.")

    dataset = _load_cifar(dataset_name, dataset_dir)
    label_names = dataset.classes

    # Create the output directory and a subdirectory for each label
    _ensure_label_dirs(output_dir, label_names)

    # Save images to corresponding subdirectory; the raw arrays in
    # dataset.data are used directly, bypassing the dataset transform.
    for i, (image, label) in enumerate(zip(dataset.data, dataset.targets)):
        image_path = os.path.join(output_dir, label_names[label], f'{i}.png')
        Image.fromarray(image).save(image_path)

    print(f"{dataset_name} images saved, label array generated.")
    print(label_names)
    return label_names

def main(id, dataset, ref_img_path, save_path):
    """Generate upsampled GLIDE images for one shard of a dataset's reference images.

    The dataset is first exported to ``ref_img_path`` via ``save_images``.
    Every exported image is paired with the prompt ``'a photo of <label>'``,
    the pair list is split into ``num_gpu`` shards, and the shard selected by
    ``id`` is processed: each reference image is passed to the base sampler,
    the result is upsampled, and the output is written below
    ``save_path/<label>/``.

    Args:
        id: Zero-based shard index (the name shadows the builtin but is kept
            for interface compatibility).
        dataset: Dataset name forwarded to ``save_images``.
        ref_img_path: Directory the reference images are exported to, one
            sub-directory per label.
        save_path: Root directory for the generated images.
    """
    has_cuda = th.cuda.is_available()
    device = th.device('cpu' if not has_cuda else 'cuda')

    # Create base model.
    options = model_and_diffusion_defaults()
    options['use_fp16'] = has_cuda
    options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
    model, diffusion = create_model_and_diffusion(**options)
    model.eval()
    if has_cuda:
        model.convert_to_fp16()
    model.to(device)
    model.load_state_dict(load_checkpoint('base', device))
    print('total base parameters', sum(x.numel() for x in model.parameters()))

    # Create upsampler model.
    options_up = model_and_diffusion_defaults_upsampler()
    options_up['use_fp16'] = has_cuda
    options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
    model_up, diffusion_up = create_model_and_diffusion(**options_up)

    model_up.eval()
    if has_cuda:
        model_up.convert_to_fp16()
    model_up.to(device)
    model_up.load_state_dict(load_checkpoint('upsample', device))
    print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

    def save_images_multi(batch: th.Tensor, save_path=('1.png', '2.png')):
        """Convert a batch in [-1, 1] to uint8 images and write them to disk.

        NOTE(review): only entries whose path equals the last path in
        ``save_path`` are written, so with distinct paths just the final
        sample of the batch is saved. Kept as-is; confirm this is intended
        before raising ``batch_size`` above 1.
        """
        scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu() # B 3 H W
        reshaped = scaled.permute(0, 2, 3, 1) # B H W 3
        for i, im in enumerate(reshaped):
            if save_path[i] == save_path[-1]:
                Image.fromarray(im.numpy()).save(save_path[i])

    batch_size = 1
    batch_size_time = 1  # number of generations per reference image
    refer_img_iters = 50
    guidance_scale = 3.0
    # Tune this parameter to control the sharpness of 256x256 images.
    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
    upsample_temp = 0.997

    label_names = save_images(dataset, output_dir=ref_img_path)

    def get_few_shot_images_path_prompt_pairs(root='/path/to/few-shot/images', prompts=None):
        """Pair every image below ``root`` with the prompt of its class.

        The i-th sub-directory of ``root`` (in sorted order) is matched with
        ``prompts[i]``, so the prompts must be in the same order as the
        sorted directory names.
        """
        import glob
        path_prompt_pairs = []
        cls = sorted(glob.glob(root+'/*'))
        for i, class_prompt in enumerate(prompts):
            for im in sorted(glob.glob(cls[i]+'/*')):
                path_prompt_pairs.append([im, class_prompt])
        return path_prompt_pairs

    prompt_prefix = 'a photo of '
    prompts = [prompt_prefix + v for v in label_names]
    path_prompt_pairs = get_few_shot_images_path_prompt_pairs(root=ref_img_path, prompts=prompts)

    # Split the work into num_gpu equally sized shards (ceiling division);
    # slicing past the end is harmless, so the last shard takes the remainder.
    total_len = len(path_prompt_pairs)
    each_len = (total_len + num_gpu - 1) // num_gpu
    shard_start = id * each_len
    shard_end = (id + 1) * each_len
    prompt_list = path_prompt_pairs[shard_start:shard_end]
    print('GPU {}: {}-{}'.format(id, shard_start, shard_end))

    def text2image(prompt, batch_size, img=None):
        """Sample ``batch_size`` images for ``prompt`` and return the upsampled batch.

        ``img`` is handed to the base sampler as its ``noise`` argument
        together with ``flag_refer_img_iters``.
        NOTE(review): the exact effect of these two arguments is defined by
        the diffusion implementation in use, not by this file.
        """

        ##############################
        # Sample from the base model #
        ##############################

        # Create the text tokens to feed to the model.
        tokens = model.tokenizer.encode(prompt)
        tokens, mask = model.tokenizer.padded_tokens_and_mask(
            tokens, options['text_ctx']
        )

        # Create the classifier-free guidance tokens (empty)
        full_batch_size = batch_size * 2
        uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
            [], options['text_ctx']
        )

        # Pack the tokens together into model kwargs.
        model_kwargs = dict(
            tokens=th.tensor(
                [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
            ),
            mask=th.tensor(
                [mask] * batch_size + [uncond_mask] * batch_size,
                dtype=th.bool,
                device=device,
            ),
        )

        # Create a classifier-free guidance sampling function
        def model_fn(x_t, ts, **kwargs):
            # The first half of the batch is duplicated so the conditional
            # and unconditional passes see the same inputs.
            half = x_t[: len(x_t) // 2]
            combined = th.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = th.cat([half_eps, half_eps], dim=0)
            return th.cat([eps, rest], dim=1)

        # Sample from the base model.
        model.del_cache()
        samples = diffusion.p_sample_loop(
            model_fn,
            (full_batch_size, 3, options["image_size"], options["image_size"]),
            device=device,
            clip_denoised=True,
            progress=False,
            model_kwargs=model_kwargs,
            cond_fn=None,
            flag_refer_img_iters=refer_img_iters,
            noise=img,
        )[:batch_size]
        model.del_cache()

        ##############################
        # Upsample the 64x64 samples #
        ##############################

        tokens = model_up.tokenizer.encode(prompt)
        tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
            tokens, options_up['text_ctx']
        )

        # Create the model conditioning dict.
        model_kwargs = dict(
            # Low-res image to upsample.
            low_res=((samples + 1) * 127.5).round() / 127.5 - 1,

            # Text tokens
            tokens=th.tensor(
                [tokens] * batch_size, device=device
            ),
            mask=th.tensor(
                [mask] * batch_size,
                dtype=th.bool,
                device=device,
            ),
        )

        # Sample from the upsampler model.
        model_up.del_cache()
        up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
        up_samples = diffusion_up.ddim_sample_loop(
            model_up,
            up_shape,
            noise=th.randn(up_shape, device=device) * upsample_temp,
            device=device,
            clip_denoised=True,
            progress=False,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]
        model_up.del_cache()

        return up_samples

    # Built once: resize, scale to [0, 1], then normalise to [-1, 1].
    ref_transform = transforms.Compose([
        transforms.Resize(64, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    def read_img_from_path_to_tensor(path):
        """Load an image as a normalised 1 x 3 x H x W tensor on ``device``."""
        with Image.open(path) as img:
            # Force three channels so the 3-channel Normalize always applies.
            tensor = ref_transform(img.convert('RGB'))
        # Use the selected device rather than .cuda(), which fails on
        # CPU-only machines.
        return tensor.unsqueeze(0).to(device)

    for ref_path, prompt in tqdm(prompt_list):
        # Use the full label as the directory name; taking only the last word
        # of the prompt would merge or mis-name multi-word labels.
        label_name = prompt[len(prompt_prefix):]
        save_path_label = os.path.join(save_path, label_name)
        os.makedirs(save_path_label, exist_ok=True)

        # Strip the extension whatever its length ('.png', '.JPEG', ...).
        ref_stem = os.path.splitext(os.path.basename(ref_path))[0]
        file_prefix = prompt.replace(' ', '_').replace('/', '_') + '_' + ref_stem

        for b in range(batch_size_time):
            ref_img = read_img_from_path_to_tensor(ref_path)
            up_samples = text2image(prompt, batch_size, ref_img)
            save_paths = [
                save_path_label + '/' + file_prefix + '_' + str(b) + '_' + str(k) + '.png'
                for k in range(batch_size)
            ]
            save_images_multi(up_samples, save_path=save_paths)


if __name__ == "__main__":
    # Expected invocation: <script> <id> <dataset> <ref_img_path> <save_path>
    # Fail with a usage message instead of a bare IndexError when arguments
    # are missing.
    if len(sys.argv) < 5:
        sys.exit(f"Usage: {sys.argv[0]} <id> <dataset> <ref_img_path> <save_path>")

    print(sys.argv[1])
    # The shard id arrives as text and must be an integer for the slicing
    # arithmetic in main().
    main(int(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4])







