import glob
import lpips
import math
import os

import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import torch
from tqdm import tqdm

from code.exp_utils import load_base_model


def gather_results_vi_old(*, target: str, temp: float, num_samples: int):
    """Collect legacy ``.npy`` VI samples from ``<target>/index=*`` directories.

    Each directory is read for ``samples_temp=<temp>.npy``; directories without
    that file are skipped with a message. The first ``num_samples`` entries of
    each file are kept.

    Args:
        target: Root directory containing ``index=<int>`` subdirectories.
        temp: Temperature value formatted into the sample filename.
        num_samples: Number of leading samples to keep per index.

    Returns:
        ``(samples, indices)``: ``samples`` stacks the per-index arrays along a
        new leading axis; ``indices[i]`` is the integer index of ``samples[i]``.

    Raises:
        AssertionError: if a file holds fewer than ``num_samples`` samples.
        ValueError: (from ``np.stack``) if no directory had a sample file.
    """
    # Sort so the output order is deterministic; raw glob order depends on the filesystem.
    dirs = sorted(glob.glob(os.path.join(target, 'index=*')))
    pbar = tqdm(dirs)
    samples = []
    indices = []
    for d in pbar:
        fn = os.path.join(d, f'samples_temp={temp}.npy')
        if not os.path.exists(fn):
            print(f'Skipping {fn} -- does not exist')
            continue
        # Directory names look like "index=<int>"; basename is separator-agnostic.
        index = int(os.path.basename(d).split('=')[1])

        pbar.set_description(f'Processing {d}')
        x = np.load(fn)[:num_samples]
        assert len(x) == num_samples, f'{fn}: expected {num_samples} samples, found {len(x)}'
        samples.append(x)
        indices.append(index)

    samples = np.stack(samples, axis=0)
    return samples, indices


def gather_results_vi(*, target: str, step: int, temp: float, num_samples: int, strict: bool=True):
    """Collect VI samples stored as ``samples.pt`` under ``<target>/index=*``.

    Each ``samples.pt`` is loaded with ``torch.load`` and indexed as
    ``dd[step][temp]``. The first ``num_samples`` entries are kept, axis 1 is
    moved to the end (presumably channels-first -> channels-last), and the
    result is scaled by 256, clipped and cast to ``uint8``. The stored values
    are presumably in ``[0, 1]`` (TODO confirm with the writer of ``samples.pt``).

    Args:
        target: Root directory containing ``index=<int>`` subdirectories.
        step: First-level key into the loaded object.
        temp: Second-level key into the loaded object.
        num_samples: Number of leading samples to keep per index.
        strict: If True, a missing ``samples.pt`` is an error; otherwise the
            directory is skipped.

    Returns:
        ``(samples, indices)``: ``samples`` is a uint8 array stacking one batch
        per collected index; ``indices[i]`` is the integer index of ``samples[i]``.

    Raises:
        AssertionError: on a missing file in strict mode, a short file, or if
            the final sanity check fails.
        ValueError: (from ``np.stack``) if nothing was collected.
    """
    dirs = sorted(glob.glob(os.path.join(target, 'index=*')))
    pbar = tqdm(dirs)
    samples = []
    indices = []
    for d in pbar:
        fn = os.path.join(d, 'samples.pt')
        pbar.set_description(f'Processing {d}')
        # Directory names look like "index=<int>"; basename is separator-agnostic.
        index = int(os.path.basename(d).split('=')[1])

        if strict:
            assert os.path.exists(fn), f'{fn} does not exist'
        elif not os.path.exists(fn):
            print(f'Skipping {fn} -- does not exist')
            continue

        # map_location='cpu' so .numpy() below also works for tensors saved from a GPU.
        dd = torch.load(fn, map_location='cpu')
        # clone() copies the slice, presumably so the full stored tensor is not kept alive.
        x = dd[step][temp][:num_samples].clone()
        x = x.permute(0, 2, 3, 1).numpy()
        assert len(x) == num_samples, f'{fn}: expected {num_samples} samples, found {len(x)}'

        samples.append(x)
        # Record the index only once its samples are collected, so `indices` stays
        # aligned with `samples` when strict=False skips a directory.
        indices.append(index)

    samples = np.stack(samples, axis=0) * 256.
    samples = samples.clip(0, 255.999)
    # Heuristic sanity check: rejects (near-)black output, e.g. data that was all zeros.
    assert samples.min() >= 0.0 and samples.max() <= 256.0 and samples.mean() >= 10.0
    samples = samples.astype(np.uint8)

    return samples, indices


def gather_results_mcmc_old(*, target: str, num_samples: int):
    """Collect legacy MCMC samples saved as ``sample_x.pt`` files.

    For each sample index ``k`` in ``range(num_samples)``, the directories
    matching ``<target>/samples_<k:03d>_indices_*`` are read in sorted order.
    Their tensors are concatenated along dim 0. Every ``k`` must yield the same
    number of rows.

    Returns:
        A uint8 array. The per-``k`` tensors are stacked on axis 1, scaled by 256
        and clamped. The trailing axes are then permuted ``(0, 1, 3, 4, 2)``, so
        ``(B, num_samples, C, H, W)`` becomes ``(B, num_samples, H, W, C)``,
        assuming 4-D ``(n, C, H, W)`` inputs (TODO confirm). The values are
        presumably in ``[0, 1]`` before scaling.

    Raises:
        AssertionError: if a ``sample_x.pt`` is missing or row counts differ.
    """
    samples = []
    batch_count = None
    for sample_idx in range(num_samples):
        dirs = sorted(glob.glob(os.path.join(target, f'samples_{sample_idx:03d}_indices_*')))
        pbar = tqdm(dirs)
        xs = []
        for d in pbar:
            fn = os.path.join(d, 'sample_x.pt')
            assert os.path.exists(fn), f'{fn} does not exist'
            pbar.set_description(f'Processing {os.path.basename(d)}')
            xs.append(torch.load(fn))

        xs = torch.cat(xs, dim=0)
        if batch_count is None:
            batch_count = len(xs)
        else:
            assert batch_count == len(xs), f'Expected: {batch_count}   Actual: {len(xs)}   for sample {sample_idx}'
        samples.append(xs)

    samples = torch.stack(samples, dim=1) * 256.
    # Clamp before the uint8 cast. A value of exactly 1.0 scales to 256.0, which
    # (like any negative value) is outside uint8 range, and NumPy's float->uint8
    # conversion of such values is not well-defined.
    samples = samples.clamp(0, 255.9)
    samples = samples.cpu().numpy().astype(np.uint8).transpose(0, 1, 3, 4, 2)

    return samples


def gather_results_mcmc(*, target: str, step: int, num_samples:int, max_index=None, strict: bool=True):
    """Collect MCMC samples saved as ``samples.pt`` files.

    For each sample index ``k`` in ``range(num_samples)``, the directories
    matching ``<target>/samples_<k:03d>_indices_*`` are read in sorted order.
    ``torch.load(...)[step]`` from each is concatenated along dim 0 and,
    optionally, truncated to ``max_index`` rows.

    Args:
        target: Root directory holding the ``samples_<k>_indices_*`` folders.
        step: Key selecting the entry of each loaded ``samples.pt``.
        num_samples: Number of sample indices ``k`` to collect.
        max_index: If given, keep at most this many rows per ``k``.
        strict: If True, missing files, empty ``k`` and unequal row counts are
            errors. Otherwise missing files and empty ``k`` are skipped, and all
            collected ``k`` are truncated to their common (shortest) length.

    Returns:
        A uint8 array. The per-``k`` tensors are stacked on axis 1, scaled by 256
        and clamped to ``[0, 255.9]``. The axes are then permuted
        ``(0, 1, 3, 4, 2)``, presumably ``(B, k, C, H, W) -> (B, k, H, W, C)``
        (TODO confirm).

    Raises:
        AssertionError: on the strict-mode violations above.
        RuntimeError: if no samples were collected at all.
    """
    samples = []
    batch_size = None
    for sample_idx in range(num_samples):
        dirs = sorted(glob.glob(os.path.join(target, f'samples_{sample_idx:03d}_indices_*')))
        xs = []
        for d in dirs:
            print(f'\rProcessing {os.path.basename(d)} ... ', end='', flush=True)
            fn = os.path.join(d, 'samples.pt')
            if strict:
                assert os.path.exists(fn), f'{fn} does not exist'
            elif not os.path.exists(fn):
                print('Skipping -- samples do not exist')
                continue

            # map_location='cpu' lets GPU-saved files load on any machine.
            xs.append(torch.load(fn, map_location='cpu')[step])
        print()

        if strict:
            assert len(xs) > 0, f'No samples found for sample {sample_idx}'
        elif len(xs) == 0:
            print(f'Sample {sample_idx} batch is empty; continuing.')
            continue

        xs = torch.cat(xs, dim=0)
        if max_index is not None and len(xs) > max_index:
            print(f'Reducing sample {sample_idx} batch size to {max_index}')
            xs = xs[:max_index]

        if batch_size is None:
            batch_size = len(xs)
        elif strict:
            assert batch_size == len(xs), f'Expected: {batch_size}   Actual: {len(xs)}'
        print(f'Sample {sample_idx} batch size: {len(xs)}')
        samples.append(xs)

    if not samples:
        raise RuntimeError(f'No MCMC samples collected under {target}')

    # Non-strict mode tolerates unequal row counts, but torch.stack needs equal shapes.
    # Truncate to the shortest (in strict mode all lengths already match).
    common = min(len(xs) for xs in samples)
    if any(len(xs) != common for xs in samples):
        print(f'Truncating all samples to common batch size {common}')
        samples = [xs[:common] for xs in samples]

    samples = torch.stack(samples, dim=1) * 256.
    # Clamp so 1.0 (-> 256.0) and negatives stay within uint8 range before the cast.
    samples = samples.clamp(0, 255.9)
    assert samples.ndim == 5
    samples = samples.cpu().numpy().astype(np.uint8).transpose(0, 1, 3, 4, 2)

    return samples


@torch.no_grad()
def gather_results_base(*, target: str, num_samples: int, base_model: str, base_ckpt: str, temp: float,
                        batch_size: int = 64):
    """Return ``num_samples`` base-model samples, cached in ``<target>/samples.pt``.

    If the cache file exists and holds at least ``num_samples`` samples, it is
    reused. Otherwise the model from ``load_base_model(base_model, base_ckpt)``
    is moved to CUDA and sampled in ``ceil(num_samples / batch_size)`` batches of
    ``batch_size`` at temperature ``temp``. The result is scaled by 256, clamped,
    cast to ``torch.uint8`` and written to the cache, overwriting it.

    NOTE(review): the cache is keyed only by ``target``. ``base_model``,
    ``base_ckpt`` and ``temp`` are not checked, so a cache written with other
    settings is silently reused.

    Args:
        target: Directory for the cache file (created if missing).
        num_samples: Number of samples to return.
        base_model: Model identifier passed to ``load_base_model``.
        base_ckpt: Checkpoint passed to ``load_base_model``.
        temp: Sampling temperature passed to ``model.sample``.
        batch_size: Samples generated per ``model.sample`` call (default 64).

    Returns:
        A uint8 array of ``num_samples`` samples. Axis 1 of each stored tensor is
        moved last (presumably channels-last) and a singleton axis is inserted at
        position 1, giving a 5-D result.
    """
    fn = os.path.join(target, 'samples.pt')
    samples = None
    if os.path.exists(fn):
        samples = torch.load(fn)
        if len(samples) < num_samples:
            samples = None  # cache too small -> regenerate

    if samples is None:
        print('Generating samples from base model...')
        model = load_base_model(base_model, base_ckpt)
        model.cuda()
        device = torch.device('cuda:0')
        n_batches = math.ceil(num_samples / batch_size)
        batches = [model.sample(batch_size, temp=temp, device=device).cpu() for _ in range(n_batches)]
        samples = torch.cat(batches, dim=0)
        samples = (samples * 256.).clamp(0, 255.9).type(torch.uint8)
        # Make sure the cache directory exists, so saving cannot fail after generation.
        os.makedirs(target, exist_ok=True)
        torch.save(samples, fn)

    samples = samples[:num_samples]
    samples = samples.cpu().numpy().transpose(0, 2, 3, 1)[:, None]
    assert len(samples) == num_samples and samples.dtype == np.uint8 and samples.ndim == 5

    return samples


def batched_lpips(dist_fn: str, x, y, bs=64):
    """Compute LPIPS distances between paired images ``x[i]`` and ``y[i]`` in batches.

    A fresh ``lpips.LPIPS`` model is constructed and moved to CUDA on every call.

    Args:
        dist_fn: LPIPS backbone name, ``'alex'`` or ``'vgg'``.
        x, y: 4-D tensors of identical shape. They are presumably ``(N, C, H, W)``
            images on the CUDA device, in the value range LPIPS expects (its
            default is [-1, 1]). TODO confirm with callers.
        bs: Number of pairs per forward pass.

    Returns:
        A 1-D tensor of length ``len(x)`` with one distance per pair.

    Raises:
        ValueError: for an unknown ``dist_fn``.
        AssertionError: if ``x``/``y`` are not 4-D or differ in shape.
    """
    if dist_fn not in ('alex', 'vgg'):
        raise ValueError(f'Invalid dist_fn: {dist_fn}')
    model = lpips.LPIPS(net=dist_fn).cuda()

    assert x.ndim == y.ndim == 4
    assert x.shape == y.shape
    n_batches = math.ceil(len(x) / bs)
    dists = []
    # Evaluation only. Without no_grad, each batch's autograd graph would be kept
    # alive through `dists` until the function returns.
    with torch.no_grad():
        for i in range(n_batches):
            bx = x[i*bs : (i+1)*bs]
            by = y[i*bs : (i+1)*bs]
            dists.append(model(bx, by))
    # Use reshape(-1), not squeeze(): squeeze() turns a single pair into a 0-d
    # tensor, on which len() fails.
    dists = torch.cat(dists, dim=0).reshape(-1)
    assert len(dists) == len(x)
    return dists


