
import argparse
import time
import pickle


import numpy as np
import common_args
import random
import os
# Pin the run to GPU 0 unless the caller already chose devices via the
# environment. This must be set before CUDA is first initialised (the
# torch.cuda calls further down); setting it afterwards has no effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

from dataset import Dataset, ImageDataset, SubDataset, UnifiedDataset
from net import Transformer, ImageTransformer
from utils import (
    build_bandit_data_filename,
    build_bandit_model_filename,
    build_linear_bandit_data_filename,
    build_linear_bandit_model_filename,
    build_darkroom_data_filename,
    build_darkroom_model_filename,
    build_metaworld_data_filename,
    build_metaworld_model_filename,
    worker_init_fn,
)


import torch
from torchvision.transforms import transforms
# Report the accelerator this run will use. Every torch.cuda query that needs
# a GPU is guarded so the script also starts on a CPU-only machine, where
# `device` below falls back to "cpu".
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    _cuda_index = torch.cuda.current_device()
    print(f"Using device: {_cuda_index} ({torch.cuda.get_device_name(_cuda_index)})")
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))

import torch.multiprocessing as mp
# Choose 'spawn' (or 'forkserver') unless a start method was already set --
# presumably because the datasets keep tensors on the GPU (store_gpu=True)
# and CUDA cannot be used from fork-started worker processes.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn', force=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` converts any non-empty string to True
    (``bool("False") is True``), so a flag declared that way can never be
    switched off from the command line. This converter understands the usual
    spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"/"false"``, ``"yes"/"no"``, ``"t"/"f"``, ``"y"/"n"`` or
            ``"1"/"0"`` (case-insensitive, surrounding whitespace ignored).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling;
            argparse reports this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def batch_loss(model, batch, loss_fn, action_dim, device):
    """Move ``batch`` to ``device``, run ``model`` on it and return the loss.

    The target ``batch['optimal_actions']`` is repeated along a new dim 1 so
    that it matches every position of the prediction (``pred_actions.shape[1]``).
    Prediction and target are then flattened to ``(-1, action_dim)`` and passed
    to ``loss_fn``.

    Args:
        model: Callable taking the batch dict and returning predictions.
        batch: Dict of tensors; must contain ``'optimal_actions'``.
        loss_fn: Loss called as ``loss_fn(pred, target)``.
        action_dim: Size of the last (action) dimension.
        device: Device every tensor in ``batch`` is moved to.

    Returns:
        The loss tensor from ``loss_fn``. It is still attached to the autograd
        graph unless called under ``torch.no_grad()``.
    """
    batch = {k: v.to(device) for k, v in batch.items()}
    true_actions = batch['optimal_actions']
    pred_actions = model(batch)
    true_actions = true_actions.unsqueeze(1).repeat(1, pred_actions.shape[1], 1)
    return loss_fn(pred_actions.reshape(-1, action_dim),
                   true_actions.reshape(-1, action_dim))


if __name__ == '__main__':
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    common_args.add_model_args(parser)
    common_args.add_train_args(parser)

    # NOTE: --seed is recorded in args, but the seed loop below overwrites
    # args['seed'] with each seed it trains.
    parser.add_argument('--seed', type=int, default=0)
    # Accepts `--lifelong`, `--lifelong true` or `--lifelong false`.
    parser.add_argument('--lifelong', type=str2bool, nargs='?', const=True, default=True)

    args = vars(parser.parse_args())
    print("Args: ", args)

    env = args['env']
    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    horizon = args['H']
    dim = args['dim']
    n_embd = args['embd']
    n_head = args['head']
    n_layer = args['layer']
    lr = args['lr']
    shuffle = args['shuffle']
    dropout = args['dropout']
    num_epochs = args['num_epochs']

    # Suffix appended to every checkpoint this script writes.
    ckpt_suffix = '_icaa_iter3'

    for seed in [1, 2]:
        args['seed'] = seed
        # A seed of -1 is mapped to RNG seed 0.
        tmp_seed = 0 if seed == -1 else seed

        torch.manual_seed(tmp_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(tmp_seed)
            torch.cuda.manual_seed_all(tmp_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(tmp_seed)
        random.seed(tmp_seed)

        dataset_config = {
            'n_hists': n_hists,
            'n_samples': n_samples,
            'horizon': horizon,
            'dim': dim,
        }
        model_config = {
            'shuffle': shuffle,
            'lr': lr,
            'dropout': dropout,
            'n_embd': n_embd,
            'n_layer': n_layer,
            'n_head': n_head,
            'n_envs': n_envs,
            'n_hists': n_hists,
            'n_samples': n_samples,
            'horizon': horizon,
            'dim': dim,
            'seed': seed,
        }

        # Defaults; the env branches below override what they need. A split
        # left as None is simply not loaded (no evaluation / no OOD mixing).
        state_dim = dim
        action_dim = dim
        path_test = None
        path_ood_train = None

        if env.startswith('darkroom'):
            state_dim = 2
            action_dim = 5
            dataset_config.update({'rollin_type': 'uniform'})
            path_train = 'datasets/trajs_icml_darkroom_heldout_envs100000_hists1_samples1_H100_d10_train_neurips.pkl'
            # No test split is configured for darkroom, so evaluation is skipped.
            if args['lifelong']:
                # Lifelong pipeline: the base dataset plus the iter-k dataset
                # train the iter-(k+1) model, which then generates the
                # iter-(k+1) dataset.
                path_ood_train = 'datasets/trajs_icml_darkroom_ood8_envs100000_hists1_samples1_H100_d10_train_iter5_neurips.pkl'
            else:
                path_ood_train = build_darkroom_data_filename(
                    'darkroom_ood8', n_envs, dataset_config, mode=0)
            print('path ood train: ', path_ood_train)
            filename = build_darkroom_model_filename(env, model_config)
        elif env.startswith('metaworld'):
            state_dim = 39
            action_dim = 4
            path_train = 'datasets/ml1_pick_place_H100_q80_n2000_train.pkl'
            path_test = 'datasets/ml1_pick_place_H100_q80_n2000_test.pkl'
            path_ood_train = 'datasets/ml1_pick_place_H100_q80_n2000_train_iter3.pkl'
            filename = build_darkroom_model_filename(env, model_config)
        elif env.startswith('mujoco'):
            if 'hopper' in env:
                state_dim, action_dim = 15, 4
            elif 'cartpole' in env:
                state_dim, action_dim = 5, 1
            elif 'reacher' in env:
                state_dim, action_dim = 6, 2
            elif 'cheetah' in env:
                state_dim, action_dim = 17, 6
            elif 'walker' in env:
                state_dim, action_dim = 24, 6
            elif 'quadruped' in env:
                state_dim, action_dim = 78, 12
            # mujoco runs have a training split only (no test / OOD data).
            path_train = f'datasets/trajs_icml_{env}.pkl'
            filename = build_darkroom_model_filename(env, model_config)
        else:
            raise NotImplementedError

        config = {
            'horizon': horizon,
            'state_dim': state_dim,
            'action_dim': action_dim,
            'n_layer': n_layer,
            'n_embd': n_embd,
            'n_head': n_head,
            'shuffle': shuffle,
            'dropout': dropout,
            'test': False,
            'store_gpu': True,
            'goal': False,
        }
        model = Transformer(config).to(device)
        if args['lifelong']:
            # Warm-start from the previous lifelong iteration's checkpoint.
            # NOTE(review): this path is metaworld-specific but is used for
            # every env; with a different state/action dim, load_state_dict
            # will likely fail on a shape mismatch. The darkroom run used
            # models/darkroom_heldout_..._seed{seed}_epoch100_iter5.pt.
            model_path = (
                'models/metaworld_shufTrue_lr0.001_do0_embd32_layer4_head4_'
                f'envs100000_hists1_samples1_H100_d10_seed{seed}_epoch400_icaa_iter2.pt'
            )
            # map_location lets a GPU-saved checkpoint load on a CPU host.
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(checkpoint)
            print('load model howard')

        params = {
            'batch_size': 128,
            'shuffle': True,
        }

        log_filename = f'figs/loss/{filename}_logs.txt'
        # Truncate any log left over from a previous run with the same name.
        with open(log_filename, 'w'):
            pass

        def printw(string):
            """Print ``string`` and append the same line to this run's log file."""
            print(string)
            with open(log_filename, 'a') as f:
                print(string, file=f)

        train_dataset = UnifiedDataset(path_train, config, store_gpu=True)
        print("Original train_dataset size:", len(train_dataset))

        test_dataset = None
        if path_test is not None:
            test_dataset = UnifiedDataset(path_test, config, store_gpu=True)
            print("test_dataset size:", len(test_dataset))

        if path_ood_train is not None:
            train_ood_dataset = UnifiedDataset(path_ood_train, config, store_gpu=True)
            print("train_ood_dataset size:", len(train_ood_dataset))
            # Mix the full OOD dataset into the training set.
            train_dataset.concatenate(train_ood_dataset)
        print("Final train_dataset size:", len(train_dataset))

        train_loader = torch.utils.data.DataLoader(train_dataset, **params)
        test_loader = None
        if test_dataset is not None:
            test_loader = torch.utils.data.DataLoader(test_dataset, **params)

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

        if env.startswith('darkroom'):
            print('discrete')
            loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
        else:
            print('continuous')
            loss_fn = torch.nn.MSELoss(reduction='sum')

        test_loss = []
        train_loss = []

        printw("Num train batches: " + str(len(train_loader)))
        if test_loader is not None:
            printw("Num test batches: " + str(len(test_loader)))

        start_time = time.time()
        print('filename: ', filename)
        # Keep the initial (fresh or warm-started) weights as epoch 0.
        torch.save(model.state_dict(), f'models/{filename}_epoch0{ckpt_suffix}.pt')

        for epoch in range(num_epochs):
            printw(f"Epoch: {epoch + 1}")

            # EVALUATION
            if test_loader is not None:
                model.eval()  # disable dropout while measuring test loss
                epoch_test_loss = 0.0
                with torch.no_grad():
                    for i, batch in enumerate(test_loader):
                        print(f"Batch {i} of {len(test_loader)}", end='\r')
                        loss = batch_loss(model, batch, loss_fn, action_dim, device)
                        epoch_test_loss += loss.item() / horizon
                # loss_fn sums over every element; dividing by horizon above
                # and by the dataset size here gives a per-sample, per-step
                # average (assuming the model predicts `horizon` steps).
                test_loss.append(epoch_test_loss / len(test_dataset))
                printw(f"\tTest loss: {test_loss[-1]}")

            # TRAINING
            model.train()
            epoch_train_loss = 0.0
            if torch.cuda.is_available():
                print(f"Using device: {torch.cuda.current_device()} ({torch.cuda.get_device_name()})")
            else:
                print(f"Using device: {device}")

            for i, batch in enumerate(train_loader):
                print(f"Batch {i} of {len(train_loader)}", end='\r')
                optimizer.zero_grad()
                loss = batch_loss(model, batch, loss_fn, action_dim, device)
                loss.backward()
                optimizer.step()
                epoch_train_loss += loss.item() / horizon

            train_loss.append(epoch_train_loss / len(train_dataset))
            print('Train loss: ', train_loss[-1])

            # CHECKPOINTING
            if (epoch + 1) % 50 == 0:
                torch.save(model.state_dict(),
                           f'models/{filename}_epoch{epoch + 1}{ckpt_suffix}.pt')

        torch.save(model.state_dict(), f'models/{filename}{ckpt_suffix}.pt')
        print("Done.")
        print('training time: ', time.time() - start_time)
