import os, sys, time, glob, random, argparse
import random
import numpy as np
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from tqdm import tqdm
import scipy.stats as stats
import torch
from torch import nn
from torch.distributions import Categorical, Distribution
from procedures import TEG
from procedures   import prepare_seed, prepare_logger
from pdb import set_trace as bp
from models import CellStructure, Transformer, count_matmul, matmul, SEARCH_SPACE
from datasets import get_imagenet_dataset
import pickle
import ml_collections
import matplotlib
from matplotlib import pyplot as plt
from thop_modified import profile
from typing import List


# https://github.com/pytorch/pytorch/issues/43250
class MultiCategorical(Distribution):
    """Joint distribution over several independent ``Categorical`` variables.

    A value of this distribution is an integer tensor whose last axis has one
    entry per component distribution, in the order of ``dists``.
    """

    # This class takes no tensor parameters of its own. Declaring an empty
    # mapping keeps ``Distribution.__init__`` from warning about missing
    # constraints and lets the inherited ``__repr__`` work.
    arg_constraints = {}

    def __init__(self, dists: List[Categorical]):
        """Wrap ``dists``; the list is stored as given, not copied."""
        super().__init__()
        self.dists = dists

    def log_prob(self, value):
        """Return the joint log-probability of ``value``.

        ``value`` is split along its last axis, one slice per component, and
        the per-component log-probabilities are summed.
        """
        per_component = [
            dist.log_prob(choice.squeeze(-1))
            for dist, choice in zip(self.dists, torch.split(value, 1, dim=-1))
        ]
        return torch.stack(per_component, dim=-1).sum(dim=-1)

    def entropy(self):
        """Return the sum of the component entropies."""
        return torch.stack([dist.entropy() for dist in self.dists], dim=-1).sum(dim=-1)

    def sample(self, sample_shape=torch.Size()):
        """Draw one index per component and stack them along a new last axis."""
        return torch.stack([dist.sample(sample_shape) for dist in self.dists], dim=-1)

    def derive(self):
        """Return the most probable index of every component.

        The argmax is taken over each component's category axis (the last
        one), so batched components yield a ``(*batch, n_components)`` tensor;
        unbatched components yield a 1-D tensor of length ``n_components``.
        """
        return torch.stack([dist.probs.argmax(dim=-1) for dist in self.dists], dim=-1)


def multi_categorical_maker(nvec):
    """Build a factory that turns a flat logit tensor into a ``MultiCategorical``.

    Args:
        nvec: iterable of ints, the number of categories of each component.

    Returns:
        A function ``get_multi_categorical(logits)``. It cuts the last axis of
        ``logits`` into consecutive chunks of the sizes in ``nvec`` and wraps
        one ``Categorical`` per chunk. ``logits`` may be 1-D or carry leading
        batch dimensions.
    """
    def get_multi_categorical(logits):
        components = []
        start = 0
        for n in nvec:
            # Slice the category axis (last one) so leading batch axes, if any, are preserved.
            components.append(Categorical(logits=logits[..., start: start + n]))
            start += n
        return MultiCategorical(components)
    return get_multi_categorical



class Policy(nn.Module):
    """Learnable policy over the architecture search space.

    One logit vector is kept per search-space key. Sampling (or taking the
    argmax) yields one choice index per key; the indices are rendered into an
    architecture string that ``CellStructure`` parses.
    """

    def __init__(self, search_space=SEARCH_SPACE):
        # search space: mapping from choice name to the list of candidate values for that choice
        super(Policy, self).__init__()
        self.search_space = search_space
        self.arch_parameters = nn.ParameterList()
        self.space_dims = []
        self.space_keys = list(search_space.keys())
        for value in search_space.values():
            # near-zero logits: the initial policy is close to uniform over each choice
            self.arch_parameters.append(nn.Parameter(1e-3*torch.randn(len(value))))
            self.space_dims.append(len(value))
        self.dist_maker = multi_categorical_maker(self.space_dims)

    def load_arch_params(self, arch_params):
        """Overwrite the policy logits in place.

        Args:
            arch_params: iterable with one entry per search-space key, in key
                order. Each entry may be anything ``torch.as_tensor`` accepts
                (tensor, numpy array, list of floats).

        Raises:
            ValueError: if the number of entries differs from the number of
                logit vectors held by the policy.
        """
        arch_params = list(arch_params)
        if len(arch_params) != len(self.arch_parameters):
            raise ValueError("expected {:d} logit vectors, got {:d}".format(
                len(self.arch_parameters), len(arch_params)))
        # nn.ParameterList has no ``.data`` of its own, so copy parameter by parameter.
        with torch.no_grad():
            for param, new_value in zip(self.arch_parameters, arch_params):
                param.copy_(torch.as_tensor(new_value, dtype=param.dtype, device=param.device))

    def action2arch_str(self, actions):
        """Render one choice index per search-space key as an architecture string.

        ``actions`` is indexed positionally in the iteration order of
        ``self.search_space``; entries may be 0-dim tensors or plain ints.
        """
        _keyactions = {
            key: choices[int(actions[_idx])]
            for _idx, (key, choices) in enumerate(self.search_space.items())
        }
        arch_str = "{KERNEL_CHOICE1:d},{WINDOW_CHOICE1:d},1,{FFN_EXP_CHOICE1:d}|{KERNEL_CHOICE2:d},{WINDOW_CHOICE2:d},1,{FFN_EXP_CHOICE2:d}|{KERNEL_CHOICE3:d},{WINDOW_CHOICE3:d},1,{FFN_EXP_CHOICE3:d}|{KERNEL_CHOICE4:d},1,1,{FFN_EXP_CHOICE4:d}|{HEAD_CHOICE:d}".format(**_keyactions)
        return arch_str

    def generate_arch(self, arch, image_size, hidden_dim, depth, num_classes=1000, dropout=0, emb_dropout=0):
        """Instantiate the ``Transformer`` described by ``arch``.

        Args:
            arch: either a list/tuple/tensor of choice indices (converted with
                ``action2arch_str``) or an architecture string used as is.
            image_size, hidden_dim, depth, num_classes: forwarded to
                ``Transformer`` as ``img_size``, ``embed_dim``, ``depths`` and
                ``num_classes``.
            dropout, emb_dropout: accepted but currently unused; all drop
                rates passed to ``Transformer`` are fixed at 0.

        Returns:
            ``(arch_str, network)``.

        Raises:
            NotImplementedError: if ``arch`` is of any other type.
        """
        if isinstance(arch, (list, tuple, torch.Tensor)):
            arch_str = self.action2arch_str(arch)
        elif isinstance(arch, str):
            arch_str = arch
        else:
            raise NotImplementedError("unsupported arch type: {:}".format(type(arch).__name__))
        genotype = CellStructure(arch_str)
        return arch_str, Transformer(img_size=image_size, patch_sizes=genotype.patch_sizes, stride=4, in_chans=3, num_classes=num_classes,
                             embed_dim=hidden_dim, depths=depth, num_heads=genotype.heads,
                             window_sizes=genotype.window_sizes, num_mlps=genotype.num_mlps, mlp_ratios=genotype.mlp_ratios,
                             drop_rate=0., attn_drop_rate=0., drop_path_rate=0., use_checkpoint=False)

    def _refresh_distribution(self):
        """Rebuild ``self.distribution`` from the current logits and return it."""
        self.distribution = self.dist_maker(logits=torch.cat(list(self.arch_parameters)))
        return self.distribution

    def genotype(self):
        """Return the architecture string of the currently most probable choices."""
        genotypes = self._refresh_distribution().derive() # ~ tensor of actions
        return self.action2arch_str(genotypes)

    def sample(self):
        """Sample one choice per key; return ``(actions, log_prob)`` under the current policy."""
        distribution = self._refresh_distribution()
        actions = distribution.sample()
        log_prob = distribution.log_prob(actions)
        return actions, log_prob


class ExponentialMovingAverage:
    """Running exponentially weighted mean of the values passed to ``update``.

    A weighted sum of the values and the matching sum of weights are tracked
    separately; ``value`` is their ratio.
    """

    def __init__(self, momentum):
        self._momentum    = momentum
        self._numerator   = 0
        self._denominator = 0

    def update(self, value):
        """Decay both accumulators by the momentum and fold ``value`` in."""
        decay = self._momentum
        fresh_weight = 1 - decay
        self._denominator = decay * self._denominator + fresh_weight
        self._numerator = decay * self._numerator + fresh_weight * value

    def value(self):
        """Return the current average (weighted sum divided by accumulated weight)."""
        weighted_sum, total_weight = self._numerator, self._denominator
        return weighted_sum / total_weight


def _profile_network(network, image_size):
    """Return ``(flops, params)`` of ``network`` for one 3 x image_size x image_size input."""
    flops, _ = profile(network, inputs=(torch.randn(1, 3, image_size, image_size),), custom_ops={matmul: count_matmul}, verbose=False)
    # thop did not consider nn.Parameter, so trainable parameters are counted directly
    params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    return flops, params


def _summarize_network(te_reward_generator, arch_str, network, image_size):
    """Build the ``[arch_str, flops, params, ntk, curve_complexity]`` summary row for one network."""
    # Profile before the network is moved to the GPU: the profiling input is created on the default device.
    flops, params = _profile_network(network, image_size)
    te_reward_generator.set_network(network.cuda())
    ntk = te_reward_generator.get_ntk()
    curve_complexity = te_reward_generator.get_curve_complexity()
    return [arch_str, flops, params, ntk, curve_complexity]


def _run_search(xargs, logger):
    """Run the REINFORCE loop, checkpoint its history every step and save the final summary."""
    if xargs.dataset == 'imagenet':
        image_size = 224
    elif xargs.dataset == 'imagenet_64':
        image_size = 64
    else:
        raise NotImplementedError("unsupported dataset: {:}".format(xargs.dataset))
    logger.log("preparing dataset...")
    _, loader_train, _, loader_eval = get_imagenet_dataset(data_path=xargs.data_path, no_aug=True, img_size=image_size,
                                                           batch_size=xargs.batch_size, workers=0)

    eps = np.finfo(np.float32).eps.item()
    logger.log('eps       : {:}'.format(eps))

    # REINFORCE
    total_steps = xargs.total_steps
    hidden_dim = 32
    depth = [1, 1, 1, 1]
    constraint = xargs.flops
    seed = xargs.rand_seed
    run_tag = "%.1f_s%d"%(constraint, seed)

    def _save(name, obj):
        # every per-run artifact is written as <save_dir>/<name>_<constraint>_s<seed>.npy
        np.save(os.path.join(xargs.save_dir, "%s_%s.npy"%(name, run_tag)), obj)

    print("<< ============== JOB (PID = %d) %s ============== >>"%(os.getpid(), '/'.join(xargs.save_dir.split("/")[-5:])))
    _size_curve = 10
    te_reward_generator = TEG(loader_train, loader_eval, size_curve=(_size_curve, 3, image_size, image_size), repeat=xargs.repeat,
                              reward_types=xargs.reward_types, buffer_size=xargs.te_buffer_size, batch_curve=6, constraint_weight=0)
    prepare_seed(seed)
    te_reward_generator.reset(constraint)
    policy = Policy().cuda()
    optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
    baseline = ExponentialMovingAverage(xargs.EMA_momentum)
    arch_history = [] # logits of the policy at the start of every step
    arch_str_history = [] # architecture string sampled at every step
    pbar = tqdm(range(total_steps), position=0, leave=True)
    for _step in pbar:
        # Record the policy and reward-generator state; files are rewritten every step
        # so that the run can be inspected at any time.
        # NOTE(review): if the search-space keys have different numbers of choices this
        # list is ragged, which NumPy >= 1.24 refuses to convert implicitly -- confirm.
        arch_history.append([_alpha.detach().clone().cpu().numpy() for _alpha in policy.arch_parameters])
        _save("arch_history", arch_history)
        _save("buffers", te_reward_generator._buffers)
        _save("buffer_changes", te_reward_generator._buffers_change)
        _save("buffers_bad", te_reward_generator._buffers_bad)

        action, log_prob = policy.sample()
        arch_str, network = policy.generate_arch(action, image_size, hidden_dim, depth)
        arch_str_history.append(arch_str)
        _save("arch_str_history", arch_str_history)

        flops, params = _profile_network(network, image_size)
        logger.writer.add_scalar("TE/flops", flops, _step)
        logger.writer.add_scalar("TE/params", params, _step)
        reward = te_reward_generator.step(network, constraint=flops)
        description = " | Params %.0f | FLOPs %.0f"%(params, flops)
        if 'ntk' in te_reward_generator._buffers:
            ntk = te_reward_generator._buffers['ntk'][-1]
            description += " | NTK %.2f"%ntk
            logger.writer.add_scalar("TE/NTK", ntk, _step)
        if 'exp' in te_reward_generator._buffers:
            exp = te_reward_generator._buffers['exp'][-1]
            description += " | Exp %.4f"%exp
            logger.writer.add_scalar("TE/Exp", exp, _step)
        # entropy of the distribution the action was sampled from (before this step's update)
        entropy = policy.distribution.entropy().item()
        logger.writer.add_scalar("reinforce/entropy", entropy, _step)
        logger.writer.add_scalar("reward/reward", reward, _step)

        pbar.set_description("Entropy {entropy:.2f} | Reward {reward:.2f}".format(entropy=entropy, reward=reward) + description)

        # the baseline already includes the current reward when the advantage is formed
        baseline.update(reward)
        # calculate loss
        policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

    # Evaluate the argmax architecture first, then the best architecture seen during search.
    arch_str_derived, network_derived = policy.generate_arch(policy.genotype(), image_size, hidden_dim, depth)
    derived = _summarize_network(te_reward_generator, arch_str_derived, network_derived, image_size)

    best_idx = te_reward_generator._buffer_rank_best()
    arch_str_history_best = arch_str_history[best_idx]
    _, network_history_best = policy.generate_arch(arch_str_history_best, image_size, hidden_dim, depth)
    history_best = _summarize_network(te_reward_generator, arch_str_history_best, network_history_best, image_size)

    results_summary = {constraint: {seed: {'derived': derived, 'history_best': history_best}}}
    np.save(os.path.join(xargs.save_dir, "results_summary.npy"), results_summary)
    logger.log('[Policy] {:s} | flops {:.0f} | params {:.0f} | ntk {:.2f} | exp {:.5f}'.format(*derived))


def main(xargs):
    """Configure the run, then search for an architecture with REINFORCE.

    Args:
        xargs: parsed command-line namespace. ``timestamp``, ``reward_types``
            and ``save_dir`` are rewritten in place: the timestamp is filled
            in when it is ``'none'``, ``reward_types`` is split on ``'_'``,
            and ``save_dir`` gets a sub-path encoding the hyper-parameters,
            the timestamp and the seed.

    Raises:
        AssertionError: if CUDA is not available.
        NotImplementedError: if ``xargs.dataset`` is not a supported name.
    """
    if xargs.timestamp == 'none':
        # Only directives documented for Python's strftime are used, so the
        # format does not depend on the platform's C library.
        xargs.timestamp = time.strftime('%b-%d-%Y_%H-%M-%S', time.gmtime())

    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads( xargs.workers )

    xargs.reward_types = xargs.reward_types.split('_')
    xargs.save_dir = xargs.save_dir + \
        "/LR%.2f-steps%d-%s-buffer%d-batch%d-repeat%d-flops%.2f"%(xargs.learning_rate, xargs.total_steps, '.'.join(xargs.reward_types), xargs.te_buffer_size, xargs.batch_size, xargs.repeat, xargs.flops) + \
        "/{:}/seed{:}".format(xargs.timestamp, xargs.rand_seed)
    logger = prepare_logger(xargs)
    try:
        _run_search(xargs, logger)
    finally:
        # close the logger even when the search aborts part-way
        logger.close()



if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser("Reinforce")
    parser.add_argument('--data_path',          type=str,   help='Path to dataset')
    parser.add_argument('--dataset',            type=str,   default='imagenet', choices=['imagenet', 'imagenet_64'], help='Choose between imagenet and imagenet_64.')
    # REINFORCE hyper-parameters
    parser.add_argument('--learning_rate',      type=float, default=0.08, help='The learning rate for REINFORCE.')
    parser.add_argument('--total_steps',      type=int, default=500, help='Number of iterations for REINFORCE.')
    parser.add_argument('--EMA_momentum',       type=float, default=0.9, help='The momentum value for EMA.')
    # log
    parser.add_argument('--workers',            type=int,   default=4,    help='number of CPU threads for torch (default: 4)')
    # required: main() appends the run sub-path to this string
    parser.add_argument('--save_dir',           type=str,   required=True, help='Folder to save checkpoints and log.')
    parser.add_argument('--rand_seed',          type=int,   default=0,   help='manual seed')
    parser.add_argument('--timestamp', default='none', type=str, help='timestamp for logging naming')
    parser.add_argument('--batch_size',            type=int,   default=16,    help='batch size for ntk')
    parser.add_argument('--repeat',          type=int,   default=5,   help='repeat calculation for TEG')
    parser.add_argument('--te_buffer_size',        type=int,   default=20,   help='buffer size for TE reward generator')
    parser.add_argument('--reward_types',       type=str, default='ntk_exp',  help='underscore-separated reward types passed to TEG, e.g. ntk_exp')
    parser.add_argument('--flops', type=float, default=0., help='add flops weight into reward/penalty in RL (to be used in TEG)')
    args = parser.parse_args()
    main(args)
