import gc
import ast
import logging
import os
import math
import time
import argparse
from easydict import EasyDict as edict
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import random

from datasets.affect.data_loader import MMDataLoader
from model import imder, dicmor
from training_structures.trainer import Trainer
from training_structures.utils.functions import ConfigParser

logger = logging.getLogger('MMSA')


def _build_arg_parser():
    """Create the command-line parser for a single train-then-test run.

    All options are optional. The defaults select model 'imder', dataset
    'mosi' and train mode 'regression'.
    """
    cli = argparse.ArgumentParser()

    # Experiment / data selection.
    cli.add_argument('--config_path', type=str, default='configs/config_generate.json')
    cli.add_argument('--model_name', type=str, default='imder')
    cli.add_argument('--dataset_name', type=str, default='mosi')
    cli.add_argument('--train_mode', type=str, default='regression')
    cli.add_argument('--num_workers', type=int, default=4)
    cli.add_argument('--cur_seed', type=int, default=42)

    # Boolean switches (modality dropping / conflict modelling).
    for switch in ("--no_vision", "--no_audio", "--no_text", "--use_modal_conflict"):
        cli.add_argument(switch, action="store_true")

    # Missing-modality setup and auxiliary loss weights.
    cli.add_argument('--num_modal', type=int, default=1)
    cli.add_argument('--ava_modal_idx', type=int, nargs='+', default=[0])
    cli.add_argument('--available_size', type=float, default=0.1)
    cli.add_argument('--cross_modal_generation_loss_weight', type=float, default=0.1)
    cli.add_argument('--modality_conflict_loss_weight', type=float, default=0.01)
    cli.add_argument("--generate", action="store_true")
    cli.add_argument('--train_ratio', type=float, default=0.2)

    # Output locations and compute device.
    cli.add_argument('--log_dir', type=str, default='./logs')
    cli.add_argument('--save_dir', type=str, default='results/saved_models')
    cli.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    return cli


parser = _build_arg_parser()
args = parser.parse_args()


def _set_logger(log_dir, model_name, dataset_name, verbose_level):
    """Configure and return the shared ``'MMSA'`` logger.

    Attaches a DEBUG-level file handler writing to
    ``<log_dir>/<model_name>-<dataset_name>.log`` and a console handler whose
    threshold is selected by ``verbose_level``. Handlers installed by an
    earlier call are removed first, so calling this repeatedly does not
    duplicate output.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        model_name: Used to build the log file name.
        dataset_name: Used to build the log file name.
        verbose_level: Console verbosity: 0 -> ERROR, 1 -> INFO, 2 -> DEBUG.

    Returns:
        The configured ``logging.Logger`` named ``'MMSA'``.

    Raises:
        KeyError: If ``verbose_level`` is not 0, 1 or 2 (checked before any
            handler is attached).
    """
    # Resolve the console level first so a bad value fails before side effects.
    stream_level = {0: logging.ERROR, 1: logging.INFO, 2: logging.DEBUG}
    console_level = stream_level[verbose_level]

    # Attribute used to recognise handlers created by this function.
    marker = '_mmsa_set_logger_handler'

    # base logger
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)  # FileHandler fails on a missing directory
    log_file_path = log_dir / f"{model_name}-{dataset_name}.log"
    logger = logging.getLogger('MMSA')
    logger.setLevel(logging.DEBUG)

    # Drop (and close) handlers from a previous call; otherwise every call
    # adds another pair and each record is emitted multiple times.
    for old_handler in [h for h in logger.handlers if getattr(h, marker, False)]:
        logger.removeHandler(old_handler)
        old_handler.close()

    # file handler
    fh = logging.FileHandler(log_file_path)
    fh_formatter = logging.Formatter('%(asctime)s - %(name)s [%(levelname)s] - %(message)s')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(fh_formatter)
    setattr(fh, marker, True)
    logger.addHandler(fh)

    # stream handler
    ch = logging.StreamHandler()
    ch.setLevel(console_level)
    ch_formatter = logging.Formatter('%(name)s - %(message)s')
    ch.setFormatter(ch_formatter)
    setattr(ch, marker, True)
    logger.addHandler(ch)

    return logger

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``seed``.

    When CUDA is available, also seeds all CUDA devices and switches cuDNN to
    deterministic mode with benchmarking disabled, for reproducible runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def main(args):
    """Train a model for one (dataset, seed) configuration and evaluate it on the test split.

    Args:
        args: Parsed CLI namespace. It is converted to an ``EasyDict`` and the
            hyper-parameters that ``ConfigParser`` returns for
            ``(args.dataset_name, args.model_name)`` are merged on top of it
            via ``update``.

    Returns:
        The value returned by ``Trainer.do_test`` on the test split.
    """
    logger = _set_logger(args.log_dir, args.model_name, args.dataset_name, verbose_level=1)
    set_seed(args.cur_seed)

    # update config.json: config-file values overwrite CLI values for shared keys
    config = ConfigParser(args.config_path)
    hyparams = config.get_params(args.dataset_name, args.model_name)
    args = edict(vars(args))
    args.update(hyparams)

    logger.info(f"Arguments: {args}")

    dataloader = MMDataLoader(args, train_ratio=args.train_ratio, use_video=not args.no_vision,
                              use_audio=not args.no_audio, use_text=not args.no_text)

    # NOTE(review): the architecture is hard-coded to DICMOR (the IMDER variant,
    # imder.IMDER, was disabled), while hyper-parameters and the checkpoint
    # name are keyed on args.model_name (default 'imder'). Confirm intended.
    model = dicmor.DICMOR(args)
    model = model.to(args.device)

    save_path = Path(args.save_dir) / f"best_{args['model_name']}-{args['dataset_name']}-{args.train_ratio}.pth"
    # The best checkpoint is expected at this path (it is reloaded below);
    # create the directory so saving does not fail on a fresh checkout.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    args['model_save_path'] = save_path
    trainer = Trainer(args)

    trainer.do_train(model, dataloader, return_epoch_results=False)
    # map_location lets the checkpoint load even if it was written on another device.
    model.load_state_dict(torch.load(args.model_save_path, map_location=args.device))

    results = trainer.do_test(model, dataloader['test'], mode="TEST")

    # Free the model and cached GPU memory before returning, so consecutive
    # runs in one process start from a clean state.
    del model
    torch.cuda.empty_cache()
    gc.collect()
    time.sleep(1)

    return results
    
    
if __name__ == "__main__":
    # The parser defines --cur_seed (there is no --seed), so args.seed would
    # raise AttributeError before main() ever ran.
    set_seed(args.cur_seed)
    main(args)
