import contextlib
import fire
import mup
import numpy as np
from transformers import AutoTokenizer

import lib.datasets
import lib.models
import lib.utils
import os
import time
import torch
import torch.nn.functional as F
import tqdm
from torch import nn, optim, autograd
import wandb
from datasets import load_dataset
from einops import rearrange


def main(**args):
    """Sample from a pretrained diffusion language model and pickle the decoded text.

    Keyword arguments (normally supplied as command-line flags via ``fire``)
    override the defaults set below. With ``prefix_length == 0`` the script
    draws ``num_samples`` unconditional samples. Otherwise it takes the first
    ``num_samples`` texts of a C4 validation shard, truncates each to
    ``prefix_length`` tokens, and uses the re-tokenized prefix as per-position
    guidance for one sample each.

    Outputs ``unconditional_seed_{seed}.pkl`` or ``conditional_seed_{seed}.pkl``
    (a pickled list of decoded strings). When ``wandb`` is enabled, it also
    writes per-batch ``{batch_idx}_metrics.pkl`` files with per-step
    entropy / patience / KL traces.

    Raises:
        ValueError: if ``weights_path`` is not given, or per-sample guidance
            does not match the batch size.
        TypeError: if ``sampler`` is not one of 'ddim', 'euler', 'ddpm'.
    """
    import pickle  # local import: the file header does not import pickle

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    args = lib.utils.AttributeDict(args)
    args.setdefault('seq_len', 256)
    args.setdefault('seed', 42)
    args.setdefault('vocab_size', 32768)
    args.setdefault('weights_path', None)
    args.setdefault('dim', 2048)
    args.setdefault('n_blocks', 24)
    args.setdefault('n_heads', 32)
    args.setdefault('gamma_0', -3.)
    args.setdefault('gamma_1', 6.)
    args.setdefault('embed_dim', 16)
    args.setdefault('initial_noise_scale', 1.0)
    args.setdefault('sampling_timesteps', 4096)
    args.setdefault('score_temp', 0.9)
    args.setdefault('output_scale', 1.)
    args.setdefault('owt2_tokenizer', True)
    args.setdefault('guidance_weight', 2.)
    args.setdefault('sampler', 'ddpm')
    args.setdefault('wandb', False)
    args.setdefault('num_samples', 5000)
    args.setdefault('prefix_length', 32)
    args.setdefault('batch_size', 256)
    args.setdefault('_wandb', {})
    args.setdefault('log_metrics', False)
    lib.utils.print_args(args)

    # Fail fast instead of crashing later inside os.path.join(None, ...).
    if args.weights_path is None:
        raise ValueError('weights_path must point to a directory containing the <module>.pt checkpoints')

    # dict(args) rather than vars(args): for a dict subclass, vars() returns the
    # instance __dict__, which is not guaranteed to hold the stored keys.
    wandb.init(resume=True, project="PROJECT_NAME", config=dict(args))
    callback = None
    if args.wandb:
        print('using wandb')
        callback = wandb
    torch.set_default_device('cuda')

    # Lots of annoying big/small numbers throughout this code, so we'll do
    # everything in fp64 by default and explicitly switch to fp32/bf16 where
    # appropriate.
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(args.seed)

    def log1mexp(x):
        """Compute log(1 - exp(-|x|)) elementwise in a numerically stable way."""
        x = -x.abs()
        # Switch formulas at -log(2) ~= -0.693: expm1 is accurate near 0,
        # log1p(-exp) is accurate for very negative x.
        return torch.where(
            x > -0.693,
            torch.log(-torch.expm1(x)),
            torch.log1p(-torch.exp(x))
        )

    def guidance_objective(entries, logprobs, logprobs_any):
        """Weighted sum of guidance log-probabilities for a list of guidance entries.

        entries: [(token, weight, position, complement), ...]
        logprobs: log-probs whose trailing dims are (position, vocab); may or
            may not carry a leading batch dim (indexing uses ``...``).
        logprobs_any: log-probs whose trailing dim is vocab, for position 'any'.
        Returns a tensor, or the float 0. when ``entries`` is empty.
        """
        total = 0.
        for token, weight, position, complement in entries:
            if position == 'any':
                logp = logprobs_any[..., token]
            elif position == 'all':
                logp = logprobs[..., :, token]
            else:
                logp = logprobs[..., position, token]
            if complement:
                logp = log1mexp(logp)
            total = total + weight * logp.sum()
        return total

    def create_modules(dim, n_heads):
        """Build the four model components at the given width, cast to fp32."""
        return {
            'noise_schedule': lib.models.NoiseSchedule().float(),
            'gamma_bounds': lib.models.GammaBounds(args.gamma_0, args.gamma_1).float(),
            'embedding_matrix': lib.models.EmbeddingMatrix(args.vocab_size, args.embed_dim).float(),
            'model': lib.models.DiffusionModel(dim, args.embed_dim, args.n_blocks, n_heads, args.vocab_size).float()
        }

    modules = create_modules(args.dim, args.n_heads)
    # The width-256 base and width-128 delta copies tell muP which parameter
    # dimensions scale with model width.
    base_modules = create_modules(256, 4)
    delta_modules = create_modules(128, 2)
    for key, target in modules.items():
        mup.set_base_shapes(target, base_modules[key], delta=delta_modules[key])
        target.cuda()

    print(f'Loading weights from {args.weights_path}')
    for name, module in modules.items():
        module.load_state_dict(torch.load(
            os.path.join(args.weights_path, f'{name}.pt'),
            map_location=torch.device('cuda')
        ))

    for key, module in modules.items():
        print(key + ':')
        lib.utils.print_model(module)

    def append_dims(x, target_dims):
        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
        dims_to_append = target_dims - x.ndim
        if dims_to_append < 0:
            raise ValueError(
                f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
            )
        return x[(...,) + (None,) * dims_to_append]

    def to_d(x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / append_dims(sigma, x.ndim)

    def generate_samples(guidance_tokens, seq_len=args.seq_len, batch_size=args.batch_size, callback=None, prefix='',
                         seed: int = 42):
        """
        Sampling (implements Appendix A.4 eqn 33 in VDM). Needs float64 to work.

        guidance_tokens: either a flat list of tuples shared by the whole batch,
                [(token, weight, position, complement), ...]
            or a per-sample list of lists (one list per batch element, so
            len(guidance_tokens) must equal batch_size):
                [[(token, weight, position, complement), ...], ...]
            token: vocab index of token
            weight: guidance weight
            position: sequence index, or 'any', or 'all'
            complement: if True, do guidance on log(1-p(y|x))
            An empty list means unguided sampling.
        seq_len, batch_size: shape of the generated token matrix.
        callback: object with a ``log(dict)`` method (wandb here). When given,
            per-step entropy / patience / KL metrics are logged and pickled to
            ``{prefix}_metrics.pkl``.
        seed: re-seeds torch's global RNG at the start of the call.

        Returns a (batch_size, seq_len) tensor of argmax token ids.
        """
        with torch.no_grad():
            embedding_matrix = modules['embedding_matrix']()
            gamma_0, gamma_1 = modules['gamma_bounds']()
            torch.manual_seed(seed)

            def denoise(z_in, gamma, selfcond):
                """Run the denoiser on an fp32 copy of z; returns (logits, x_reconst)."""
                return modules['model'](
                    z=z_in.to(torch.float32, copy=True),
                    gamma=gamma.float(),
                    embedding_matrix=embedding_matrix,
                    bias_scale=1.,
                    x_selfcond=selfcond
                )

            use_guidance = len(guidance_tokens) > 0
            # Flat form entries are tuples; anything else is the per-sample form.
            per_sample_guidance = use_guidance and not isinstance(guidance_tokens[0], tuple)
            if per_sample_guidance and len(guidance_tokens) != batch_size:
                raise ValueError(
                    f'per-sample guidance has {len(guidance_tokens)} entries but batch_size is {batch_size}'
                )

            z = torch.randn((batch_size, seq_len, args.embed_dim), device='cuda') * args.initial_noise_scale
            x_selfcond = torch.zeros_like(z).float()
            entropy_accumulator = []
            patience_accumulator = []
            kl_accumulator = []
            tokens_accumulator = []
            patience_counter = torch.zeros(batch_size, dtype=torch.long, device=z.device)
            prev_tokens = None
            prev_log_distribution = None
            progress = tqdm.tqdm(enumerate(torch.linspace(1., 0., args.sampling_timesteps)), leave=False,
                                 desc='Sampling steps')
            for step, t in progress:
                t = t[None].cuda()
                s = t - 1. / args.sampling_timesteps
                gamma_s = modules['noise_schedule'](s).double()
                gamma_t = modules['noise_schedule'](t).double()
                gamma_s = gamma_0 + (gamma_1 - gamma_0) * gamma_s
                gamma_t = gamma_0 + (gamma_1 - gamma_0) * gamma_t
                alpha_squared_s = torch.sigmoid(-gamma_s)
                alpha_squared_t = torch.sigmoid(-gamma_t)
                alpha_s = alpha_squared_s.sqrt()
                alpha_t = alpha_squared_t.sqrt()
                sigma_s = torch.sigmoid(gamma_s).sqrt()
                sigma_squared_t = torch.sigmoid(gamma_t)
                sigma_t = sigma_squared_t.sqrt()

                if use_guidance:
                    with torch.enable_grad():
                        z.requires_grad = True
                        logits, x_reconst = denoise(z, gamma_t, x_selfcond)

                        logprobs = F.log_softmax(logits.float(), dim=2)
                        # Log of the mean per-position probability of each token:
                        # logsumexp over positions minus log(seq_len).
                        logprobs_any = logprobs.logsumexp(dim=1) - float(np.log(seq_len))

                        if per_sample_guidance:
                            sum_logp = sum(
                                guidance_objective(entries, logprobs[b], logprobs_any[b])
                                for b, entries in enumerate(guidance_tokens)
                            )
                        else:
                            sum_logp = guidance_objective(guidance_tokens, logprobs, logprobs_any)

                        # With no guidance terms at all, sum_logp is a plain float
                        # and autograd.grad cannot differentiate it.
                        if torch.is_tensor(sum_logp):
                            guidance_grad = autograd.grad(sum_logp, [z])[0]
                        else:
                            guidance_grad = torch.zeros_like(z)
                        z.requires_grad = False
                else:
                    logits, x_reconst = denoise(z, gamma_t, x_selfcond)

                x_selfcond = x_reconst.clone().detach()
                x_reconst = x_reconst.double()
                # Temperature-scale the implied noise prediction, then recompute x.
                epsilon_pred = (z - (alpha_t * x_reconst)) / sigma_t
                epsilon_pred /= args.score_temp
                x_reconst = (z - (sigma_t * epsilon_pred)) / alpha_t
                if use_guidance:
                    x_reconst += guidance_grad.double() * sigma_squared_t / alpha_t
                    epsilon_pred = (z - (alpha_t * x_reconst)) / sigma_t

                if t > 0:
                    if args.sampler == "ddim":
                        z = (alpha_s * x_reconst) + (sigma_s * epsilon_pred)
                    elif args.sampler == "euler":
                        # NOTE(review): Karras-style Euler usually evaluates the
                        # derivative at the current level (sigma_t); confirm sigma_s is intended.
                        dz = to_d(z, sigma_s, x_reconst)
                        dt = sigma_s - sigma_t
                        z = z + dt * dz
                    elif args.sampler == "ddpm":
                        c = -torch.expm1(gamma_s - gamma_t)
                        z *= (1 - c) * alpha_s / alpha_t
                        z += c * (alpha_s * x_reconst)
                        z += (c * (1 - alpha_squared_s)).sqrt() * torch.randn_like(z)
                    else:
                        raise TypeError(f"Sampler {args.sampler} not understood")

                if callback is not None:
                    c_tokens = logits.detach().argmax(-1)
                    if prev_tokens is not None:
                        # Patience = consecutive steps with an unchanged argmax sequence.
                        unchanged = (prev_tokens == c_tokens).all(-1)
                        patience_counter += unchanged
                        patience_counter[~unchanged] = 0

                    log_p = torch.log_softmax(logits.float(), dim=-1)
                    divergence = None
                    if prev_log_distribution is not None:
                        divergence = F.kl_div(
                            prev_log_distribution,
                            log_p,
                            log_target=True, reduction='batchmean'
                        )
                        kl_accumulator.append([
                            F.kl_div(prev_log_distribution[b], log_p[b], log_target=True,
                                     reduction='batchmean').detach().cpu().numpy()
                            for b in range(len(log_p))
                        ])
                    prev_log_distribution = log_p.detach().clone()
                    prev_tokens = c_tokens
                    entropy = -(log_p * torch.exp(log_p)).sum(-1)
                    mean_entropy = entropy.mean(1).detach().cpu().numpy()
                    callback.log({
                        "entropy": mean_entropy,
                        "patience": patience_counter.detach().cpu().numpy(),
                        "kl": None if divergence is None else divergence.detach().cpu().numpy()
                    })
                    progress.set_description(f'entropy: {mean_entropy.mean()}')
                    entropy_accumulator.append(entropy.detach().cpu().numpy())
                    # Exactly one patience snapshot per step.
                    patience_accumulator.append(patience_counter.detach().cpu().numpy())
                    tokens_accumulator.append(log_p.argmax(-1).detach().cpu().numpy())

            if callback is not None:
                metrics_path = f'{prefix}_metrics.pkl'
                with open(metrics_path, 'wb') as outp:
                    metrics = {
                        'entropy': entropy_accumulator,
                        'patience': patience_accumulator,
                        'kl': kl_accumulator,
                        'tokens': tokens_accumulator
                    }
                    pickle.dump(metrics, outp)
                wandb.save(metrics_path)
            # Final decode at the last timestep's gamma.
            logits, _ = denoise(z, gamma_t, x_selfcond)
            return logits.argmax(dim=-1)

    def log_to_file_samples(x_samples, fn: str = "generated.pkl"):
        """Decode token-id sequences with the OWT2 tokenizer and pickle the list of strings to ``fn``."""
        owt2_tokenizer = lib.datasets.openwebtext2_tokenizer()
        decoded = [owt2_tokenizer.decode(x, skip_special_tokens=False) for x in x_samples]
        with open(fn, "wb") as outp:
            pickle.dump(decoded, outp)

    def print_samples(x_samples, idx2word=None):
        """Print decoded samples one per line, with newlines shown as '↵'.

        idx2word: id -> bytes vocabulary. It is required when
        ``args.owt2_tokenizer`` is False; the original referenced an
        undefined global here.
        """
        if args.owt2_tokenizer:
            owt2_tokenizer = lib.datasets.openwebtext2_tokenizer()
            for x in x_samples:
                x = owt2_tokenizer.decode(x.tolist(), skip_special_tokens=False)
                print(x.replace("\n", "↵"))
        else:
            if idx2word is None:
                raise ValueError('print_samples needs idx2word when owt2_tokenizer is False')
            for x in x_samples:
                x = x.tolist()
                x = [idx2word[i].decode('utf-8', 'ignore') for i in x]
                x = ' '.join(x)
                x = x.replace('START', '')
                x = x.replace('END', '')
                x = x.replace('PAD', '')
                x = x.replace(' .', '.')
                x = x.replace(' !', '!')
                x = x.replace(' ,', ',')
                x = x.replace(' \' ', '\'')
                x = x.strip()
                # replace newlines with '↵' symbol for cleaner printing
                print(x.replace("\n", "↵"))

    tokenizer = lib.datasets.openwebtext2_tokenizer()

    if args.prefix_length == 0:
        print('Unconditional:')
        x_samples = []
        # Ceil division so a final partial batch is generated rather than dropped.
        num_batches = (args.num_samples + args.batch_size - 1) // args.batch_size
        for batch_idx in tqdm.trange(num_batches):
            start = batch_idx * args.batch_size
            x_samples += generate_samples(
                [],
                seq_len=args.seq_len,
                batch_size=min(args.batch_size, args.num_samples - start),
                callback=callback,
                prefix=str(batch_idx),
                # A distinct seed per batch; reusing one seed would give
                # identical unconditional batches.
                seed=args.seed + batch_idx,
            ).detach().cpu().tolist()

        log_to_file_samples(x_samples, f"unconditional_seed_{args.seed}.pkl")
    else:
        dataset = load_dataset("allenai/c4", data_files=["en/c4-validation.00000-of-00008.json.gz"])
        dataset = dataset.remove_columns(["timestamp", "url"])["train"]
        full_texts = dataset[:args.num_samples]["text"]
        ddlm_tokenizer = AutoTokenizer.from_pretrained("elephantmipt/c-tokenizer")
        prefixes = [ddlm_tokenizer.decode(ddlm_tokenizer.encode(t)[:args.prefix_length]) for t in full_texts]
        # full_texts = [ddlm_tokenizer.decode(ddlm_tokenizer.encode(t)[:args.max_length]) for t in full_texts]
        prompts_ids = [
            [
                (token, args.guidance_weight, position, False) for position, token in
                enumerate(tokenizer.encode(prefix).ids)
            ] for prefix in prefixes
        ]
        x_samples = []
        num_batches = (len(prompts_ids) + args.batch_size - 1) // args.batch_size
        for batch_idx in tqdm.trange(num_batches):
            # Non-overlapping slices: batch k covers prompts [k*bs, (k+1)*bs).
            start = batch_idx * args.batch_size
            batch_prompts = prompts_ids[start:start + args.batch_size]
            x_samples += generate_samples(
                batch_prompts,
                seq_len=args.seq_len,
                batch_size=len(batch_prompts),
                callback=callback,
                prefix=str(batch_idx),
                seed=args.seed + batch_idx,
            ).detach().cpu().tolist()
        log_to_file_samples(x_samples, f"conditional_seed_{args.seed}.pkl")
    wandb.save("*.pkl")

    # print('Infilling: A year ago in Paris, [...] Wow, what a great day!')
    # tokenizer = lib.datasets.openwebtext2_tokenizer()
    # prefix = tokenizer.encode(' A year ago in Paris,').ids
    # suffix = tokenizer.encode('. Wow, what a great day!').ids
    # infill_len = 40
    # print_samples(generate_samples(
    #     [(token, args.guidance_weight, position, False) for position, token in enumerate(prefix)]
    #     + [(token, args.guidance_weight, position + len(prefix) + infill_len, False) for position, token in enumerate(suffix)],
    #     callback=callback
    # ))
    # print("\n"*10)

    # print('Word-level weights: Let\'s talk about law[10] and medicine[1].')
    # guidance = [
    #     (tokenizer.encode(' Let').ids,      args.guidance_weight,   0,  False),
    #     (tokenizer.encode('\'s').ids,       args.guidance_weight,   1,  False),
    #     (tokenizer.encode(' talk').ids,     args.guidance_weight,   2,  False),
    #     (tokenizer.encode(' about').ids,    args.guidance_weight,   3,  False),
    #     (tokenizer.encode(' law').ids,      10.,                    4,  False),
    #     (tokenizer.encode(' and').ids,      args.guidance_weight,   5,  False),
    #     (tokenizer.encode(' medicine').ids, args.guidance_weight,   6,  False),
    #     (tokenizer.encode('.').ids,         args.guidance_weight,   7,  False),
    # ]
    # assert(all(len(a) == 1 for a,_,_,_ in guidance))
    # guidance = [(a[0], b, c, d) for a,b,c,d in guidance]
    # print_samples(generate_samples(guidance))
    # print('\n'*10)
    #
    # print('Word-level weights: Let\'s talk about law[1] and medicine[10].')
    # guidance = [
    #     (tokenizer.encode(' Let').ids,      args.guidance_weight,   0,  False),
    #     (tokenizer.encode('\'s').ids,       args.guidance_weight,   1,  False),
    #     (tokenizer.encode(' talk').ids,     args.guidance_weight,   2,  False),
    #     (tokenizer.encode(' about').ids,    args.guidance_weight,   3,  False),
    #     (tokenizer.encode(' law').ids,      args.guidance_weight,   4,  False),
    #     (tokenizer.encode(' and').ids,      args.guidance_weight,   5,  False),
    #     (tokenizer.encode(' medicine').ids, 10.,                    6,  False),
    #     (tokenizer.encode('.').ids,         args.guidance_weight,   7,  False),
    # ]
    # assert(all(len(a) == 1 for a,_,_,_ in guidance))
    # guidance = [(a[0], b, c, d) for a,b,c,d in guidance]
    # print_samples(generate_samples(guidance))
    # print('\n'*10)
    #
    # print(f'Lexically constrained generation: Donald')
    # guidance = [
    #     (tokenizer.encode(' Donald').ids, 3., 'any', False),
    # ]
    # assert(all(len(a) == 1 for a,_,_,_ in guidance))
    # guidance = [(a[0], b, c, d) for a,b,c,d in guidance]
    # print_samples(generate_samples(guidance))
    # print("\n"*10)
    #
    # print(f'Negation: Donald but not Trump')
    # guidance = [
    #     (tokenizer.encode(' Donald').ids, 3., 'any', False),
    #     (tokenizer.encode(' Trump').ids, 10., 'all', True),
    # ]
    # assert(all(len(a) == 1 for a,_,_,_ in guidance))
    # guidance = [(a[0], b, c, d) for a,b,c,d in guidance]
    # print_samples(generate_samples(guidance))
    # print("\n"*10)


if __name__ == '__main__':
    # Expose main's keyword arguments as command-line flags (e.g. --seed=1).
    fire.Fire(component=main)
