import argparse
from copy import deepcopy

import numpy as np
import torch
from tqdm import tqdm

import configs
from data.load import get_ada_cl
from models.resnet import reduced_ResNet18

# Experiment configuration: the settings mapping is looked up on `configs` by
# dataset name and exposed as attribute-style `args`.
dataset = "cifar100"
config = getattr(configs, f"{dataset}_augment")
args = argparse.Namespace(**config)

# Seed NumPy and PyTorch (CPU and CUDA) before any model is constructed.
seed = 1
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

# Use the GPU only when one is available *and* the config asks for it.
if torch.cuda.is_available() and args.cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# `ada_cl` is the batch stream consumed by the main loop; `test_cl_list` is
# iterated as one test loader per task at the end of the script.
# Calling get_ada_cl with return_train_list=True additionally returns
# `train_list`, which only the disabled pre-training loop below uses.
ada_cl, test_cl_list = get_ada_cl(**vars(args), return_train_list=False)

# One network and one SGD optimizer per task.  Each network is constructed
# with n_classes // n_tasks (presumably its number of output classes).
model_list = []
op_list = []
for _ in range(args.n_tasks):
    net = reduced_ResNet18(args.n_classes // args.n_tasks).to(device)
    model_list.append(net)
    op_list.append(torch.optim.SGD(net.parameters(), lr=args.lr))

# Mean-reduced cross-entropy, shared by every training and scoring step.
criterion = torch.nn.CrossEntropyLoss(reduction="mean")
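# Disabled pre-training loop, kept for reference.  It trains one model per task
# on the `train_list` returned by `get_ada_cl(..., return_train_list=True)`,
# saving a "roughtrained" checkpoint when the batch index reaches
# len(dataloder) // 3 and a "pretrained" checkpoint after the full pass.
# Re-enabling it also requires `import os`.
#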
# for task_index, dataloder in enumerate(tqdm(train_list)):
#     for time, (inputs, labels) in enumerate(dataloder):
#         inputs, labels = inputs.to(device), labels.to(device)
#         labels = labels % (args.n_classes // args.n_tasks)
#
#         op_list[task_index].zero_grad()
#         loss = criterion(model_list[task_index](inputs), labels)
#         loss.backward()
#         op_list[task_index].step()
#         if time == len(dataloder)//3:
#             path = f'./store/roughtrained/{dataset}/'
#             if os.path.exists(path) == False:
#                 os.makedirs(path)
#             torch.save(model_list[task_index].state_dict(), path+f'task_{task_index}.pth')
#     path = f'./store/pretrained/{dataset}/'
#     if os.path.exists(path) == False:
#         os.makedirs(path)
#     torch.save(model_list[task_index].state_dict(), path + f'/task_{task_index}.pth')

# Warm-start every per-task network from its "roughtrained" checkpoint.
# map_location=device remaps the stored tensors onto the device selected for
# this run, so checkpoints written from CUDA tensors can still be loaded when
# the script has fallen back to the CPU (without it torch.load tries to
# restore the tensors onto the device they were saved from).
for task_index, model in enumerate(model_list):
    state_dict = torch.load(
        f"./store/roughtrained/{dataset}/task_{task_index}.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)

# Second, independent set of networks with their own optimizers.  These start
# from the same weights as `model_list` but are updated only by the
# loss-based model-selection branch of the online loop.
model_list_2 = [deepcopy(model) for model in model_list]
op_list_2 = [torch.optim.SGD(model.parameters(), lr=args.lr) for model in model_list_2]


# ---------------------------------------------------------------------------
# Online pass over `ada_cl`, comparing two strategies batch by batch:
#
#   1. `model_list`   - the network for the batch's true task (derived from the
#                       labels) is evaluated on the batch, then trained on it.
#   2. `model_list_2` - every network scores the batch; an exponential moving
#                       average of each network's loss picks a single network,
#                       and only that network is trained on the batch.
#
# Printed at the end: accuracy of (1), accuracy of (2), and the fraction of
# batches on which (2) selected the network of the true task.
# ---------------------------------------------------------------------------
classes_per_task = args.n_classes // args.n_tasks

total_correct, time = 0, 0
total_correct_2 = 0
idx_correct = 0
count = 0
# Mixing weights for the model_list_2 ensemble: uniform for the first batch,
# one-hot on the lowest-average-loss network afterwards.
weights = torch.ones(args.n_tasks, device=device)
weights /= weights.sum()
cum_weights = weights.clone()  # NOTE(review): never read in this script
avg_loss_sq = torch.ones(args.n_tasks, device=device)  # NOTE(review): never read in this script

# Exponential moving average of each model_list_2 network's batch loss.
avg_loss = torch.zeros(args.n_tasks, device=device)

# Every network in model_list_2 is queried on every batch, so all of them must
# be in eval mode up front.  Previously only the true task's copy was switched,
# leaving the other copies in nn.Module's default training mode until their
# first refinement step.
for model in model_list_2:
    model.eval()

for inputs, labels in tqdm(ada_cl):
    time += 1
    inputs, labels = inputs.to(device), labels.to(device)
    # The task id is taken from the first label only.
    # NOTE(review): assumes each batch holds samples of a single task - confirm
    # against get_ada_cl.
    task_index = int(labels[0]) // classes_per_task
    labels = labels % classes_per_task  # map to task-local class ids
    count += len(labels)

    # --- Strategy 1: evaluate the true task's network, then train it. -------
    oracle_model = model_list[task_index]
    oracle_model.eval()
    with torch.no_grad():
        preds = oracle_model(inputs).argmax(1)
    total_correct += (preds == labels).sum().item()

    oracle_model.train()
    op_list[task_index].zero_grad()
    loss = criterion(oracle_model(inputs), labels)
    loss.backward()
    op_list[task_index].step()

    # --- Strategy 2: score with all networks, then select one. --------------
    with torch.no_grad():
        # One forward pass per network; the logits feed both the weighted
        # ensemble prediction and the per-network loss below.
        logits_2 = [model(inputs) for model in model_list_2]

        # Predict with the weights chosen on the *previous* batch.
        preds2 = torch.zeros_like(logits_2[0])
        for weight, logits in zip(weights.tolist(), logits_2):
            preds2 += weight * logits.softmax(1)
        total_correct_2 += (preds2.argmax(1) == labels).sum().item()

        batch_losses = torch.tensor(
            [criterion(logits, labels).item() for logits in logits_2],
            device=device,
        )
        # EMA update: avg <- 0.9 * avg + 0.1 * current batch loss.
        avg_loss -= 0.1 * (avg_loss - batch_losses)

        # Hard selection: all weight on the network with the lowest EMA loss.
        weights = torch.zeros(args.n_tasks, device=device)
        weights[avg_loss.argmin()] = 1

        # Adds 1 when the selected network is the true task's network.
        idx_correct += weights[task_index].item()

    # refine: train the selected network on this batch.  With one-hot weights
    # the threshold always passes; the check is kept from the original code.
    if weights.max() > 0.5:
        i = weights.argmax().item()

        model_list_2[i].train()
        op_list_2[i].zero_grad()
        loss = criterion(model_list_2[i](inputs), labels)
        loss.backward()
        op_list_2[i].step()
        model_list_2[i].eval()  # back to eval for scoring the next batch


print(total_correct / count)
print(total_correct_2 / count)
print(idx_correct / time)

# Final per-task test accuracy for both sets of networks.
# Switch every network to eval mode first (model_list, then model_list_2).
for net in (*model_list, *model_list_2):
    net.eval()

classes_per_task = args.n_classes // args.n_tasks
acc = torch.zeros(args.n_tasks, device=device)
acc_2 = torch.zeros(args.n_tasks, device=device)

for task_index, dataloder in enumerate(tqdm(test_cl_list)):
    hits_1 = 0
    hits_2 = 0
    seen = 0
    for inputs, labels in dataloder:
        inputs = inputs.to(device)
        # Map labels to task-local class ids, as in the online loop.
        labels = labels.to(device) % classes_per_task
        with torch.no_grad():
            # Each test loader is scored with the network at the same index.
            predicted_1 = model_list[task_index](inputs).argmax(1)
            hits_1 += (predicted_1 == labels).sum().item()
            predicted_2 = model_list_2[task_index](inputs).argmax(1)
            hits_2 += (predicted_2 == labels).sum().item()
        seen += len(labels)

    acc[task_index] = hits_1 / seen
    acc_2[task_index] = hits_2 / seen

print(acc)
print(acc_2)
