import torch
import argparse
import pickle
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch import tensor as tt
import torch.nn as nn
import torch.nn.functional as F
import time
from statsmodels.distributions.empirical_distribution import ECDF



# Parse command-line arguments
parser = argparse.ArgumentParser(description="Train a model with a specified dataset.")
parser.add_argument("--dataset", type=str, required=True, help="Dataset name (e.g., 'magic_ecdf')")
# Parsed as an integer so that a non-numeric value is rejected by argparse with
# a usage message instead of surfacing later as a ValueError from int().
parser.add_argument("--epochs", type=int, default=50, help="Number of epochs to train the model.")
parser.add_argument("--cv_seed", type=int, default=50, help="seeds to run for")
args = parser.parse_args()

# The single seed used for the train/test split; argparse already converted it to int.
cv_seeds = args.cv_seed

# Use dataset name to construct file paths and variable names
dataset_name = args.dataset
csv_path = f"Data/{dataset_name}.csv"

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Fail early, before any model is built, if the dataset file is missing
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"Dataset file '{csv_path}' not found.")






#### IGC model

class SoftRank(nn.Module):
    """Differentiable ranking layer.

    For every batch element and every feature dimension, each of the S samples
    is replaced by a smooth estimate of its normalised rank among those S
    samples::

        out[b, i, d] = (1 / S) * sum_j sigmoid(alpha * (x[b, i, d] - x[b, j, d]))

    The comparison of a sample with itself contributes sigmoid(0) = 0.5, so for
    well separated values the output approaches (rank - 0.5) / S with a 1-based
    rank.

    Args:
        alpha: Scale applied inside the sigmoid so that it approximates a step
            function; larger values give sharper (closer to hard) ranks.
    """

    def __init__(self, alpha=1000.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, inputs):
        """Return the soft normalised ranks of ``inputs``.

        Args:
            inputs: Tensor of shape (B, S, D); the S samples are ranked
                independently for each batch element and each dimension.

        Returns:
            Tensor of shape (B, S, D) with entries in [0, 1].

        Raises:
            ValueError: If ``inputs`` is not a rank-3 tensor.
        """
        if inputs.dim() != 3:
            raise ValueError(
                f"SoftRank expects a (batch, samples, dims) tensor, got shape {tuple(inputs.shape)}"
            )
        # Broadcasting builds the pairwise differences directly:
        # pairwise[b, i, j, d] = inputs[b, i, d] - inputs[b, j, d].
        # This avoids materialising an S-fold repeated copy of the input.
        pairwise = inputs.unsqueeze(2) - inputs.unsqueeze(1)
        # Summing over j counts (softly) how many samples lie below sample i.
        return torch.sigmoid(self.alpha * pairwise).sum(dim=2) / inputs.shape[1]
class IGC(nn.Module):
    """Generative copula model (IGC) driven by Gaussian latent noise.

    A small fully connected network maps latent noise of dimension
    ``3 * output_size`` to ``output_size`` values. The outputs are turned into
    copula samples by ranking them: with a differentiable soft rank during
    training (``forward_train``) and with cached empirical CDFs when sampling
    (``forward``).

    Args:
        hidden_size: Width of the hidden layers.
        layers_number: How many times the hidden layer is applied after the
            input layer.
        output_size: Dimension of the generated samples.
    """

    def __init__(self, hidden_size=100, layers_number=2, output_size=2):
        super().__init__()
        self.dim_latent = 3 * output_size
        self.hidden_size = hidden_size
        # NOTE: the misspelled attribute name is kept on purpose so that any
        # code reading `layers_nuber` keeps working.
        self.layers_nuber = layers_number
        self.output_size = output_size
        self.linear_in = nn.Linear(in_features=self.dim_latent, out_features=self.hidden_size)
        # A single hidden module is applied `layers_number` times, so its
        # weights are shared across the hidden depth. Changing this would alter
        # the state_dict layout of saved checkpoints.
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size)
        self.linear_out = nn.Linear(in_features=self.hidden_size, out_features=self.output_size)
        # Per-dimension empirical CDFs, built lazily on the first `forward` call.
        self.marginal_cdfs = None
        # Network outputs (CPU tensor) that the empirical CDFs were fitted on.
        self.ecdf_10e6_samples = None

    def _latent_to_output(self, z):
        """Run latent noise through the network and return the raw (unranked) outputs."""
        y = torch.relu(self.linear_in(z))
        for _ in range(self.layers_nuber):
            y = torch.relu(self.linear(y))
        return self.linear_out(y)

    def forward_train(self, z):
        """Map latent noise to approximately uniform pseudo-observations.

        Args:
            z: Latent noise of shape (M, dim_latent).

        Returns:
            Tensor of shape (M, output_size). Each column holds the soft
            normalised ranks of the M network outputs, which keeps every
            marginal close to uniform while remaining differentiable.
        """
        # SoftRank ranks along the sample axis of a (batch, samples, dims)
        # tensor, so the M samples are wrapped in a batch of size one.
        y = self._latent_to_output(z).unsqueeze(0)
        return SoftRank()(y).squeeze(0)

    def Energy_Score_pytorch(self, beta, observations_y, simulations_Y):
        """Compute the energy-score loss between observations and simulations.

        The returned value is::

            2 * mean_{i,j} ||y_i - Y_j||^beta
              - 1 / (m * (m - 1)) * sum_{j != k} ||Y_j - Y_k||^beta

        Args:
            beta: Exponent applied to the Euclidean distances.
            observations_y: Tensor of shape (n, d) with observed points.
            simulations_Y: Floating point tensor of shape (m, d) with
                simulated points. At least two simulations are needed; with a
                single one the second term is 0 / 0.

        Returns:
            Scalar tensor holding the score (lower is better).
        """
        m = len(simulations_Y)

        # First part |Y-y|: Euclidean distance between every observation and
        # every simulation (n * m values), raised to the power beta.
        diff_Y_y = torch.pow(
            torch.norm(
                (observations_y.unsqueeze(1) - simulations_Y.unsqueeze(0)).float(),
                dim=2),
            beta)

        # Second part |Y-Y'|: pdist returns each unordered pair once, hence the
        # factor 2 to account for both orderings.
        diff_Y_Y = 2 * torch.pow(
            nn.functional.pdist(simulations_Y),
            beta)
        return 2 * torch.mean(diff_Y_y) - torch.sum(diff_Y_Y) / (m * (m - 1))

    def forward(self, n_samples):
        """Sample from the copula once training is done.

        On the first call, 10**6 latent points are pushed through the network
        and one empirical CDF per output dimension is fitted and cached. The
        cache is not refreshed if the weights change afterwards; reset
        ``marginal_cdfs`` to ``None`` to rebuild it.

        Args:
            n_samples: Number of samples to generate.

        Returns:
            CPU tensor of shape (n_samples, output_size) on copula space, i.e.
            network outputs mapped through the cached empirical CDFs.
        """
        with torch.no_grad():
            # Create the noise where the weights live so that sampling also
            # works after the model has been moved to a GPU.
            model_device = self.linear_in.weight.device
            if self.marginal_cdfs is None:
                z = torch.randn(10**6, self.dim_latent, device=model_device)
                # ECDF works on NumPy arrays, which requires CPU memory.
                reference = self._latent_to_output(z).cpu()
                self.marginal_cdfs = [
                    ECDF(reference[:, dim].numpy()) for dim in range(reference.shape[1])
                ]
                self.ecdf_10e6_samples = reference
            # sample the latent space and apply ecdfs
            z = torch.randn(n_samples, self.dim_latent, device=model_device)
            y = self._latent_to_output(z).cpu()
            for dim in range(y.shape[1]):
                y[:, dim] = torch.tensor(self.marginal_cdfs[dim](y[:, dim].numpy()), dtype=torch.float32)
            return y
        
# Load the dataset
X_ecdf = pd.read_csv(csv_path).values.astype(np.float32)

n_epochs = int(args.epochs)
weights_path = f'Model_weights/IGC/igc_weights_2times512_{dataset_name}_{cv_seeds}_iters{n_epochs}.pt'
samples_path = f'Model_samples/IGC/{dataset_name}_igc_sims_IGC_smaller_{cv_seeds}_iters{n_epochs}.npy'

# Create the output folders up front: otherwise a missing directory is only
# discovered when saving, after the whole training run has already been spent.
for out_path in (weights_path, samples_path):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

sims_igc_per_seed = []

start_time = time.time()

for cv_seed in [cv_seeds]:

    # Split into train and test sets (mnist_ecdf uses a 50/50 split, everything
    # else 80/20); robocup_train is trained on the full dataset.
    test_size = 0.5 if dataset_name == 'mnist_ecdf' else 0.2
    X_ecdf_train, X_ecdf_test = train_test_split(X_ecdf, test_size=test_size, random_state=cv_seed)
    if dataset_name == 'robocup_train':
        X_ecdf_train = X_ecdf

    # smaller IGC model, or larger with depths 6
    igc_cop = IGC(hidden_size=512, layers_number=2, output_size=X_ecdf.shape[1]).to(device)

    u_obs = torch.tensor(X_ecdf_train).to(device)

    optimizer = torch.optim.Adam(igc_cop.parameters())

    # Each step compares 100 generated points with 100 observations drawn with
    # replacement from the training set, using the energy score as the loss.
    for _ in range(n_epochs):
        optimizer.zero_grad()
        u = igc_cop.forward_train(torch.randn((100, igc_cop.dim_latent)).to(device))
        # Passing the population size draws the same kind of indices without
        # building an N-element array on every iteration.
        batch_idx = np.random.choice(u_obs.shape[0], 100, replace=True)
        loss = igc_cop.Energy_Score_pytorch(1, u_obs[batch_idx], u)
        loss.backward()
        optimizer.step()

    # Draw 10 batches of 100 samples through the soft-rank path; gradients are
    # not needed here, so no autograd graph is recorded.
    with torch.no_grad():
        sims_igc = np.vstack([
            igc_cop.forward_train(torch.randn((100, igc_cop.dim_latent)).to(device)).cpu().numpy()
            for _ in range(10)
        ])
    sims_igc_per_seed.append(sims_igc)
    print('cv_seed:', cv_seed, 'sims_igc_per_seed shape:', sims_igc.shape)

# save the trained weights
torch.save(igc_cop.state_dict(), weights_path)
print('saved: ', weights_path)

# save the generated samples as npy array
sims_igc_per_seed = np.array(sims_igc_per_seed)
np.save(samples_path, sims_igc_per_seed)
print('sims_igc_per_seed shape:', sims_igc_per_seed.shape, 'saved: ', samples_path)
print('time taken to train IGC field:', time.time()-start_time, 's, dataset:', dataset_name, 'cv_seed:', cv_seed , 'epochs:', args.epochs)
