import os
import random
import pyrallis
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from models.base import load_model
from optimization.optimizer import load_optimizer
from loss.task_loss import load_task_loss
from metrics.accuracy import accuracy
from utils.train_util import AverageMeter, ProgressMeter
from utils.log_utils import log_scalar_dict, create_experiment_dir

from options import ModelBaseTrainConfig

@pyrallis.wrap()
def main(cfg: ModelBaseTrainConfig):
    """Continue training a base model, evaluating and checkpointing every epoch.

    Builds ``cfg.task.original_model_name`` (optionally initialised from
    ``cfg.checkpoint_path``), trains it for ``cfg.epochs`` epochs on the
    ``train`` split of ``cfg.task.task_name``, evaluates after each epoch,
    logs the metrics via ``log_scalar_dict`` and saves the model
    ``state_dict`` to ``<exp_dir>/model_cont{epoch + 1}.pt``.

    Args:
        cfg: run configuration; ``pyrallis.wrap`` builds it, so ``main`` is
            invoked without arguments.
    """
    init_seed(2025)
    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    use_cuda = not cfg.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ----------------------------------------
    # logging configuration
    # ----------------------------------------
    if not cfg.logging.disable_logging:
        logger = SummaryWriter(log_dir=os.path.join(cfg.logging.log_dir, "runs", cfg.logging.exp_name, cfg.logging.sub_exp_name))
        logger.add_text("Config", json.dumps(pyrallis.encode(cfg), indent=4))
    else:
        logger = None

    # Everything after the writer is opened runs under try/finally so the
    # TensorBoard event file is flushed and closed even if training fails.
    try:
        # Checkpoints are written here whether or not TensorBoard logging is on.
        exp_dir_path = create_experiment_dir(cfg.logging.log_dir, cfg.logging.exp_name, 'cont', cfg.logging.sub_exp_name)

        # ----------------------------------------
        # model base configuration
        # ----------------------------------------
        model = load_model(cfg.task.original_model_name, num_classes=cfg.task.num_classes).to(device)
        if cfg.checkpoint_path is not None:
            # NOTE(security): weights_only=False lets torch.load unpickle
            # arbitrary Python objects -- only load trusted checkpoints. If the
            # file is a plain state_dict, weights_only=True should suffice;
            # TODO confirm and tighten.
            model.load_state_dict(torch.load(cfg.checkpoint_path, map_location=device, weights_only=False))

        # ----------------------------------------
        # data loader configuration
        # ----------------------------------------
        ba_train_loader = load_images(cfg.task.data_dir, cfg.task.task_name, data_type='train', batch_size=cfg.batch_size)
        ba_test_loader = load_images(cfg.task.data_dir, cfg.task.task_name, data_type='test', batch_size=cfg.batch_size)
        # Clean reference split: 'cifar10' for task names containing
        # 'cifar10', otherwise 'Tiny-Imagenet'.
        clean_test_loader = load_images(cfg.task.data_dir, 'cifar10' if 'cifar10' in cfg.task.task_name else 'Tiny-Imagenet', data_type='test', batch_size=cfg.batch_size)

        # ----------------------------------------
        # optimization configuration
        # ----------------------------------------
        criterion = load_task_loss(cfg.task)
        optimizer = load_optimizer(model.parameters(), cfg.optim)

        # ----------------------------------------
        # each epoch
        # ----------------------------------------
        for epoch in tqdm(range(cfg.epochs), desc='Epoch'):
            loss, acc = train(ba_train_loader, model, criterion, optimizer, device)
            clean_acc, asr, ra = 0, 0, 0
            # Runs whose sub-experiment name contains 'src' additionally report
            # asr / ra (presumably attack success rate / robust accuracy) on
            # the task's test split; other runs only evaluate the clean split.
            # NOTE(review): `test` returns the mean of `accuracy(...)` while
            # `test_ba` returns an `all_acc` fraction in [0, 1]; if `accuracy`
            # reports percentages, the logged clean acc changes scale between
            # the two branches -- TODO confirm.
            if 'src' in cfg.logging.sub_exp_name:
                clean_acc, asr, ra = test_ba(ba_test_loader, clean_test_loader, model, criterion, device)
            else:
                clean_acc = test(clean_test_loader, model, criterion, device)

            loss_acc_dict = {
                'original model training loss': loss.avg,
                'original model training acc': acc.avg,
                'original model clean test acc': clean_acc,
                'original model asr': asr,
                'original model ra': ra
            }

            log_scalar_dict(loss_acc_dict,
                            title='model_base_cont',
                            iteration=epoch,
                            logger=logger)
            # ----------------------------------------
            # save a checkpoint for every epoch (not only the best one)
            # ----------------------------------------
            torch.save(model.state_dict(), os.path.join(exp_dir_path, f'model_cont{epoch+1}.pt'))
    finally:
        if logger is not None:
            logger.close()

    print('MODEL DIR', exp_dir_path)


def train(train_loader, model, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch must expose inputs at index 0 and labels at index 1; any
    further fields are ignored. The model is switched to training mode, and
    for every batch the loss is back-propagated and ``optimizer`` steps.

    Returns:
        ``(loss_meter, acc_meter)`` -- the ``AverageMeter`` objects updated
        once per batch (weighted by batch size) with the loss and the top-1
        accuracy.
    """
    model.train()

    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc', ':6.2f')
    progress = ProgressMeter(total=len(train_loader), step=20, prefix='Train',
                             meters=[loss_meter, acc_meter])

    for step, batch in enumerate(train_loader):
        x = batch[0].to(device)
        y = batch[1].to(device)
        batch_size = x.size(0)

        logits = model(x)
        batch_loss = criterion(logits, y)
        top1, _top5 = accuracy(logits, y, topk=(1, 5))

        loss_meter.update(batch_loss.item(), batch_size)
        acc_meter.update(top1.item(), batch_size)

        # Clear stale gradients, back-propagate, then apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        progress.display(step)

    return loss_meter, acc_meter

def test(test_loader, model, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and return its mean top-1 accuracy.

    Batches must be ``(inputs, labels)`` pairs. The forward pass runs with
    gradients disabled; the averaged accuracy is printed under the label
    ``clean_acc`` and returned.
    """
    model.eval()

    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc', ':6.2f')

    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device)

        with torch.no_grad():
            logits = model(x)
            batch_loss = criterion(logits, y)
            top1, _top5 = accuracy(logits, y, topk=(1, 5))

            count = x.size(0)
            loss_meter.update(batch_loss.item(), count)
            acc_meter.update(top1.item(), count)

    mean_acc = acc_meter.avg
    print('clean_acc', mean_acc)
    return mean_acc

def test_ba(ba_test_loader, clean_test_loader, model, criterion, device):
    """Compute clean accuracy, asr and ra for ``model``.

    Args:
        ba_test_loader: yields 5-tuples ``(inputs, labels, original_index,
            poison_or_not, original_target)``. ``asr`` is measured against
            ``labels`` and ``ra`` against ``original_target``; the index and
            poison flag are not used here.
        clean_test_loader: yields ``(inputs, labels)`` pairs.
        model: network to evaluate; it is switched to ``eval`` mode.
        criterion: unused; kept so the signature matches ``test``.
        device: device inputs and labels are moved to.

    Returns:
        ``(clean_acc, asr, ra)``, each a fraction in ``[0, 1]`` from ``all_acc``.

    Raises:
        ValueError: if either loader yields no batches.
    """
    model.eval()

    # Only the arg-max class per sample is kept, not the full logit matrix:
    # the metrics below need just the predicted label, and storing logits for
    # the whole test set costs num_classes times more device memory.
    ba_pred_list, ba_label_list, ba_original_target_list = [], [], []
    clean_pred_list, clean_label_list = [], []

    with torch.no_grad():
        for inputs, labels, _original_index, _poison_or_not, original_target in ba_test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            original_target = original_target.to(device)

            ba_pred_list.append(model(inputs).argmax(dim=-1))
            ba_label_list.append(labels)
            ba_original_target_list.append(original_target)

        for inputs, labels in clean_test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            clean_pred_list.append(model(inputs).argmax(dim=-1))
            clean_label_list.append(labels)

    # torch.cat on an empty list raises an opaque error; fail clearly instead.
    if not ba_pred_list or not clean_pred_list:
        raise ValueError('test_ba: both loaders must yield at least one batch')

    ba_pred = torch.cat(ba_pred_list, dim=0)
    ba_label = torch.cat(ba_label_list, dim=0)
    ba_original_target = torch.cat(ba_original_target_list, dim=0)
    clean_pred = torch.cat(clean_pred_list, dim=0)
    clean_label = torch.cat(clean_label_list, dim=0)

    clean_acc = all_acc(clean_pred, clean_label)
    asr = all_acc(ba_pred, ba_label)
    ra = all_acc(ba_pred, ba_original_target)

    print('clean_acc', clean_acc)
    print('asr', asr)
    print('ra', ra)

    return clean_acc, asr, ra

def all_acc(preds: torch.Tensor,
            labels: torch.Tensor) -> float:
    """Return the fraction of elements where ``preds`` equals ``labels``.

    Args:
        preds: predicted class indices (e.g. ``logits.argmax(dim=-1)``).
        labels: target class indices; must have exactly the same shape as
            ``preds``.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.

    Raises:
        ValueError: if the shapes differ or the tensors are empty.
    """
    # ``eq`` broadcasts: comparing shape (N,) with (N, 1) would compare every
    # pair of elements and silently return a meaningless value.
    if preds.shape != labels.shape:
        raise ValueError(
            f'all_acc: shape mismatch {tuple(preds.shape)} vs {tuple(labels.shape)}'
        )
    # numel (not len) so inputs with more than one dimension are averaged
    # over every element, not just the first axis.
    total = preds.numel()
    if total == 0:
        raise ValueError('all_acc: cannot compute accuracy of an empty tensor')
    return preds.eq(labels).sum().item() / total

def init_seed(seed = 2025):
    """Seed every RNG used by the script and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, and torch's CPU and CUDA
    generators (current device and all devices) with ``seed``. It then
    disables cuDNN autotuning and selects deterministic kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

if __name__ == "__main__":
    # `main` is wrapped by `pyrallis.wrap`, which constructs the `cfg`
    # argument itself, hence the argument-less call.
    main()
