import os
import copy
import time
import itertools
import numpy as np
import torch
import datasets
import models
import argparse
from objectives import compute_batch_loss
from instrumentation import train_logger

from tools import plot_stats

__author__ = 'Anonymous'
__docformat__ = 'reStructuredText'
# Public API of this module. The list must only contain names that are defined
# here: an entry such as '' makes ``from <module> import *`` raise
# AttributeError, because the import machinery looks every entry up by name.
__all__ = [
    'run_adaptation',
    'run_source_only',
    'run_evaluation',
    'train',
    'initialize_training_run',
    'execute_training_run',
]



def _cycle_batches(loader):

    '''
    Yield batches from a dataloader indefinitely, restarting it when exhausted.

    Unlike itertools.cycle, nothing is cached: every pass re-iterates ``loader``,
    so previously seen batches are not kept alive and a shuffling loader can
    produce a new order on each pass.

    Parameters
    loader: Re-iterable object producing batches (e.g. a DataLoader).
    '''

    while True:
        produced = False
        for item in loader:
            produced = True
            yield item
        if not produced:
            # An empty loader would otherwise make this loop spin forever;
            # finishing the generator makes next() raise StopIteration instead.
            return


def run_adaptation(model, P, Z, logger, epoch, phase):

    '''
    Run one training phase on paired source/target batches.

    One optimizer step is taken per source batch. Each source batch is stacked
    with the next target batch (the target dataloader is restarted whenever it
    runs out), passed through the model, and the loss returned by
    compute_batch_loss is back-propagated.

    Parameters
    model: Model to train.
    P: Dictionary of parameters, which completely specify the training procedure.
    Z: Dictionary of temporary objects used during training.
    logger: Object used to track various metrics during training.
    epoch: Integer index of the current epoch (not used inside this function).
    phase: String giving the phase name; must be 'train'.

    Returns
    The 'lamda' entry of the last processed batch, or None if the source
    dataloader produced no batches.

    Raises
    StopIteration if the target dataloader produces no batches at all.
    '''

    assert phase == 'train'
    model.train()

    # Restart the target loader on exhaustion instead of replaying cached batches.
    target_iter = _cycle_batches(Z['dataloaders']['target'])

    # Stays None when the source loader is empty, so the return below is safe.
    batch = None

    for src_data in Z['dataloaders']['source']:

        batch = {}

        tgt_data = next(target_iter)

        src_img, src_labels = src_data
        tgt_img, _ = tgt_data  # target labels are not used for training

        # Source rows come first in the stacked batch; their count is stored
        # alongside the images in the batch dictionary.
        batch['num_src_samples'] = src_img.shape[0]
        batch['images'] = torch.vstack([src_img, tgt_img]).to(Z['device'])

        batch['labels'] = src_labels.to(Z['device'])
        batch['labels_np'] = src_labels.clone().cpu().numpy()

        # forward pass:
        Z['optimizer'].zero_grad()

        batch['logits'], batch['latent_feat'] = model(batch['images'], P['act_flag'])
        batch['preds'] = torch.softmax(batch['logits'], dim=1)
        batch['preds_np'] = batch['preds'].clone().detach().cpu().numpy()

        batch = compute_batch_loss(batch, P)

        # backward pass:
        batch['loss_tensor'].backward()
        Z['optimizer'].step()

        # save current batch data:
        logger.update_phase_data(batch, P, phase)

    if batch is None:
        return None

    return batch['lamda']


def run_source_only(model, P, Z, logger, epoch, phase):

    '''
    Run one training phase using source batches only.

    For every batch of the source dataloader: move the data to the device,
    run the model, compute the loss with compute_batch_loss, take one
    optimizer step and hand the batch dictionary to the logger.

    Parameters
    model: Model to train.
    P: Dictionary of parameters, which completely specify the training procedure.
    Z: Dictionary of temporary objects used during training.
    logger: Object used to track various metrics during training.
    epoch: Integer index of the current epoch (not used inside this function).
    phase: String giving the phase name; must be 'train'.
    '''

    assert phase == 'train'
    model.train()

    for images, targets in Z['dataloaders']['source']:
        device = Z['device']
        batch = {
            'images': images.to(device),
            'labels_np': targets.clone().cpu().numpy(),
            'labels': targets.to(device),
        }

        # forward pass:
        optimizer = Z['optimizer']
        optimizer.zero_grad()

        logits, _ = model(batch['images'], P['act_flag'])
        batch['logits'] = logits

        probs = torch.softmax(logits, dim=1)
        batch['preds'] = probs
        batch['preds_np'] = probs.clone().detach().cpu().numpy()

        batch = compute_batch_loss(batch, P)

        # backward pass:
        batch['loss_tensor'].backward()
        optimizer.step()

        # save current batch data:
        logger.update_phase_data(batch, P, phase)


def run_evaluation(model, P, Z, logger, epoch, phase):

    '''
    Run one evaluation phase over the 'test' dataloader.

    The model is switched to eval mode and every batch is passed through it
    with gradients disabled; the resulting batch dictionary is handed to the
    logger. No loss is computed here.

    Parameters
    model: Model to evaluate.
    P: Dictionary of parameters, which completely specify the training procedure.
    Z: Dictionary of temporary objects used during training.
    logger: Object used to track various metrics during training.
    epoch: Integer index of the current epoch (not used inside this function).
    phase: String giving the phase name
    '''

    model.eval()

    for sample in Z['dataloaders']['test']:
        # move data to the device:
        images = sample[0].to(Z['device'])
        labels = sample[1].to(Z['device'])

        batch = {
            'images': images,
            'labels': labels,
            'labels_np': labels.clone().cpu().numpy(),
        }

        # forward pass:
        with torch.no_grad():
            logits, _ = model(images, P['act_flag'])
            if logits.dim() == 1:
                # restore the leading dimension when the model returns 1-D logits
                logits = torch.unsqueeze(logits, 0)
            batch['logits'] = logits

            probs = torch.softmax(logits, dim=1)
            batch['preds'] = probs
            batch['preds_np'] = probs.clone().detach().cpu().numpy()

        # save current batch data:
        logger.update_phase_data(batch, P, phase)


def train(model, P, Z):

    '''
    Train the model for P['num_epochs'] epochs and plot the collected curves.

    Every epoch runs each phase listed in P['phase']: the 'train' phase uses
    run_source_only or run_adaptation depending on P['exp_mode'], any other
    phase uses run_evaluation. Per-phase accuracy and classification loss
    (plus the domain loss for adaptation training) are read from the logger
    and plotted into P['result_dir'] at the end.

    Parameters
    model: Model to train.
    P: Dictionary of parameters, which completely specify the training procedure.
       P['epoch'] is overwritten with the 1-based index of the running epoch.
    Z: Dictionary of temporary objects used during training.

    Returns
    The tuple (P, model, logger).
    '''

    logger = train_logger(P)  # initialize logger

    phases = P['phase']
    acc = {name: [] for name in phases}
    cls_loss = {name: [] for name in phases}
    da_loss = {'train': []}

    total_epochs = P['num_epochs']

    for epoch_idx in range(total_epochs):
        current = epoch_idx + 1
        print('\nstart epoch [{}/{}] ...'.format(current, total_epochs))
        P['epoch'] = current

        for phase in phases:
            # reset phase metrics:
            logger.reset_phase_data()

            # run one phase:
            t_init = time.time()
            if phase != 'train':
                run_evaluation(model, P, Z, logger, current, phase)
            elif P['exp_mode'] == 'source_only':
                run_source_only(model, P, Z, logger, current, phase)
            else:
                last_iter_lamda = run_adaptation(model, P, Z, logger, current, phase)

            # save end-of-phase metrics:
            logger.compute_phase_metrics(phase, current, P['exp_mode'])

            # print epoch status (lamda is only reported for adaptation training):
            adapting = P['exp_mode'] == 'adaptation' and phase == 'train'
            extra = (last_iter_lamda,) if adapting else ()
            logger.report(t_init, time.time(), phase, current, *extra)

            phase_metrics = logger.logs['metrics'][phase][current]
            acc[phase].append(phase_metrics['acc'])
            cls_loss[phase].append(phase_metrics['class_loss'])
            if adapting:
                da_loss[phase].append(phase_metrics['domain_loss'])

    out_dir = P['result_dir']
    plot_stats(res_dict=acc, ylabel='Accuracy', dir=out_dir, filename='acc.png')
    plot_stats(res_dict=cls_loss, ylabel='Cross Entropy', dir=out_dir, filename='cls_loss.png')
    if P['exp_mode'] == 'adaptation':
        plot_stats(res_dict=da_loss, ylabel='Adaptation Loss', dir=out_dir, filename='da_loss.png')

    return P, model, logger


def _make_loader(title, dataset, P, training, tail=''):

    '''
    Build one DataLoader and print its batch size and number of batches.

    Parameters
    title: Heading line printed before the statistics.
    dataset: Dataset to wrap.
    P: Dictionary of parameters; 'bsize' and 'num_workers' are read.
    training: If True the loader shuffles and drops the last incomplete batch,
              otherwise it does neither.
    tail: Text appended to the last printed line.
    '''

    # never ask for more samples per batch than the dataset holds
    size = min(P['bsize'], len(dataset))
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        drop_last=training,
        shuffle=training,
        batch_size=size,
        num_workers=P['num_workers'],
    )
    print(title)
    print(f"Batch size: {size}")
    print(f"Number of batches: {len(loader)}" + tail)
    return loader


def initialize_training_run(P):

    '''
    Set up for model training.

    Seeds numpy, selects the device, loads the source and target datasets,
    builds the dataloaders, the model and the Adam optimizer.

    Parameters
    P: Dictionary of parameters, which completely specify the training procedure.
       P['num_classes'] is set from the source dataset.

    Returns
    The tuple (P, Z, model), where Z holds 'device', 'datasets', 'dataloaders'
    and 'optimizer'.
    '''

    np.random.seed(P['seed'])

    Z = {}
    Z['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # datasets
    Z['datasets'] = {}
    for role in ('source', 'target'):
        Z['datasets'][role] = datasets.get_data(P[role + '_domain'], P)

    source_set = Z['datasets']['source']
    target_set = Z['datasets']['target']

    print('Source domain: {}'.format(P['source_domain']))
    print('Source samples: {} total'.format(len(source_set)))
    print('Target domain: {}'.format(P['target_domain']))
    print('Target samples: {} total\n'.format(len(target_set)))

    # save dataset-specific parameters:
    P['num_classes'] = len(source_set.classes)
    print('Num classes: {}'.format(P['num_classes']))

    Z['dataloaders'] = {}
    Z['dataloaders']['source'] = _make_loader("\nSource dataloader:", source_set, P, True)

    # The target training loader is not needed in source_only mode.
    if P['exp_mode'] != 'source_only':
        Z['dataloaders']['target'] = _make_loader("\nTarget dataloader:", target_set, P, True)

    # The test loader always reads the target dataset, in order and in full.
    Z['dataloaders']['test'] = _make_loader("\nTest dataloader:", target_set, P, False, tail="\n")

    model = models.ImageClassifier(P)

    # optimization objects:
    backbone = model.feature_extractor
    f_conv_params = [weight for weight in backbone.layer4.parameters() if weight.requires_grad]
    f_fc_params = list(backbone.fc.parameters())
    c_params = list(model.linear_classifier.parameters())

    # layer4 uses the global learning rate; both fully connected parts use P['fc_lr'].
    param_groups = [
        {'params': f_conv_params},
        {'params': f_fc_params, 'lr': P['fc_lr']},
        {'params': c_params, 'lr': P['fc_lr']},
    ]
    Z['optimizer'] = torch.optim.Adam(param_groups, lr=P['global_lr'])

    return P, Z, model


def execute_training_run(P):

    '''
    Initialize, run the training process, and collect the results.

    Parameters
    P: Dictionary of parameters, which completely specify the training procedure.

    Returns
    The tuple (feature extractor, linear classifier, logs), where the first two
    are attributes of the trained model and logs comes from logger.get_logs().
    '''

    params, state, net = initialize_training_run(P)
    net.to(state['device'])

    _, net, run_logger = train(net, params, state)
    logs = run_logger.get_logs()

    return net.feature_extractor, net.linear_classifier, logs


if __name__ == '__main__':

    # Command-line entry point: parse the arguments, assemble the parameter
    # dictionary consumed by execute_training_run, and start the run.

    parser = argparse.ArgumentParser(description='GeoAdapt -- 2025')
    parser.add_argument('-j', '--job_id', type=int)
    parser.add_argument('-ps', '--pytorch_seed', default=1, type=int)
    parser.add_argument('-dr', '--data_root', default='./data/', type=str)
    parser.add_argument('-d', '--dataset', choices=['office31', 'visda2017'], default='visda2017', type=str)
    parser.add_argument('-s', '--source_domain', default='train', type=str)
    parser.add_argument('-t', '--target_domain', default='test', type=str)
    parser.add_argument('-e', '--exp_mode', default='adaptation', choices=['source_only', 'adaptation'], type=str)
    parser.add_argument('-am', '--adapt_method', default='geo_adapt', choices=['ddc', 'coral', 'log_coral', 'cmd', 'homm', 'geo_adapt'], type=str)
    parser.add_argument('-dm', '--dist_metric', default='airm', choices=['hilbert', 'log_euclidean', 'airm'], type=str)
    parser.add_argument('-fd', '--feat_dim', type=int, default=25)
    parser.add_argument('-l', '--lamda', type=float, default=0.1)
    parser.add_argument('-bs', '--batch_size', type=int, default=861)
    parser.add_argument('-ep', '--epochs', type=int, default=50)
    parser.add_argument('-dt', '--det_thr', type=float, default=1e-8)
    parser.add_argument('-hm', '--highest_moment', type=int, default=2)
    parser.add_argument('-rs', '--result_folder', default='results/', type=str)
    parser.add_argument('-ft', '--fine_tune', default='last_block', choices=['last_layer', 'last_block'], type=str)
    args = parser.parse_args()

    # System parameters:
    if torch.cuda.is_available():
        print('CUDA is available!')
        print('Number of available GPUs: {}\n'.format(torch.cuda.device_count()))
        # NOTE(review): this is set after torch has already queried CUDA, so it
        # may have no effect on device visibility -- confirm whether it should
        # be exported before the first torch.cuda call.
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    else:
        print('CUDA is NOT available.\n')

    config = {}

    # NOTE(review): --job_id has no default, so this becomes the string 'None'
    # (and the result directory '<result_folder>/None') when it is omitted.
    config['job_id'] = str(args.job_id)
    print('Job ID: {}'.format(config['job_id']))
    # Top-level parameters:
    config['dataset'] = args.dataset

    config['exp_mode'] = args.exp_mode
    print('Experiment mode: {}'.format(config['exp_mode']))

    config['adapt_method'] = args.adapt_method

    config['highest_moment'] = args.highest_moment

    config['source_domain'] = args.source_domain
    config['target_domain'] = args.target_domain

    config['dist_metric_type'] = args.dist_metric
    config['det_thr'] = args.det_thr

    config['data_root'] = args.data_root
    config['result_dir'] = os.path.join(args.result_folder, config['job_id'])
    # exist_ok avoids the check-then-create race of testing os.path.exists
    # first: a directory created in between no longer aborts the run.
    os.makedirs(config['result_dir'], exist_ok=True)

    config['user_lamda'] = args.lamda

    # Both datasets selectable on the command line run without a 'val' phase.
    if config['dataset'] in ['office31', 'visda2017']:
        config['val_ext'] = False
        config['phase'] = ['train', 'test']
    else:
        config['val_ext'] = True
        config['phase'] = ['train', 'val', 'test']

    config['pytorch_seed'] = args.pytorch_seed
    print('PyTorch seed: {}'.format(str(config['pytorch_seed'])))
    torch.manual_seed(config['pytorch_seed'])
    torch.cuda.manual_seed_all(config['pytorch_seed'])

    # Optimization & method parameters (kept per dataset so they can diverge):
    if config['dataset'] == 'office31':
        config['bsize'] = args.batch_size
        config['global_lr'] = 3e-5
        config['fc_lr'] = 3e-4
        config['warmup_epoch'] = 0
        config['img_resize'] = 224
        config['lamda_type'] = 'fixed' #['fixed', 'epoch_based']

    elif config['dataset'] == 'visda2017':
        config['bsize'] = args.batch_size
        config['global_lr'] = 3e-5
        config['fc_lr'] = 3e-4
        config['warmup_epoch'] = 0
        config['img_resize'] = 224
        config['lamda_type'] = 'fixed'

    # Additional parameters:
    config['seed'] = 1200  # overall numpy seed
    config['num_workers'] = 8
    config['stop_metric'] = 'acc'  # metric used to select the best epoch

    # act_flag is passed to the model on every forward call; it is switched
    # off only for adaptation with the geo_adapt method.
    if config['exp_mode'] == 'adaptation' and config['adapt_method'] == 'geo_adapt':
        config['act_flag'] = False
    else:
        config['act_flag'] = True

    # training parameters:
    config['num_epochs'] = args.epochs
    config['arch'] = 'resnet50'
    config['use_pretrained'] = True
    config['resnet50_weights'] = 'IMAGENET1K_V1'  # 'DEFAULT'
    config['freeze_feature_extractor'] = False
    config['feat_dim'] = args.feat_dim
    print('Latent feature size: {}'.format(str(config['feat_dim'])))

    config['fine_tune'] = args.fine_tune

    # Echo the method-specific settings that apply to the chosen mode.
    if config['exp_mode'] == 'adaptation':
        print('\nDA mode: {}'.format(config['adapt_method']))
        print('Lambda: {:.2e}'.format(config['user_lamda']))

        if config['adapt_method'] == 'geo_adapt':
            print('\nDistance metric: {}'.format(config['dist_metric_type']))
            print('Determinant threshold: {:.1e}'.format(config['det_thr']))

        elif config['adapt_method'] in ['cmd', 'homm']:
            print('\nHighest moment: {}'.format(config['highest_moment']))

    # run the training process:
    print('\n[{}] {} --> {}'.format(config['dataset'], config['source_domain'], config['target_domain']))
    (feature_extractor, linear_classifier, logs) = execute_training_run(config)