import collections
import json
import time
import copy
from pathlib import Path
import os

import numpy as np
import torch
import torch.utils.data

from domainbed.datasets import get_dataset, split_dataset
from domainbed import algorithms
from domainbed.evaluator_t import Evaluator
from domainbed.lib import misc
from domainbed.lib import swa_utils
from domainbed.lib.query import Q
from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
from domainbed import swad as swad_module

def json_handler(v):
    """Fallback serializer for ``json.dumps(default=...)``.

    ``Path`` and ``range`` objects are converted with ``str``; every other
    type is rejected with ``TypeError`` so ``json.dumps`` reports it.
    """
    if not isinstance(v, (Path, range)):
        raise TypeError(f"`{type(v)}` is not JSON Serializable")
    return str(v)


def _build_eval_meta(dataset, in_splits, out_splits, test_splits, hparams, args):
    """Build the ``(name, loader_or_kwargs, weights)`` triples passed to ``Evaluator``.

    One entry is produced for every split in ``in_splits + out_splits + test_splits``,
    named ``env{i}_in``, ``env{i}_out`` and ``env{i}_inTE`` respectively. The second
    element is a ``FastDataLoader`` when ``args.prebuild_loader`` is set, otherwise the
    kwargs dict to build one. Weights are always ``None``.
    """
    all_splits = in_splits + out_splits + test_splits
    loaders = []
    for env, _ in all_splits:
        loader_kwargs = {
            "dataset": env,
            "batch_size": hparams["test_batchsize"],
            "num_workers": dataset.N_WORKERS,
        }
        if args.prebuild_loader:
            loader_kwargs = FastDataLoader(**loader_kwargs)
        loaders.append(loader_kwargs)

    names = ["env{}_in".format(i) for i in range(len(in_splits))]
    names += ["env{}_out".format(i) for i in range(len(out_splits))]
    names += ["env{}_inTE".format(i) for i in range(len(test_splits))]
    weights = [None] * len(all_splits)
    return list(zip(names, loaders, weights))


def _save_checkpoint(algorithm, args, hparams, device, step, test_envs, train_envs, target_env, logger):
    """Save the model state dict (plus args/hparams) to ``args.out_dir / "checkpoints"``.

    The model is moved to CPU to take the state dict (so the saved tensors live on
    CPU) and then moved back to ``device``. In debug mode nothing is written.
    """
    ckpt_dir = args.out_dir / "checkpoints"
    ckpt_dir.mkdir(exist_ok=True)

    test_env_str = ",".join(map(str, test_envs))
    filename = "TE{}_{}.pth".format(test_env_str, step)
    if len(test_envs) > 1 and target_env is not None:
        train_env_str = ",".join(map(str, train_envs))
        filename = f"TE{target_env}_TR{train_env_str}_{step}.pth"
    path = ckpt_dir / filename

    save_dict = {
        "args": vars(args),
        "model_hparams": dict(hparams),
        "test_envs": test_envs,
        "model_dict": algorithm.cpu().state_dict(),
    }
    algorithm.to(device)
    if not args.debug:
        torch.save(save_dict, path)
    else:
        logger.debug("DEBUG Mode -> no save (org path: %s)" % path)


def _save_teacher_predictions(
    algorithm, algorithm_class, args, hparams, dataset, train_envs, test_envs, device, logger,
    batch_size=256,
):
    """Run ``algorithm.predict`` over every environment and save the outputs.

    For each env index in ``sorted(test_envs + train_envs)`` the split
    ``in_splits[0][0]`` returned by ``get_dataset(i, ..., infer=True)`` is predicted
    in batches of ``batch_size``. The concatenated predictions are saved to
    ``<args.data_dir2>/<name of train_envs[0]>/<name of env i>/tensor.pt``.

    Inference runs on CPU (as the original code did). Afterwards the model is moved
    back to ``device`` and its previous train/eval mode is restored, even on error.
    """
    logger.info("begin infer.....")
    root = os.path.join(args.data_dir2, dataset.environments[train_envs[0]])
    os.makedirs(root, exist_ok=True)

    was_training = algorithm.training
    algorithm.cpu()
    algorithm.eval()
    try:
        for i in sorted([*test_envs, *train_envs]):
            # Local names: do not shadow the caller's training dataset / splits.
            infer_dataset, infer_in_splits, _ = get_dataset(
                i, args, hparams, algorithm_class, infer=True
            )
            dataset_test = infer_in_splits[0][0]
            domain_name = infer_dataset.environments[i]
            n_samples = len(dataset_test)
            if n_samples == 0:
                logger.warning(f"env {i} ({domain_name}) has no samples -> skip teacher predictions")
                continue
            num_batches = -(-n_samples // batch_size)  # ceil division

            predict_lists = []
            with torch.no_grad():
                for batch_idx, start_idx in enumerate(range(0, n_samples, batch_size)):
                    end_idx = min(start_idx + batch_size, n_samples)
                    batch_input_tensor = torch.stack(
                        [dataset_test[idx]["x"] for idx in range(start_idx, end_idx)], 0
                    )
                    logger.debug(f"Here STEP1~ Batch {batch_idx+1}/{num_batches}")
                    predict_lists.append(algorithm.predict(batch_input_tensor))
            predict_rst = torch.cat(predict_lists, 0)

            path = os.path.join(root, domain_name)
            os.makedirs(path, exist_ok=True)
            torch.save(predict_rst, f"{path}/tensor.pt")
            logger.info(f"{i}th envs save done!")
    finally:
        algorithm.to(device)
        algorithm.train(was_training)
    logger.info("save done!")


def train(train_envs, args, hparams, n_steps, checkpoint_freq, logger, writer, target_env=None, check_teacher=None):
    """Train ``args.algorithm`` on ``train_envs``, evaluate periodically, and export teacher predictions.

    Args:
        train_envs: indices of the environments used for training; all other
            environments of the dataset are treated as test environments.
        args: run arguments (reads e.g. ``device``, ``algorithm``, ``seed``,
            ``prebuild_loader``, ``evalmode``, ``debug``, ``out_dir``, ``model_save``,
            ``tb_freq``, ``data_dir2``).
        hparams: hyper-parameters (reads ``batch_size``, ``test_batchsize``, ``swad``,
            ``swad_kwargs``, ``freeze_bn``).
        n_steps: number of optimisation steps.
        checkpoint_freq: evaluate / log / checkpoint every this many steps.
        logger: logger used for all progress output.
        writer: summary writer providing ``add_scalars_with_prefix``.
        target_env: optional env index used for naming results/checkpoints.
        check_teacher: accepted for backward compatibility; currently unused.

    Returns:
        ``(ret, records)``: ``ret`` maps selection-criterion names to the ``test_in``
        value of the selected record (plus SWAD entries when enabled); ``records`` is
        a ``Q`` of the per-checkpoint result dicts.
    """
    logger.info("")

    device = args.device if torch.cuda.is_available() else "cpu"

    #######################################################
    # setup dataset & loader
    #######################################################
    algorithm_class = algorithms.get_algorithm_class(args.algorithm)

    dataset, in_splits, out_splits = get_dataset([0], args, hparams, algorithm_class)
    test_splits = []
    n_envs = len(dataset)
    test_envs = sorted(set(range(n_envs)) - set(train_envs))

    logger.info(f"Seed:{args.seed}!!!!!!!!!!!!!")

    # TODO: double-check the contents of in_splits and out_splits.
    if target_env is not None:
        testenv_name = f"te_{dataset.environments[target_env]}"
        logger.info(f"Target env = {target_env}")
    else:
        testenv_properties = [str(dataset.environments[i]) for i in test_envs]
        testenv_name = "te_" + "_".join(testenv_properties)

    logger.info(
        "Testenv name escaping {} -> {}".format(testenv_name, testenv_name.replace(".", ""))
    )
    testenv_name = testenv_name.replace(".", "")
    logger.info(f"Test envs = {test_envs}, name = {testenv_name}")

    iterator = misc.SplitIterator(test_envs)
    # Test environments get batch size 0 so they contribute nothing to training.
    batch_sizes = np.full([n_envs], hparams["batch_size"], dtype=int)
    batch_sizes[test_envs] = 0
    batch_sizes = batch_sizes.tolist()

    logger.info(f"Batch sizes for each domain: {batch_sizes} (total={sum(batch_sizes)})")

    # An "epoch" is defined by the smallest training domain.
    steps_per_epochs = [
        len(env) / batch_size
        for (env, _), batch_size in iterator.train(zip(in_splits, batch_sizes))
    ]
    steps_per_epoch = min(steps_per_epochs)
    prt_steps = ", ".join([f"{step:.2f}" for step in steps_per_epochs])
    logger.info(f"steps-per-epoch for each domain: {prt_steps} -> min = {steps_per_epoch:.2f}")

    train_loaders = [
        InfiniteDataLoader(
            dataset=env,
            weights=env_weights,
            batch_size=batch_size,
            num_workers=dataset.N_WORKERS,
        )
        for (env, env_weights), batch_size in iterator.train(zip(in_splits, batch_sizes))
    ]

    eval_meta = _build_eval_meta(dataset, in_splits, out_splits, test_splits, hparams, args)

    #######################################################
    # setup algorithm (model)
    #######################################################
    algorithm = algorithm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(test_envs),
        hparams,
    )
    algorithm.to(device)

    n_params = sum(p.numel() for p in algorithm.parameters())
    logger.info("# of params = %d" % n_params)

    train_minibatches_iterator = zip(*train_loaders)
    checkpoint_vals = collections.defaultdict(list)

    #######################################################
    # start training loop
    #######################################################
    # Use the resolved device (falls back to CPU when CUDA is unavailable) so the
    # evaluator and the model agree.
    evaluator = Evaluator(
        train_envs,
        eval_meta,
        n_envs,
        logger,
        device=device,
        evalmode=args.evalmode,
        debug=args.debug,
        target_env=target_env,
    )

    swad = None
    if hparams["swad"]:
        swad_algorithm = swa_utils.AveragedModel(algorithm)
        swad = swad_module.LossValley(evaluator, **hparams.swad_kwargs)

    last_results_keys = None
    records = []
    epochs_path = args.out_dir / "results.jsonl"

    for step in range(n_steps):
        step_start_time = time.time()
        # batches_dictlist: [{env0_data_key: tensor, env0_...}, env1_..., ...]
        batches_dictlist = next(train_minibatches_iterator)
        # batches: {data_key: [env0_tensor, ...], ...}
        batches = misc.merge_dictlist(batches_dictlist)
        batches = {
            key: [tensor.to(device) for tensor in tensorlist] for key, tensorlist in batches.items()
        }
        inputs = {**batches, "step": step}
        step_vals = algorithm.update(**inputs)

        for key, val in step_vals.items():
            checkpoint_vals[key].append(val)
        checkpoint_vals["step_time"].append(time.time() - step_start_time)

        if swad is not None:
            # swad_algorithm is segment_swa for swad
            swad_algorithm.update_parameters(algorithm, step=step)

        if step % checkpoint_freq == 0:
            results = {
                "step": step,
                "epoch": step / steps_per_epoch,
            }
            for key, val in checkpoint_vals.items():
                results[key] = np.mean(val)

            eval_start_time = time.time()
            accuracies, summaries = evaluator.evaluate(algorithm)
            results["eval_time"] = time.time() - eval_start_time

            # Column order: summaries, then accuracies (sorted), then step statistics.
            results_keys = list(summaries.keys()) + sorted(accuracies.keys()) + list(results.keys())
            results.update(summaries)
            results.update(accuracies)

            # Print the header only when the set/order of columns changes.
            if results_keys != last_results_keys:
                logger.info(misc.to_row(results_keys))
                last_results_keys = results_keys
            logger.info(misc.to_row([results[key] for key in results_keys]))
            records.append(copy.deepcopy(results))

            # hparams/args are written to the jsonl file but not kept in `records`.
            results.update({"hparams": dict(hparams), "args": vars(args)})
            with open(epochs_path, "a") as f:
                f.write(json.dumps(results, sort_keys=True, default=json_handler) + "\n")

            checkpoint_vals = collections.defaultdict(list)

            writer.add_scalars_with_prefix(summaries, step, f"{testenv_name}/summary/")
            writer.add_scalars_with_prefix(accuracies, step, f"{testenv_name}/all/")

            if args.model_save and step >= args.model_save:
                _save_checkpoint(
                    algorithm, args, hparams, device, step, test_envs, train_envs, target_env, logger
                )

            if swad is not None:
                def prt_results_fn(results, avgmodel):
                    step_str = f" [{avgmodel.start_step}-{avgmodel.end_step}]"
                    row = misc.to_row([results[key] for key in results_keys if key in results])
                    logger.info(row + step_str)

                swad.update_and_evaluate(
                    swad_algorithm, results["train_out"], results["tr_outloss"], prt_results_fn
                )

                if hasattr(swad, "dead_valley") and swad.dead_valley:
                    logger.info("SWAD valley is dead -> early stop !")
                    break

                swad_algorithm = swa_utils.AveragedModel(algorithm)  # reset

        if step % args.tb_freq == 0:
            # add step values only for tb log
            writer.add_scalars_with_prefix(step_vals, step, f"{testenv_name}/summary/")

    #######################################################
    # infer and save teacher predictions
    #######################################################
    # Done once after training (instead of inside the checkpoint branch) so it also
    # runs when n_steps-1 is not a checkpoint step or SWAD stops early.
    if n_steps > 0:
        _save_teacher_predictions(
            algorithm, algorithm_class, args, hparams, dataset, train_envs, test_envs, device, logger
        )

    # find best
    logger.info("---")
    records = Q(records)
    tr_val_best = records.argmax("train_out")["test_in"]
    in_key = "train_out"

    # NOTE: for clarity, report only training-domain validation results
    # (test-domain validation / last / in-domain variants are not reported).
    ret = {
        "training-domain validation": tr_val_best,
    }

    # Evaluate SWAD
    if swad is not None:
        swad_algorithm = swad.get_final_model()
        if hparams["freeze_bn"] is False:
            n_bn_steps = 500 if not args.debug else 10
            logger.warning(f"Update SWAD BN statistics for {n_bn_steps} steps ...")
            swa_utils.update_bn(train_minibatches_iterator, swad_algorithm, n_bn_steps)

        logger.warning("Evaluate SWAD ...")
        accuracies, summaries = evaluator.evaluate(swad_algorithm)
        results = {**summaries, **accuracies}
        start = swad_algorithm.start_step
        end = swad_algorithm.end_step
        step_str = f" [{start}-{end}]  (N={swad_algorithm.n_averaged})"
        row = misc.to_row([results[key] for key in results_keys if key in results]) + step_str
        logger.info(row)

        ret["SWAD"] = results["test_in"]
        ret["SWAD (inD)"] = results[in_key]

    for k, acc in ret.items():
        logger.info(f"{k} = {acc:.3%}")

    return ret, records
