import argparse
import collections
import json
import os
import random
import sys
import time
import uuid
import timm
import pickle  # Import pickle

import numpy as np
import PIL
import torch
import torchvision
import torch.utils.data

from domainbed import datasets
from domainbed import hparams_registry
from domainbed import algorithms
from domainbed.lib import misc
from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader

if __name__ == "__main__":
    # Export script: build a domainbed dataset, split every environment into an
    # "in" part and a held-out "out" part, and pickle the input tensors of every
    # batch (labels dropped), keyed 'env1', 'env2', ..., to
    # <output_dir>/dric_envs_inputs.pkl.
    parser = argparse.ArgumentParser(description='Domain generalization')
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--dataset', type=str, default="RotatedMNIST")
    parser.add_argument('--algorithm', type=str, default="ERM")
    # NOTE: --task, --steps, --checkpoint_freq, --uda_holdout_fraction,
    # --skip_model_save and --save_model_every_checkpoint are parsed but never
    # read by this script; they are kept so existing command lines still parse.
    parser.add_argument('--task', type=str, default="domain_generalization",
                        choices=["domain_generalization", "domain_adaptation"])
    parser.add_argument('--hparams', type=str,
                        help='JSON-serialized hparams dict')
    parser.add_argument('--hparams_seed', type=int, default=0,
                        help='Seed for random hparams (0 means "default hparams")')
    parser.add_argument('--trial_seed', type=int, default=0,
                        help='Trial number (used for seeding split_dataset and '
                             'random_hparams).')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed for everything else')
    parser.add_argument('--steps', type=int, default=None,
                        help='Number of steps. Default is dataset-dependent.')
    parser.add_argument('--checkpoint_freq', type=int, default=None,
                        help='Checkpoint every N steps. Default is dataset-dependent.')
    parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
    parser.add_argument('--output_dir', type=str, default="train_output")
    parser.add_argument('--holdout_fraction', type=float, default=0.2)
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%' (a bare '% of' makes `--help` raise TypeError).
    parser.add_argument('--uda_holdout_fraction', type=float, default=0,
                        help="For domain adaptation, %% of test to use unlabeled for training.")
    parser.add_argument('--skip_model_save', action='store_true')
    parser.add_argument('--save_model_every_checkpoint', action='store_true')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Batch size used when reading each split for export.')
    args = parser.parse_args()

    # Mirror stdout/stderr into files inside the output directory. This also
    # guarantees the directory exists for the pickle written at the end.
    os.makedirs(args.output_dir, exist_ok=True)
    sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt'))
    sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt'))

    print("Environment:")
    print("\tPython: {}".format(sys.version.split(" ")[0]))
    print("\tPyTorch: {}".format(torch.__version__))
    print("\tTorchvision: {}".format(torchvision.__version__))
    print("\tCUDA: {}".format(torch.version.cuda))
    print("\tCUDNN: {}".format(torch.backends.cudnn.version()))
    print("\tNumPy: {}".format(np.__version__))
    print("\tPIL: {}".format(PIL.__version__))

    print('Args:')
    for k, v in sorted(vars(args).items()):
        print('\t{}: {}'.format(k, v))

    # hparams_seed == 0 selects the registry defaults; any other value draws a
    # random configuration seeded by (hparams_seed, trial_seed). Explicit JSON
    # overrides from --hparams are applied last.
    if args.hparams_seed == 0:
        hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
    else:
        hparams = hparams_registry.random_hparams(args.algorithm, args.dataset,
                                                  misc.seed_hash(args.hparams_seed, args.trial_seed))
    if args.hparams:
        hparams.update(json.loads(args.hparams))

    print('HParams:')
    for k, v in sorted(hparams.items()):
        print('\t{}: {}'.format(k, v))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if args.dataset not in vars(datasets):
        raise NotImplementedError(f"Unknown dataset: {args.dataset}")
    dataset = vars(datasets)[args.dataset](args.data_dir,
                                           args.test_envs, hparams)

    # The first int(len(env) * holdout_fraction) examples chosen by
    # split_dataset form the "out" split and the rest form the "in" split.
    # Seeding by (trial_seed, env index) makes the split reproducible.
    in_splits, out_splits = [], []
    for env_i, env in enumerate(dataset):
        out, in_ = misc.split_dataset(env, int(len(env) * args.holdout_fraction),
                                      misc.seed_hash(args.trial_seed, env_i))
        in_splits.append(in_)
        out_splits.append(out)

    # For each environment, keep the input tensor of every batch from the "in"
    # split followed by the "out" split; labels are discarded as we go, so
    # they are never held in memory for the whole dataset.
    # NOTE(review): the dataset's transforms live in domainbed.datasets (not
    # visible here); if any are random, the exported tensors depend on the
    # seeded RNG state — confirm that is intended for this export.
    dric_envs_inputs = {}
    for env_i, (in_split, out_split) in enumerate(zip(in_splits, out_splits)):
        env_name = f'env{env_i + 1}'
        env_inputs = []
        for split in (in_split, out_split):
            loader = FastDataLoader(dataset=split, batch_size=args.batch_size,
                                    num_workers=dataset.N_WORKERS)
            env_inputs.extend(inputs for inputs, _ in loader)
        dric_envs_inputs[env_name] = env_inputs

        print(f"DRIC Environment: {env_name}")
        for j, inputs in enumerate(env_inputs):
            print(f"  Batch {j+1}: Inputs shape: {inputs.shape}")
        print(f"Total number of batches: {len(env_inputs)}\n")

    # Save alongside out.txt/err.txt in --output_dir (created above) rather
    # than in a hard-coded directory.
    save_path = os.path.join(args.output_dir, 'dric_envs_inputs.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(dric_envs_inputs, f)

    print(f"Input data saved to {save_path}")
#