import numpy as np
import torch as t
from torch import nn
import random
import os
from utils.visualiser import visualiseModel, plot_gif
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from argparse import Namespace
# Compute device for tensors this module allocates directly (e.g. the t.randn call in init_random_world_state).
device = "cuda" if t.cuda.is_available() else "cpu"
from collections import defaultdict


def get_acc(ps, ys, p_thresh=0.5, y_thresh=0.01):
    """Return the binary accuracy of logits ``ps`` against (soft) targets ``ys``.

    A prediction counts as positive when ``sigmoid(ps) >= p_thresh``; a target counts
    as positive when ``ys >= y_thresh``. Both tensors are flattened and compared
    element-wise, and the fraction of agreements is returned as a Python float.
    """
    predicted_positive = t.sigmoid(ps).flatten().ge(p_thresh)
    actual_positive = ys.flatten().ge(y_thresh)
    agreement = predicted_positive.eq(actual_positive).double()
    return float(agreement.mean())


def smooth_labels(t, epsilon=0.01):
    """Apply symmetric label smoothing to target ``t``.

    Each value is mixed with its complement: a hard 1 becomes ``1 - epsilon`` and a
    hard 0 becomes ``epsilon``. Only arithmetic operators are used, so this works on
    plain numbers as well as tensors.

    Note: inside this function the parameter ``t`` shadows the module-level
    ``torch`` alias; the name is kept so keyword callers keep working.
    """
    kept = t * (1 - epsilon)
    flipped = epsilon * (1 - t)
    return kept + flipped


def save_model(d, path, checkpoint=None):
    """Serialize ``d`` with ``torch.save`` to ``path``, creating parent directories first.

    Args:
        d: Object to serialize (in this module, a dict of state dicts and metadata).
        path: Destination file path. A bare file name (no directory part) is written
            relative to the current working directory.
        checkpoint: Optional integer. When given, ``"_checkpoint_{checkpoint + 1}"`` is
            appended to ``path``. NOTE(review): callers in this module already pass a
            1-based ``world_reset + 1``, so the suffix ends up offset by two from the
            reset index; confirm this is intended before relying on the numbering.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when the
    # path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    if checkpoint is not None:
        path += "_checkpoint_{}".format(checkpoint + 1)
    t.save(d, path)


def save_extractor_model(CONFIG, model, checkpoint=None):
    """Persist an extractor model plus its construction metadata to ``CONFIG.save_path``.

    The saved dict contains the run configuration (``vars(CONFIG)``), the model's
    ``state_dict()`` and its ``modelname`` / ``initparams`` attributes (presumably
    what a loader needs to rebuild the model before loading weights). ``checkpoint``
    is forwarded unchanged to ``save_model``.
    """
    payload = dict(
        CONFIG=vars(CONFIG),
        model_state_dict=model.state_dict(),
        modelname=model.modelname,
        modelinits=model.initparams,
    )
    save_model(payload, CONFIG.save_path, checkpoint)


def get_opt_state(opt, parameters):
    """Return ``[opt.state[p] for p in parameters]``, preserving iteration order.

    ``parameters`` may be any iterable, e.g. the generator from ``module.parameters()``.
    NOTE: for torch optimizers ``opt.state`` is a defaultdict, so a lookup for a
    parameter without state presumably inserts and returns an empty dict; verify.
    """
    lookup = opt.state.__getitem__
    return list(map(lookup, parameters))


def set_opt_state(opt, parameters, values):
    """Install ``values[i]`` as the optimizer state of the i-th entry of ``parameters``.

    Pairs are formed with ``zip``, so surplus parameters or values are ignored.
    """
    opt.state.update(zip(parameters, values))


def _owning_optimizer(optimizers, parameters):
    """Return the first optimizer in ``optimizers`` whose param_groups hold any of ``parameters``, else None."""
    wanted = {id(p) for p in parameters}
    for opt in optimizers:
        for group in opt.param_groups:
            if any(id(p) in wanted for p in group["params"]):
                return opt
    return None


def save_joint_model(CONFIG, extractor, updater, optimizers=None, checkpoint=None):
    """Save extractor, updater and (optionally) their optimizer state to ``CONFIG.save_path``.

    Args:
        CONFIG: Run configuration namespace; ``vars(CONFIG)`` is stored and
            ``CONFIG.save_path`` is the destination.
        extractor: Model whose ``state_dict``, ``modelname`` and ``initparams`` are saved
            under ``model_state_dict`` / ``modelname`` / ``modelinits``.
        updater: Model saved under ``updater_state_dict`` / ``updatername`` /
            ``updaterinits``.
        optimizers: Optional sequence of torch optimizers. For each model, the state of
            the optimizer that owns that model's parameters is stored under
            ``"modelopt"`` (extractor) or ``"updateropt"`` (updater). A model that no
            given optimizer trains gets no entry.
        checkpoint: Forwarded to ``save_model``.
    """
    d = {
        'CONFIG': vars(CONFIG),
        'model_state_dict': extractor.state_dict(),
        'modelname': extractor.modelname,
        'modelinits': extractor.initparams,
        'updater_state_dict': updater.state_dict(),
        'updatername': updater.modelname,
        'updaterinits': updater.initparams
    }
    if optimizers:
        # Match optimizers to models by the parameters they manage, not by list position.
        # train_advance_state builds a one-element list when only one model is trained,
        # and there positional indexing raised IndexError or paired the wrong optimizer.
        for key, module in (("modelopt", extractor), ("updateropt", updater)):
            params = list(module.parameters())
            owner = _owning_optimizer(optimizers, params)
            if owner is not None:
                d[key] = get_opt_state(owner, params)
    path = CONFIG.save_path
    save_model(d, path, checkpoint)


def init_random_world_state(CONFIG, datahandler):
    """Sample fresh Gaussian world states and start a new data-handler run.

    Returns a ``(CONFIG.n_worlds, world_state_size)`` tensor on the module ``device``.
    - If ``CONFIG.init_scale >= 0``, every row is rescaled to L2 norm ``init_scale``.
    - Otherwise, unless ``CONFIG.ws_norm == -1``, every row is normalised to unit
      ``ws_norm``-norm.
    """
    state_size = CONFIG.model_vars["model_inits"]["world_state_size"]
    states = t.randn((CONFIG.n_worlds, state_size), device=device)
    # Keep the sampling before init_run: if init_run draws from the torch RNG,
    # swapping the two would change the sampled states.
    datahandler.init_run(CONFIG)
    scale = CONFIG.init_scale
    if scale >= 0:
        row_norms = t.norm(states, dim=1, p=2).unsqueeze(1)
        states *= scale / row_norms
    elif CONFIG.ws_norm != -1:
        row_norms = t.norm(states, dim=1, p=CONFIG.ws_norm).unsqueeze(1)
        states /= row_norms
    return states


def advance_state(CONFIG, datahandler, world_states, updater, sample_test=False):
    """Advance the data handler one step and update the world states from its training samples.

    Args:
        CONFIG: Run configuration; ``CONFIG.ws_norm`` selects the norm used to
            renormalise the new states (``-1`` disables renormalisation).
        datahandler: Provides ``advance_run(CONFIG)`` and ``get_samples(CONFIG, sample_test=...)``.
        world_states: Current world states, passed to ``updater.forward``.
        updater: Model whose ``forward(triples, targets, world_states)`` returns new states.
        sample_test: Also request a test split from the data handler.

    Returns:
        ``(samples, new_world_states)``. ``samples`` is the tuple returned by
        ``get_samples``: ``(triples_train, targets_train)``, or with ``sample_test``
        ``(triples_train, targets_train, triples_test, targets_test)``.
    """
    datahandler.advance_run(CONFIG)
    samples = datahandler.get_samples(CONFIG, sample_test=sample_test)
    # Only the training split feeds the updater; the test split is returned untouched.
    if sample_test:
        train_triples, train_targets, _, _ = samples
    else:
        train_triples, train_targets = samples
    updated = updater.forward(train_triples, train_targets.unsqueeze(2), world_states)
    if CONFIG.ws_norm == -1:
        return samples, updated
    row_norms = t.norm(updated, dim=1, p=CONFIG.ws_norm).unsqueeze(1)
    return samples, updated / row_norms


def train_advance_state(CONFIG, datahandler, world_states, updater, extractor, loss_func=None, train_updater=True,
                        train_extractor=True, callbacks=[]):
    """Roll the world states forward for up to ``CONFIG.n_updater_steps`` steps, training as configured.

    Each step does the following:
    - The updater consumes new training samples (``advance_state`` with ``sample_test=True``).
    - The extractor predicts targets for the "direct" training triples plus the test triples.
    - When ``loss_func`` is given, the loss is backpropagated and the optimizers are stepped
      according to ``CONFIG.gradient_step_ratio``.

    Args:
        CONFIG: Run configuration namespace. Reads lr_net, updater_lr_net,
            n_updater_steps, gradient_step_ratio, updater_smoothing_eps and
            updater_ws_reg, and optionally direct_samples / learnable_init_world_state.
        datahandler: Data source. ``reset_run()`` is called once, then
            ``advance_state`` drives it every step.
        world_states: Initial world states for the rollout.
        updater: Model that produces the new world states each step.
        extractor: Model that predicts targets from triples and world states.
        loss_func: ``loss_func(preds, smoothed_targets)``. When None, no backward pass
            happens in this function, so any "loss" terms returned by callbacks are
            accumulated into ``loss_upd`` but never backpropagated.
        train_updater / train_extractor:
            - True or None: build a fresh SGD optimizer for that model.
            - False: do not train it (and do not switch it to train mode).
            - Any other value: used directly as that model's optimizer.
        callbacks: Callables ``callback(local, event)``. ``local`` is a ``Namespace``
            snapshot of this function's local variables. ``event`` is one of "start",
            "loss", "iter" or "end". "loss" callbacks may return ``(loss, metrics_dict)``.
            A truthy return from an "iter" callback stops the loop early.

    Optimizer stepping: if ``CONFIG.gradient_step_ratio`` is truthy, the optimizers step
    (and zero grads) every ``gradient_step_ratio`` steps. Otherwise they step once after
    the loop.

    NOTE(review): callbacks read locals by name (preds, targets, direct_count,
    loss_upd, new_world_states, optimizers, ...). Renaming locals here breaks them.
    NOTE(review): with ``loss_func`` set, every step calls ``backward()`` on a loss whose
    graph reaches earlier steps through ``current_world_states``, which is not detached
    here. Unless the updater detaches internally or ``n_updater_steps == 1``, the second
    backward presumably hits already-freed buffers. Confirm.
    """
    # Optimizer list: True/None -> fresh SGD, an optimizer object -> used as is, False -> skipped.
    optimizers = []
    if train_extractor is True or train_extractor is None:
        optimizers += [t.optim.SGD(list(extractor.parameters()), lr=CONFIG.lr_net)]
    elif train_extractor is not False:
        optimizers += [train_extractor]
    if train_updater is True or train_updater is None:
        optimizers += [t.optim.SGD(list(updater.parameters()), lr=CONFIG.updater_lr_net)]
    elif train_updater is not False:
        optimizers += [train_updater]
    for callback in callbacks:
        callback(Namespace(**locals()), "start")
    # Optionally let the updater transform the initial states before the rollout.
    if hasattr(CONFIG,"learnable_init_world_state") and CONFIG.learnable_init_world_state:
        current_world_states = updater.get_init_world_state(world_states)
    else:
        current_world_states = world_states
    datahandler.reset_run()
    for opt in optimizers:
        opt.zero_grad()
    for step in range(CONFIG.n_updater_steps):
        # Re-enter train mode every step (callbacks such as visualisation may switch to eval).
        if train_updater is not False:
            updater.train()
        if train_extractor is not False:
            extractor.train()
        samples, new_world_states = advance_state(CONFIG, datahandler, current_world_states, updater, sample_test=True)
        triples_train, targets_train, triples_test, targets_test = samples
        # Optionally restrict the "direct" targets to the first CONFIG.direct_samples along dim 1.
        # `in CONFIG` presumably relies on argparse.Namespace.__contains__.
        if "direct_samples" in CONFIG:
            triples_direct = triples_train[:, :CONFIG.direct_samples]
            targets_direct = targets_train[:, :CONFIG.direct_samples]
        else:
            triples_direct = triples_train
            targets_direct = targets_train
        # Predict direct + test triples in one pass; the first direct_count entries along dim 1 are "direct".
        direct_count = triples_direct.shape[1]
        triples = t.cat([triples_direct, triples_test], dim=1)
        targets = t.cat([targets_direct, targets_test], dim=1).unsqueeze(2)
        preds = extractor.forward(triples, new_world_states)
        sub_loss_metrics = {}
        loss_upd = 0
        for callback in callbacks:
            sub_loss = callback(Namespace(**locals()), "loss")  # should return (loss: scalar, metrics: dict)
            if sub_loss is not None:
                loss_upd += sub_loss[0]
                sub_loss_metrics.update(sub_loss[1])
        if loss_func is not None:
            # Prediction loss on label-smoothed targets plus a mean-square penalty on the new states.
            loss_upd += loss_func(preds, smooth_labels(targets, epsilon=CONFIG.updater_smoothing_eps)) + \
                       CONFIG.updater_ws_reg * (new_world_states**2).mean()
            loss_upd.backward()
            # Gradient accumulation: step and reset grads every gradient_step_ratio steps.
            if CONFIG.gradient_step_ratio and (step + 1) % CONFIG.gradient_step_ratio == 0:
                for opt in optimizers:
                    opt.step()
                    opt.zero_grad()
        stop = False
        for callback in callbacks:
            if callback(Namespace(**locals()), "iter"):
                stop = True
        current_world_states = new_world_states
        if stop:
            break
    # Without a step ratio, take a single optimizer step after the whole rollout.
    if not CONFIG.gradient_step_ratio:
        for opt in optimizers:
            opt.step()
    for callback in callbacks:
        callback(Namespace(**locals()), "end")


def logging_step(world_reset, local):
    """Return whether ``world_reset`` falls on a test-logging interval.

    ``local`` is the callback namespace; only ``local.CONFIG.test_log_freq`` is read.
    """
    log_every = local.CONFIG.test_log_freq
    remainder = world_reset % log_every
    return remainder == 0


def attach_advance_step(d, step):
    """Return a copy of ``d`` with ``_{step}`` appended to every key.

    Example: ``attach_advance_step({"acc": 1}, 3) == {"acc_3": 1}``.
    """
    suffix = "_{}".format(step)
    return {key + suffix: value for key, value in d.items()}


def log_train_advance_state(logger, world_reset, return_metrics=False):
    """Build a ``train_advance_state`` callback that logs per-step rollout metrics.

    The callback is active only when ``world_reset % CONFIG.test_log_freq == 0``, and
    only on "iter" events. For the current rollout step it adds the following into
    ``all_metrics[step]``:
    - direct and indirect accuracy (``get_acc`` with ``CONFIG.acc_thresh``);
    - ``loss_upd``;
    - the mean per-world L1 norm of the new world states;
    - any sub-loss metrics reported by other callbacks.

    It then logs a snapshot of those accumulated totals, plus ``accTotal`` and (at
    step 0) ``initstate_magnitude``.

    Args:
        logger: Either a dict indexed by rollout step whose values support
            ``+= [dict]`` (presumably ``defaultdict(list)``), or an object with
            ``log(d, step=...)``; in the latter case keys get a ``_{step}`` suffix.
        world_reset: Index of the current world reset. Used for gating and as the
            logger step.
        return_metrics: Also return the accumulated ``{step: {metric: total}}`` mapping.

    Returns:
        ``callback``, or ``(callback, all_metrics)`` when ``return_metrics`` is true.
    """
    all_metrics = defaultdict(lambda: defaultdict(lambda: 0))

    def callback(local, step):
        if world_reset % local.CONFIG.test_log_freq != 0 or step != "iter":
            return
        metrics = all_metrics[local.step]
        preds_view = local.preds.view(local.CONFIG.n_worlds, -1)
        Y_view = local.targets.view(local.CONFIG.n_worlds, -1)
        direct = local.direct_count
        # Columns before direct_count are the "direct" (training) predictions, the rest are test ones.
        metrics["accDirect"] += get_acc(preds_view[:, :direct], Y_view[:, :direct], *local.CONFIG.acc_thresh)
        metrics["accIndirect"] += get_acc(preds_view[:, direct:], Y_view[:, direct:], *local.CONFIG.acc_thresh)
        try:
            metrics["loss_upd"] += float(local.loss_upd)
        except (TypeError, ValueError, RuntimeError):
            # loss_upd could not be converted to a single float; skip it instead of
            # aborting logging (a bare except here used to swallow KeyboardInterrupt too).
            pass
        metrics["state_magnitude"] += float(local.new_world_states.detach().abs().sum(dim=1).mean())
        for key, val in local.sub_loss_metrics.items():
            metrics[key] += val
        with t.no_grad():
            log_dict = {}
            if local.step == 0:
                log_dict['initstate_magnitude'] = float(local.current_world_states.detach().abs().sum(dim=1).mean())
            log_dict.update(metrics)
            log_dict["accTotal"] = (log_dict["accDirect"] + log_dict["accIndirect"]) / 2
            if isinstance(logger, dict):
                logger[local.step] += [log_dict]
            else:
                logger.log(attach_advance_step(log_dict, local.step), step=world_reset)

    if return_metrics:
        return callback, all_metrics
    return callback


def log_dict_mean(logs, CONFIG, wandb=None):
    """Average the per-step metric dicts collected over several rollouts.

    ``logs`` is indexed by rollout step for ``step in range(CONFIG.n_updater_steps)``.
    Each entry is a (possibly empty) list of metric dicts. The keys of the first dict
    decide which metrics are averaged with ``np.mean``. An empty step yields ``{}``.
    When ``wandb`` is given, each averaged dict is logged with ``step=step``.

    Returns:
        List of averaged dicts, one per step.
    """
    averaged_per_step = []
    for step in range(CONFIG.n_updater_steps):
        entries = logs[step]
        averaged = {key: np.mean([entry[key] for entry in entries]) for key in entries[0]} if entries else {}
        if wandb is not None:
            wandb.log(averaged, step=step)
        averaged_per_step.append(averaged)
    return averaged_per_step

def classification_loss_callback(loss_func):
    """Build a ``train_advance_state`` "loss" callback for an auxiliary classification loss.

    On the "loss" event the extractor is queried with an all-zero index tensor of shape
    ``(batch, 1, 1)`` against the freshly updated world states. The predictions are
    scored against ``datahandler.get_labels(CONFIG)``, label-smoothed with
    ``CONFIG.smoothing_eps``. Other events return None.

    If ``CONFIG.detach_class_loss`` is truthy, the world states are detached, so this
    loss does not backpropagate into them (or, through them, into the updater).

    Returns:
        A callback that, on "loss", returns
        ``(class_loss * CONFIG.class_loss_scale, {"class_acc": float, "class_loss": float})``.
    """
    def callback(local, step):
        if step != "loss":
            return None
        # Labels reshaped to (batch, 1, 1); assumes get_labels returns a 1-D tensor -- TODO confirm.
        y = local.datahandler.get_labels(local.CONFIG).unsqueeze(1).unsqueeze(2)
        world_states = local.new_world_states
        if getattr(local.CONFIG, "detach_class_loss", False):
            world_states = world_states.detach()
        # Build the query indices on the world states' device. Previously they were always
        # created on the CPU, which mismatches an extractor running on CUDA.
        query = t.zeros((y.shape[0], 1, 1), dtype=t.long, device=world_states.device)
        preds = local.extractor.forward(query, world_states)
        class_acc = get_acc(preds, y)
        class_loss = loss_func(preds, smooth_labels(y, epsilon=local.CONFIG.smoothing_eps))
        return class_loss * local.CONFIG.class_loss_scale, {"class_acc": float(class_acc), "class_loss": float(class_loss)}
    return callback


def visualise_model(wandb, world_reset, init=True, update_step_number_for_init=False, gif_mode=False, labeled=False, progressive=False):
    """Build a ``train_advance_state`` callback that renders extractor reconstructions to wandb.

    Modes (``init`` is compared with ``is True``, so only the literal True selects init mode):
      * init mode: on the "end" event, render from ``datahandler.init_state`` with the
        rollout's initial ``world_states`` and log the result under "reconstruction".
        ``step_number`` is ``world_reset`` if ``update_step_number_for_init``, else
        ``(world_reset + 1) * CONFIG.world_state_init_steps``.
      * advance mode: on every "iter" event of a logging reset (``logging_step``),
        render ``datahandler.current_state`` with ``new_world_states`` and log the
        result under ``"reconstruction_advance_{step + 1}"``.

    With ``gif_mode`` the rendered images are buffered in ``gif_frames`` instead of
    logged. In advance mode, the "end" event combines them with ``plot_gif``, logs the
    result as "reconstruction_advance" and clears the buffer.

    NOTE(review): with ``init=True`` and ``gif_mode=True``, frames are appended but never
    logged or cleared, because the gif flush exists only in the advance-mode branch.
    NOTE(review): in init mode, ``local.triples_train`` / ``local.targets_train`` are the
    last rollout step's samples. They are absent if the loop ran zero steps.
    """
    # Buffered frames for gif_mode (flushed on "end" in advance mode).
    gif_frames = []
    def callback(local, step):
        if init is True:
            if step == "end":
                name = "reconstruction"
                state = local.datahandler.init_state
                if update_step_number_for_init:
                    step_number = world_reset
                else:
                    step_number = (world_reset + 1) * local.CONFIG.world_state_init_steps
                # NOTE(review): num_relations is computed but never passed to visualiseModel below.
                num_relations = local.CONFIG.max_init - local.CONFIG.min_init + 1
                init_frame = local.CONFIG.min_init
                relation_offset = local.CONFIG.min_init
                world_state = local.world_states
            else:
                return
        else:
            if step == "iter" and logging_step(world_reset, local):
                name = "reconstruction_advance_{}".format(local.step + 1)
                state = local.datahandler.current_state
                step_number = world_reset
                # Cover the union of the train and test ranges.
                init_frame = min(local.CONFIG.min_train, local.CONFIG.min_test)
                num_relations = max(local.CONFIG.max_train, local.CONFIG.max_test) + 1 - init_frame
                relation_offset = init_frame
                world_state = local.new_world_states
            elif step == "end" and gif_frames:
                # Flush the buffered frames as one animated gif.
                gif = plot_gif(gif_frames)
                step_number = world_reset
                wandb.log({"reconstruction_advance": gif}, step = step_number)
                gif_frames.clear()
                return
            else:
                return
        # Render without gradients; the extractor is left in eval mode afterwards.
        local.extractor.eval()
        with t.no_grad():
            vis = visualiseModel(local.CONFIG, local.extractor, local.datahandler, world_state,
                                 run=(state, local.triples_train, local.targets_train),
                                 batch_size=local.CONFIG.n_worlds,
                                 init_frame=init_frame + local.datahandler.step,
                                 relation_offset=relation_offset, labeled=labeled, progressive=progressive)
        if gif_mode:
            gif_frames.append(vis)
        else:
            wandb.log({name: [wandb.Image(vis)]}, step=step_number)
    return callback

def save_extractor_callback(world_reset):
    """Build a callback that checkpoints the extractor at the end of selected world resets.

    On the "end" event a checkpoint is written when ``world_reset + 1`` is a multiple of
    ``CONFIG.save_reset_freq``. Other events are ignored.
    NOTE(review): ``checkpoint=world_reset + 1`` is incremented again in save_model's
    filename suffix. Confirm the resulting numbering is intended.
    """
    def callback(local, step):
        if step != "end":
            return
        if (world_reset + 1) % local.CONFIG.save_reset_freq:
            return
        save_extractor_model(local.CONFIG, local.extractor, checkpoint=world_reset + 1)
    return callback

def save_joint_callback(world_reset):
    """Build a callback that checkpoints extractor, updater and their optimizers.

    On the "end" event a checkpoint is written when ``world_reset + 1`` is a multiple of
    ``CONFIG.save_reset_freq``. The rollout's ``optimizers`` list is saved alongside the
    models. Other events are ignored.
    NOTE(review): ``checkpoint=world_reset + 1`` is incremented again in save_model's
    filename suffix. Confirm the resulting numbering is intended.
    """
    def callback(local, step):
        if step != "end":
            return
        cfg = local.CONFIG
        if (world_reset + 1) % cfg.save_reset_freq:
            return
        save_joint_model(cfg, local.extractor, local.updater,
                         optimizers=local.optimizers, checkpoint=world_reset + 1)
    return callback




def max_step_callback(world_reset, max_steps, logger=None):
    """Build a callback that stops a rollout after ``max_steps`` steps.

    On "iter" it returns True once ``local.step == max_steps - 1`` (which makes
    ``train_advance_state`` break), and None otherwise. On "end", if a ``logger`` is
    given and this is a logging reset (``world_reset % CONFIG.test_log_freq == 0``), it
    logs ``{"max_steps": max_steps}`` at ``step=world_reset``.
    """
    def callback(local, step):
        if step == "iter":
            reached_limit = local.step == max_steps - 1
            return True if reached_limit else None
        should_log = (step == "end" and logger is not None
                      and world_reset % local.CONFIG.test_log_freq == 0)
        if should_log:
            logger.log({"max_steps": max_steps}, step=world_reset)
    return callback


def parse_scheduler(scheduler, optimizer):
    """Build a learning-rate scheduler for ``optimizer`` from a ``(name, args, kwargs)`` spec.

    Supported names:
      * ``"cosine"``: ``CosineAnnealingWarmRestarts(optimizer, *args, **kwargs)``.
      * ``"linear_warmup"``: a ``LambdaLR`` whose multiplier rises linearly from 0 to 1
        over ``tau = 2 / (1 - beta)`` epochs. ``beta`` comes from kwargs (default
        0.999); ``args`` and other kwargs are ignored. This is the warmup strategy of
        https://arxiv.org/pdf/1910.04209.pdf.

    Raises:
        ValueError: If the scheduler name is not recognised. Previously None was
            returned, which only failed later when ``.step`` was called on it.
    """
    sched, args, kwargs = scheduler
    if sched == "cosine":
        return CosineAnnealingWarmRestarts(optimizer, *args, **kwargs)
    if sched == "linear_warmup":
        beta = kwargs.get("beta", .999)
        tau = 2 / (1 - beta)

        def lr_lambda(epoch):
            return min(1, epoch / tau)

        return LambdaLR(optimizer, lr_lambda=lr_lambda)
    raise ValueError("Unknown lr scheduler {!r}; expected 'cosine' or 'linear_warmup'".format(sched))


def get_scheduler(CONFIG, optimizer_extractor, optimizer_updater):
    """Build the configured LR schedulers and return a factory for stepping callbacks.

    A scheduler is created for the extractor optimizer when ``CONFIG.lr_scheduler`` is
    not None, and for the updater optimizer when ``CONFIG.updater_lr_scheduler`` is not
    None. Each spec is parsed by ``parse_scheduler``.

    Returns:
        ``scheduler_callback(world_reset)``. It returns a ``train_advance_state``
        callback that calls ``scheduler.step(world_reset)`` on every scheduler at the
        "end" event.
    """
    specs = (
        (CONFIG.lr_scheduler, optimizer_extractor),
        (CONFIG.updater_lr_scheduler, optimizer_updater),
    )
    schedulers = [parse_scheduler(spec, opt) for spec, opt in specs if spec is not None]

    def scheduler_callback(world_reset):
        def callback(local, step):
            if step != "end":
                return
            for sched in schedulers:
                sched.step(world_reset)
        return callback
    return scheduler_callback

