
import random
import numpy as np
import torch
from torch import nn
import gc
from collections import defaultdict
from collections import OrderedDict
import sys
import os

# Device every helper in this module moves tensors to. Prefer the first CUDA
# GPU, but fall back to the CPU so the module still works on GPU-less hosts.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def set_seed(seed):
    """Seed all random number generators used here and pin cuDNN settings.

    Applies `seed` to Python's `random`, NumPy's global RNG, and PyTorch's
    CPU and CUDA seeding functions, then sets the cuDNN `deterministic` flag
    and clears the cuDNN `benchmark` flag.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def train_model(model, opt, train_data, test_data, args, lamb_lr):
    """Train `model` with a cross-entropy loss plus a multiplier-weighted MI constraint.

    Each step minimises ``l_sup + lamb * (args.MI_const - MI)`` with respect to
    the model parameters, then moves the multiplier ``lamb`` by
    ``lamb_lr * (args.MI_const - MI)`` (gradient ascent on the multiplier).
    Train/test accuracy and loss are recorded every 10 steps.

    Args:
        model: module put into train mode here; ``model(xs)`` must return
            ``(preds, max_prob)``. It is also handed to `est_MI`.
        opt: optimiser; only ``zero_grad()`` and ``step()`` are used.
        train_data: iterable of ``(xs, ys)`` batches exposing a ``.dataset``
            attribute from which MI samples are drawn.
        test_data: iterable of ``(xs, ys)`` batches, used for evaluation only.
        args: namespace providing ``lamb_init``, ``epochs`` and ``MI_const``
            (it is also forwarded to `evaluate`). An optional ``jensen``
            attribute selects the marginal estimator passed to `est_MI`.
        lamb_lr: step size of the multiplier update.

    Returns:
        ``(model, final_MI, diagnostics)`` where ``final_MI`` is a float
        estimated without gradients on up to 1000 training samples and
        ``diagnostics`` is a dict of the per-step / per-evaluation histories.
    """
    model.train()

    train_accs, train_losses = [], []
    test_accs, test_losses = [], []
    l_MIs = []
    maxes = []
    lambs = []

    t = 0
    lamb = args.lamb_init

    # est_MI has a required `jensen` parameter that the calls here used to
    # omit, which raised TypeError on the first step.
    # NOTE(review): defaulting to False when args carries no `jensen` is an
    # assumption -- confirm which estimator training is meant to use.
    jensen = getattr(args, "jensen", False)
    dataset = train_data.dataset

    analyse(model, grads=True)

    for _ in range(args.epochs):
        for xs, ys in train_data:

            if t % 10 == 0:
                train_acc, train_loss = evaluate(model, train_data, args, "train", plot=False)
                test_acc, test_loss = evaluate(model, test_data, args, "test", plot=False)
                train_accs.append(train_acc)
                train_losses.append(train_loss)
                test_accs.append(test_acc)
                test_losses.append(test_loss)
                sys.stdout.flush()

            xs = xs.to(device)
            ys = ys.to(device)

            opt.zero_grad()

            preds, max_prob = model(xs)
            l_sup = nn.functional.cross_entropy(preds, ys, reduction="mean")

            # Keep only the value: storing a tensor that is still attached to
            # the autograd graph would keep graph state alive for every step.
            maxes.append(max_prob.detach() if torch.is_tensor(max_prob) else max_prob)

            MI = est_MI(model, dataset, sz=min(100, len(dataset)), jensen=jensen)
            # when this is -ve, lambda becomes smaller. When positive, lambda becomes bigger
            constraint = (args.MI_const - MI)

            l_MIs.append(MI.item())

            loss = l_sup + lamb * constraint

            loss.backward()

            opt.step()

            analyse(model, grads=True, t=t)

            lamb += lamb_lr * constraint.item()  # gradient ascent
            lambs.append(lamb)

            t += 1

    MI = est_MI(model, dataset, sz=min(1000, len(dataset)), jensen=jensen,
                requires_grad=False)

    diagnostics = {"train_losses": train_losses,
                   "train_accs": train_accs,
                   "test_losses": test_losses,
                   "test_accs": test_accs,
                   "l_MIs": l_MIs,
                   "maxes": maxes,
                   "lambs": lambs,
                   }
    return model, MI.item(), diagnostics


def est_MI(model, dataset, sz, jensen, requires_grad=True):
    """Estimate mutual information on a random subsample of `dataset`.

    Draws `sz` distinct indices, stacks the first element of each selected
    item into one batch, and returns the mean of
    ``log_prob - log_marg_prob`` as computed by the model.

    Args:
        model: module whose ``model(x, repr=True)`` returns ``(z, log_prob)``
            and which provides ``log_marg_prob(z, x, jensen=...)``.
        dataset: indexable collection whose items hold the input tensor at
            position 0.
        sz: number of samples to draw without replacement; must not exceed
            ``len(dataset)``.
        jensen: forwarded unchanged to ``model.log_marg_prob``.
        requires_grad: when False the estimate is computed in eval mode under
            ``torch.no_grad()``.

    Returns:
        Scalar tensor with the mean estimate.
    """
    ii = np.random.choice(len(dataset), size=sz, replace=False)
    x = torch.stack([dataset[i][0] for i in ii], dim=0).to(device)

    if requires_grad:
        z, log_prob = model(x, repr=True)
        log_marg_prob = model.log_marg_prob(z, x, jensen=jensen)
    else:
        # Remember the caller's mode so evaluating does not silently switch a
        # model that was in eval mode back to training.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                z, log_prob = model(x, repr=True)
                log_marg_prob = model.log_marg_prob(z, x, jensen=jensen)
        finally:
            model.train(was_training)

    return (log_prob - log_marg_prob).mean()



def est_MI_cond(model, num_classes, dl, jensen):
    """Estimate class-conditional mutual information, averaged over classes.

    Inputs from `dl` are grouped by label; for each class the mean of
    ``log_prob - log_marg_prob`` is computed on that class's samples only, and
    the per-class values are combined weighted by class frequency.

    Args:
        model: module whose ``model(x, repr=True)`` returns ``(z, log_prob)``
            and which provides ``log_marg_prob(z, x, jensen=...)``.
        num_classes: number of classes; labels are compared against
            ``0 .. num_classes - 1``.
        dl: iterable of ``(xs, ys)`` batches.
        jensen: forwarded unchanged to ``model.log_marg_prob``.

    Returns:
        Scalar tensor with the frequency-weighted average.

    Raises:
        ValueError: if no sample with a label in ``0 .. num_classes - 1`` was
            seen, since the class weights would be undefined.
    """
    was_training = model.training
    model.eval()
    try:
        x_class = [[] for _ in range(num_classes)]
        counts = torch.zeros(num_classes, device=device)

        for xs, ys in dl:
            xs = xs.to(device)
            ys = ys.to(device)

            for c in range(num_classes):
                c_inds = ys == c
                x_class[c].append(xs[c_inds])
                counts[c] += c_inds.sum()

        total = counts.sum()
        if total == 0:
            raise ValueError("est_MI_cond: no samples found for any class")

        MIs = torch.zeros(num_classes, device=device)
        with torch.no_grad():
            for c in range(num_classes):
                x_c = torch.cat(x_class[c], dim=0)
                if x_c.shape[0] == 0:
                    # A class with no samples has weight 0; skip it because the
                    # mean over an empty batch is NaN and NaN * 0 would turn
                    # the whole average into NaN.
                    continue
                z, log_prob = model(x_c, repr=True)
                log_marg_prob = model.log_marg_prob(z, x_c, jensen=jensen)
                MIs[c] = (log_prob - log_marg_prob).mean()

        return (MIs * (counts / total)).sum()
    finally:
        # Restore whichever mode the caller had instead of forcing train mode.
        model.train(was_training)




def evaluate(model, data, args, s, plot=False):
    """Compute mean accuracy and mean cross-entropy of `model` over `data`.

    Args:
        model: module whose ``model(xs)`` returns ``(preds, _)`` with class
            scores along dimension 1 of ``preds``.
        data: iterable of ``(xs, ys)`` batches.
        args: namespace; only read when `plot` is true (``args.C`` classes,
            ``args.out_dir`` output directory).
        s: tag used in the plot file name ``preds_<s>.png``.
        plot: when true, save a scatter plot of input features 0 and 1 with
            one series per predicted class (assumes the inputs are indexable
            as ``xs[:, 0]`` and ``xs[:, 1]`` -- TODO confirm for non-2-D data).

    Returns:
        ``(accuracy, loss)`` as Python floats, averaged over all samples.
    """
    # Remember the caller's mode so evaluating does not silently switch a
    # model that was in eval mode back to training.
    was_training = model.training
    model.eval()

    accs = []
    losses = []
    all_hard = []
    all_xs = []
    try:
        with torch.no_grad():
            for xs, ys in data:
                xs = xs.to(device)
                ys = ys.to(device)

                preds, _ = model(xs)

                losses.append(torch.nn.functional.cross_entropy(preds, ys, reduction="none"))

                hard = preds.argmax(dim=1)
                accs.append((hard == ys).to(torch.float))

                # Inputs and predictions are only needed for the plot; keeping
                # them otherwise would hold the whole dataset on the device.
                if plot:
                    all_hard.append(hard)
                    all_xs.append(xs)
    finally:
        model.train(was_training)

    if plot:
        f, ax = plt.subplots(1)

        all_xs = torch.cat(all_xs, dim=0)
        all_hard = torch.cat(all_hard, dim=0)
        for c in range(args.C):
            in_class = all_hard == c
            ax.scatter(all_xs[in_class, 0].cpu().numpy(), all_xs[in_class, 1].cpu().numpy())

        plt.tight_layout()
        f.savefig(os.path.join(args.out_dir, "preds_%s.png" % s), bbox_inches="tight")
        plt.close("all")

    return torch.cat(accs).mean().item(), torch.cat(losses).mean().item()


def analyse(model, grads=True, t=None):
    """Print magnitude statistics for the model's parameters and gradients.

    For every parameter the largest and the mean absolute value are collected;
    when `grads` is true the same is done for each gradient that is not None.
    One summary line is printed, labelled with `t`.

    Returns:
        The average, over parameters, of each parameter's mean absolute value.
    """
    param_peaks, param_means = [], []
    grad_peaks, grad_means = [], []

    for param in model.parameters():
        magnitude = param.data.abs()
        param_peaks.append(magnitude.max().item())
        param_means.append(magnitude.mean().item())

        if not grads or param.grad is None:
            continue
        grad_magnitude = param.grad.abs()
        grad_peaks.append(grad_magnitude.max().item())
        grad_means.append(grad_magnitude.mean().item())

    val_m = np.array(param_means).mean()

    # Gradient statistics stay None when they were not requested or when no
    # parameter has a gradient yet.
    if grads and grad_peaks:
        max_grad = max(grad_peaks)
        grad_m = np.array(grad_means).mean()
    else:
        max_grad = grad_m = None

    summary = (t, max(param_peaks), val_m, max_grad, grad_m)
    print("\t analyse %s: params max %s mean %s, grads max %s mean %s" % summary)
    return val_m


def clean(s):
    """Return `s` with every space character replaced by an underscore."""
    pieces = s.split(" ")
    return "_".join(pieces)