import torch
import torch.nn as nn
from tqdm import tqdm
import random
import numpy as np
from omegaconf import OmegaConf
from AutoEncoder_models import MLPAE
from train_test import train_and_eval_per_epoch, train_and_eval_per_epoch_anomalies
from utils import l2_norm_model_weights
import pandas as pd
import os
from create_data_loader import gaussian_data_loader_non_linear


# Read the experiment configuration; the path is relative to the current working directory.
config = OmegaConf.load('config.yaml')

# Top-level settings.
seed, scenario, cuda = config.seed, config.scenario, config.cuda
model_check_point, num_workers = config.model_check_point, config.num_workers
n_samples, higher_dim = config.n_samples, config.higher_dim

# Values stored under the `noise` section.
noise_per, snr_db = config.noise.noise_per, config.noise.snr_db

# Values stored under the `hyperparameters` section.
batch_size, epochs = config.hyperparameters.batch_size, config.hyperparameters.epochs
lr, opt = config.hyperparameters.lr, config.hyperparameters.opt
shift = config.hyperparameters.shift


def _results_file_name(scenario_name, run_seed, noise_fraction, snr, domain_shift):
    """Return the CSV file name used to store the results of one seed.

    Raises:
        ValueError: if ``scenario_name`` is not a scenario this script knows how to name.
    """
    if scenario_name in ('sample_noise', 'feature_noise'):
        return f'seed_{run_seed}_noise_per_{noise_fraction}_snr_{snr}.csv'
    if scenario_name == 'anomalies':
        return f'seed_{run_seed}_anomaly_per_{noise_fraction}_snr_{snr}.csv'
    if scenario_name == 'domain_shift':
        return f'seed_{run_seed}_noise_per_{noise_fraction}_snr_{snr}_shift_{domain_shift}.csv'
    raise ValueError(
        f"Unsupported scenario {scenario_name!r}; expected one of "
        f"'sample_noise', 'feature_noise', 'anomalies', 'domain_shift'"
    )


def _append_results_to_csv(new_rows, results_dir, csv_name):
    """Append ``new_rows`` to ``results_dir/csv_name``, creating the directory and file if needed."""
    os.makedirs(results_dir, exist_ok=True)
    csv_path = os.path.join(results_dir, csv_name)
    if os.path.isfile(csv_path):
        # Keep the rows of earlier runs and add the new ones after them.
        new_rows = pd.concat([pd.read_csv(csv_path), new_rows])
    new_rows.to_csv(csv_path, index=False)


# Results are stored in this directory, relative to the current working directory.
exp_name = 'more_experiments/non_linear_subspace_sample_noise'
for seed in [0]:
    # Resolve the output file name before any training so that an unsupported
    # scenario fails immediately instead of after a complete training run.
    file_name = _results_file_name(scenario, seed, noise_per, snr_db, shift)
    for latent_dim in [5, 10, 15, 20, 25, 30, 35, 40, 45]:
        for hidden_dim in range(4, 504, 4):
            # Re-seed before every configuration so that each (latent_dim, hidden_dim)
            # pair starts from the same random state.
            torch.manual_seed(seed)
            random.seed(seed)
            np.random.seed(seed)
            # NOTE(review): the meaning of the literal 20 (third positional argument) is
            # defined by gaussian_data_loader_non_linear - confirm against that function.
            if scenario == 'anomalies':
                train_loader, test_loader, anomaly_test_loader = gaussian_data_loader_non_linear(
                    n_samples, higher_dim, 20, batch_size, scenario, noise_per, snr_db)
            else:
                train_loader, test_loader = gaussian_data_loader_non_linear(
                    n_samples, higher_dim, 20, batch_size, scenario, noise_per, snr_db, shift)

            print(f'num train={len(train_loader.dataset)}')
            print(f'hidden_dim={hidden_dim}')
            print(f'latent_dim={latent_dim}')
            if scenario == 'anomalies':
                print(f'anomaly_per={noise_per}')
            else:
                print(f'noise_per={noise_per}')
            print(f'snr_db={snr_db}')
            print(f'seed={seed}')
            if scenario == 'domain_shift':
                print(f'shift={shift}')

            # The input width is taken from the first element of the first training sample.
            model = MLPAE(input_dim=train_loader.dataset[0][0].shape[0], latent_dim=latent_dim, hidden_dim=hidden_dim,
                          n_hidden_layers=0, final_activation=nn.Identity())

            total_params = sum(param.numel() for param in model.parameters())
            print(f'total_params={total_params}')
            run_results = []
            # Running minima of the per-epoch losses, stored with every epoch's row.
            min_train_loss = np.inf
            min_test_loss = np.inf
            for epoch in tqdm(range(epochs)):
                if scenario == 'anomalies':
                    train_loss, test_loss, anomaly_loss, roc_auc = train_and_eval_per_epoch_anomalies(
                        train_loader, test_loader, anomaly_test_loader, model, opt, lr)
                else:
                    # The last two returned values are not used by this script.
                    train_loss, test_loss, _, _ = train_and_eval_per_epoch(train_loader, test_loader, model, opt, lr)
                min_train_loss = min(min_train_loss, train_loss)
                min_test_loss = min(min_test_loss, test_loss)
                l2_norm = l2_norm_model_weights(model)
                if (epoch + 1) % 10 == 0:
                    if scenario == 'anomalies':
                        print(f'Epoch: {epoch + 1}, Train loss: {round(train_loss, 3)}, clean test loss: '
                              f'{round(test_loss, 3)}, anomaly test loss: {round(anomaly_loss, 3)}, roc auc: '
                              f'{round(roc_auc, 3)}')
                    else:
                        print(f'Epoch: {epoch + 1}, Train loss: {train_loss}, test loss: {test_loss}')

                # One row per epoch; 'epoch' is zero-based, unlike the one-based progress prints.
                results_per_epoch = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'epoch': epoch,
                    'n_model_params': total_params,
                    'hidden_dim': hidden_dim,
                    'latent_dim': latent_dim,
                    'weights_l2_norm': l2_norm,
                    'min_train_loss': min_train_loss,
                    'min_test_loss': min_test_loss,
                    'seed': seed,
                    'noise_per': noise_per,
                    'snr_db': snr_db,
                    'n_samples': n_samples,
                    'batch_size': batch_size
                }
                if scenario == 'domain_shift':
                    results_per_epoch['shift'] = shift
                elif scenario == 'anomalies':
                    # In the anomalies scenario the same values are stored under anomaly-specific names.
                    results_per_epoch['anomaly_per'] = results_per_epoch.pop('noise_per')
                    results_per_epoch['clean_test_loss'] = results_per_epoch.pop('test_loss')
                    results_per_epoch['anomaly_test_loss'] = anomaly_loss
                    results_per_epoch['roc_auc'] = roc_auc

                run_results.append(results_per_epoch)

            # Persist this configuration's rows right away so finished runs survive an interrupted sweep.
            _append_results_to_csv(pd.DataFrame(run_results), exp_name, file_name)