import warnings
warnings.filterwarnings("ignore")

import torch
import sys
import os
import argparse
import gc
import time
import random
import numpy as np
import pandas as pd
sys.path.append(os.getcwd())
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from training_structures.multimodal import train, test # noqa
from model.common_models import Identity  # noqa
from datasets.affect.data_loader import MMDataLoader
from model.mult import MULTModel   # noqa
from training_structures.utils.functions import ConfigParser, get_subset_distribution
from easydict import EasyDict as edict

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator plus the CUDA generators), then switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Benchmark mode is disabled and deterministic mode enabled together.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Command-line interface of the script; one option per line, grouped by purpose.
parser = argparse.ArgumentParser()

# Which configuration entry to run and how much of the training data to use.
parser.add_argument("--config_path", type=str, default="config_complete.json")
parser.add_argument("--dataset_name", type=str, default="mosi")
parser.add_argument("--model_name", type=str, default="multimodaltransformer")
parser.add_argument("--train_ratio", type=float, default=0.2)

# Modality switches: each flag is False unless given on the command line.
parser.add_argument("--no_vision", action="store_true")
parser.add_argument("--no_audio", action="store_true")
parser.add_argument("--no_text", action="store_true")
parser.add_argument("--num_modal", type=int, default=3)

# Data loading and model size.
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--embed_dim", type=int, default=512)

# Optimisation settings.
parser.add_argument(
    "--train_mode",
    type=str,
    default="regression",
    choices=["regression", "classification"],
)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--early_stop", type=int, default=8)
parser.add_argument("--patience", type=int, default=10)
parser.add_argument("--clip_val", type=float, default=1.0)

# Output location, device name and the list of seeds (one or more integers).
parser.add_argument("--save_dir", type=str, default="results/saved_models")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument(
    "--seeds",
    type=int,
    nargs="+",
    default=[1234, 2234, 3234, 4234, 5234],
)

# NOTE: parsing happens at module level, so importing this file reads sys.argv.
args = parser.parse_args()


def _build_model(args):
    """Create a fresh, untrained set of encoders, fusion module and head.

    Args:
        args: Merged run configuration. Reads ``no_text``, ``no_audio``,
            ``no_vision`` and ``feature_dims``; the whole object is also
            forwarded to ``MULTModel``.

    Returns:
        Tuple ``(encoders, fusion, head)``, all moved to the GPU.
    """
    encoders = [Identity().cuda() for _ in range(3)]

    # Feature sizes of the enabled modalities, always in text/audio/vision
    # order (indices 0/1/2 of ``feature_dims``).
    features_list = []
    if not args.no_text:
        features_list.append(args.feature_dims[0])  # text features
    if not args.no_audio:
        features_list.append(args.feature_dims[1])  # audio features
    if not args.no_vision:
        features_list.append(args.feature_dims[2])  # vision features
    used_modalities = len(features_list)

    fusion = MULTModel(used_modalities, features_list, args).cuda()
    head = Identity().cuda()
    return encoders, fusion, head


def _load_trained_model(path):
    """Load the whole model object stored at *path*.

    The checkpoint is used as a complete module (the caller invokes
    ``.cuda()`` on the result), so it has to be fully unpickled. PyTorch 2.6
    changed the default of ``weights_only`` to ``True``, which rejects such
    files, hence the explicit ``weights_only=False``. Only load checkpoints
    this script wrote itself: unpickling can execute arbitrary code.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # PyTorch releases that predate the ``weights_only`` keyword.
        return torch.load(path)


def main(args):
    """Train and evaluate the multimodal model once per seed and log results.

    For every seed a new model is built, trained, reloaded from the saved
    checkpoint and evaluated on the test split. Each test result is appended
    to a text summary, and the mean/std over all seeds (scaled by 100) is
    appended as one row to ``<dataset_name>_<train_ratio>.csv``. Both files
    live in ``<save_dir>/multimodal``.

    Args:
        args: Parsed command-line namespace. It is converted to an
            ``EasyDict`` and updated with the hyper-parameters that
            ``ConfigParser`` returns for the dataset/model pair.

    Raises:
        ValueError: If the merged configuration contains no seeds.
    """
    config = ConfigParser(args.config_path)
    hyparams = config.get_params(args.dataset_name, args.model_name)
    args = edict(vars(args))
    args.update(hyparams)

    # Fail before any data loading instead of crashing on the empty result
    # list after the loop.
    if not args.seeds:
        raise ValueError("at least one seed is required (see --seeds)")

    args.use_bert = not args.no_text

    dataloader = MMDataLoader(args, train_ratio=args.train_ratio,
                              use_video=not args.no_vision,
                              use_audio=not args.no_audio,
                              use_text=not args.no_text)
    get_subset_distribution(dataloader['train'])  # check label distribution

    multimodal_dir = os.path.join(args.save_dir, 'multimodal')
    os.makedirs(multimodal_dir, exist_ok=True)
    summary_file = os.path.join(multimodal_dir, 'all_multimodal_experiments_summary.txt')

    # multiple seeds
    model_results = []
    for seed in args.seeds:
        args.seed = seed
        print(f"Arguments: {args}")
        set_seed(args.seed)

        # Build the model only after seeding and once per seed: every run
        # then starts from its own seeded initialisation rather than from
        # the weights the previous seed finished with.
        encoders, fusion, head = _build_model(args)

        model_name = f"best_{not args.no_vision}_{not args.no_audio}_{not args.no_text}_{args.train_ratio}_{args.seed}.pt"
        saved_model = os.path.join(multimodal_dir, model_name)

        print("Training:")
        # NOTE(review): L1Loss is used for every train_mode, and num_modal
        # comes from the CLI/config rather than from the enabled modalities.
        # Confirm both are intended for 'classification' / --no_* runs.
        train(
            encoders,
            fusion,
            head,
            dataloader['train'],
            dataloader['valid'],
            args.epoch,
            task=args.train_mode,
            optimtype=torch.optim.AdamW,
            early_stop=args.early_stop,
            patience=args.patience,
            lr=args.lr,
            clip_val=args.clip_val,
            save=saved_model,
            weight_decay=args.weight_decay,
            objective=torch.nn.L1Loss(),
            num_modal=args.num_modal,
        )

        print("Testing:")
        model = _load_trained_model(saved_model).cuda()

        test_result = test(
            model,
            dataloader['test'],
            criterion=torch.nn.L1Loss(),
            task=args.train_mode,
            dataset=args.dataset_name,
            num_modal=args.num_modal,
        )
        model_results.append(test_result)

        # save results
        with open(summary_file, 'a') as f:
            f.write(f"{model_name}: {test_result}\n")

        # Drop this seed's modules before the next seed allocates new ones.
        del model, encoders, fusion, head
        gc.collect()
        torch.cuda.empty_cache()

    # save results to csv
    criterions = list(model_results[0].keys())
    csv_file = os.path.join(multimodal_dir, f"{args.dataset_name}_{args.train_ratio}.csv")
    if os.path.isfile(csv_file):
        df = pd.read_csv(csv_file)
    else:
        df = pd.DataFrame(columns=["Model"] + criterions)

    # One row per invocation: the model name followed by a (mean, std) pair
    # for each metric, both expressed in percent and rounded to 2 decimals.
    res = [args.model_name]
    for c in criterions:
        values = [r[c] for r in model_results]
        mean = round(np.mean(values) * 100, 2)
        std = round(np.std(values) * 100, 2)
        res.append((mean, std))
    df.loc[len(df)] = res
    df.to_csv(csv_file, index=False)
    print(f"Results saved to {csv_file}.")

    # Short pause kept from the original script; its purpose is not evident
    # from this file.
    time.sleep(1)

# Script entry point: run the experiment with the module-level parsed arguments.
if __name__ == "__main__":
    main(args)
