import json
import os

import torch
import torch.optim as optim
from torchvision.models.densenet import densenet121

from data_utils.data_manager import DataManager
from trainer import Trainer
from utils.data_logs import save_logs_about
import utils.losses as loss_functions
from torch.utils.tensorboard import SummaryWriter


def get_optimizer_curriculum_lr(model, initial_lr, weight_decay=0):
    """Build an Adam optimizer with a per-block learning rate.

    Every parameter gets its own param group. The group's learning rate is
    ``initial_lr`` scaled by the multiplier of the block the parameter belongs
    to. The block is looked up from the parameter's dotted name: first its
    second component (e.g. ``features.denseblock2.…`` -> ``denseblock2``),
    then its first component (e.g. ``classifier.weight`` -> ``classifier``).
    Parameters that match no block keep ``initial_lr`` unchanged.

    Args:
        model: ``torch.nn.Module`` whose ``named_parameters()`` are optimized.
        initial_lr: Base learning rate that the block multipliers scale.
        weight_decay: Weight decay applied to every group. Defaults to 0,
            Adam's own default, so existing two-argument callers behave
            exactly as before.

    Returns:
        A ``torch.optim.Adam`` instance over all of the model's parameters.
    """
    lr_multipliers = {
        "denseblock1": 1e-1,
        "denseblock2": 1e-1,
        "denseblock3": 1e-2,
        "denseblock4": 1e-4,
        "classifier": 1e-5,
    }
    param_groups = []
    for name, param in model.named_parameters():
        parts = name.split(".")
        # Try the second component before the first (same priority as before).
        # Slicing instead of indexing avoids an IndexError for a top-level
        # parameter whose name contains no dot.
        candidates = parts[1:2] + parts[:1]
        multiplier = next((lr_multipliers[part] for part in candidates if part in lr_multipliers), None)
        lr = initial_lr if multiplier is None else initial_lr * multiplier
        param_groups.append({'params': param, 'lr': lr})
    return optim.Adam(param_groups, lr=initial_lr, weight_decay=weight_decay)


def _build_model(device):
    """Return an ImageNet-pretrained DenseNet-121 adapted to 1-channel input and 6 classes.

    The model is moved to *device* and its floating-point parameters are cast
    to float32.
    """
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=`; kept for compatibility with the installed version.
    model = densenet121(pretrained=True)
    # Size the new head from the stock one instead of hard-coding its width (1024).
    model.classifier = torch.nn.Linear(model.classifier.in_features, 6)
    # Replace the 3-channel stem with a 1-channel conv. The new layer is freshly
    # initialised, so the pretrained conv0 weights are discarded.
    # NOTE(review): torchvision's stock conv0 uses bias=False; confirm bias=True is intended.
    model.features.conv0 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=True)
    model.to(device)
    model.float()
    return model


def main():
    """Train and test the model described by ``./config.json``.

    Steps:
      1. Load ``./config.json`` and add a ``'device'`` entry (``'cuda'`` when
         available, otherwise ``'cpu'``).
      2. Create ``<exp_path>/<exp_name>`` (including missing parents) and save
         the resolved config there via ``save_logs_about``.
      3. Build the model, the loss (looked up by name in ``utils.losses``), an
         Adam optimizer (optionally with per-block learning rates), and a
         StepLR scheduler.
      4. Train with ``Trainer`` on the train/validation loaders, then evaluate
         on the test loader.

    TensorBoard events are written to ``runs/<exp_name>``. The writer is
    closed (and flushed) even if training raises.
    """
    with open('./config.json', encoding='utf-8') as config_file:
        config = json.load(config_file)
    config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

    exp_dir = os.path.join(config['exp_path'], config['exp_name'])
    try:
        # makedirs also creates a missing exp_path parent, where os.mkdir would raise.
        os.makedirs(exp_dir)
    except FileExistsError:
        print("Directory already exists! It will be overwritten!")

    logs_writer = SummaryWriter(os.path.join('runs', config['exp_name']))
    try:
        model = _build_model(config['device'])

        # Save info about the experiment.
        save_logs_about(exp_dir, json.dumps(config, indent=2))

        criterion = getattr(loss_functions, config['loss_function'])

        if config['enable_curriculum_lr']:
            # NOTE(review): config['weight_decay'] is not passed on this path.
            optimizer = get_optimizer_curriculum_lr(model, config['lr'])
        else:
            optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config['lr_sch_step'], gamma=config['lr_sch_gamma'])

        data_manager = DataManager(config)
        train_loader, validation_loader, test_loader = data_manager.get_train_eval_test_dataloaders()

        trainer = Trainer(model, train_loader, validation_loader, criterion, optimizer, lr_scheduler, logs_writer, config)
        trainer.train()

        trainer.test_net(test_loader)
    finally:
        # Flush pending TensorBoard events; SummaryWriter.close() ignores a second call.
        logs_writer.close()


if __name__ == "__main__":
    # Run the training/testing pipeline only when executed as a script,
    # not when this module is imported.
    main()
