import os
import random
import pyrallis
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from models.base import load_model
from optimization.optimizer import load_optimizer
from loss.task_loss import load_task_loss
from metrics.accuracy import accuracy
from utils.train_util import AverageMeter, ProgressMeter
from utils.log_utils import log_scalar_dict, create_experiment_dir

from options import ModelBaseTrainConfig

@pyrallis.wrap()
def main(cfg: ModelBaseTrainConfig):
    """Train the configured model and write one checkpoint per epoch.

    The function seeds the RNGs, builds the model (optionally restoring
    weights from ``cfg.checkpoint_path``), creates the train/test loaders,
    the task loss and the optimizer, and then runs ``cfg.epochs`` rounds of
    ``train`` followed by ``test``. After every epoch the averaged loss and
    accuracy of both phases are handed to ``log_scalar_dict`` and the model's
    ``state_dict`` is saved as ``model_cont<epoch+1>.pt`` inside the
    experiment directory.

    Args:
        cfg: Run configuration. Fields read here: ``no_cuda``,
            ``checkpoint_path``, ``batch_size``, ``epochs``, ``optim``,
            ``logging.{disable_logging, log_dir, exp_name, sub_exp_name}``
            and ``task.{original_model_name, num_classes, data_dir,
            task_name, path_prefix}``. ``cfg.task`` as a whole is passed to
            ``load_task_loss``.
    """
    init_seed(2025)
    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    use_cuda = not cfg.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ----------------------------------------
    # logging configuration
    # ----------------------------------------

    if not cfg.logging.disable_logging:
        logger = SummaryWriter(log_dir=os.path.join(cfg.logging.log_dir, "runs", cfg.logging.exp_name, cfg.logging.sub_exp_name))
        # Store the full configuration with the run so it can be inspected later.
        logger.add_text("Config", json.dumps(pyrallis.encode(cfg), indent=4))
    else:
        logger = None

    # Directory that receives the per-epoch checkpoints below.
    exp_dir_path = create_experiment_dir(cfg.logging.log_dir, cfg.logging.exp_name, 'cont', cfg.logging.sub_exp_name)

    # ----------------------------------------
    # model base configuration
    # ----------------------------------------
    model = load_model(cfg.task.original_model_name, num_classes=cfg.task.num_classes).to(device)
    if cfg.checkpoint_path is not None:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects. Only point checkpoint_path at trusted files; if the file
        # holds a plain state_dict, weights_only=True is likely sufficient —
        # confirm before changing.
        model.load_state_dict(torch.load(cfg.checkpoint_path, map_location=device, weights_only=False))

    # ----------------------------------------
    # data loader configuration
    # ----------------------------------------
    train_loader = load_images(cfg.task.data_dir, cfg.task.task_name, data_type='train', path_prefix=cfg.task.path_prefix, batch_size=cfg.batch_size)
    test_loader = load_images(cfg.task.data_dir, cfg.task.task_name, data_type='test', path_prefix=cfg.task.path_prefix, batch_size=cfg.batch_size)

    # ----------------------------------------
    # optimization configuration
    # ----------------------------------------
    criterion = load_task_loss(cfg.task)
    optimizer = load_optimizer(model.parameters(), cfg.optim)

    # ----------------------------------------
    # each epoch
    # ----------------------------------------
    try:
        for epoch in tqdm(range(cfg.epochs), desc='Epoch'):
            loss, acc = train(train_loader, model, criterion, optimizer, device)
            eval_loss, eval_acc = test(test_loader, model, criterion, device)

            loss_acc_dict = {
                'original model training loss': loss.avg,
                'original model training acc': acc.avg,
                'original model test loss': eval_loss.avg,
                'original model test acc': eval_acc.avg
            }

            # NOTE(review): logger is None when logging is disabled;
            # log_scalar_dict is assumed to cope with that — TODO confirm.
            log_scalar_dict(loss_acc_dict,
                            title='model_base_cont',
                            iteration=epoch,
                            logger=logger)
            # ----------------------------------------
            # save a checkpoint after every epoch
            # (no best-model selection happens here)
            # ----------------------------------------
            torch.save(model.state_dict(), os.path.join(exp_dir_path, f'model_cont{epoch+1}.pt'))
    finally:
        # Flush and release the TensorBoard writer even if training fails,
        # so the events logged so far are not lost.
        if logger is not None:
            logger.close()

    print('MODEL DIR', exp_dir_path)


def train(train_loader, model, criterion, optimizer, device):
    """Run one training pass over ``train_loader`` and return the running meters.

    For every batch the model output is scored with ``criterion``, the
    accuracy helper is queried with ``topk=(1, 5)`` (only the first value is
    recorded), and one optimizer step is taken.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` pairs; must
            support ``len()``.
        model: Module to optimise; switched to training mode here.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer that updates the model's parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(loss_meter, acc_meter)`` — the two ``AverageMeter`` objects that
        were updated with each batch's loss and first accuracy value, using
        ``inputs.size(0)`` as the second ``update`` argument.
    """
    model.train()

    loss_tracker = AverageMeter('Loss', ':.4e')
    top1_tracker = AverageMeter('Acc', ':6.2f')
    reporter = ProgressMeter(
        total=len(train_loader),
        step=20,
        prefix='Train',
        meters=[loss_tracker, top1_tracker],
    )

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        # The second value returned for topk=(1, 5) is not used.
        top1, _top5 = accuracy(logits, labels, topk=(1, 5))

        n_samples = inputs.size(0)
        loss_tracker.update(batch_loss.item(), n_samples)
        top1_tracker.update(top1.item(), n_samples)

        # Standard update: clear old gradients, backpropagate, apply the step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        reporter.display(batch_idx)

    return loss_tracker, top1_tracker


def test(test_loader, model, criterion, device):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Args:
        test_loader: Iterable yielding ``(inputs, labels)`` pairs; must
            support ``len()``.
        model: Module to evaluate; switched to eval mode here.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(loss_meter, acc_meter)`` — the two ``AverageMeter`` objects that
        were updated with each batch's loss and the first value returned by
        ``accuracy(..., topk=(1, 5))``.
    """
    model.eval()

    loss_tracker = AverageMeter('Loss', ':.4e')
    top1_tracker = AverageMeter('Acc', ':6.2f')
    reporter = ProgressMeter(
        total=len(test_loader),
        step=20,
        prefix='Test',
        meters=[loss_tracker, top1_tracker],
    )

    for batch_idx, (inputs, labels) in enumerate(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # No parameter update happens here, so autograd tracking is turned off.
        with torch.no_grad():
            logits = model(inputs)
            batch_loss = criterion(logits, labels)
            # The second value returned for topk=(1, 5) is not used.
            top1, _top5 = accuracy(logits, labels, topk=(1, 5))

            n_samples = inputs.size(0)
            loss_tracker.update(batch_loss.item(), n_samples)
            top1_tracker.update(top1.item(), n_samples)

            reporter.display(batch_idx)

    return loss_tracker, top1_tracker

def init_seed(seed=2025):
    """Seed every RNG this script relies on and pin the cuDNN flags.

    Seeds Python's ``random``, NumPy's global RNG and torch (via
    ``torch.manual_seed`` plus both CUDA seeding calls), then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Script entry point. ``main`` is decorated with ``pyrallis.wrap()``, so it is
# called without arguments here; the decorator is expected to build ``cfg``
# (presumably from the command line / a config file — see the pyrallis docs).
if __name__ == '__main__':
    main()
