

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import random
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from datetime import datetime
from tasks.task_generator import configure_dataset
from utils.utils import *
import wandb
import os 
from agent import *
from collections import deque
import hydra
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from copy import deepcopy
from contextlib import contextmanager
import random, torch, numpy as np
# from hat import HATPayload


def save_artifact(matrix, name, wandb_enable, num_tasks, length):
    """
    Save a result matrix to ``./forgetting_data/{name}.npy`` and optionally
    upload the file to wandb as a ``dataset`` artifact.

    Args:
        matrix: 2-D nested list (rows of cells). A cell is either a scalar or
            a sequence of per-evaluation accuracies.
        name: File stem of the ``.npy`` file and name of the wandb artifact.
        wandb_enable: When truthy, log the saved file with ``wandb.log_artifact``.
        num_tasks: Unused; kept so existing callers keep working.
        length: Length of the per-cell sequences. Scalar cells are expanded to
            this length when the matrix also contains sequences.

    Note:
        If the sequences do not all share one length, the matrix is stored as
        an object array; reading it back needs ``np.load(..., allow_pickle=True)``.
    """
    os.makedirs('./forgetting_data/', exist_ok=True)
    path = f"./forgetting_data/{name}.npy"

    def _is_sequence(value):
        return isinstance(value, (list, tuple, np.ndarray))

    # Scalars and sequences mixed in one matrix form a ragged array, which
    # recent NumPy versions refuse to build. Expand every scalar cell to a
    # sequence of `length` copies of itself (untouched cells are 0, so they
    # become zero-lists) to obtain a regular 3-D array.
    rows_are_sequences = all(_is_sequence(row) for row in matrix)
    if rows_are_sequences and length > 0:
        has_sequence_cells = any(_is_sequence(cell) for row in matrix for cell in row)
        if has_sequence_cells:
            matrix = [
                [cell if _is_sequence(cell) else [cell] * length for cell in row]
                for row in matrix
            ]

    try:
        array = np.asarray(matrix)
    except ValueError:
        # Sequences of unequal length: keep them as Python objects.
        array = np.asarray(matrix, dtype=object)

    np.save(path, array)

    if wandb_enable:
        artifact = wandb.Artifact(name, type="dataset")
        artifact.add_file(path)

        wandb.log_artifact(artifact)


def _accuracy_on_loader(model, device, loader):
    """Return the fraction of samples in ``loader`` that ``model.predict`` classifies correctly."""
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        logits = model.predict(data)
        preds = logits.argmax(dim=1)
        correct += preds.eq(target).sum().item()
        total += data.size(0)
    # An empty loader has no samples to score; report 0.0 instead of dividing by zero.
    return correct / total if total else 0.0


def evaluate_one_task(model, device, aux_train_loader, aux_test_loader, ema_update=True):
    """
    Measure the classification accuracy of ``model`` on a train and a test loader.

    Args:
        model: Object exposing ``predict(data)``, ``eval()`` and ``train()``.
        device: Device the batches are moved to before prediction.
        aux_train_loader: Iterable of ``(data, target)`` batches.
        aux_test_loader: Iterable of ``(data, target)`` batches.
        ema_update: When False, ``model.eval()`` is called before evaluating.
            When True (default) the model is left in whatever mode it is in.

    Returns:
        ``(train_accuracy, test_accuracy)`` as floats; 0.0 for an empty loader.

    Side effects:
        ``model.train()`` is always called before returning.
    """
    if not ema_update:
        model.eval()

    with torch.no_grad():
        train_acc = _accuracy_on_loader(model, device, aux_train_loader)
        test_acc = _accuracy_on_loader(model, device, aux_test_loader)

    model.train()

    return train_acc, test_acc



def BWA_one_task(model, cfg, agent_config, arch_config, task_config, device, train_loader, test_loader, epochs, eval_interval):
    """
    Train ``model`` on one task for ``epochs`` epochs and record its accuracy
    every ``eval_interval`` batches (backward-adaptation curve).

    At each evaluation point a fresh model is built with ``create_model``,
    loaded with the current ``state_dict`` (plus the consolidated EMA weights
    for ``NeuroSyncAgent``) and scored with ``evaluate_one_task``, so the model
    being trained is not itself evaluated.

    Args:
        model: Agent exposing ``to``, ``train``, ``step(data, target)`` and
            ``state_dict``.
        cfg, agent_config, arch_config, task_config: Forwarded to ``create_model``.
        device: Device used for training and evaluation.
        train_loader: Iterable of ``(data, target)`` batches used for training
            and for the train-accuracy measurement.
        test_loader: Iterable of ``(data, target)`` batches used for the
            test-accuracy measurement.
        epochs: Number of passes over ``train_loader``.
        eval_interval: Evaluate whenever the running batch index is a multiple
            of this value (batch 0 is always evaluated).

    Returns:
        ``(train_accuracies, test_accuracies)``: two lists with one entry per
        evaluation point.
    """
    model.to(device)
    model.train()
    batch_counter = 0
    train_result = []
    test_result = []

    for _ in range(epochs):
        for data, target in train_loader:
            # Re-applied every batch, as in the original loop.
            model.to(device)
            model.train()
            data, target = data.to(device), target.to(device)
            model.step(data, target)

            if batch_counter % eval_interval == 0:
                eval_model = create_model(cfg, agent_config, arch_config, task_config, device)
                eval_model.load_state_dict(model.state_dict())
                if cfg.agent.agent_type == "NeuroSyncAgent":
                    eval_model.load_consolidated_weights(model.model._ema_params)
                eval_model.to(device)
                eval_model.eval()
                train_acc, test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=train_loader, aux_test_loader=test_loader)
                train_result.append(train_acc)
                test_result.append(test_acc)

            batch_counter += 1

    return train_result, test_result

def train_one_task(task_id, model, cfg, agent_config, arch_config, task_config, device, train_loader, test_loader, epochs, reoccured, global_counter, eval_interval, wandb_enable, aux_train_loader, aux_test_loader):
    """
    Train ``model`` on one task for ``epochs`` epochs, logging and evaluating
    along the way.

    Args:
        task_id: Index of the task; only used in the progress-bar label.
        model: Agent exposing ``to``, ``train``, ``step`` and ``state_dict``.
        cfg, agent_config, arch_config, task_config: Experiment configuration;
            forwarded to ``create_model`` for the evaluation copies.
        device: Device used for training and evaluation.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Unused here; kept for interface compatibility.
        epochs: Number of passes over ``train_loader``.
        reoccured: When truthy, an extra evaluation is taken after the second
            epoch (``ep == 1``) and returned as the "few shot" accuracies.
        global_counter: Running batch count across tasks; used as wandb step.
        eval_interval: With ``cfg.monitor_forward_transfer`` set, evaluate
            whenever the in-epoch batch index is a multiple of this value.
        wandb_enable: When truthy, per-batch metrics are sent to wandb.
        aux_train_loader, aux_test_loader: Loaders used for every evaluation.

    Returns:
        ``(few_shot_train_acc, few_shot_test_acc, final_train_acc,
        final_test_acc, train_result, test_result, metrics, global_counter)``
        where ``train_result``/``test_result`` are the periodic accuracies,
        ``metrics`` is the dict returned by the last ``model.step`` (empty if
        no batch was processed) and ``global_counter`` is the updated count.
    """
    model.to(device)
    model.train()
    few_shot_train_acc = 0
    few_shot_test_acc = 0

    train_result = []
    test_result = []
    # Bound up front so the return statement works even with zero batches.
    metrics = {}
    total_batches = int(task_config.limit / task_config.batch_size)
    # A single batch per epoch would otherwise make the HAT schedule divide by zero.
    mask_scale_denominator = (total_batches - 1) or 1

    with tqdm(range(epochs), desc=f'training the model task {task_id}', disable=False) as pbar:
        # Iterate the bar itself so it advances once per epoch.
        for ep in pbar:
            batch_counter = 0
            model.train()
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)

                if isinstance(model, ERWrapper):
                    logits, metrics = model.step(data, target, raw_image=data)
                else:
                    logits, metrics = model.step(data, target)

                log_info = {**metrics}
                if isinstance(model, HATAgent):
                    # NOTE(review): the schedule uses batch_counter - 1, so it is
                    # negative on the first batch of each epoch - confirm intent.
                    model.current_mask_scale = (1/model.max_mask_scale) + (model.max_mask_scale - 1/model.max_mask_scale) * ((batch_counter-1)/mask_scale_denominator)
                if batch_counter % eval_interval == 0 and cfg.monitor_forward_transfer:
                    eval_model = create_model(cfg, agent_config, arch_config, task_config, device)
                    eval_model.load_state_dict(model.state_dict())
                    if cfg.agent.agent_type == "NeuroSyncAgent":
                        eval_model.load_consolidated_weights(model.model._ema_params)
                    eval_model.to(device)
                    eval_model.eval()
                    train_acc, test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)
                    if wandb_enable:
                        log_info = {f'train_acc_{task_config.benchmark}':train_acc,
                                    f'test_acc_{task_config.benchmark}':test_acc, **log_info}

                    train_result.append(train_acc)
                    test_result.append(test_acc)

                global_counter += 1
                batch_counter += 1

                # wandb.log requires an initialised run, which only exists
                # when wandb logging is enabled.
                if wandb_enable:
                    wandb.log(log_info, step=global_counter)

            if reoccured and ep == 1:
                eval_model = create_model(cfg, agent_config, arch_config, task_config, device)
                eval_model.load_state_dict(model.state_dict())
                if cfg.agent.agent_type == "NeuroSyncAgent":
                    eval_model.load_consolidated_weights(model.model._ema_params)
                eval_model.eval()
                eval_model.to(device)
                few_shot_train_acc, few_shot_test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)

    final_train_acc, final_test_acc = evaluate_one_task(model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)

    return few_shot_train_acc, few_shot_test_acc, final_train_acc, final_test_acc, train_result, test_result, metrics, global_counter


def create_model(cfg, agent_config, arch_config, task_config, device):
    """
    Build the agent selected by ``cfg.agent.agent_type`` and, when
    ``cfg.forgetting_mech`` is set, wrap it in a replay wrapper.

    Args:
        cfg: Experiment config; reads ``agent.agent_type``, ``forgetting_mech``
            and, when wrapping, ``er_type``, ``buffer_size`` and
            ``buffer_batch_size_ratio``.
        agent_config, arch_config, task_config: Passed to the agent constructor.
        device: Passed to the agents and wrappers that take a device.

    Returns:
        The agent, or an ``ERWrapper`` / ``AGemWrapper`` around it.

    Raises:
        ValueError: Unknown agent type, an agent that cannot be combined with
            ``cfg.forgetting_mech``, or an unsupported ``cfg.er_type``.

    Side effects:
        For ``NeuroSyncAgent`` with ``cfg.agent.use_ema_target`` the value of
        ``agent_config.ema_decay`` is overwritten, and the decay is printed.
    """
    agent_type = cfg.agent.agent_type
    configs = dict(agent_config=agent_config, arch_config=arch_config, task_config=task_config)

    # Class names are resolved only in the selected branch, exactly as before.
    if agent_type == "BaseAgent":
        model = BaseAgent(**configs)
    elif agent_type == "L2Agent":
        model = L2Agent(**configs)
    elif agent_type == "ReDoAgent":
        model = ReDoAgent(**configs)
    elif agent_type == "L2InitAgent":
        model = L2InitAgent(**configs, device=device)
    elif agent_type == "LayerNormAgent":
        model = LayerNormAgent(**configs)
    elif agent_type == "HATAgent":
        model = HATAgent(**configs)
    elif agent_type == "NeuroSyncAgent":
        if cfg.agent.use_ema_target:
            # Choose the per-step decay so that, compounded over all forward
            # passes of one task, it equals task_config.ema_target.
            num_foward_pass_each_task = task_config.epochs * (task_config.limit / task_config.batch_size)
            agent_config.ema_decay = float(np.power(task_config.ema_target, 1/num_foward_pass_each_task))
        print('ema_decay:', agent_config.ema_decay)
        model = NeuroSyncAgent(**configs, device=device)
    elif agent_type == "DeepFourierAgent":
        model = DeepFourierAgent(**configs)
    elif agent_type == "CReLUAgent":
        model = CReLUAgent(**configs)
    elif agent_type == "PReLUAgent":
        model = PReLUAgent(**configs)
    elif agent_type == "CBPAgent":
        model = CBPAgent(**configs, device=device)
    elif agent_type == "EWCAgent":
        model = EWCAgent(**configs, device=device)
    elif agent_type == "L2InitPlusEWCAgent":
        model = L2InitPlusEWCAgent(**configs, device=device)
    elif agent_type == "ShrinkAndPerturbAgent":
        model = ShrinkAndPerturbAgent(**configs, device=device)
    elif agent_type == "ViTAgent":
        model = ViTAgent(**configs, device=device)
    else:
        raise ValueError(f"Unknown agent: {cfg.agent}")

    if cfg.forgetting_mech:
        # Raised explicitly: an assert would vanish under `python -O`.
        if agent_type in ['EWCAgent', 'L2InitPlusEWCAgent', 'HATAgent']:
            raise ValueError(f'{agent_type} cannot be combined with forgetting_mech')

        replay_batch_size = int(cfg.buffer_batch_size_ratio * task_config.batch_size)
        if cfg.er_type == 'er':
            model = ERWrapper(model, buffer_size=cfg.buffer_size,
                              minibatch_size=replay_batch_size,
                              device=device)
        elif cfg.er_type == 'agem':
            model = AGemWrapper(model, buffer_size=cfg.buffer_size,
                                minibatch_size=replay_batch_size,
                                num_tasks=task_config.num_tasks,
                                device=device)
        else:
            # Fail here rather than hand back an unwrapped model that callers
            # will treat as a wrapper.
            raise ValueError(f'{cfg.er_type} is not supported')

    return model

# ------------------------------------------------------------------------
# 9) Main Experiment
# ------------------------------------------------------------------------
def _materialize_dataset(dataset_or_loader):
    """Return every input and target of a dataset (or DataLoader) as two tensors."""
    if isinstance(dataset_or_loader, DataLoader):
        xs_list = []
        ys_list = []
        for x, y in dataset_or_loader:
            xs_list.append(x)
            ys_list.append(y)
        return torch.cat(xs_list, dim=0), torch.cat(ys_list, dim=0)

    # One batch that spans the whole dataset.
    loader = DataLoader(dataset_or_loader, batch_size=len(dataset_or_loader), shuffle=False)
    xs, ys = next(iter(loader))
    return xs, ys


def _clone_for_evaluation(model, cfg, agent_config, arch_config, task_config, device):
    """Build a fresh agent and copy the weights of ``model`` into it."""
    clone = create_model(cfg, agent_config, arch_config, task_config, device)
    clone.load_state_dict(model.state_dict())
    if cfg.agent.agent_type == "NeuroSyncAgent":
        clone.load_consolidated_weights(model.model._ema_params)
    return clone


@hydra.main(config_path='configs/sl', config_name='config')
def main(cfg: DictConfig):
    """
    Run one continual-learning experiment described by the Hydra config ``cfg``.

    For each of ``cfg.task.num_tasks`` tasks the agent is evaluated before
    training (zero-shot), trained with ``train_one_task`` and then, depending
    on the configuration, used to update the A-GEM memory or the EWC Fisher
    estimate, probed for backward adaptation (``cfg.monitor_backward_transfer``)
    and re-evaluated on all tasks seen so far. Metrics go to wandb when
    ``cfg.wandb`` is set; the two task-by-task result matrices are written with
    ``save_artifact`` at the end.

    wandb credentials are taken from the ``WANDB_API_KEY`` environment variable
    (or whatever ``wandb.login`` finds on its own when it is not set).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    set_seed(cfg.seed)

    agent_config = cfg.agent
    arch_config = cfg.arch
    task_config = cfg.task

    get_task_dataset, _, _ = configure_dataset(task_config=task_config, arch_config=arch_config, use_hat=agent_config.agent_type=='HATAgent', args=cfg)

    model = create_model(cfg, agent_config, arch_config, task_config, device)

    if not cfg.agent.agent_type == "CBPAgent":
        model.get_optimizer()

    print(f"\n===== Training Model: {cfg.agent} =====")

    # Compare the agent *type*; cfg.agent itself is a config node and never
    # equals a string.
    if cfg.agent.agent_type == "NeuroSyncAgent":
        orchs_param_count, network_param_count = model.model.compute_total_params()
        print('Total number of trainable parameters:', orchs_param_count + network_param_count)
        print('Total number of trainable parameters in the network:', network_param_count)
        print('Percentage of trainable parameters in the network:', network_param_count / (orchs_param_count + network_param_count) * 100)
    else:
        print('Total number of trainable parameters:', model.model.compute_total_params())

    model.to(device)
    global_counter = 0
    if cfg.forgetting_mech:
        exp_name = f'{model.agent.__class__.__name__}_{cfg.er_type}_{arch_config.arch_name}_{agent_config.optimizer}_{agent_config.lr}_{task_config.benchmark}_transform={task_config.tranform}_{cfg.seed}'
    else:
        exp_name = f'{model.__class__.__name__}_{arch_config.arch_name}_{agent_config.optimizer}_{agent_config.lr}_{task_config.benchmark}_transform={task_config.tranform}_{cfg.seed}'

    if agent_config.reset_network:
        exp_name = 'Scratch_' + exp_name

    if cfg.wandb:
        merged_config = {}
        merged_config.update(namespace_to_dict(task_config))
        merged_config.update(namespace_to_dict(agent_config))
        merged_config.update(namespace_to_dict(arch_config))
        merged_config['seed'] = cfg.seed  # keep args fields if you like
        # Never commit API keys: read the key from the environment. With no
        # key set, wandb.login falls back to its own credential lookup.
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        wandb.init(
            project=cfg.proj_name,
            name= exp_name,
            group= task_config.benchmark,
            config=merged_config,
            save_code=True,
        )

    bt_train_matrix = [[0 for _ in range(task_config.num_tasks)] for _ in range(task_config.num_tasks)]
    bt_test_matrix = [[0 for _ in range(task_config.num_tasks)] for _ in range(task_config.num_tasks)]
    length_k = 0

    assess_forgetting = bool(cfg.forgetting_mech) or cfg.agent.agent_type in ['EWCAgent', 'L2InitPlusEWCAgent', 'HATAgent']
    # Loaders of earlier tasks are only retained when something below revisits
    # them; otherwise nothing is kept, so memory does not grow with the number
    # of tasks.
    keep_loaders = bool(cfg.monitor_backward_transfer) or assess_forgetting

    train_loader_list = []
    test_load_list = []
    for task_id in range(task_config.num_tasks):
        train_dataset_task, test_dataset_task, not_aug_train_dataset_task, reoccured = get_task_dataset(task_id)

        g = torch.Generator()
        g.manual_seed(cfg.seed)

        # The task generator may hand back ready-made DataLoaders; use them
        # directly instead of leaving the loader variables unset or stale.
        if isinstance(train_dataset_task, DataLoader):
            train_loader = aux_train_loader = train_dataset_task
        else:
            train_loader = DataLoader(train_dataset_task, batch_size=task_config.batch_size, shuffle=True, generator=g)
            aux_train_loader = DataLoader(train_dataset_task, batch_size=task_config.batch_size, shuffle=True, generator=g)

        if isinstance(test_dataset_task, DataLoader):
            test_loader = aux_test_loader = test_dataset_task
        else:
            test_loader = DataLoader(test_dataset_task, batch_size=task_config.batch_size, shuffle=False)
            aux_test_loader = DataLoader(test_dataset_task, batch_size=task_config.batch_size, shuffle=False)

        if keep_loaders:
            train_loader_list.append(train_loader)
            test_load_list.append(test_loader)

        if isinstance(model, HATAgent):
            model.current_task_id = task_id

        zero_shot_train_acc, zero_shot_test_acc = evaluate_one_task(model, device, train_loader, test_loader)

        (few_shot_train_acc, few_shot_test_acc, final_train_acc, final_test_acc,
         train_result, test_result, metrics, global_counter) = train_one_task(
            task_id,
            model,
            cfg, agent_config, arch_config, task_config, device,
            train_loader,
            test_loader,
            epochs=task_config.epochs,
            reoccured=reoccured,
            global_counter=global_counter,
            eval_interval=cfg.train_eval_interval,
            wandb_enable=cfg.wandb,
            aux_train_loader=aux_train_loader,
            aux_test_loader=aux_test_loader,
        )

        if agent_config.reset_network:
            model = create_model(cfg, agent_config, arch_config, task_config, device)
            model.to(device)

        if not cfg.agent.agent_type == "CBPAgent":
            model.get_optimizer()

        # Get completed task's train data.
        task_test_xs, task_test_ys = _materialize_dataset(train_dataset_task)

        if cfg.forgetting_mech:
            if cfg.er_type == 'agem':
                # Get completed not augmented task's train data.
                not_aug_task_xs, not_aug_task_ys = _materialize_dataset(not_aug_train_dataset_task)

                model.end_task(not_aug_task_xs, not_aug_task_ys)

        if cfg.agent.agent_type in ['EWCAgent', 'L2InitPlusEWCAgent']:

            # Shuffle data and keep a tenth of it.
            dataset_len = len(task_test_xs)
            indices = np.arange(dataset_len)
            np.random.shuffle(indices)
            task_test_xs = task_test_xs[indices][:int(dataset_len/10)]
            task_test_ys = task_test_ys[indices][:int(dataset_len/10)]

            # Update Fisher matrix using this data.
            model.update_params_and_fisher(
                task_test_xs, task_test_ys, batch_size=task_config.batch_size)

        log_info = {}

        # if cfg.agent.agent_type == "NeuroSyncAgent":
        #     log_info_neuro_sync = model.model.plot_params()
        #     log_info = log_info_neuro_sync

        if cfg.monitor_backward_transfer:
            is_last_task = task_id == task_config.num_tasks - 1
            for i_task in range(task_id):
                # Before the last task only the immediately preceding task is
                # probed; after the last task every earlier task is.
                if not is_last_task and i_task < task_id - 1:
                    continue

                prev_train_loader = train_loader_list[i_task]
                prev_test_loader = test_load_list[i_task]

                aux_model = _clone_for_evaluation(model, cfg, agent_config, arch_config, task_config, device)

                # The copy is the one trained below, so it needs the optimizer.
                if not cfg.agent.agent_type == "CBPAgent":
                    aux_model.get_optimizer()

                print(f'Backward adaptation task {task_id} to task {i_task}')
                bwa_train_result, bwa_test_result = BWA_one_task(aux_model,
                                                                 cfg, agent_config, arch_config, task_config, device,
                                                                 prev_train_loader,
                                                                 prev_test_loader,
                                                                 cfg.num_epoch_backward_adaptation,
                                                                 eval_interval=cfg.bwa_eval_interval)

                length_k = len(bwa_train_result)
                print(f'Backward adaptation task {task_id} to task {i_task} len: {length_k}')
                bt_train_matrix[task_id][i_task] = bwa_train_result
                bt_test_matrix[task_id][i_task] = bwa_test_result

        # Assess forgetting
        ave_forgetting = 0
        ave_test_forgetting = 0
        if assess_forgetting:

            for i_task in range(task_id+1):
                prev_train_loader = train_loader_list[i_task]
                prev_test_loader = test_load_list[i_task]

                aux_model = _clone_for_evaluation(model, cfg, agent_config, arch_config, task_config, device)

                # Select the task on the model that is actually evaluated.
                if isinstance(aux_model, HATAgent):
                    aux_model.current_task_id = i_task

                aux_model.to(device)

                task_train_acc, task_test_acc = evaluate_one_task(aux_model, device=device, aux_train_loader=prev_train_loader, aux_test_loader=prev_test_loader, ema_update=False)
                ave_forgetting += task_train_acc
                ave_test_forgetting += task_test_acc
                print(f'Train Accuracy for Taks {i_task} after Trained on Taks {task_id}: {task_train_acc}')

                bt_train_matrix[task_id][i_task] = task_train_acc
                bt_test_matrix[task_id][i_task] = task_test_acc
            print(f'Average Accuracy on Task {task_id}: {ave_forgetting/(task_id+1)}')

        if reoccured:
            log_info[f'reoccuring/train_acc_{task_config.benchmark}'] = final_train_acc
            log_info[f'reoccuring/test_acc_{task_config.benchmark}'] = final_test_acc
            log_info[f'reoccuring/few_epoch_train_acc_{task_config.benchmark}'] = few_shot_train_acc
            log_info[f'reoccuring/few_epoch_test_acc_{task_config.benchmark}'] = few_shot_test_acc

        if cfg.wandb:
            log_info[f"final_train_acc_{task_config.benchmark}"] = final_train_acc
            log_info[f'final_test_acc_{task_config.benchmark}'] = final_test_acc
            log_info[f'zero_shot_train_acc_{task_config.benchmark}'] = zero_shot_train_acc
            log_info[f'zero_shot_test_acc_{task_config.benchmark}'] = zero_shot_test_acc
            log_info[f'average_performance_forgetting_{task_config.benchmark}'] = ave_forgetting/(task_id+1)
            log_info[f'average_test_performance_forgetting_{task_config.benchmark}'] = ave_test_forgetting/(task_id+1)
            activation_statistics = {} #model.compute_activation_statistics(task_test_xs.to(device=device)) TODO
            log_info = {**log_info, **activation_statistics}

            if global_counter == 0:
                wandb.log(log_info)
            else:
                wandb.log(log_info, step=global_counter)

        # The label says "test acc", so report the test accuracy.
        print(f" Task {task_id+1}/{task_config.num_tasks} -> test acc: {final_test_acc:.3f}")

    save_artifact(bt_train_matrix, f'{task_config.num_tasks}_{cfg.agent.agent_type}_train_backward', cfg.wandb, task_config.num_tasks, length_k)
    save_artifact(bt_test_matrix, f'{task_config.num_tasks}_{cfg.agent.agent_type}_test_backward', cfg.wandb, task_config.num_tasks, length_k)

# Script entry point. ``main`` is wrapped by ``@hydra.main``, so it is called
# without arguments here and receives its ``cfg`` from the decorator.
if __name__ == "__main__":
    main()




