import argparse
import os
import os.path as osp
import numpy as np
import pandas as pd
import optuna 

import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from plmdata_repeat.dataset_pyg import PygPolymerDataset
from plmdata_repeat.data_aug import csvcatg
from .model import GNN
from .utils import seed_torch, training, validate, print_info

class MainGRIN():
    """Train GNN models over several seeds, or evaluate saved checkpoints.

    The mode is selected by ``args.use_ck``:

    * falsy  -- train one model per seed and save the best-validation
      checkpoint of each under ``self.ck_path``;
    * truthy -- load the per-seed checkpoints, evaluate them on every test
      set listed in ``args.test_rep`` and append a "mean±std" summary row per
      test set to ``self.res_csv_name``.

    ``args`` is expected to provide: ``root``, ``train_rep``, ``test_rep``,
    ``num_layer``, ``task_name``, ``model_type``, ``use_ck``, ``seeds``,
    ``polymer_type``, ``batch_size``, ``num_workers``, ``emb_dim``,
    ``drop_ratio``, ``lr``, ``epochs`` and ``patience``.
    """

    # Task names accepted by the constructor.
    _VALID_TASKS = ('tg', 'mt', 'density', 'o2', 'EA', 'IP')
    # Identifying columns of a result row; also the group-by key of the summary.
    _SUMMARY_COLS = ["model", "layer", "train_repeat_time", "test_repeat_time"]
    # Metric suffixes, in the column order used by the summary CSV.
    _METRICS = ['r2', 'rmse', 'mae']
    # Column order of the per-seed result rows.
    _RESULT_COLS = _SUMMARY_COLS + ["test_mae", "test_rmse", "test_r2"]

    def __init__(self, args):
        """Validate ``args`` and derive checkpoint / result locations.

        Raises:
            ValueError: if ``args.train_rep`` is empty, if ``args.task_name``
                is not a supported task, or if ``args.use_ck`` is set but the
                checkpoint directory does not exist.
        """
        self.args = args
        self.root = args.root
        # More than one training repeat value -> keep the whole sequence
        # (merged training); exactly one -> unwrap it.
        if len(args.train_rep) > 1:
            self.r = args.train_rep
        elif len(args.train_rep) == 1:
            self.r = args.train_rep[0]
        else:
            raise ValueError("Invalid training repetition")
        self.num_layer = args.num_layer
        self.task_name = args.task_name
        self.model_type = args.model_type
        self.test_rep = args.test_rep
        self._use_ck = args.use_ck
        self.seed_num = args.seeds
        self.polymer_type = args.polymer_type

        res_path = f'res/{self.model_type}/'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(res_path, exist_ok=True)

        if self.task_name not in self._VALID_TASKS:
            raise ValueError("Invalid task name")

        self._use_concat_train = isinstance(self.r, list)
        if self._use_concat_train:
            csvcatg(self.root, self.polymer_type, self.task_name, self.r)

            # path for saving results for each model
            self.ck_path = f'model/checkpoints/{self.task_name}/{self.model_type}/merge/{self.num_layer}/{self.r}'
            print(self.ck_path)
            self.res_csv_name = osp.join(res_path, f'merge_{self.task_name}.csv')
        else:
            self.ck_path = f'model/checkpoints/{self.task_name}/{self.model_type}/single/{self.r}'
            self.res_csv_name = osp.join(res_path, f'single_{self.task_name}.csv')

        if not osp.exists(self.ck_path):
            if self._use_ck:
                # Evaluation needs existing checkpoints; never create an empty dir.
                raise ValueError(f"************ No such {self.model_type} model exist! ************")
            os.makedirs(self.ck_path, exist_ok=True)
        print("************************** Work on {} task use {} model trained on {} repeat times **************************".format(self.task_name,self.model_type,self.r))

    def test(self, model, device):
        """Evaluate ``model`` on the test set of every entry in ``self.test_rep``.

        Returns:
            A list with one ``validate`` result per entry of ``self.test_rep``,
            in the same order.
        """
        # test on different test sets
        test_res = []
        for rep in self.test_rep:
            dataset = PygPolymerDataset(
                polymer_type=self.polymer_type,
                root=self.root,
                set_name="test",
                repeat_times=rep,
                task_name=self.task_name,
            )
            test_loader = DataLoader(
                dataset,
                batch_size=1,
                shuffle=False,
                num_workers=0,
            )
            test_perf = validate(model, test_loader, device)
            print_info('Test result of test set with {} repeating units'.format(rep), test_perf)
            test_res.append(test_perf)

        print('Finished testing!')

        return test_res

    def model_seed(self, i):
        """Build the model for seed index ``i`` and train it or load its checkpoint.

        In checkpoint mode the weights are read from ``model-{i}.pt`` inside
        ``self.ck_path``. Otherwise the model is trained for at most
        ``args.epochs`` epochs, stopping once more than ``args.patience``
        epochs have passed since the best validation RMSE; the best state is
        saved to the checkpoint file and restored into the returned model.

        Returns:
            ``(model, device, best_train, best_valid)``; the two performance
            entries are ``None`` in checkpoint mode.

        Raises:
            ValueError: if ``args.model_type`` is neither ``"gin"`` nor
                ``"gcn"``, or if the checkpoint file is missing in
                checkpoint mode.
        """
        dataset_kwargs = dict(
            polymer_type=self.polymer_type,
            root=self.root,
            task_name=self.task_name,
            repeat_times=self.r,
            _use_concat_train=self._use_concat_train,
        )
        train_dataset = PygPolymerDataset(set_name="train", **dataset_kwargs)
        valid_dataset = PygPolymerDataset(set_name="valid", **dataset_kwargs)
        train_loader = DataLoader(train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_workers)
        valid_loader = DataLoader(valid_dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.num_workers)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # initiate model
        if self.args.model_type not in ("gin", "gcn"):
            raise ValueError("Invalid GNN type")
        model = GNN(
            gnn_type=self.args.model_type,
            repeat_time=train_dataset.repeat_times,
            num_task=1,
            num_layer=self.args.num_layer,
            emb_dim=self.args.emb_dim,
            drop_ratio=self.args.drop_ratio,
        ).to(device)
        optimizer = optim.Adam(model.parameters(), lr=self.args.lr)

        ck_name = osp.join(self.ck_path, 'model-{}.pt'.format(i))
        print(ck_name)

        best_train, best_valid = None, None
        if self._use_ck:
            # checkpoint
            if not osp.exists(ck_name):
                raise ValueError("************************** No model found! **************************")
            # map_location lets a checkpoint saved on GPU be loaded on a
            # CPU-only machine (and vice versa).
            state = torch.load(ck_name, map_location=device)
            model.load_state_dict(state['model'])
        else:
            # Training settings
            best_epoch = 0
            best_state = None
            print("Start training...")

            for epoch in range(self.args.epochs):
                training(epoch, model, train_loader, optimizer, device)
                valid_perf = validate(model, valid_loader, device)
                if best_valid is None or valid_perf['rmse'] < best_valid['rmse']:
                    train_perf = validate(model, train_loader, device)
                    best_valid, best_train, best_epoch = valid_perf, train_perf, epoch

                    # Snapshot the full state dict (parameters AND buffers) so
                    # the in-memory restore below matches the saved file.
                    best_state = {
                        key: value.detach().clone()
                        for key, value in model.state_dict().items()
                    }
                    # save checkpoint of the best epoch so far
                    state = {"epoch": best_epoch, "model": model.state_dict()}
                    torch.save(state, ck_name)
                elif epoch - best_epoch > self.args.patience:
                    # early stopping: no improvement for more than `patience` epochs
                    break

            print('Finished training of {}th model of {} repeat times! Best validation results from epoch {}.'.format(i,self.r,best_epoch))
            print('Model saved as {}.'.format(ck_name))
            print_info('train', best_train)
            print_info('valid', best_valid)

            # Roll the model back to its best-validation state. Nothing to
            # restore when no epoch was run.
            if best_state is not None:
                model.load_state_dict(best_state)

        return (model, device, best_train, best_valid)

    @staticmethod
    def _summarize(df):
        """Collapse per-seed result rows into one "mean±std" row per group.

        Rows are grouped by the identifying columns (model, layer, train and
        test repeat time); each ``test_<metric>`` column becomes the string
        ``"<mean>±<std>"`` with both values rounded to 4 decimals.
        """
        summary_cols = MainGRIN._SUMMARY_COLS
        grouped = df.groupby(summary_cols)
        df_mean = grouped.mean().round(4).reset_index()
        df_std = grouped.std().round(4).reset_index()

        df_summary = df_mean[summary_cols].copy()
        # Format the test metric columns as "mean±std".
        for metric in MainGRIN._METRICS:
            col_name = 'test_' + metric
            df_summary[col_name] = df_mean[col_name].astype(str) + "±" + df_std[col_name].astype(str)
        return df_summary

    def main(self):
        """Run training or checkpoint evaluation for every seed.

        Seeds ``0 .. self.seed_num - 1`` are processed in order. In checkpoint
        mode the per-seed test metrics are summarised per test set and
        appended to ``self.res_csv_name`` (the header is written only when the
        file does not exist yet).
        """
        # One list of per-seed result rows for each test set. Separate lists
        # (not `[x] * n`) so the containers are never aliased.
        rows_per_test = [[] for _ in self.test_rep] if self._use_ck else []

        for i in range(self.seed_num):
            seed_torch(i)
            model, device, _, _ = self.model_seed(i)
            if not self._use_ck:
                continue
            # start testing
            print("Start testing...")
            test_res = self.test(model, device)
            for rows, rep, test_perf in zip(rows_per_test, self.test_rep, test_res):
                rows.append({
                    "model": self.model_type,
                    "layer": self.num_layer,
                    "train_repeat_time": str(self.r),
                    "test_repeat_time": str(rep),
                    "test_mae": test_perf["mae"],
                    "test_rmse": test_perf["rmse"],
                    "test_r2": test_perf["r2"],
                })

        if not self._use_ck:
            print('Finished training for all {} models of {} repeat times!'.format(self.seed_num, self.r))
            return

        for rows in rows_per_test:
            # Build the frame once from all rows instead of repeatedly
            # concatenating onto an empty DataFrame.
            df = pd.DataFrame(rows, columns=self._RESULT_COLS)
            df_summary = self._summarize(df)

            # Save and print the summary DataFrame; append to an existing
            # file, writing the header only for a new one.
            write_header = not osp.exists(self.res_csv_name)
            df_summary.to_csv(self.res_csv_name, mode="a", header=write_header, index=False)
            print(df_summary)
        print('Finished testing for all {} models of {} repeat times!'.format(self.seed_num, self.r))

        
# No script entry point is defined here: when this file is executed directly
# the guard below does nothing.
# NOTE(review): MainGRIN is presumably constructed and run from another
# module — confirm against the caller.
if __name__ == "__main__":
    pass

        