import torch
import torch.nn as nn
from tqdm import tqdm
import random
import numpy as np
from omegaconf import OmegaConf
from AutoEncoder_models import MLPAE
from create_data_loader import single_cell_data_loader
from train_test import train_and_eval_per_epoch
from utils import l2_norm_model_weights
import pandas as pd
import os

# Experiment settings are read from `config.yaml`, resolved relative to the
# current working directory.
config = OmegaConf.load('config.yaml')

# Top-level settings.
seed = config.seed
scenario = config.scenario
n_features = config.n_features
n_samples = config.n_samples

# Values stored under the `hyperparameters` section (this section also holds
# the source/target identifiers).
hparams_cfg = config.hyperparameters
batch_size = hparams_cfg.batch_size
lr = hparams_cfg.lr
opt = hparams_cfg.opt
epochs = hparams_cfg.epochs
source = hparams_cfg.source
target = hparams_cfg.target

# Values stored under the `noise` section.
noise_cfg = config.noise
noise_per = noise_cfg.noise_per
snr_db = noise_cfg.snr_db

# Name of the sub-folder under `results/` that this script writes into.
exp_name = 'cells_feature_noise'

# Seed the Python and NumPy RNGs before the data loaders are constructed.
random.seed(seed)
np.random.seed(seed)

train_loader, test_loader = single_cell_data_loader(
    n_samples, n_features, batch_size, source, target, scenario, snr_db, noise_per
)

# Hidden-layer widths to sweep: 10..500 in steps of 10, then 550..3000 in
# steps of 50.
hidden_dims = [*range(10, 510, 10), *range(550, 3050, 50)]
# Sweep over model widths: for every (latent_dim, hidden_dim) pair, train a
# fresh MLPAE for `epochs` epochs, record per-epoch metrics, and append them to
# a CSV under results/<exp_name>/.

# The scenario-specific metadata and the output file name do not depend on the
# loop variables, so resolve them once. This also rejects an unsupported
# scenario before any training starts.
if scenario == 'domain_shift':
    scenario_fields = {'source': source, 'target': target}
    file_name = f'seed_{seed}_source_{source}_target_{target}.csv'
elif scenario in ('sample_noise', 'feature_noise'):
    scenario_fields = {'noise_per': noise_per, 'snr_db': snr_db}
    file_name = f'seed_{seed}_noise_per_{noise_per}_snr_{snr_db}.csv'
else:
    raise ValueError('Scenario not implemented')

results_dir = os.path.join('results', exp_name)
results_path = os.path.join(results_dir, file_name)

# Only a single latent size is swept at the moment; use range(10, 60, 10)
# here to sweep several.
for latent_dim in [20]:
    for hidden_dim in hidden_dims:
        # Re-seed every RNG so each model starts from the same random state.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        print(f'hidden_dim={hidden_dim}')
        print(f'latent_dim={latent_dim}')
        print(f'seed={seed}')
        for field_name, field_value in scenario_fields.items():
            print(f'{field_name}={field_value}')

        # Feature count of the first training sample. Read once per model and
        # reused for the 'n_features' column instead of re-indexing the
        # dataset on every epoch.
        input_dim = train_loader.dataset[0][0].shape[0]
        model = MLPAE(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim,
                      n_hidden_layers=0, final_activation=nn.Identity())

        total_params = sum(param.numel() for param in model.parameters())
        print(f'total_params={total_params}')
        run_results = []
        min_train_loss = np.inf
        min_test_loss = np.inf
        for epoch in tqdm(range(epochs)):
            train_loss, test_loss, _, _ = train_and_eval_per_epoch(train_loader, test_loader, model, opt, lr)
            # Running minima over the epochs seen so far for this model.
            min_train_loss = min(min_train_loss, train_loss)
            min_test_loss = min(min_test_loss, test_loss)
            l2_norm = l2_norm_model_weights(model)
            if (epoch + 1) % 10 == 0:
                print(f'Epoch: {epoch + 1}, Train loss: {train_loss}, test loss: {test_loss}')

            results_per_epoch = {
                'train_loss': train_loss,
                'test_loss': test_loss,
                'epoch': epoch,
                'n_model_params': total_params,
                'hidden_dim': hidden_dim,
                'latent_dim': latent_dim,
                'weights_l2_norm': l2_norm,
                'min_train_loss': min_train_loss,
                'min_test_loss': min_test_loss,
                'seed': seed,
                'n_features': input_dim,
                'n_samples_batch_1': n_samples,
                'batch_size': batch_size,
            }
            results_per_epoch.update(scenario_fields)
            run_results.append(results_per_epoch)

        run_results = pd.DataFrame(run_results)

        # Persist this model's rows. The target file is checked directly
        # rather than walking the results tree: walking runs the write once
        # per directory, so any sub-folder would overwrite the accumulated
        # CSV with just the current run.
        os.makedirs(results_dir, exist_ok=True)
        if os.path.isfile(results_path):
            # Append to the rows written by earlier models / earlier runs.
            previous_results = pd.read_csv(results_path)
            combined_results = pd.concat([previous_results, run_results])
        else:
            combined_results = run_results
        combined_results.to_csv(results_path, index=False)
