import itertools
import os
import random

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from AutoEncoder_models import CnnAE
from create_data_loader import mnist_data_loader
from train_test import train_and_eval_per_epoch
from utils import l2_norm_model_weights

# Experiment settings, read once from config.yaml in the current working directory.
config = OmegaConf.load('config.yaml')

# Run-level settings. NOTE: `seed` is only a default; the sweep loop further
# down rebinds `seed` for every run it launches.
seed, num_workers, scenario = config.seed, config.num_workers, config.scenario

# Training / data hyperparameters (`source` is the domain-shift source dataset).
batch_size, lr, opt, epochs, source = (
    getattr(config.hyperparameters, field)
    for field in ('batch_size', 'lr', 'opt', 'epochs', 'source')
)

# Noise settings, used by the 'sample_noise' / 'feature_noise' scenarios.
noise_per, snr_db = config.noise.noise_per, config.noise.snr_db

# Sweep: train a CnnAE for every (seed, n_samples, latent_dim, channels)
# combination below, record per-epoch metrics, and store them in one CSV per
# run under results/<exp_name>/ (appending if that CSV already exists).
exp_name = 'for_reviewers/mnist_sample_noise_sample_wise'
# Single output directory: it is both the directory we create and the one we
# write into, so results are always persisted.
results_dir = os.path.join('results', exp_name)

SEEDS = [5, 6, 7, 8, 9]  # alternatively: list(range(10))
N_SAMPLES_GRID = [200, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 7000, 8000, 9000,
                  10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000]
LATENT_DIMS = [500]
CHANNELS = [6]  # alternatively: range(1, 65, 1)

_NOISE_SCENARIOS = ('sample_noise', 'feature_noise')

# Fail fast, before any data is loaded or any model is trained.
if scenario != 'domain_shift' and scenario not in _NOISE_SCENARIOS:
    raise ValueError('Scenario not implemented')


def _target_domain(source):
    """Return the domain-shift target dataset paired with *source*."""
    return 'mnistm' if source == 'mnist' else 'mnist'


def _results_file_name(scenario, source, noise_per, snr_db, seed, latent_dim, n_train):
    """Build the CSV file name identifying one run of the sweep.

    Raises:
        ValueError: if *scenario* is not one of the supported scenarios.
    """
    if scenario == 'domain_shift':
        return (f'seed_{seed}_source_{source}_target_{_target_domain(source)}'
                f'_latent_{latent_dim}_n_samples_{n_train}.csv')
    if scenario in _NOISE_SCENARIOS:
        return f'seed_{seed}_noise_per_{noise_per}_snr_{snr_db}_n_samples_{n_train}.csv'
    raise ValueError('Scenario not implemented')


def _save_run_results(run_results, out_dir, file_name):
    """Write *run_results* to out_dir/file_name, appending to an existing CSV.

    The directory is created if needed. When the CSV already exists, its rows
    are kept and the new rows are concatenated after them.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, file_name)
    if os.path.isfile(path):
        run_results = pd.concat([pd.read_csv(path), run_results])
    run_results.to_csv(path, index=False)


# itertools.product iterates the rightmost grid fastest, i.e. the same order
# as the equivalent nested for-loops (seed outermost, channels innermost).
for seed, n_samples, latent_dim, channel in itertools.product(SEEDS, N_SAMPLES_GRID, LATENT_DIMS, CHANNELS):
    # Re-seed every RNG per run so each configuration is reproducible on its own.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    train_loader, test_loader = mnist_data_loader(n_samples, batch_size, num_workers, scenario, noise_per, snr_db, source)
    # Actual training-set size: used in the file name, while the requested
    # n_samples is what gets recorded in the CSV rows.
    n_train = len(train_loader.dataset)

    print(f'channels={channel}')
    print(f'latent_dim={latent_dim}')
    print(f'seed={seed}')
    print(f'n_samples = {n_train}')
    if scenario == 'domain_shift':
        print(f'source={source}')
        print(f'target={_target_domain(source)}')
    else:
        print(f'noise_per={noise_per}')
        print(f'snr_db={snr_db}')

    model = CnnAE(latent_dim=latent_dim, channels=channel)

    total_params = sum(param.numel() for param in model.parameters())
    print(f'total_params={total_params}')
    run_results = []
    min_train_loss = np.inf
    min_test_loss = np.inf
    for epoch in tqdm(range(epochs)):
        # NOTE(review): opt and lr are passed on every epoch; confirm the helper
        # does not rebuild the optimizer (resetting its state) on each call.
        train_loss, test_loss, _, _ = train_and_eval_per_epoch(train_loader, test_loader, model, opt, lr)
        min_train_loss = min(min_train_loss, train_loss)
        min_test_loss = min(min_test_loss, test_loss)
        l2_norm = l2_norm_model_weights(model)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch: {epoch + 1}, Train loss: {train_loss}, test loss: {test_loss}')

        run_results.append({
            'train_loss': train_loss,
            'test_loss': test_loss,
            'epoch': epoch,
            'batch_size': batch_size,
            'n_model_params': total_params,
            'channels': channel,
            'latent_dim': latent_dim,
            'weights_l2_norm': l2_norm,
            'min_train_loss': min_train_loss,
            'min_test_loss': min_test_loss,
            'seed': seed,
            'noise_per': noise_per,
            'snr_db': snr_db,
            'n_samples': n_samples,
        })

    file_name = _results_file_name(scenario, source, noise_per, snr_db, seed, latent_dim, n_train)
    _save_run_results(pd.DataFrame(run_results), results_dir, file_name)
