from utilities.train import *
import json
import numpy as np
import matplotlib.pyplot as plt


def dump_list_of_dicts_to_json(data, path):
    """Write ``data`` to ``path`` as pretty-printed JSON.

    This is a best-effort helper: any exception raised while opening the file
    or serializing ``data`` is caught and reported on stdout rather than
    propagated, so the function always returns ``None``.

    Args:
        data: Object handed to ``json.dump``.
        path: Destination file path; the file is opened in text write mode.
    """
    try:
        with open(path, 'w') as out_stream:
            # Four-space indentation keeps the stored results human-readable.
            json.dump(obj=data, fp=out_stream, indent=4)
        print(f"Data successfully saved to {path}")
    except Exception as err:
        print(f"An error occurred: {err}")


def save_speed_results(results, args, job_type):
    """Persist accuracy-per-iteration results together with the run arguments.

    The payload ``{"args": args, "result": results}`` is written to
    ``./experiment_results/<JOB_TYPE>/<dataset>/<model>/acc_per_iter.json``;
    missing directories are created first.

    Args:
        results: Result records to store under the ``"result"`` key.
        args: Run configuration; must provide the ``"dataset"`` and ``"model"``
            keys. It is stored verbatim under the ``"args"`` key.
        job_type: Job identifier string; it is upper-cased to form the
            directory name.
    """
    total_result = {"args": args, "result": results}
    # NOTE(review): the job directory is upper-cased here, whereas
    # yield_json_results walks job.lower() / dataset.lower(). On a
    # case-sensitive filesystem these paths differ -- confirm the intended layout.
    output_dir = os.path.join(
        "./experiment_results", job_type.upper(), args["dataset"], args["model"]
    )
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "acc_per_iter.json")
    dump_list_of_dicts_to_json(total_result, output_path)


def yield_json_results(args, job: str, prediction_targets):
    """Yield ``(model_name, DataFrame)`` for every stored JSON result of a job.

    Walks ``experiment_results/<job>/<dataset>`` (both lower-cased) and loads
    each ``*.json`` file found. The name of the directory containing the file
    is used as the model name.

    Args:
        args: Run configuration; must provide the ``"dataset"`` key.
        job: Job identifier. ``'mtl'`` (case-insensitive) selects the
            multi-task parser; anything else uses the continual-learning one.
        prediction_targets: Task names, forwarded to the multi-task parser.

    Yields:
        Tuples ``(model_name, df)`` where ``df`` is the flattened result table.
    """
    # Normalise once so directory lookup and parser dispatch agree; previously
    # the path used job.lower() while the dispatch compared the raw string.
    job_key = job.lower()
    search_root = os.path.join("experiment_results", job_key, args["dataset"].lower())
    for root, _dirs, files in os.walk(search_root):
        for file_name in files:
            if not file_name.endswith(".json"):
                continue
            model_name = root.split(os.path.sep)[-1]
            json_path = os.path.join(root, file_name)
            with open(json_path, 'r') as f:
                model_json_dict = json.load(f)
            if job_key == 'mtl':
                df = result_dict_to_df_mtl(model_json_dict, prediction_targets)
            else:
                df = result_dict_to_df_cl(model_json_dict)
            yield model_name, df


def result_dict_to_df_cl(json_dict):
    """Flatten stored continual-learning results into a long-format table.

    Only ``json_dict["result"]`` is read. Each entry contributes one row per
    element of its ``"iterations"`` list; the row carries the entry's seed and
    task index, the iteration value, and the value at the same position of
    every list in the entry's ``"accuracies"`` mapping.

    Args:
        json_dict: Loaded result file containing a ``"result"`` list.

    Returns:
        A ``pandas.DataFrame`` built from the collected rows.
    """
    records = []
    for run in json_dict["result"]:
        seed = run["seed"]
        task_index = run["task_index"]
        eval_points = run["iterations"]
        accuracy_series = run["accuracies"]

        for position, iteration in enumerate(eval_points):
            record = {"seed": seed, "task_index": task_index, "iteration": iteration}
            # One column per accuracy series, taken at the current position.
            record.update(
                {name: series[position] for name, series in accuracy_series.items()}
            )
            records.append(record)

    return pd.DataFrame(records)


def result_dict_to_df_mtl(json_dict, prediction_targets):
    """Flatten stored multi-task results into a long-format table.

    Only ``json_dict["result"]`` is read. For every entry and every task index
    ``k`` in ``range(len(prediction_targets))``, one row per iteration is
    emitted. That row holds the accuracies of the first ``k + 1`` prediction
    targets only, so the output has the same column layout as the
    continual-learning table.

    Args:
        json_dict: Loaded result file containing a ``"result"`` list.
        prediction_targets: Ordered task names; each must be a key of every
            entry's ``"accuracies"`` mapping.

    Returns:
        A ``pandas.DataFrame`` built from the collected rows.
    """
    records = []
    for run in json_dict["result"]:
        seed = run["seed"]
        eval_points = run["iterations"]
        accuracy_series = run["accuracies"]

        for task_index, _ in enumerate(prediction_targets):
            # Tasks considered "seen" at this task index.
            visible_tasks = prediction_targets[:task_index + 1]
            for position, iteration in enumerate(eval_points):
                record = {"seed": seed, "task_index": task_index, "iteration": iteration}
                for name in visible_tasks:
                    record[name] = accuracy_series[name][position]
                records.append(record)

    return pd.DataFrame(records)


def df_to_avg_acc(result_df: pd.DataFrame):
    """Return the mean of every other column per ``(task_index, iteration)`` pair.

    The grouping keys stay regular columns of the result (``as_index=False``).
    """
    per_eval_point = result_df.groupby(by=["task_index", "iteration"], as_index=False)
    return per_eval_point.mean()


def df_to_var_acc(result_df: pd.DataFrame):
    """Return the variance of every other column per ``(task_index, iteration)`` pair.

    Uses the pandas group-by ``var`` with its default settings; the grouping
    keys stay regular columns of the result (``as_index=False``).
    """
    per_eval_point = result_df.groupby(by=["task_index", "iteration"], as_index=False)
    return per_eval_point.var()


def visualize_speed_results(args, task_index, prediction_targets):
    """Plot smoothed accuracy-vs-iteration curves for one task index and save them.

    Stored results of the ``'cl'`` and ``'mtl'`` jobs are read through
    ``yield_json_results``. For every model in the plotted set, the rows of the
    requested ``task_index`` are kept, the accuracies of the tasks
    ``prediction_targets[:task_index + 1]`` are averaged per row, and the mean
    and standard deviation of that average are computed per iteration (i.e.
    across the rows sharing an iteration, such as different seeds). Both
    curves are smoothed with a 5-point rolling mean; the mean is drawn as a
    line and +/- one standard deviation as a shaded band.

    The figure is saved to
    ``figures/<dataset>_task_<task_index>_converge_speed.png`` and then shown.

    Args:
        args: Run configuration; must provide the ``"dataset"`` key.
        task_index: Index of the task whose convergence is plotted. Indices
            0, 1 and 2 additionally get a target-accuracy line and fixed axis
            limits.
        prediction_targets: Ordered task names (column names in the result
            tables).
    """
    # Local import: ``os`` is not imported explicitly at module level.
    import os

    plt.figure(figsize=(10, 6))

    predefined_colors = {
        "lwp": "#2196F3",  # Muted Blue-Grey (ours)
        "er": "#A5D6A7",  # Muted Green-Grey
        "der": "#FFAB91",  # Muted Orange-Grey
        "lwf": "#FFE082",  # Muted Yellow-Grey
        "mtl": "#CE93D8",  # Muted Purple-Grey
        "nashmtl": "#80DEEA",  # Muted Cyan-Grey
        "fdr": "#FFCC80",  # Muted Orange-Grey
    }
    # Only these models are drawn; results of other models are skipped.
    plotted_models = ('fdr', 'lwp', 'der', 'mtl', 'nashmtl')
    # Target accuracy (horizontal reference line) per task index.
    target_accuracy_by_task = {0: 0.80, 1: 0.85, 2: 0.87}
    available_tasks = prediction_targets[:task_index + 1]
    group_keys = ["task_index", "iteration"]

    for job in ['cl', 'mtl']:
        for model_name, df in yield_json_results(args, job=job, prediction_targets=prediction_targets):
            if model_name not in plotted_models:
                continue
            color = predefined_colors[model_name]
            df = df[df['task_index'] == task_index].copy()

            # Per-row average over the tasks available at this task index.
            df['avg_acc'] = df[available_tasks].mean(skipna=True, axis=1)
            grouped = df.groupby(by=group_keys, as_index=False)
            acc_mean = grouped.mean()
            acc_std = grouped.std()['avg_acc']

            iterations = acc_mean['iteration']
            acc_mean = acc_mean['avg_acc']

            # Smooth both curves with a 5-point rolling window.
            acc_mean = acc_mean.rolling(5, min_periods=1).mean()
            acc_std = acc_std.rolling(5, min_periods=1).mean()

            plt.plot(iterations, acc_mean, label=f"{model_name}", c=color)
            plt.fill_between(iterations, acc_mean + acc_std, acc_mean - acc_std,
                             color=color, alpha=0.2)

    # Set labels and title
    plt.xlabel("Iterations")
    plt.ylabel("Accuracy")
    plt.title(f"Accuracy Convergence per Model, Task Index {task_index}")

    # Horizontal line marking the convergence threshold, with fixed axis limits.
    if task_index in target_accuracy_by_task:
        plt.axhline(y=target_accuracy_by_task[task_index], color='r',
                    linestyle='--', label="Target Accuracy")
        plt.xlim(0, 90)
        plt.ylim(0.55, 1)

    plt.tight_layout()
    plt.grid()
    # Legend maps each line to its model (plus the target-accuracy line).
    plt.legend()

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs("figures", exist_ok=True)
    plt.savefig(
        f"figures/{args['dataset']}_task_{task_index}_converge_speed.png",
        dpi=300,
    )
    plt.show()


def get_average_accuracy(eval_results, prediction_targets):
    """Return the unweighted mean of ``eval_results[task]`` over ``prediction_targets``.

    Args:
        eval_results: Mapping from task name to its accuracy value.
        prediction_targets: Task names to average over; must be non-empty,
            otherwise the division raises ``ZeroDivisionError``.
    """
    per_task_scores = [eval_results[name] for name in prediction_targets]
    return sum(per_task_scores) / len(prediction_targets)


def train_cl_speed(
        args: Dict[str, Any],
        device: torch.device,
        seeds: List[int],
        logger: Logger = None,
):
    """Run the continual-learning convergence-speed experiment and save its results.

    For every seed a fresh model is built via ``get_models`` and trained on the
    prediction targets one after another through ``train_one_task_cl_speed``.
    The recorded iterations and accuracies of each (seed, task) pair are
    collected and finally written with ``save_speed_results(..., job_type='cl')``.

    Args:
        args: Run configuration dictionary (model, dataset and training options).
        device: Device the model is moved to.
        seeds: Seeds to run; must be non-empty because ``seeds[0]`` is used to
            prepare the data.
        logger: Not used by this function.
    """
    # NOTE(review): datasets are prepared once with seeds[0] and shared by all
    # seeds -- confirm this is intended.
    init_transform = get_transform(args["augment"], args["input_size"], args["model"])
    encoder_past = load_model(args, device)
    train_dataset_overall, valid_dataset, _test_dataset, prediction_targets = (
        prepare_data(args, seeds[0], init_transform, return_datasets=True)
    )
    criterion = torch.nn.CrossEntropyLoss()

    valid_dataloader = DataLoader(
        valid_dataset, batch_size=args["batch_size"], shuffle=False
    )
    custom_transform = get_custom_transform(
        input_size=args["input_size"], dataset_name=args["dataset"]
    )

    collected = []
    for seed in seeds:
        torch.cuda.empty_cache()
        fix_seed(seed)
        # NOTE(review): the same encoder_past object is handed to get_models for
        # every seed (train_mtl_speed deep-copies it instead) -- confirm that
        # get_models does not let training modify it in place.
        model = get_models(
            args,
            encoder_past,
            cls_output_dim=args["cls_output_dim"],
            lr=args["lr"],
            input_size=args["input_size"],
            dataset_name=args["dataset"],
            buffer_size=args["buffer_size"],
            num_tasks=len(prediction_targets),
            z_dim=args["z_dim"],
            n_epochs=args["epochs"],
            device=device
        ).to(device)

        for task_index, task_name in enumerate(prediction_targets):
            # Either train on the data subset of this task or on everything.
            train_dataset = (
                train_dataset_overall.split_data_by_task(task_index)
                if args['split_task']
                else train_dataset_overall
            )
            train_dataloader = DataLoader(
                train_dataset, batch_size=args["batch_size"], shuffle=True
            )

            # Loss and timing are returned as well but are not stored.
            _loss, _epoch_time, iterations, accs = train_one_task_cl_speed(
                args,
                task_name,
                task_index,
                prediction_targets,
                model,
                train_dataloader,
                valid_dataloader,
                seed,
                device,
                criterion,
                custom_transform,
            )

            collected.append({
                "seed": seed,
                "task_index": task_index,
                "iterations": iterations,
                "accuracies": accs
            })

    save_speed_results(collected, args, job_type='cl')


def train_one_task_cl_speed(
        args,
        task_name,
        task_index,
        prediction_targets,
        model: ContinualLearning,
        train_dataloader,
        valid_dataloader,
        seed,
        device,
        criterion,
        custom_transform,
):
    """Train ``model`` on one task while periodically recording validation accuracies.

    Every ``args['eval_period']`` training iterations (starting with iteration
    0) the accuracies of all tasks seen so far
    (``prediction_targets[:task_index + 1]``) are computed on
    ``valid_dataloader`` and appended to the per-task history.

    Args:
        args: Run configuration; reads ``"epochs"``, ``"eval_period"`` and
            ``"early_stopping_tolerance"``.
        task_name: Name of the task being trained; also the key of the labels
            in each batch dictionary.
        task_index: Position of the task in ``prediction_targets``.
        prediction_targets: Ordered list of all task names.
        model: Continual-learning model being trained.
        train_dataloader: Yields batch dictionaries with an ``"image"`` entry
            and one entry per task name.
        valid_dataloader: Data used for the periodic accuracy evaluation.
        seed: Seed of the current run, forwarded to ``save_model``.
        device: Device the batches are moved to.
        criterion: Loss function forwarded to the model.
        custom_transform: Transform forwarded to ``model.compute_loss``.

    Returns:
        Tuple ``(loss, avg_epoch_time, iterations, accuracy_dict)`` where
        ``loss`` is currently always 0 (see the note in the body),
        ``avg_epoch_time`` is the running mean of the per-epoch
        ``perf_counter`` durations, ``iterations`` lists the iteration counts
        at which accuracies were recorded, and ``accuracy_dict`` maps each
        seen task name to its list of recorded accuracies.
    """
    best_loss = float("inf")
    early_stopping_counter = 0
    avg_epoch_time = 0
    iter_count = 0
    iterations = []
    # Defined up front so the return statement also works when
    # args["epochs"] is 0 (previously an UnboundLocalError).
    total_loss = 0
    tasks_up_to_now = prediction_targets[:task_index + 1]
    accuracy_dict = {seen_task: [] for seen_task in tasks_up_to_now}
    for epoch in range(args["epochs"]):
        time_start = perf_counter()
        model.train()
        # NOTE(review): begin_task runs at the start of every epoch, not once
        # per task -- confirm this matches the model's contract.
        model.begin_task(train_dataloader, task_name, task_index, criterion=criterion)
        with tqdm(
                total=len(train_dataloader),
                desc=f"[Task {task_index}][Epoch {epoch + 1}/{args['epochs']}][Train]",
        ) as pbar:
            for sample in train_dataloader:
                # Re-enter train mode for each batch, as the original loop did.
                model.train()
                image = sample["image"].to(device)
                cur_task_y = sample[task_name].type(torch.LongTensor).to(device)
                tqdm_loss = model.compute_loss(
                    image,
                    cur_task_y,
                    image,
                    criterion,
                    custom_transform,
                    task_index,
                )
                pbar.set_postfix(Loss=tqdm_loss)
                pbar.update(1)
                if iter_count % args['eval_period'] == 0:
                    results = model.calculate_accuraciess(valid_dataloader,
                                                          tasks_up_to_now,
                                                          device)
                    iterations.append(iter_count)
                    # A dedicated loop variable: reusing ``task_name`` here would
                    # rebind the parameter that selects the labels above and is
                    # passed to save_model / end_task below.
                    for seen_task in tasks_up_to_now:
                        accuracy_dict[seen_task].append(results[seen_task])
                iter_count += 1

        # Running mean of the wall-clock duration per epoch.
        avg_epoch_time = (avg_epoch_time * (epoch) + (perf_counter() - time_start)) / (epoch + 1)

        # NOTE(review): no validation loss is computed here -- the value below is
        # always 0. As a result the model is saved after the first epoch only
        # and training stops once ``early_stopping_tolerance`` further epochs
        # have passed. Kept as-is; replace with a real validation loss if
        # loss-based early stopping is intended.
        total_loss = 0
        model.eval()

        tqdm_loss = total_loss / len(valid_dataloader)
        total_loss += tqdm_loss
        if tqdm_loss < best_loss:
            best_loss = tqdm_loss
            save_model(model, args, seed, task_name)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
        if early_stopping_counter >= args["early_stopping_tolerance"]:
            break
    model.end_task(train_dataloader, task_name, task_index)

    return total_loss / len(valid_dataloader), avg_epoch_time, iterations, accuracy_dict


def train_mtl_speed(
        args: Dict[str, Any],
        device: torch.device,
        seeds: List[int],
):
    """Run the multi-task convergence-speed experiment and save its results.

    For every seed the data is prepared, a model is built on a deep copy of
    the loaded encoder and trained jointly on all prediction targets for
    ``args["epochs"]`` epochs. Every ``args['eval_period']`` iterations
    (starting with iteration 0) the accuracies of all tasks on the validation
    loader are recorded. The per-seed histories are written with
    ``save_speed_results(..., job_type='mtl')``.

    Args:
        args: Run configuration dictionary (model, dataset and training options).
        device: Device the model and batches are moved to.
        seeds: Seeds to run.
    """
    per_seed_results = []
    init_transform = get_transform(args["augment"], args["input_size"], args["model"])
    encoder_past = load_model(args, device)
    criterion = torch.nn.CrossEntropyLoss()

    for seed in seeds:
        fix_seed(seed)
        # The test loader is returned as well but is not used here.
        train_dataloader, valid_dataloader, _test_dataloader, prediction_targets = prepare_data(
            args, seed, init_transform
        )
        # Each seed gets its own copy of the loaded encoder.
        cur_encoder = copy.deepcopy(encoder_past).to(device)
        model = get_models_MTL(
            args,
            cur_encoder,
            # Every task is registered with 2 classes.
            tasks_name_to_cls_num=dict.fromkeys(prediction_targets, 2),
            z_dim=args['z_dim'],
            cls_output_dim=args["cls_output_dim"],
        ).to(device)

        eval_iterations = []
        # One accuracy history per task.
        accuracy_history = {name: [] for name in prediction_targets}
        iter_count = 0
        for epoch_number in range(1, args["epochs"] + 1):
            epoch_loss_sum = 0
            progress_label = f"[Epoch {epoch_number}/{args['epochs']}][Train]"
            with tqdm(total=len(train_dataloader), desc=progress_label) as pbar:
                for sample in train_dataloader:
                    image = sample["image"].to(device)
                    tasks_y = {
                        name: sample[name].type(torch.long).to(device)
                        for name in prediction_targets
                    }
                    train_loss = model.compute_loss(image, tasks_y, criterion)
                    loss_value = train_loss.item()
                    epoch_loss_sum += loss_value
                    pbar.set_postfix(train_loss=loss_value)
                    pbar.update(1)

                    # Record validation accuracies on every eval_period-th iteration.
                    if iter_count % args['eval_period'] == 0:
                        results = model.calculate_accuraciess(
                            valid_dataloader, prediction_targets, device
                        )
                        eval_iterations.append(iter_count)
                        for name in prediction_targets:
                            accuracy_history[name].append(results[name])

                    iter_count += 1

        per_seed_results.append({
            "seed": seed,
            "iterations": eval_iterations,
            "accuracies": accuracy_history})

    # Save results
    save_speed_results(per_seed_results, args, job_type='mtl')
