import argparse
import os
import sys
import json
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from task_generator import generate_and_plot_data, sample_minibatch
from utils import savefig, plot_current_status

sys.path.insert(1, os.path.join(sys.path[0], ".."))
from divdis import DivDisLoss


# Command-line configuration for the 2-D linear toy experiment.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=32, help="labelled training minibatch size")
parser.add_argument("--test_batch_size", type=int, default=100, help="unlabelled/test minibatch size")
parser.add_argument("--train_iter", type=int, default=20000, help="optimizer steps per trained head")
parser.add_argument("--log_every", type=int, default=500)
parser.add_argument("--plot_every", type=int, default=1000)

parser.add_argument("--lr", type=float, default=0.1, help="Adam learning rate")
parser.add_argument("--heads", type=int, default=2, help="number of prediction heads")
parser.add_argument("--aux_weight", type=float, default=1.0,
                    help="weight of the repulsion term (typically DivDis 1.0, D-BAT 10.0)")
parser.add_argument("--mode", type=str, default="mi", help="mode passed to DivDisLoss")
parser.add_argument("--reduction", type=str, default="mean", help="reduction passed to DivDisLoss")
# Dataset sizes, method selection and data-generation options.
parser.add_argument("--train_size", type=int, default=500)
parser.add_argument("--test_size", type=int, default=5000)
# Restrict to the methods this script actually implements so a typo fails at
# parse time instead of after data generation.
parser.add_argument("--method", type=str, default="DivDis",
                    choices=["DivDis", "DivDis-Seq", "D-BAT"])
parser.add_argument("--predefined_erm", action="store_true", default=False,
                    help="DivDis-Seq/D-BAT only: fix head 0's weight to [0, 1000] instead of training it")
parser.add_argument("--degree_of_balance", type=float, default=None,
                    help="in [0, 1], where 1 is balanced")

args = parser.parse_args()

# Build the train/test splits; the generator receives the full parsed args.
training_data, test_data = generate_and_plot_data(args)

# Run identifier reused for figure, GIF and result file names.
exp_name = "linear_{}_balance{}_alpha{}".format(args.method, args.degree_of_balance, args.aux_weight)
print(exp_name)

# Metric histories filled during training (insertion order is kept in the JSON dump).
logs = {"slopes": [], "xent": [], "adv_loss": []}

# Train according to --method. Each branch periodically calls
# plot_current_status; whenever it returns non-empty slopes, those slopes and
# the current repulsion / cross-entropy losses are appended to `logs`. The last
# `times` returned by plot_current_status is kept for the GIF/JSON steps below.
if args.method == "DivDis":
    # Joint training: a single linear layer whose output columns are the heads.
    net = nn.Linear(2, args.heads, bias=True).cuda()
    opt = torch.optim.Adam(net.parameters(), lr=args.lr)
    loss_fn = DivDisLoss(heads=args.heads, mode=args.mode, reduction=args.reduction)

    for t in range(args.train_iter + 1):
        # Supervised term: BCE of every head on a labelled minibatch, summed.
        x, y = sample_minibatch(training_data, args.batch_size)
        x, y = x.cuda(), y.cuda()
        logits = net(x)
        logits_chunked = torch.chunk(logits, args.heads, dim=-1)
        losses = [F.binary_cross_entropy_with_logits(logit, y) for logit in logits_chunked]
        xent = sum(losses)

        # Repulsion term: DivDisLoss on all heads' logits for a test-split minibatch
        # (labels are discarded).
        target_x, _ = sample_minibatch(test_data, args.test_batch_size)
        target_x = target_x.cuda()
        target_logits = net(target_x)
        repulsion_loss = loss_fn(target_logits)

        full_loss = xent + args.aux_weight * repulsion_loss
        opt.zero_grad()
        full_loss.backward()
        opt.step()

        times, slopes = plot_current_status(args, net, t, xent, repulsion_loss, training_data, test_data, exp_name)
        if len(slopes) != 0:
            logs["slopes"].append(slopes)
            logs["adv_loss"].append(repulsion_loss.item())
            logs["xent"].append(xent.item())

elif args.method == "DivDis-Seq":
    assert args.heads == 2  # For now. DivDis-Seq only supports 2 heads
    # Sequential training: one single-output linear model per head, trained in
    # order. Head 0 is trained with BCE only; head i > 0 adds a DivDisLoss term
    # computed against head 0's logits, which are produced under no_grad and
    # are not in head i's optimizer.
    nets = [nn.Linear(2, 1, bias=True).cuda() for _ in range(args.heads)]
    loss_fn = DivDisLoss(heads=args.heads, mode=args.mode, reduction=args.reduction)
    for i in range(args.heads):
        # whether to use a predefined ERM model
        # NOTE(review): only the weight is overwritten; the bias keeps its random init.
        if i == 0 and args.predefined_erm:
            nets[0].weight.data = torch.tensor([[0., 1000.]]).cuda()
            continue
        opt = torch.optim.Adam(nets[i].parameters(), lr=args.lr)
        # NOTE(review): `t` restarts at 0 for every head; check whether
        # plot_current_status output for later heads overwrites earlier heads'.
        for t in range(args.train_iter + 1):
            x, y = sample_minibatch(training_data, args.batch_size)
            x, y = x.cuda(), y.cuda()
            logits = nets[i](x)
            xent = F.binary_cross_entropy_with_logits(logits, y)

            if i != 0:
                target_x, _ = sample_minibatch(test_data, args.test_batch_size)
                target_x = target_x.cuda()
                with torch.no_grad():
                    target_logits_1 = nets[0](target_x)
                target_logits_2 = nets[i](target_x)
                target_logits = torch.cat([target_logits_1, target_logits_2], dim=-1)
                repulsion_loss = loss_fn(target_logits)
                full_loss = xent + args.aux_weight * repulsion_loss
            else:
                # Placeholder so the logging below can call .item() uniformly.
                repulsion_loss = torch.zeros(1)
                full_loss = xent
            opt.zero_grad()
            full_loss.backward()
            opt.step()

            times, slopes = plot_current_status(args, nets, t, xent, repulsion_loss, training_data, test_data, exp_name)
            if len(slopes) != 0:
                logs["slopes"].append(slopes)
                logs["adv_loss"].append(repulsion_loss.item())
                logs["xent"].append(xent.item())

elif args.method == "D-BAT":
    assert args.heads == 2  # For now. D-BAT only supports 2 heads
    # Same sequential scheme as DivDis-Seq, but head i > 0 minimizes
    # -log(P(heads disagree) + 1e-5) on test inputs, where
    # P(disagree) = (1 - p_1) * p_2 + p_1 * (1 - p_2) and p_1 comes from the
    # frozen head 0.
    nets = [nn.Linear(2, 1, bias=True).cuda() for _ in range(args.heads)]
    for i in range(args.heads):
        # whether to use a predefined ERM model
        # NOTE(review): only the weight is overwritten; the bias keeps its random init.
        if i == 0 and args.predefined_erm:
            nets[0].weight.data = torch.tensor([[0., 1000.]]).cuda()
            continue
        opt = torch.optim.Adam(nets[i].parameters(), lr=args.lr)
        for t in range(args.train_iter + 1):
            x, y = sample_minibatch(training_data, args.batch_size)
            x, y = x.cuda(), y.cuda()
            logits = nets[i](x)
            xent = F.binary_cross_entropy_with_logits(logits, y)

            if i != 0:
                target_x, _ = sample_minibatch(test_data, args.test_batch_size)
                target_x = target_x.cuda()
                with torch.no_grad():
                    p_1 = torch.sigmoid(nets[0](target_x))
                p_2 = torch.sigmoid(nets[i](target_x))
                repulsion_loss = - ((1.-p_1) * p_2 + p_1 * (1.-p_2) + 1e-5).log().mean()
                full_loss = xent + args.aux_weight * repulsion_loss
            else:
                # Placeholder so the logging below can call .item() uniformly.
                repulsion_loss = torch.zeros(1)
                full_loss = xent
            opt.zero_grad()
            full_loss.backward()
            opt.step()

            times, slopes = plot_current_status(args, nets, t, xent, repulsion_loss, training_data, test_data, exp_name)
            if len(slopes) != 0:
                logs["slopes"].append(slopes)
                logs["adv_loss"].append(repulsion_loss.item())
                logs["xent"].append(xent.item())

else:
    # Fail fast: otherwise no model is trained and `times` below is undefined.
    raise ValueError(
        f"Unknown --method {args.method!r}; expected one of 'DivDis', 'DivDis-Seq', 'D-BAT'"
    )

logs["time"] = times


# Test Eval: run every head over the full test split in minibatches and report
# per-head accuracy, treating a logit > 0 as a positive prediction.
te_x, te_y = test_data
logits = []
# Explicit check instead of `assert`, which is stripped under `python -O`.
if args.test_size % args.test_batch_size != 0:
    raise ValueError(
        f"--test_size ({args.test_size}) must be divisible by --test_batch_size ({args.test_batch_size})"
    )
# Inference only: no_grad avoids building an autograd graph over the whole test set.
with torch.no_grad():
    for minibatch_idx in range(0, len(te_x), args.test_batch_size):
        # as_tensor accepts numpy arrays or tensors without a redundant copy/warning.
        curr_te_x = torch.as_tensor(te_x[minibatch_idx: minibatch_idx + args.test_batch_size]).float().cuda()
        if args.method == "DivDis":
            logits.append(net(curr_te_x))
        elif args.method in ("D-BAT", "DivDis-Seq"):
            # Stack the two single-output heads into one (batch, 2) logit matrix.
            logits.append(torch.cat([nets[0](curr_te_x), nets[1](curr_te_x)], dim=-1))
logits = torch.cat(logits, dim=0)
# Column-wise mean gives one accuracy per head.
test_acc = ((logits.cpu().numpy() > 0) == te_y).mean(0)
print("Test Accuracy:", test_acc)
logs["test_acc"] = test_acc.tolist()

# Load the figure saved for each logged checkpoint (the `{t=}` field expands to
# "t=<value>", so the loop variable must stay named `t`) and animate them.
images = [imageio.imread(f"figures/linear/{exp_name}_{t=}.png") for t in times]
gif_path = f"gifs/{exp_name}.gif"
os.makedirs("gifs", exist_ok=True)
imageio.mimsave(gif_path, images)
print(f"Saved gif to {gif_path}")


# Persist the training histories, checkpoint times and test accuracy as JSON.
# Create the directory first so a long run is not lost to FileNotFoundError.
os.makedirs("results", exist_ok=True)
with open(f"./results/{exp_name}", "w") as fp:
    json.dump(logs, fp)
