

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import random
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from datetime import datetime
from tasks.task_generator import configure_dataset
from utils.utils import *
import wandb
import os 
from agent import *
from collections import deque
import hydra
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from copy import deepcopy
from contextlib import contextmanager
import random, torch, numpy as np
#from hat import HATPayload


def save_artifact(matrix, name, wandb_enable, num_tasks, length):
    """Save a result matrix to ``./forgetting_data/<name>.npy`` (and optionally to W&B).

    The directory is created relative to the current working directory if it
    does not exist. When ``wandb_enable`` is true the saved file is also logged
    as a W&B artifact of type ``"dataset"`` named ``name``.

    Args:
        matrix: Nested list (rows x columns). Cells are usually scalars, but
            may also be lists (e.g. accuracy curves), which makes the nesting
            ragged.
        name: File stem and artifact name.
        wandb_enable: Whether to upload the file as a W&B artifact.
        num_tasks: Unused; kept for call-site compatibility.
        length: Unused; kept for call-site compatibility.
    """
    out_dir = './forgetting_data/'
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{name}.npy")

    try:
        array = np.asarray(matrix)
    except ValueError:
        # Recent NumPy refuses to infer a numeric shape for ragged nesting
        # (scalar cells mixed with list-valued cells). Store an object array
        # instead; reading it back requires np.load(..., allow_pickle=True).
        array = np.asarray(matrix, dtype=object)
    np.save(out_path, array)

    if wandb_enable:
        artifact = wandb.Artifact(name, type="dataset")
        artifact.add_file(out_path)
        wandb.log_artifact(artifact)


def _count_correct(model, device, loader):
    """Return ``(num_correct, num_seen)`` for arg-max predictions of ``model`` over ``loader``.

    ``loader`` yields ``(data, target)`` pairs; ``model.predict(data)`` must
    return scores whose class axis is dim 1 (it is reduced with ``argmax(dim=1)``).
    """
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        preds = model.predict(data).argmax(dim=1)
        correct += preds.eq(target).sum().item()
        total += data.size(0)
    return correct, total


def evaluate_one_task(model, device, aux_train_loader, aux_test_loader, ema_update=True):
    """Compute classification accuracy of ``model`` on a train and a test loader.

    Args:
        model: Object exposing ``predict``, ``eval`` and ``train``.
        device: Target device for each batch.
        aux_train_loader: Iterable of ``(data, target)`` batches.
        aux_test_loader: Iterable of ``(data, target)`` batches.
        ema_update: When False the model is switched to eval mode before
            scoring. When True (default) it is scored in whatever mode it is
            currently in. NOTE(review): presumably this lets normalization
            running statistics keep updating; confirm.

    Returns:
        ``(train_acc, test_acc)`` as fractions in [0, 1]. An empty loader
        yields ``nan`` for its accuracy instead of raising ``ZeroDivisionError``.

    The model is always left in train mode on return.
    """
    with torch.no_grad():
        if not ema_update:
            model.eval()
        train_correct, train_total = _count_correct(model, device, aux_train_loader)
        test_correct, test_total = _count_correct(model, device, aux_test_loader)
    model.train()

    nan = float("nan")
    train_acc = train_correct / train_total if train_total else nan
    test_acc = test_correct / test_total if test_total else nan
    return train_acc, test_acc



def BWA_one_task(model, cfg, agent_config, arch_config, task_config, device, train_loader, test_loader, epochs, eval_interval):
    """Fine-tune ``model`` on one task while recording accuracy curves.

    Runs ``epochs`` passes over ``train_loader``, calling ``model.step`` on
    every batch. After the step of every ``eval_interval``-th batch (counted
    across epochs, starting with the very first batch), a fresh model is built
    with ``create_model``. It is loaded with ``model``'s current weights, plus
    the consolidated EMA weights for ``NeuroSyncAgent``, and scored with
    ``evaluate_one_task`` on ``train_loader`` / ``test_loader``.

    Returns:
        ``(train_accs, test_accs)``: one entry per evaluation, in order.
    """
    train_accs = []
    test_accs = []

    def _score_snapshot():
        # Score a copy; the live model is only read (weights / EMA params).
        snapshot = create_model(cfg, agent_config, arch_config, task_config, device)
        snapshot.load_state_dict(model.state_dict())
        if cfg.agent.agent_type == "NeuroSyncAgent":
            snapshot.load_consolidated_weights(model.model._ema_params)
        snapshot.to(device)
        snapshot.eval()
        return evaluate_one_task(snapshot, device=device, aux_train_loader=train_loader, aux_test_loader=test_loader)

    model.to(device)
    model.train()

    steps_done = 0
    for _epoch in range(epochs):
        for inputs, labels in train_loader:
            # Device placement and train mode are re-applied on every batch.
            model.to(device)
            model.train()
            inputs, labels = inputs.to(device), labels.to(device)
            _, step_metrics = model.step(inputs, labels)
            _train_loss = step_metrics['curr_train_loss']  # read but currently unused

            if steps_done % eval_interval == 0:
                acc_train, acc_test = _score_snapshot()
                train_accs.append(acc_train)
                test_accs.append(acc_test)
            steps_done += 1

    return train_accs, test_accs

def train_one_task(task_id, model, cfg, agent_config, arch_config, task_config, device, train_loader, test_loader, epochs, reoccured, global_counter, eval_interval, wandb_enable, aux_train_loader, aux_test_loader):
    """
    Train a given model on one random-label CIFAR10 task for 'epochs' epochs,
    then return the final test accuracy.
    """
    model.to(device)
    model.train()
    few_shot_train_acc = 0
    few_shot_test_acc = 0
    final_train_acc = 0.0
    final_test_acc = 0.0
    
    train_result = []
    test_result = []
    total_batches = int(task_config.limit / task_config.batch_size)
    with tqdm(range(epochs), desc=f'training the model task {task_id}', disable=False) as pbar:
        for ep in range(epochs):
            batch_counter = 0
            model.train()
            #for data, raw_image, target in tqdm(train_loader, desc= 'training the model in one epoch', disable=True):
            for data, target in tqdm(train_loader, desc= 'training the model in one epoch', disable=True):
                data, target = data.to(device), target.to(device)
                
                if isinstance(model, ERWrapper):
                    logits, metrics = model.step(data, target, raw_image=data)
                else:
                    logits, metrics = model.step(data, target)
                
                log_info = {**metrics}
                if isinstance(model, HATAgent):
                    model.current_mask_scale = (1/model.max_mask_scale) + (model.max_mask_scale - 1/model.max_mask_scale) * ((batch_counter-1)/(total_batches-1))
                if batch_counter % eval_interval == 0 and cfg.monitor_forward_transfer:
                    eval_model = create_model(cfg, agent_config, arch_config, task_config, device)
                    eval_model.load_state_dict(model.state_dict())
                    if cfg.agent.agent_type == "NeuroSyncAgent":
                        eval_model.load_consolidated_weights(model.model._ema_params) 
                    eval_model.to(device)
                    eval_model.eval()
                    train_acc, test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)
                    if wandb_enable:
                        log_info = {f'train_acc_{task_config.benchmark}':train_acc,
                                    f'test_acc_{task_config.benchmark}':test_acc, **log_info}
                    
                    train_result.append(train_acc)
                    test_result.append(test_acc)
                    
                global_counter += 1
                batch_counter += 1
                
                wandb.log(log_info, step=global_counter)
                
            
            if reoccured and ep == 1:
                eval_model = create_model(cfg, agent_config, arch_config, task_config, device)
                eval_model.load_state_dict(model.state_dict())
                if cfg.agent.agent_type == "NeuroSyncAgent":
                    eval_model.load_consolidated_weights(model.model._ema_params)
                eval_model.eval()
                eval_model.to(device)
                few_shot_train_acc, few_shot_test_acc = evaluate_one_task(eval_model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)

    final_train_acc, final_test_acc = evaluate_one_task(model, device=device, aux_train_loader=aux_train_loader, aux_test_loader=aux_test_loader)

    return few_shot_train_acc, few_shot_test_acc, final_train_acc, final_test_acc, train_result, test_result, metrics, global_counter


def create_model(cfg, agent_config, arch_config, task_config, device):
    """Instantiate the agent selected by ``cfg.agent.agent_type``.

    Some agent types mutate the passed configs:

    - ``HareTortoiseAgent`` sets ``agent_config.reset_to_ema``. With
      ``agent_config.warmup`` it also overrides ``task_config.epochs`` and
      ``task_config.num_tasks``; otherwise it loads a checkpoint from a
      hard-coded path.
    - ``NeuroSyncAgent`` with ``cfg.agent.use_ema_target`` sets
      ``agent_config.ema_decay``. The value is chosen so that the decay
      compounded over one task's forward passes equals
      ``task_config.ema_target``.

    When ``cfg.forgetting_mech`` is set, the agent is wrapped in ``ERWrapper``
    (``cfg.er_type == 'er'``) or ``AGemWrapper`` (``'agem'``).

    Raises:
        ValueError: Unknown agent type, unsupported ``cfg.er_type``, or a
            replay wrapper requested for an agent that does not support it.
    """
    agent_type = cfg.agent.agent_type

    # Validate the replay configuration up front. Previously an unsupported
    # er_type was only printed and an unwrapped model returned, which then
    # failed later in main(). An `assert` would be stripped under `python -O`.
    if cfg.forgetting_mech:
        if agent_type in ('EWCAgent', 'L2InitPlusEWCAgent', 'HATAgent'):
            raise ValueError(f'{agent_type} cannot be combined with forgetting_mech')
        if cfg.er_type not in ('er', 'agem'):
            raise ValueError(f'{cfg.er_type} is not supported')

    common = dict(agent_config=agent_config, arch_config=arch_config, task_config=task_config)

    if agent_type == "HareTortoiseAgent":
        agent_config.reset_to_ema = 10 * task_config.limit // task_config.batch_size
        model = HareTortoiseAgent(**common)
        if agent_config.warmup:
            task_config.epochs = agent_config.warm_epoch
            task_config.num_tasks = 1
        else:
            # NOTE(review): machine-specific path, duplicated in main(); consider a config entry.
            ckpt_path = f"/localhome/srr8/project/ICRL_17/17-before_forgetting/model/h_t_{task_config.benchmark}/model.pt"
            # Accepts either a payload dict or a bare state_dict.
            load_hare_tortoise(model, ckpt_path, map_location=device)

    elif agent_type == "NeuroSyncAgent":
        if cfg.agent.use_ema_target:
            num_forward_passes = task_config.epochs * (task_config.limit / task_config.batch_size)
            agent_config.ema_decay = float(np.power(task_config.ema_target, 1 / num_forward_passes))
        model = NeuroSyncAgent(**common, device=device)

    else:
        # Agents that need no extra set-up. Lambdas keep each class lookup
        # lazy, so only the selected class has to be importable.
        builders = {
            "BaseAgent": lambda: BaseAgent(**common),
            "L2Agent": lambda: L2Agent(**common),
            "ReDoAgent": lambda: ReDoAgent(**common),
            "L2InitAgent": lambda: L2InitAgent(**common, device=device),
            "LayerNormAgent": lambda: LayerNormAgent(**common),
            "HATAgent": lambda: HATAgent(**common),
            "DeepFourierAgent": lambda: DeepFourierAgent(**common),
            "CReLUAgent": lambda: CReLUAgent(**common),
            "PReLUAgent": lambda: PReLUAgent(**common),
            "CBPAgent": lambda: CBPAgent(**common, device=device),
            "EWCAgent": lambda: EWCAgent(**common, device=device),
            "L2InitPlusEWCAgent": lambda: L2InitPlusEWCAgent(**common, device=device),
            "ShrinkAndPerturbAgent": lambda: ShrinkAndPerturbAgent(**common, device=device),
            "ViTAgent": lambda: ViTAgent(**common, device=device),
        }
        if agent_type not in builders:
            raise ValueError(f"Unknown agent: {cfg.agent}")
        model = builders[agent_type]()

    if cfg.forgetting_mech:
        minibatch_size = int(cfg.buffer_batch_size_ratio * task_config.batch_size)
        if cfg.er_type == 'er':
            model = ERWrapper(model, buffer_size=cfg.buffer_size,
                              minibatch_size=minibatch_size,
                              device=device)
        else:  # 'agem' (validated above)
            model = AGemWrapper(model, buffer_size=cfg.buffer_size,
                                minibatch_size=minibatch_size,
                                num_tasks=task_config.num_tasks,
                                device=device)

    return model


from pathlib import Path

def save_hare_tortoise(agent, path):
    """Write ``agent``'s weights to ``path`` with ``torch.save``.

    The saved dict always holds ``"model"`` (``agent.model.state_dict()``).
    It also holds ``"ema_model"`` when the agent has an ``ema_model``
    attribute. Missing parent directories are created first.
    """
    ckpt_file = Path(path)
    ckpt_file.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = dict(model=agent.model.state_dict())
    if hasattr(agent, "ema_model"):
        checkpoint.update(ema_model=agent.ema_model.state_dict())

    torch.save(checkpoint, ckpt_file)

def load_hare_tortoise(agent, path, map_location="cpu"):
    """Restore weights saved by ``save_hare_tortoise``, or from a bare state_dict.

    ``path`` may contain either a dict with a ``"model"`` entry (and
    optionally ``"ema_model"``) or a plain ``state_dict``. The model weights
    are loaded into ``agent.model``.

    If ``agent`` has an ``ema_model`` attribute, it receives the checkpoint's
    ``"ema_model"`` weights when present; otherwise it mirrors the model
    weights.
    """
    ckpt = torch.load(path, map_location=map_location)
    # A bare state_dict has no "model" entry: treat the whole object as weights.
    is_payload = isinstance(ckpt, dict) and "model" in ckpt
    model_sd = ckpt["model"] if is_payload else ckpt
    agent.model.load_state_dict(model_sd)

    if hasattr(agent, "ema_model"):
        # Read the EMA entry written by save_hare_tortoise. This previously
        # read "model" again, so saved EMA weights were never restored.
        ema_sd = ckpt.get("ema_model") if is_payload else None
        agent.ema_model.load_state_dict(model_sd if ema_sd is None else ema_sd)

# ------------------------------------------------------------------------
# 9) Main Experiment
# ------------------------------------------------------------------------
def _collect_tensors(data):
    """Return all ``(x, y)`` samples of ``data`` as two tensors.

    ``data`` is either a ``DataLoader``, whose batches are concatenated along
    dim 0, or a dataset, which is read as one unshuffled batch of
    ``len(data)`` items.
    """
    if isinstance(data, DataLoader):
        xs_list = []
        ys_list = []
        for x, y in data:
            xs_list.append(x)
            ys_list.append(y)
        return torch.cat(xs_list, dim=0), torch.cat(ys_list, dim=0)
    # Size the batch from *this* dataset. The A-GEM path used to size it
    # from the augmented train set, which could truncate the data.
    loader = DataLoader(data, batch_size=len(data), shuffle=False)
    xs, ys = next(iter(loader))
    return xs, ys


@hydra.main(config_path='configs/sl', config_name='config')
def main(cfg: DictConfig):
    """Run the continual-learning experiment described by the Hydra config.

    For each task returned by ``configure_dataset``:

    1. The model is evaluated zero-shot.
    2. It is trained with ``train_one_task``.
    3. Every task seen so far is re-evaluated on a copy of the model, filling
       the (task x task) matrices ``bt_train_matrix`` / ``bt_test_matrix``.

    The matrices are saved with ``save_artifact`` at the end. Per-task
    summaries are logged to W&B when ``cfg.wandb`` is set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    set_seed(cfg.seed)

    agent_config = cfg.agent
    arch_config = cfg.arch
    task_config = cfg.task

    get_task_dataset, _, _ = configure_dataset(task_config=task_config, arch_config=arch_config, use_hat=agent_config.agent_type == 'HATAgent', args=cfg)

    model = create_model(cfg, agent_config, arch_config, task_config, device)

    if not cfg.agent.agent_type == "CBPAgent":
        model.get_optimizer()

    print(f"\n===== Training Model: {cfg.agent} =====")

    # Compare the agent *type*. cfg.agent is a config node, so comparing it
    # to a string never matched and the NeuroSync branch was unreachable.
    param_counts = model.model.compute_total_params()
    if cfg.agent.agent_type == "NeuroSyncAgent" and isinstance(param_counts, tuple):
        orchs_param_count, network_param_count = param_counts
        print('Total number of trainable parameters:', orchs_param_count + network_param_count)
        print('Total number of trainable parameters in the network:', network_param_count)
        print('Percentage of trainable parameters in the network:', network_param_count / (orchs_param_count + network_param_count) * 100)
    else:
        print('Total number of trainable parameters:', param_counts)

    model.to(device)
    global_counter = 0
    if cfg.forgetting_mech:
        exp_name = f'{model.agent.__class__.__name__}_{cfg.er_type}_{arch_config.arch_name}_{agent_config.optimizer}_{agent_config.lr}_{task_config.benchmark}_transform={task_config.tranform}_{cfg.seed}'
    else:
        exp_name = f'{model.__class__.__name__}_{arch_config.arch_name}_{agent_config.optimizer}_{agent_config.lr}_{task_config.benchmark}_transform={task_config.tranform}_{cfg.seed}'

    if agent_config.reset_network:
        exp_name = 'Scratch_' + exp_name

    if cfg.wandb:
        merged_config = {}
        merged_config.update(namespace_to_dict(task_config))
        merged_config.update(namespace_to_dict(agent_config))
        merged_config.update(namespace_to_dict(arch_config))
        merged_config['seed'] = cfg.seed
        wandb.init(
            project=cfg.proj_name,
            name=exp_name,
            group=task_config.benchmark,
            config=merged_config,
            save_code=True,
        )

    # Row = task just trained, column = task evaluated.
    bt_train_matrix = [[0 for _ in range(task_config.num_tasks)] for _ in range(task_config.num_tasks)]
    bt_test_matrix = [[0 for _ in range(task_config.num_tasks)] for _ in range(task_config.num_tasks)]
    length_k = 0

    train_loader_list = []
    test_loader_list = []
    for task_id in range(task_config.num_tasks):
        train_dataset_task, test_dataset_task, not_aug_train_dataset_task, reoccured = get_task_dataset(task_id)

        g = torch.Generator()
        g.manual_seed(cfg.seed)

        # configure_dataset may already return DataLoaders; use them as-is
        # instead of falling through with undefined or stale loaders.
        if isinstance(train_dataset_task, DataLoader):
            train_loader = aux_train_loader = train_dataset_task
        else:
            train_loader = DataLoader(train_dataset_task, batch_size=task_config.batch_size, shuffle=True, generator=g)
            aux_train_loader = DataLoader(train_dataset_task, batch_size=task_config.batch_size, shuffle=True, generator=g)

        if isinstance(test_dataset_task, DataLoader):
            test_loader = aux_test_loader = test_dataset_task
        else:
            test_loader = DataLoader(test_dataset_task, batch_size=task_config.batch_size, shuffle=False)
            aux_test_loader = DataLoader(test_dataset_task, batch_size=task_config.batch_size, shuffle=False)

        train_loader_list.append(train_loader)
        test_loader_list.append(test_loader)

        if isinstance(model, HATAgent):
            model.current_task_id = task_id

        zero_shot_train_acc, zero_shot_test_acc = evaluate_one_task(model, device, train_loader, test_loader)

        (few_shot_train_acc, few_shot_test_acc, final_train_acc, final_test_acc,
         _train_result, _test_result, _metrics, global_counter) = train_one_task(
            task_id,
            model,
            cfg, agent_config, arch_config, task_config, device,
            train_loader,
            test_loader,
            epochs=task_config.epochs,
            reoccured=reoccured,
            global_counter=global_counter,
            eval_interval=cfg.train_eval_interval,
            wandb_enable=cfg.wandb,
            aux_train_loader=aux_train_loader,
            aux_test_loader=aux_test_loader)

        if agent_config.reset_network:
            model = create_model(cfg, agent_config, arch_config, task_config, device)
            model.to(device)

        if not cfg.agent.agent_type == "CBPAgent":
            model.get_optimizer()

        # Completed task's (augmented) training data, as tensors.
        task_test_xs, task_test_ys = _collect_tensors(train_dataset_task)

        if cfg.forgetting_mech and cfg.er_type == 'agem':
            # A-GEM stores the non-augmented training data of the finished task.
            not_aug_task_xs, not_aug_task_ys = _collect_tensors(not_aug_train_dataset_task)
            model.end_task(not_aug_task_xs, not_aug_task_ys)

        if cfg.agent.agent_type in ['EWCAgent', 'L2InitPlusEWCAgent']:
            # Random 10% subset of the task data for the Fisher update.
            dataset_len = len(task_test_xs)
            indices = np.arange(dataset_len)
            np.random.shuffle(indices)
            task_test_xs = task_test_xs[indices][:int(dataset_len / 10)]
            task_test_ys = task_test_ys[indices][:int(dataset_len / 10)]

            model.update_params_and_fisher(
                task_test_xs, task_test_ys, batch_size=task_config.batch_size)

        log_info = {}

        # Backward adaptation: currently disabled via `and False`.
        if cfg.monitor_backward_transfer and False:
            for i_task in range(task_id):
                if task_id == task_config.num_tasks - 1:
                    pass
                else:
                    if i_task < task_id - 1:
                        continue

                bwa_train_loader = train_loader_list[i_task]
                bwa_test_loader = test_loader_list[i_task]

                aux_model = create_model(cfg, agent_config, arch_config, task_config, device)
                aux_model.load_state_dict(model.state_dict())
                if cfg.agent.agent_type == "NeuroSyncAgent":
                    aux_model.load_consolidated_weights(model.model._ema_params)

                if not cfg.agent.agent_type == "CBPAgent":
                    model.get_optimizer()

                print(f'Backward adaptation task {task_id} to task {i_task}')
                bwa_train_accs, bwa_test_accs = BWA_one_task(aux_model,
                                                             cfg, agent_config, arch_config, task_config, device,
                                                             bwa_train_loader,
                                                             bwa_test_loader,
                                                             cfg.num_epoch_backward_adaptation,
                                                             eval_interval=cfg.bwa_eval_interval)

                length_k = len(bwa_train_accs)
                print(f'Backward adaptation task {task_id} to task {i_task} len: {length_k}')
                bt_train_matrix[task_id][i_task] = bwa_train_accs
                bt_test_matrix[task_id][i_task] = bwa_test_accs

        # Assess forgetting: evaluate a copy of the current model on every task seen so far.
        ave_forgetting = 0
        ave_test_forgetting = 0
        for i_task in range(task_id + 1):
            prev_train_loader = train_loader_list[i_task]
            prev_test_loader = test_loader_list[i_task]

            aux_model = create_model(cfg, agent_config, arch_config, task_config, device)
            aux_model.load_state_dict(model.state_dict())
            if cfg.agent.agent_type == "NeuroSyncAgent":  # type: ignore
                aux_model.load_consolidated_weights(model.model._ema_params)

            # NOTE(review): this rebuilds the *live* model's optimizer on every
            # iteration; it was probably meant for aux_model. Confirm before removing.
            if not cfg.agent.agent_type == "CBPAgent":  # type: ignore
                model.get_optimizer()

            if isinstance(model, HATAgent):
                # The copy is what gets evaluated, so it needs the task id.
                # The live model is kept in sync so it ends the loop on task_id.
                aux_model.current_task_id = i_task
                model.current_task_id = i_task

            aux_model.to(device)

            train_acc_i, test_acc_i = evaluate_one_task(aux_model, device=device, aux_train_loader=prev_train_loader, aux_test_loader=prev_test_loader, ema_update=False)
            ave_forgetting += train_acc_i
            ave_test_forgetting += test_acc_i
            print(f'Train Accuracy for Taks {i_task} after Trained on Taks {task_id}: {train_acc_i}')

            bt_train_matrix[task_id][i_task] = train_acc_i  # type: ignore
            bt_test_matrix[task_id][i_task] = test_acc_i  # type: ignore
        print(f'Average Accuracy on Task {task_id}: {ave_forgetting/(task_id+1)}')

        if reoccured:
            log_info[f'reoccuring/train_acc_{task_config.benchmark}'] = final_train_acc
            log_info[f'reoccuring/test_acc_{task_config.benchmark}'] = final_test_acc
            log_info[f'reoccuring/few_epoch_train_acc_{task_config.benchmark}'] = few_shot_train_acc
            log_info[f'reoccuring/few_epoch_test_acc_{task_config.benchmark}'] = few_shot_test_acc

        if cfg.wandb:
            log_info[f"final_train_acc_{task_config.benchmark}"] = final_train_acc
            log_info[f'final_test_acc_{task_config.benchmark}'] = final_test_acc
            log_info[f'zero_shot_train_acc_{task_config.benchmark}'] = zero_shot_train_acc
            log_info[f'zero_shot_test_acc_{task_config.benchmark}'] = zero_shot_test_acc
            log_info[f'average_performance_forgetting_{task_config.benchmark}'] = ave_forgetting / (task_id + 1)
            log_info[f'average_test_performance_forgetting_{task_config.benchmark}'] = ave_test_forgetting / (task_id + 1)
            activation_statistics = model.compute_activation_statistics(task_test_xs.to(device=device))
            log_info = {**log_info, **activation_statistics}

            if global_counter == 0:
                wandb.log(log_info)
            else:
                wandb.log(log_info, step=global_counter)

        # The message says "test acc", so report the test accuracy (it used to print the train one).
        print(f" Task {task_id+1}/{task_config.num_tasks} -> test acc: {final_test_acc:.3f}")

        # TODO
        if cfg.agent.agent_type == 'NeuroSyncAgent':
            model.reset_neuro_sync()

    if isinstance(model, HareTortoiseAgent) and agent_config.warmup:
        save_path = f"/localhome/srr8/project/ICRL_17/17-before_forgetting/model/h_t_{task_config.benchmark}/model.pt"
        save_hare_tortoise(model, save_path)
        print("Model saved:", save_path)

    save_artifact(bt_train_matrix, f'{task_config.num_tasks}_{cfg.agent.agent_type}_train_backward', cfg.wandb, task_config.num_tasks, length_k)
    save_artifact(bt_test_matrix, f'{task_config.num_tasks}_{cfg.agent.agent_type}_test_backward', cfg.wandb, task_config.num_tasks, length_k)

if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator supplies
    # cfg (config 'config' under configs/sl).
    main()





