import argparse
import torch
import numpy as np
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import random_split

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

from network import LearnableMatrix
from dataset import Cifar10h
from utils import calc_mutual_info, logits_labels

# Prefer the first CUDA device when one is present, otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Command-line interface: a single positional argument, the output directory
# under which every result of this script is written.
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str, help="path where the files are saved")
args = parser.parse_args()

def train(n_epochs, net, labels, beta, path, optimizer, scheduler, device=device):
    """Optimise ``net`` on the information-bottleneck objective and save results.

    Every epoch performs one full-batch step: ``net()`` is called with no
    inputs, ``calc_mutual_info`` returns ``(I(X;T), I(T;Y))`` for its output
    against ``labels``, and the loss ``-I(T;Y) + beta * I(X;T)`` is minimised.

    Args:
        n_epochs: number of optimisation steps; must be at least 1.
        net: module whose ``forward`` takes no arguments.
        labels: target tensor passed to ``calc_mutual_info`` with
            ``stochastic_label=True``.
        beta: weight of the ``I(X;T)`` term in the loss.
        path: output directory (created if it does not exist).
        optimizer: optimiser over ``net``'s parameters.
        scheduler: optional LR scheduler stepped once per epoch, or ``None``.
        device: currently unused; kept for interface compatibility.

    Returns:
        ``np.ndarray`` ``[I(X;T), I(T;Y)]`` from the last epoch's forward
        pass, i.e. evaluated *before* the final optimizer step.

    Raises:
        ValueError: if ``n_epochs < 1``; with no epoch run there would be no
            mutual-information values to return.

    Side effects:
        Writes ``<path>/ib_loss.npy`` (per-epoch loss values) and
        ``<path>/weight.pth`` (``net.state_dict()``).
    """
    # Fail fast: with zero epochs i_x_t / i_t_y would never be bound and the
    # function would crash at the end with an opaque UnboundLocalError.
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be >= 1, got {n_epochs}")

    net.train()
    os.makedirs(path, exist_ok=True)
    running_IBs = []
    for _ in range(n_epochs):
        optimizer.zero_grad()
        out = net()
        i_x_t, i_t_y = calc_mutual_info(out, labels, stochastic_label=True)
        # IB Lagrangian: maximise I(T;Y) while penalising I(X;T) by beta.
        loss = -i_t_y + beta * i_x_t
        loss.backward()
        optimizer.step()
        running_IBs.append(loss.item())
        if scheduler is not None:
            scheduler.step()

    mis = np.array([i_x_t.item(), i_t_y.item()])
    # np.save appends the ".npy" suffix to "ib_loss".
    np.save(os.path.join(path, "ib_loss"), np.array(running_IBs))
    torch.save(net.state_dict(), os.path.join(path, "weight.pth"))
    return mis
    
# Fix all RNG seeds so that runs are reproducible.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

# CIFAR-10H root directory. It can be overridden with the CIFAR10H_ROOT
# environment variable; the default is the original hard-coded location.
data_root = os.environ.get('CIFAR10H_ROOT', '/Volumes/csbdeep15/sota_ku/cifar10/')
dataset = Cifar10h(data_root)
# Collect every per-sample label into one tensor whose first dim indexes samples.
labels = torch.stack([dataset[i][1] for i in range(len(dataset))], dim=0).to(device)

# Size the learnable matrix from the labels instead of hard-coding 10000 x 10.
# NOTE(review): assumes labels is (n_samples, n_classes); the smoothing step
# below also treats the last dim as the class dim.
n_samples, n_classes = labels.shape[0], labels.shape[-1]

n_epochs = 1000
lr = 0.01

# Sweep the IB trade-off weight beta over [0, 1]; train a fresh net per value.
betas = np.linspace(0, 1, 11)
all_mis = []
for beta in tqdm(betas):
    net = LearnableMatrix(n_samples, n_classes).to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = None
    # NOTE: f"{beta}" on np.linspace values can yield directory names such as
    # "0.30000000000000004"; kept as-is so existing outputs/readers still match.
    mis = train(n_epochs, net, labels, beta, os.path.join(args.path, f"{beta}"), optimizer, scheduler)
    all_mis.append(mis)
# Shape (2, len(betas)): row 0 = I(X;T), row 1 = I(T;Y).
all_mis = np.stack(all_mis, axis=1)
np.save(os.path.join(args.path, "mis"), all_mis)

# Label-smoothing baseline: mix the labels with the constant 1/n_classes
# (alpha = 0 -> labels unchanged, alpha = 1 -> constant) and record the MI pair.
# 1.0 / n_classes equals the original literal 0.1 bit-for-bit when n_classes == 10.
alphas = np.linspace(0, 1, 101)
uniform = 1.0 / n_classes
ls_mis = []
for alpha in alphas:
    out = (1 - alpha) * labels + alpha * uniform
    i_x_t, i_t_y = calc_mutual_info(out, labels, stochastic_label=True, input_probs_not_logits=True)
    ls_mis.append([i_x_t.item(), i_t_y.item()])

# Shape (2, len(alphas)): row 0 = I(X;T), row 1 = I(T;Y).
ls_mis = np.array(ls_mis).T
np.save(os.path.join(args.path, "ls_mis"), ls_mis)

