

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import random
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from datetime import datetime
from tasks.task_generator import configure_dataset
from utils.utils import *
import wandb
import os 
from agent import *
from collections import deque
import hydra
from omegaconf import DictConfig
import matplotlib.pyplot as plt



def _count_correct(model, loader, device):
    """Return ``(num_correct, num_seen)`` for ``model.predict``'s argmax over ``loader``."""
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        preds = model.predict(data).argmax(dim=1)
        correct += preds.eq(target).sum().item()
        total += data.size(0)
    return correct, total


def evaluate_one_task(model, train_loader, test_loader, device):
    """Compute classification accuracy of ``model`` on a train and a test loader.

    The model is switched to eval mode for the duration of the evaluation and
    is always switched back to train mode afterwards (even if an error occurs).

    Args:
        model: object exposing ``eval()``, ``train()`` and ``predict(data)``;
            ``predict`` is expected to return per-class scores whose argmax
            over dim 1 is the predicted label.
        train_loader: iterable of ``(data, target)`` batches.
        test_loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to before prediction.

    Returns:
        ``(train_accuracy, test_accuracy)`` as fractions in [0, 1]. An empty
        loader yields 0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    try:
        with torch.no_grad():
            train_correct, train_total = _count_correct(model, train_loader, device)
            test_correct, test_total = _count_correct(model, test_loader, device)
    finally:
        # Restore training mode for the caller regardless of how evaluation ended.
        model.train()

    train_acc = train_correct / train_total if train_total else 0.0
    test_acc = test_correct / test_total if test_total else 0.0
    return train_acc, test_acc

# ------------------------------------------------------------------------
# 8) Training Routines
# ------------------------------------------------------------------------
def _snapshot_eval(model, eval_model, agent_config, train_loader, test_loader, device):
    """Measure train/test accuracy of ``model``'s current weights.

    For ``NeuroSyncAgent`` the weights are copied into the separate
    ``eval_model`` instance before evaluating; every other agent is evaluated
    in place (``eval_model`` is rebound to ``model``).

    Returns:
        ``(eval_model, train_acc, test_acc)``; callers keep the returned
        ``eval_model`` for subsequent snapshots.
    """
    if agent_config.agent_type == 'NeuroSyncAgent':
        eval_model.load_state_dict(model.state_dict())
        eval_model = eval_model.to(device)
    else:
        eval_model = model
    eval_model.eval()
    train_acc, test_acc = evaluate_one_task(eval_model, train_loader=train_loader, test_loader=test_loader, device=device)
    return eval_model, train_acc, test_acc


def train_one_task(task_id, model, eval_model, train_loader, test_loader, device, epochs, reoccured, agent_config, task_config, global_counter, wandb_enable, intermediate_epochs=True):
    """Train ``model`` on one task for ``epochs`` epochs and report accuracies.

    Each batch is passed to ``model.step(data, target)``. Optional accuracy
    snapshots are taken during training, and a final evaluation runs after the
    last epoch.

    Args:
        task_id: task index, used only in the progress-bar label.
        model: agent being trained; must expose ``step``, ``to``, ``train``.
        eval_model: second agent instance used for snapshots of NeuroSyncAgent.
        train_loader, test_loader: iterables of ``(data, target)`` batches.
        device: device batches and models are moved to.
        epochs: number of passes over ``train_loader``.
        reoccured: if true, take a "few-shot" snapshot after epoch index 1
            (i.e. after the second epoch).
        agent_config: needs ``agent_type``.
        task_config: needs ``benchmark`` (used in wandb keys).
        global_counter: wandb step counter; incremented once per epoch.
        wandb_enable: log per-batch loss (and NeuroSync params) to wandb.
        intermediate_epochs: take snapshots at epoch indices
            ``int(epochs * 0.25)``, ``int(epochs * 0.5)`` and ``int(epochs * 0.75)``
            (first match wins if they coincide).

    Returns:
        A 12-tuple ``(few_shot_train_acc, few_shot_test_acc, quarter_train_acc,
        quarter_test_acc, half_train_acc, half_test_acc, three_fourth_train_acc,
        three_fourth_test_acc, final_train_acc, final_test_acc, metrics,
        global_counter)``. Snapshot accuracies that were not taken stay 0;
        ``metrics`` is the dict from the last ``model.step`` call (``{}`` if no
        batch was processed); ``global_counter`` is the updated step counter.
    """
    model.to(device)
    model.train()

    few_shot_train_acc = few_shot_test_acc = 0
    quarter_train_acc = quarter_test_acc = 0
    half_train_acc = half_test_acc = 0
    three_fourth_train_acc = three_fourth_test_acc = 0
    # Defined up front so the return statement works even if no batch runs.
    metrics = {}

    # Epoch indices after which the intermediate snapshots are taken.
    quarter_ep = int(epochs * 0.25)
    half_ep = int(epochs * 0.5)
    three_fourth_ep = int(epochs * 0.75)

    with tqdm(total=epochs, desc=f'training the model task {task_id}', disable=False) as pbar:
        for ep in range(epochs):
            model.to(device)
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                _, metrics = model.step(data, target)

                if wandb_enable:
                    loss = metrics['curr_train_loss']
                    wandb.log({f'training_loss_{task_config.benchmark}': loss.item()}, step=global_counter)

            if reoccured and ep == 1:
                eval_model, few_shot_train_acc, few_shot_test_acc = _snapshot_eval(
                    model, eval_model, agent_config, train_loader, test_loader, device)

            if intermediate_epochs:
                if ep == quarter_ep:
                    eval_model, quarter_train_acc, quarter_test_acc = _snapshot_eval(
                        model, eval_model, agent_config, train_loader, test_loader, device)
                elif ep == half_ep:
                    eval_model, half_train_acc, half_test_acc = _snapshot_eval(
                        model, eval_model, agent_config, train_loader, test_loader, device)
                elif ep == three_fourth_ep:
                    eval_model, three_fourth_train_acc, three_fourth_test_acc = _snapshot_eval(
                        model, eval_model, agent_config, train_loader, test_loader, device)

            if wandb_enable and agent_config.agent_type == 'NeuroSyncAgent':
                wandb.log(model.model.plot_params(), step=global_counter)

            global_counter += 1
            # The original bar was never advanced; tick once per finished epoch.
            pbar.update(1)

    final_train_acc, final_test_acc = evaluate_one_task(model, train_loader=train_loader, test_loader=test_loader, device=device)

    return (few_shot_train_acc, few_shot_test_acc,
            quarter_train_acc, quarter_test_acc,
            half_train_acc, half_test_acc,
            three_fourth_train_acc, three_fourth_test_acc,
            final_train_acc, final_test_acc,
            metrics, global_counter)
# ------------------------------------------------------------------------
# 9) Main Experiment
# ------------------------------------------------------------------------
# Agents whose constructors take no ``device`` argument.
_AGENTS_WITHOUT_DEVICE = (
    "BaseAgent", "L2Agent", "ReDoAgent", "LayerNormAgent",
    "DeepFourierAgent", "CReLUAgent", "PReLUAgent",
)
# Agents whose constructors take a ``device`` argument. NeuroSyncAgent also
# takes one but is built separately because it needs extra configuration.
_AGENTS_WITH_DEVICE = (
    "L2InitAgent", "CBPAgent", "EWCAgent", "L2InitPlusEWCAgent",
    "ShrinkAndPerturbAgent",
)
# Labels of the four backward-transfer checkpoints, in the order train_one_task
# returns its (quarter, half, three-fourth, final) accuracies.
# NOTE(review): with the 4-epoch probe, the "quarter" snapshot is taken after
# epoch index 1 (two epochs) and "three-fourth" and "final" both land after
# four epochs, so these labels look off by one -- confirm intended meaning.
_BT_HORIZONS = ("one", "two", "three", "four")


def _build_agents(agent_config, arch_config, task_config, device):
    """Construct the ``(model, eval_model)`` pair for ``agent_config.agent_type``.

    ``eval_model`` is always a separate instance, so that backward-transfer
    probing (which trains ``eval_model``) never mutates ``model``.

    Raises:
        ValueError: if the agent type is unknown.
    """
    agent_type = agent_config.agent_type
    kwargs = dict(agent_config=agent_config, arch_config=arch_config, task_config=task_config)

    if agent_type == "NeuroSyncAgent":
        # EMA decay chosen so the decay compounds to `ema_target` over all
        # forward passes of one task.
        num_forward_passes = task_config.epochs * (task_config.limit / task_config.batch_size)
        agent_config.ema_decay = float(np.power(task_config.ema_target, 1 / num_forward_passes))
        model = NeuroSyncAgent(**kwargs, device=device)
        eval_model = NeuroSyncAgent(**kwargs, device=device, track_params=False)
        return model, eval_model

    if agent_type in _AGENTS_WITH_DEVICE:
        kwargs["device"] = device
    elif agent_type not in _AGENTS_WITHOUT_DEVICE:
        raise ValueError(f"Unknown agent: {agent_config}")

    # Agent classes come from `from agent import *`; look one up lazily by name.
    agent_cls = globals()[agent_type]
    # ShrinkAndPerturbAgent previously aliased eval_model to model, so the
    # backward-transfer probe trained the live model; build a separate copy.
    return agent_cls(**kwargs), agent_cls(**kwargs)


def _make_task_loaders(train_dataset_task, test_dataset_task, batch_size, seed):
    """Wrap task datasets in DataLoaders; inputs that already are loaders pass through.

    The training loader shuffles with a generator seeded by ``seed``.
    """
    if isinstance(train_dataset_task, DataLoader):
        train_loader = train_dataset_task
    else:
        g = torch.Generator()
        g.manual_seed(seed)
        train_loader = DataLoader(train_dataset_task, batch_size=batch_size, shuffle=True, generator=g)

    if isinstance(test_dataset_task, DataLoader):
        test_loader = test_dataset_task
    else:
        test_loader = DataLoader(test_dataset_task, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def _fisher_samples(train_dataset_task):
    """Return a random 10% subset ``(xs, ys)`` of a finished task's training data."""
    if isinstance(train_dataset_task, DataLoader):
        # Previously iterated the undefined name `loader` (NameError).
        xs_list = []
        ys_list = []
        for x, y in train_dataset_task:
            xs_list.append(x)
            ys_list.append(y)
        xs = torch.cat(xs_list, dim=0)
        ys = torch.cat(ys_list, dim=0)
    else:
        loader = DataLoader(train_dataset_task, batch_size=len(train_dataset_task), shuffle=False)
        xs, ys = next(iter(loader))

    dataset_len = len(xs)
    indices = np.arange(dataset_len)
    np.random.shuffle(indices)
    keep = int(dataset_len / 10)
    return xs[indices][:keep], ys[indices][:keep]


def _probe_backward_transfer(task_id, model, eval_model, train_loaders, test_loaders,
                             agent_config, arch_config, task_config, device,
                             global_counter, bt_train, bt_test):
    """Retrain a copy of ``model`` for 4 epochs on each earlier task and record accuracies.

    Results are written into row ``task_id`` of the matrices in ``bt_train`` and
    ``bt_test`` (dicts keyed by ``_BT_HORIZONS``). Returns ``eval_model``.
    """
    for i_task in range(task_id):
        eval_model.load_state_dict(model.state_dict())
        eval_model.get_optimizer()
        eval_model = eval_model.to(device)

        if agent_config.agent_type == 'NeuroSyncAgent':
            eval_eval_model = NeuroSyncAgent(agent_config=agent_config, arch_config=arch_config, task_config=task_config, device=device, track_params=False)
            eval_eval_model.load_state_dict(eval_model.state_dict())
            eval_eval_model = eval_eval_model.to(device)
        else:
            eval_eval_model = eval_model

        results = train_one_task(i_task, eval_model, eval_eval_model, train_loaders[i_task], test_loaders[i_task],
                                 device, 4, False, agent_config, task_config, global_counter,
                                 False, True)
        # results[2:10] = quarter, half, three-fourth and final (train, test) pairs.
        accs = results[2:10]
        for k, horizon in enumerate(_BT_HORIZONS):
            bt_train[horizon][task_id, i_task] = accs[2 * k]
            bt_test[horizon][task_id, i_task] = accs[2 * k + 1]
    return eval_model


def _log_bt_matrices(bt_train, bt_test):
    """Log backward-transfer matrices to wandb, first as tables, then as heatmaps."""
    for horizon in _BT_HORIZONS:
        for split, matrix in (("train", bt_train[horizon]), ("test", bt_test[horizon])):
            table = wandb.Table(
                data=matrix.tolist(),
                columns=[f"col_{i}" for i in range(matrix.shape[1])]
            )
            wandb.log({f"{horizon}_epoch_bt_{split}_matrix": table})

    for horizon in _BT_HORIZONS:
        for split, matrix in (("train", bt_train[horizon]), ("test", bt_test[horizon])):
            fig, ax = plt.subplots(figsize=(4, 4))
            cax = ax.matshow(matrix, aspect='auto')
            fig.colorbar(cax)
            wandb.log({f"{horizon}_epoch_{split}_matrix_heatmap": wandb.Image(fig)})
            # Release the figure; otherwise every heatmap stays in memory.
            plt.close(fig)


@hydra.main(config_path='configs/sl', config_name='config')
def main(cfg: DictConfig):
    """Run the continual-learning experiment described by the hydra config.

    For each task: build loaders, record zero-shot accuracy, train via
    ``train_one_task``, optionally update EWC Fisher information, optionally
    probe backward transfer on earlier tasks, and log results (to wandb when
    ``cfg.wandb`` is set).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)

    agent_config = cfg.agent
    arch_config = cfg.arch
    task_config = cfg.task
    num_tasks = task_config.num_tasks

    get_task_dataset, _, _ = configure_dataset(task_config=task_config, arch_config=arch_config, args=cfg)

    print(f"\n===== Training Model: {cfg.agent} =====")

    model, eval_model = _build_agents(agent_config, arch_config, task_config, device)

    # Previously compared the whole `cfg.agent` config to a string (always False).
    param_counts = model.model.compute_total_params()
    if agent_config.agent_type == "NeuroSyncAgent" and isinstance(param_counts, tuple):
        orchs_param_count, network_param_count = param_counts
        total_params = orchs_param_count + network_param_count
        print('Total number of trainable parameters:', total_params)
        print('Total number of trainable parameters in the network:', network_param_count)
        print('Percentage of trainable parameters in the network:', network_param_count / total_params * 100)
    else:
        print('Total number of trainable parameters:', param_counts)

    model.to(device)
    global_counter = 0
    exp_name = f'{model.__class__.__name__}_{arch_config.arch_name}_{agent_config.optimizer}_{agent_config.lr}_{task_config.benchmark}_transform={task_config.tranform}_{cfg.seed}'

    resets_network = agent_config.agent_type == 'BaseAgent' and agent_config.reset_network
    if resets_network:
        exp_name = 'Scratch_' + exp_name

    if cfg.wandb:
        merged_config = {}
        merged_config.update(namespace_to_dict(task_config))
        merged_config.update(namespace_to_dict(agent_config))
        merged_config.update(namespace_to_dict(arch_config))
        merged_config['seed'] = cfg.seed
        wandb.init(
            project=cfg.proj_name,
            name=exp_name,
            group=task_config.benchmark,
            config=merged_config,
            save_code=True,
        )

    bt_test_acc = np.zeros(num_tasks)
    bt_train_acc = np.zeros(num_tasks)
    bt_train = {h: np.zeros((num_tasks, num_tasks)) for h in _BT_HORIZONS}
    bt_test = {h: np.zeros((num_tasks, num_tasks)) for h in _BT_HORIZONS}

    train_loader_list = []
    test_load_list = []
    bench = task_config.benchmark
    for task_id in range(num_tasks):
        train_dataset_task, test_dataset_task, reoccured = get_task_dataset(task_id)
        train_loader, test_loader = _make_task_loaders(train_dataset_task, test_dataset_task, task_config.batch_size, cfg.seed)
        train_loader_list.append(train_loader)
        test_load_list.append(test_loader)

        zero_shot_train_acc, zero_shot_test_acc = evaluate_one_task(model, train_loader, test_loader, device)

        (few_shot_train_acc, few_shot_test_acc,
         quarter_train_acc, quarter_test_acc,
         half_train_acc, half_test_acc,
         three_fourth_train_acc, three_fourth_test_acc,
         final_train_acc, final_test_acc,
         metrics, global_counter) = train_one_task(
            task_id, model, eval_model, train_loader, test_loader, device,
            epochs=task_config.epochs, reoccured=reoccured, agent_config=agent_config,
            task_config=task_config, global_counter=global_counter, wandb_enable=cfg.wandb)
        bt_test_acc[task_id] = final_test_acc
        bt_train_acc[task_id] = final_train_acc

        if resets_network:
            # "Scratch" baseline: start every task from a freshly initialized agent.
            model = BaseAgent(agent_config=agent_config, arch_config=arch_config, task_config=task_config)
            model.to(device=device)
            eval_model = BaseAgent(agent_config=agent_config, arch_config=arch_config, task_config=task_config)

        model.get_optimizer()

        if agent_config.agent_type in ('EWCAgent', 'L2InitPlusEWCAgent'):
            # Fisher information is estimated on a 10% sample of the finished
            # task's *training* data.
            fisher_xs, fisher_ys = _fisher_samples(train_dataset_task)
            model.update_params_and_fisher(fisher_xs, fisher_ys, batch_size=task_config.batch_size)

        if cfg.monitor_backward_transfer:
            eval_model = _probe_backward_transfer(
                task_id, model, eval_model, train_loader_list, test_load_list,
                agent_config, arch_config, task_config, device,
                global_counter, bt_train, bt_test)

        log_info = {}
        if reoccured:
            log_info[f'reoccuring/train_acc_{bench}'] = final_train_acc
            log_info[f'reoccuring/test_acc_{bench}'] = final_test_acc
            log_info[f'reoccuring/quarter_train_acc_{bench}'] = quarter_train_acc
            log_info[f'reoccuring/quarter_test_acc_{bench}'] = quarter_test_acc
            log_info[f'reoccuring/half_train_acc_{bench}'] = half_train_acc
            log_info[f'reoccuring/half_test_acc_{bench}'] = half_test_acc
            log_info[f'reoccuring/three_fourth_train_acc_{bench}'] = three_fourth_train_acc
            log_info[f'reoccuring/three_fourth_test_acc_{bench}'] = three_fourth_test_acc
            log_info[f'reoccuring/zero_shot_test_acc_{bench}'] = zero_shot_test_acc
            log_info[f'reoccuring/zero_shot_train_acc_{bench}'] = zero_shot_train_acc
            log_info[f'reoccuring/few_epoch_train_acc_{bench}'] = few_shot_train_acc
            log_info[f'reoccuring/few_epoch_test_acc_{bench}'] = few_shot_test_acc

        # TODO do something with the metrics
        if cfg.wandb:
            log_info[f"train_acc_{bench}"] = final_train_acc
            log_info[f'test_acc_{bench}'] = final_test_acc
            log_info[f'quarter_train_acc_{bench}'] = quarter_train_acc
            log_info[f'quarter_test_acc_{bench}'] = quarter_test_acc
            log_info[f'half_train_acc_{bench}'] = half_train_acc
            log_info[f'half_test_acc_{bench}'] = half_test_acc
            log_info[f'three_fourth_train_acc_{bench}'] = three_fourth_train_acc
            log_info[f'three_fourth_test_acc_{bench}'] = three_fourth_test_acc
            log_info[f'zero_shot_train_acc_{bench}'] = zero_shot_train_acc
            log_info[f'zero_shot_test_acc_{bench}'] = zero_shot_test_acc
            wandb.log(log_info, step=global_counter)

        # The label says "test acc"; previously the train accuracy was printed.
        print(f" Task {task_id+1}/{num_tasks} -> test acc: {final_test_acc:.3f}")

    # wandb.log requires wandb.init, which only runs when cfg.wandb is set.
    if cfg.wandb:
        _log_bt_matrices(bt_train, bt_test)




# Script entry point. `main` is wrapped by `hydra.main`, which supplies `cfg`
# from the configs/sl config directory, so it is called with no arguments here.
if __name__ == "__main__":
    main()





