""" test_model.py
    Test models

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import logging
import os
import sys
from collections import OrderedDict
from tqdm import tqdm
import math
import matplotlib.pyplot as plt


import json

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
import cv2
import numpy as np

import deepthinking as dt
from deepthinking.utils.testing import get_predicted, get_visualizable_pred
from deepthinking.utils.rotation import rotate_batch

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115


def create_image(tensor, board_size=8, out_size=256):
    """Render a flat square-grid tensor as an upscaled greyscale image.

    The values are reshaped to a ``board_size x board_size`` grid, scaled so the
    largest value maps to 255, and resized to ``out_size x out_size`` with
    nearest-neighbour interpolation, so every grid cell becomes a solid block.

    Args:
        tensor: torch tensor holding exactly ``board_size * board_size`` values.
        board_size: side length of the square grid (default 8, as before).
        out_size: side length of the output image in pixels (default 256).

    Returns:
        The float64 image returned by ``cv2.resize``.
    """
    grid = tensor.reshape((board_size, board_size, 1))
    grid = grid.detach().cpu().numpy().astype(np.float64)
    peak = np.max(grid)
    # An all-zero grid (e.g. an empty prediction) would otherwise be divided
    # by zero and turn the whole image into NaNs.
    if peak != 0:
        grid = grid / peak
    grid *= 255
    return cv2.resize(grid, (out_size, out_size), interpolation=cv2.INTER_NEAREST)

def test(net, testloader, iters, problem, device, sample_idx=10, visualize_predict=False, visualize_conver=True):
    """Run ``net`` on one batch of ``testloader`` and dump debugging artefacts.

    Depending on what the network returns in debug mode this prints attention
    weights and ACT probabilities (5-tuple), or rounds the residuals and writes
    them to ``diff_norms.json`` (3-tuple). It can also save the residual/norm
    curves to ``visualize.png`` and write per-iteration prediction images under
    ``debug/``. That directory must already exist.

    Args:
        net: model called as ``net(inputs, iters_to_do=..., debug=True)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iterable of iteration counts; only ``max(iters)`` is used.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batch is moved to.
        sample_idx: 1-based index of the batch to inspect.
        visualize_predict: write per-iteration prediction images to ``debug/``.
        visualize_conver: save the residual and norm curves to ``visualize.png``.

    Raises:
        ValueError: if ``sample_idx < 1`` or the loader has fewer batches.
    """
    if sample_idx < 1:
        raise ValueError(f"sample_idx must be >= 1, got {sample_idx}")
    max_iters = max(iters)
    net.eval()
    net.debug = True

    with torch.no_grad():
        # Walk a single iterator to the requested batch. Calling
        # iter(testloader) on every step restarts the loader, which made the
        # old code return the first batch no matter what sample_idx was.
        batch = None
        for position, candidate in enumerate(testloader, start=1):
            if position == sample_idx:
                batch = candidate
                break
        if batch is None:
            raise ValueError(f"testloader has fewer than {sample_idx} batches")
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        output_dt = net(inputs, iters_to_do=max_iters, debug=True)
        # A bare tensor also supports len() (its batch size), so detect it
        # before dispatching on the number of returned items.
        if isinstance(output_dt, torch.Tensor):
            all_outputs = output_dt
            res, norms = [], []
        elif len(output_dt) == 5:
            all_outputs, attention_weights, act_probs, res, norms = output_dt
            attention_weights = [round(x.detach().cpu().numpy().tolist()[0][0], 4) for x in attention_weights]
            print("Attention Weights: ", attention_weights)
            print("-" * 50)
            print("ACT probs: ", act_probs)
        elif len(output_dt) == 3:
            all_outputs, res, norms = output_dt
            # Round the residuals and keep a copy on disk.
            res = [round(x, 4) for x in res]
            with open("diff_norms.json", "w", encoding="utf-8") as f:
                json.dump(res, f)
            print("res: ", res)
        else:
            all_outputs = output_dt
            res, norms = [], []

        if visualize_conver:
            fig, axs = plt.subplots(1, 2, figsize=(10, 5))  # 1 row, 2 columns
            panels = ((axs[0], res, 'r-', 'List 1', 'Convergence'),
                      (axs[1], norms, 'b-', 'List 2', 'Norm'))
            for ax, values, fmt, label, title in panels:
                ax.plot(range(len(values)), values, fmt, label=label)
                ax.set_title(title)
                ax.set_xlabel('X-axis')
                ax.set_ylabel('Y-axis')
                ax.legend()
            fig.tight_layout()
            fig.savefig('visualize.png', dpi=300, bbox_inches='tight')
            # Release the figure; pyplot otherwise keeps it alive indefinitely.
            plt.close(fig)

        if visualize_predict:
            num_steps = all_outputs.size(1)
            corrects = torch.zeros(num_steps)
            # Flatten targets once (loop-invariant) to compare per position.
            targets = targets.view(targets.size(0), -1)
            for i in range(num_steps):
                outputs = all_outputs[:, i]
                predicted = get_predicted(inputs, outputs, problem)
                vis_prob_map = get_visualizable_pred(outputs)
                # A sample counts as correct only if every position matches.
                corrects[i] += torch.amin(predicted == targets, dim=[1]).sum().item()
                print(f"iters: {i}: {'True' if corrects[i] else 'False'}")
                cv2.imwrite(f"debug/{i}_prob.jpg", vis_prob_map)
                cv2.imwrite(f"debug/{i}_pred.jpg", create_image(predicted))
            cv2.imwrite("debug/target.jpg", create_image(targets))
            
def test_with_ssh(net, testloader, iters, problem, device, sample_idx=1, difficulty="hard"):
    """Plot per-iteration self-supervised and classification losses for one batch.

    The ``sample_idx``-th batch (1-based) of ``testloader`` is expanded with
    ``rotate_batch(..., 'expand')``. The network's ``return_ssh`` output is
    scored against the rotation labels, and its plain output against the class
    targets. Both use mean cross-entropy at every iteration. The two curves are
    saved side by side to ``visualize_{difficulty}.png``.

    Args:
        net: model supporting ``net(x, iters_to_do=..., debug=..., return_ssh=...)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iterable of iteration counts; only ``max(iters)`` is used.
        problem: unused; kept for signature parity with ``test``.
        device: device the batch is moved to.
        sample_idx: 1-based index of the batch to inspect.
        difficulty: suffix used in the output file name.

    Raises:
        ValueError: if ``sample_idx < 1`` or the loader has fewer batches.
    """
    if sample_idx < 1:
        raise ValueError(f"sample_idx must be >= 1, got {sample_idx}")
    max_iters = max(iters)
    net.eval()
    net.debug = True
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    def mean_loss_per_iter(all_outputs, labels):
        """Mean cross-entropy of ``all_outputs[:, i]`` vs ``labels`` for each i."""
        losses = []
        for i in range(max_iters):
            out = all_outputs[:, i]
            # Keep dim 1 (the class dim for CrossEntropyLoss), flatten the rest.
            out = out.view(out.size(0), out.size(1), -1)
            losses.append(criterion(out, labels).mean().item())
        return losses

    with torch.no_grad():
        # Walk a single iterator; re-calling iter() would always yield batch 1.
        batch = None
        for position, candidate in enumerate(testloader, start=1):
            if position == sample_idx:
                batch = candidate
                break
        if batch is None:
            raise ValueError(f"testloader has fewer than {sample_idx} batches")
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        ssh_inputs, ssh_labels = rotate_batch(inputs, 'expand')
        ssh_labels = ssh_labels.view(ssh_labels.size(0), -1)
        targets = targets.view(targets.size(0), -1)

        _, all_ssh_outputs = net(ssh_inputs, iters_to_do=max_iters, debug=False, return_ssh=True)
        loss_ssh_list = mean_loss_per_iter(all_ssh_outputs, ssh_labels)

        all_outputs = net(inputs, iters_to_do=max_iters, debug=False, return_ssh=False)
        loss_cls_list = mean_loss_per_iter(all_outputs, targets)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))  # 1 row, 2 columns
    panels = ((axs[0], loss_ssh_list, 'r-', 'List 1', 'Self-supervise Loss'),
              (axs[1], loss_cls_list, 'b-', 'List 2', 'CLS Loss'))
    for ax, values, fmt, label, title in panels:
        ax.plot(range(len(values)), values, fmt, label=label)
        ax.set_title(title)
        ax.set_xlabel('X-axis')
        ax.set_ylabel('Y-axis')
        ax.legend()
    fig.tight_layout()
    fig.savefig(f'visualize_{difficulty}.png', dpi=300, bbox_inches='tight')
    # Release the figure; repeated calls would otherwise accumulate figures.
    plt.close(fig)

def visualize_immediate_state(net, loader, iters, problem, device, sample_idx=1):
    """Save one image from ``loader["test"]`` and run the network on it.

    The batch taken is the ``sample_idx``-th one (1-based), or the last one if
    the loader is shorter. Its first image is written to
    ``visualize_hard_sample.png``, its label name is printed using the
    hard-coded CIFAR-10 class list, and ``net`` is run with
    ``return_ssh=True``.

    Args:
        net: model supporting ``net(x, iters_to_do=..., debug=..., return_ssh=..., sample_idx=...)``.
        loader: mapping of split name to loader; only ``loader["test"]`` is read.
        iters: iterable of iteration counts; only ``max(iters)`` is used.
        problem: unused; kept for signature parity with the other helpers.
        device: device the batch is moved to.
        sample_idx: 1-based index of the batch to inspect.

    Returns:
        ``(all_hard_outputs, ssh_outputs)`` as returned by ``net``, so callers
        can plot per-iteration class / rotation likelihoods.

    Raises:
        ValueError: if ``loader["test"]`` yields no batches.
    """
    labels = ('airplane', 'automobile', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck')
    max_iters = max(iters)
    net.eval()
    net.debug = True

    with torch.no_grad():
        batch = None
        for count, batch in enumerate(loader["test"], start=1):
            if count >= sample_idx:
                break
        if batch is None:
            raise ValueError('loader["test"] yielded no batches')
        hard_inputs, hard_targets = batch
        hard_inputs, hard_targets = hard_inputs.to(device), hard_targets.to(device)

        # CHW tensor -> HWC array scaled by 255 (assumes inputs lie in [0, 1]).
        vis_hard_img = (hard_inputs[0].detach().cpu().permute(1, 2, 0) * 255).numpy()
        # cv2.imwrite expects BGR, so reverse the channel axis. This assumes
        # the loader yields RGB, as torchvision's CIFAR-10 does; TODO confirm.
        vis_hard_img = np.ascontiguousarray(vis_hard_img[..., ::-1])
        cv2.imwrite('visualize_hard_sample.png', vis_hard_img)
        print(f"Label hard: {labels[int(hard_targets[0])]}")

        all_hard_outputs, ssh_outputs = net(hard_inputs, iters_to_do=max_iters, debug=False,
                                            return_ssh=True, sample_idx=sample_idx)
    return all_hard_outputs, ssh_outputs

@hydra.main(config_path="config", config_name="test_model_config")
def main(cfg: DictConfig):
    """Hydra entry point: reload a trained model and run the debug ``test``.

    Loads the training run's saved config from
    ``<model_path>/.hydra/config.yaml`` and copies selected hyper-parameters
    (including ``max_iters`` and ``width``) into ``cfg``, so the rebuilt model
    matches the training run. It then loads ``<model_path>/model_best.pth``
    and calls ``test`` on the test split with a test batch size of 1.
    Artefacts go to the working directory; ``debug/`` is created here.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True
    if cfg.problem.hyp.save_period is None:
        cfg.problem.hyp.save_period = cfg.problem.hyp.epochs
    log = logging.getLogger()
    log.info("\n_________________________________________________\n")
    log.info("test_model.py main() running.")
    log.info(OmegaConf.to_yaml(cfg))

    training_args = OmegaConf.load(os.path.join(cfg.problem.model.model_path, ".hydra/config.yaml"))
    # ("model", "model") is left out, so the architecture name comes from the
    # test config rather than the training run.
    cfg_keys_to_load = (("hyp", "alpha"),
                        ("hyp", "epochs"),
                        ("hyp", "lr"),
                        ("hyp", "lr_factor"),
                        ("model", "max_iters"),
                        ("hyp", "optimizer"),
                        ("hyp", "train_mode"),
                        ("model", "width"))
    for k1, k2 in cfg_keys_to_load:
        cfg["problem"][k1][k2] = training_args["problem"][k1][k2]

    log.info(OmegaConf.to_yaml(cfg))

    ####################################################
    #               Dataset and Network and Optimizer
    cfg.problem.hyp.test_batch_size = 1
    loaders = dt.utils.get_dataloaders(cfg.problem)

    cfg.problem.model.model_path = os.path.join(cfg.problem.model.model_path, "model_best.pth")
    # Only the network is needed; the start epoch and optimizer state are ignored.
    net, _, _ = dt.utils.load_model_from_checkpoint(cfg.problem.name,
                                                    cfg.problem.model,
                                                    device)
    pytorch_total_params = sum(p.numel() for p in net.parameters())
    log.info(f"This {cfg.problem.model.model} has {pytorch_total_params/1e6:0.3f} million parameters.")
    ####################################################

    ####################################################
    #        Test
    log.info("==> Starting testing...")
    if "feedforward" in cfg.problem.model.model:
        test_iterations = [cfg.problem.model.max_iters]
    else:
        test_iterations = list(range(cfg.problem.model.test_iterations["low"],
                                     cfg.problem.model.test_iterations["high"] + 1))
    os.makedirs("debug", exist_ok=True)
    test(net, loaders["test"], test_iterations, cfg.problem.name, device)
    # Other diagnostics available in this module:
    #   test_with_ssh(net, loaders[split], test_iterations, cfg.problem.name, device, difficulty=...)
    #   visualize_immediate_state(net, loaders, test_iterations, cfg.problem.name, device, sample_idx=...)


if __name__ == "__main__":
    # Use a fixed run id so every invocation shares it. The commented-out
    # dt.utils.generate_run_id() call was the per-run alternative. The id
    # reaches Hydra as an extra "+run_id=..." command-line override.
    run_id = "test_sample"
    sys.argv += [f"+run_id={run_id}"]
    main()
