import lightning as L
import torch
import argparse
import sys 
sys.path.append(".")
from source.reward.reward_model import R_goal_model
import pandas as pd 
from source.collection.collect_data_minigrid import state_to_obs_spe
from sklearn.model_selection import train_test_split
from lightning.pytorch.loggers import TensorBoardLogger
from datetime import datetime
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np 
import glob
import os 
from multiprocessing import Pool
from tqdm import tqdm
from source.reward.building_goal_detection_dataset import filter_state
from source.reward.building_goal_detection_dataset import remove_abstract

# Root folder of the datasets. Every CSV read or written in this file lives
# under ``{path_folder}/{dataset}/env/{env_id}/{goal}/``.
path_folder = "./data/datasets/"

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(grid, inventory)`` samples with optional labels.

    Args:
        X: Indexable collection of samples. ``X[i][0]`` is the grid (an array
            exposing ``transpose`` with three axes) and ``X[i][1]`` is the
            inventory (anything ``torch.tensor`` accepts).
        y: Optional sequence of numeric labels aligned with ``X``. It is
            converted once to a ``float32`` tensor.

    Indexing returns ``([grid, inventory], label)`` when labels were given and
    ``[grid, inventory]`` alone otherwise.
    """

    def __init__(self, X, y=None):
        self.X = X
        # Always define ``self.y``: previously an unlabeled dataset had no such
        # attribute and ``__getitem__`` failed with an opaque AttributeError.
        self.y = None if y is None else torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        # Move the last grid axis to the front.
        # NOTE(review): presumably H x W x C -> C x H x W for a conv model - confirm.
        grid = torch.tensor(sample[0].transpose((2, 0, 1))).float()
        inventory = torch.tensor(sample[1]).float()
        features = [grid, inventory]
        if self.y is None:
            return features
        return features, self.y[idx]

class Dataset_test(Dataset):
    """Label-free dataset: indexing yields only ``[grid, inventory]``.

    Construction and ``__len__`` are inherited from ``Dataset``; only item
    access is overridden so that no label is looked up.
    """

    def __getitem__(self, idx):
        # The grid's last axis is moved to the front before conversion.
        grid_tensor = torch.tensor(self.X[idx][0].transpose((2, 0, 1)))
        inventory_tensor = torch.tensor(self.X[idx][1])
        return [grid_tensor.float(), inventory_tensor.float()]


def loss(y_hat, Y):
    """Return the mean squared error between ``y_hat`` and ``Y``.

    The element-wise squared errors are computed first (no reduction) and
    then averaged over every element, giving a 0-dim tensor.
    """
    squared_errors = torch.nn.functional.mse_loss(y_hat, Y, reduction="none")
    return squared_errors.mean()

def load_goal_achieved_dataset(dataset, goal, env_id, grid_size):
    """Load the labeled goal-detection samples of one (environment, goal) pair.

    Reads ``goal_detection_dataset.csv`` from the goal's folder, converts each
    entry of the ``state`` column with ``state_to_obs_spe(state, grid_size)``
    and returns those observations together with the ``goal achieved`` column.

    Returns:
        ``(X, Y)``: list of converted observations and list of labels.
    """
    csv_path = f"{path_folder}/{dataset}/env/{env_id}/{goal}/goal_detection_dataset.csv"
    frame = pd.read_csv(csv_path)

    X = []
    for raw_state in frame["state"].to_list():
        X.append(state_to_obs_spe(raw_state, grid_size))

    # Process-wide side effect kept from the original: NumPy will print arrays
    # in full instead of summarising them.
    np.set_printoptions(threshold=sys.maxsize)

    Y = frame["goal achieved"].to_list()
    return X, Y

def load_observation(dataset, env_id, grid_size, goal):
    """Load the collected transitions of one (environment, goal) pair.

    Reads ``collected_data_abstract.csv`` and turns every ``next_obs`` entry
    into a model input by applying, in order, ``filter_state``,
    ``remove_abstract`` and ``state_to_obs_spe``.

    Returns:
        ``(X, df)``: the converted observations and the untouched DataFrame.
    """
    csv_path = f"{path_folder}/{dataset}/env/{env_id}/{goal}/collected_data_abstract.csv"
    df = pd.read_csv(csv_path)

    def _to_observation(raw_obs):
        filtered = filter_state(raw_obs, goal)
        stripped = remove_abstract(filtered, goal)
        return state_to_obs_spe(stripped, grid_size)

    X = [_to_observation(raw_obs) for raw_obs in df["next_obs"].to_list()]
    return X, df


def _lookup_rewards(observations, states, labels):
    """Return, for each observation, the label of the first equal state.

    Raises:
        ValueError: if at least one observation has no equal state.
    """
    first_label = {}
    for state, label in zip(states, labels):
        # setdefault keeps the first occurrence, like a front-to-back scan.
        first_label.setdefault(state, label)

    n_missing = sum(1 for obs in observations if obs not in first_label)
    if n_missing:
        raise ValueError(
            f"{n_missing} observation(s) have no matching state in the goal "
            "detection dataset; cannot assign rewards by lookup"
        )
    return [first_label[obs] for obs in observations]


def _ensure_positive_in_both_splits(X_train, X_val, y_train, y_val, goal):
    """Move one positive sample across the split, in place, if a side has none.

    Only a single sample is moved, so the donor side is not guaranteed to keep
    a positive of its own.

    Raises:
        Exception: if neither split contains the label ``1``.
    """
    if 1 not in y_train:
        if 1 not in y_val:
            raise Exception(f"goal achieved dataset does not has positive label for goal {goal}")
        donor_X, donor_y, receiver_X, receiver_y = X_val, y_val, X_train, y_train
    elif 1 not in y_val:
        donor_X, donor_y, receiver_X, receiver_y = X_train, y_train, X_val, y_val
    else:
        return

    i = donor_y.index(1)
    receiver_X.append(donor_X[i])
    del donor_X[i]
    receiver_y.append(1)
    del donor_y[i]


def train_model_per_goal(goal, env_id, args):
    """Attach a ``reward`` column to the collected data of one goal.

    Writes ``data_with_reward.csv`` in the goal's folder. Two paths exist:

    * If the goal-detection dataset has at least as many rows as there are
      distinct observations (``obs`` and ``next_obs`` together), no model is
      trained: each ``next_obs`` takes the ``goal achieved`` label of the first
      identical ``state``.
    * Otherwise an ``R_goal_model`` is fitted on a 90/10 train/validation split
      and its predictions on every ``next_obs`` become the rewards.

    Args:
        goal: Name of the goal folder.
        env_id: Name of the environment folder.
        args: Namespace providing ``name``, ``grid_size``, ``batch_size``,
            ``seed``, ``lr`` and ``epochs``.

    Raises:
        ValueError: lookup path only, when a ``next_obs`` matches no ``state``.
        Exception: training path only, when the dataset has no positive label.
    """
    # Use a fresh generator rather than the global NumPy RNG: the global one is
    # seeded before worker processes are forked, so every worker would draw the
    # same GPU index sequence. The 3-GPU assumption is unchanged.
    device = int(np.random.default_rng().integers(3))

    goal_dir = f"{path_folder}/{args.name}/env/{env_id}/{goal}"
    output_csv = f"{goal_dir}/data_with_reward.csv"

    # Checking if training a model is needed
    training_set = pd.read_csv(f"{goal_dir}/goal_detection_dataset.csv")
    states = training_set["state"].to_list()
    len_training_set = len(states)

    df = pd.read_csv(f"{goal_dir}/collected_data_abstract.csv")
    len_obs_total = len(
        set(df["next_obs"].drop_duplicates().to_list())
        | set(df["obs"].drop_duplicates().to_list())
    )

    if len_training_set >= len_obs_total:
        # NOTE(review): raw ``next_obs`` values are compared with ``state``
        # here, whereas load_observation filters them first - confirm both
        # columns use the same representation.
        df["reward"] = _lookup_rewards(
            df["next_obs"].to_list(),
            states,
            training_set["goal achieved"].to_list(),
        )
        df.to_csv(output_csv, index=False)
        return

    X, Y = load_goal_achieved_dataset(args.name, goal, env_id, grid_size=args.grid_size)
    X_test, df = load_observation(args.name, env_id, args.grid_size, goal)

    test_dataset = Dataset_test(X_test)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.1, random_state=args.seed)
    _ensure_positive_in_both_splits(X_train, X_val, y_train, y_val, goal)

    print(f"Number of 1 in train : {y_train.count(1)}")
    print(f"Number of 1 in val : {y_val.count(1)}")

    train_dataset = Dataset(X_train, y_train)
    val_dataset = Dataset(X_val, y_val)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Build the run name once so the logger and the checkpoint callback share
    # one directory even if the clock ticks over between the two constructors.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    run_name = f"{args.name}_{goal}_{args.seed}_{timestamp}"
    log_root = "results/training_reward_goal_model/logs/"

    logger = TensorBoardLogger(save_dir=log_root, name=run_name)
    logger.log_hyperparams(args.__dict__)

    model = R_goal_model(
        loss_fn=lambda y_hat, target: loss(y_hat, target),
        lr=args.lr,
        sample=train_dataset[0][0],  # features of the first training sample
        dropout=0.0,
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=f"{log_root}{run_name}",
        save_top_k=1,
        monitor="val_loss",
        filename="best_checkpoint",
    )

    trainer = L.Trainer(
        max_epochs=args.epochs,
        devices=[device],
        accelerator="gpu",
        logger=logger,
        deterministic=True,
        callbacks=[checkpoint_callback],
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    trainer.fit(model, train_loader, val_loader)

    # NOTE(review): this reads a private ModelCheckpoint attribute. When the
    # last checkpoint was saved at global step 1 the in-memory weights are
    # used; otherwise the best checkpoint is reloaded - confirm the intent.
    if checkpoint_callback._last_global_step_saved == 1:
        ckpt_path = None
    else:
        ckpt_path = checkpoint_callback.best_model_path

    predictions = trainer.predict(model, test_loader, ckpt_path=ckpt_path)
    rewards = torch.cat(predictions).squeeze(1).cpu().numpy()

    df["reward"] = rewards
    df.to_csv(output_csv, index=False)



if __name__ == "__main__":
    # Command-line entry point: collect the (goal, env_id) pairs to process
    # and run train_model_per_goal on each of them in a process pool.
    parser = argparse.ArgumentParser(description="Train the reward model")
    parser.add_argument("--name", type=str, default="BabyAI", help="The name of the dataset")
    parser.add_argument("--seed", type=int, default=0, help="The seed to use")
    parser.add_argument("--goal", type=str, nargs="+", default=[], help="The goal to use")
    parser.add_argument("--epochs", type=int, default=200, help="The number of epochs)")
    parser.add_argument("--batch_size", type=int, default=20000, help="The batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate")
    parser.add_argument("--env_id", type=str, nargs="+", default=[], help="The environment id(s) to process (default: every environment found for the dataset)")
    parser.add_argument("--grid_size", type=int, nargs="+", default=[22, 22])
    args = parser.parse_args()
    L.seed_everything(args.seed)

    # Without --env_id, process every entry under the dataset's env/ folder.
    if not args.env_id:
        args.env_id = [os.path.basename(p) for p in glob.glob(f"{path_folder}/{args.name}/env/*")]

    list_goals_env_id = []
    for env_id in args.env_id:
        if args.goal:
            goals = args.goal
        else:
            # Without --goal, keep goal folders that have a detection dataset
            # but no reward file yet.
            goals = []
            for candidate in glob.glob(f"{path_folder}/{args.name}/env/{env_id}/*"):
                # isdir guards os.listdir against stray non-directory entries.
                if "csv" in candidate or not os.path.isdir(candidate):
                    continue
                entries = set(os.listdir(candidate))
                if "goal_detection_dataset.csv" in entries and "data_with_reward.csv" not in entries:
                    goals.append(os.path.basename(candidate))
        list_goals_env_id.extend((g, env_id) for g in goals)

    n_job = 50

    # NOTE(review): run_process is defined under the __main__ guard and closes
    # over ``args``; this presumably relies on the "fork" start method - confirm
    # before running on a platform that spawns workers.
    def run_process(p):
        goal, env_id = p
        train_model_per_goal(goal, env_id, args)

    # Never start more workers than there are tasks (Pool needs at least one).
    n_workers = max(1, min(n_job, len(list_goals_env_id)))

    # The context manager releases the worker processes once every task has
    # been consumed; previously the pool was never closed.
    with Pool(n_workers) as pool:
        for _ in tqdm(pool.imap_unordered(run_process, list_goals_env_id), total=len(list_goals_env_id)):
            pass
    
    