import torch
import torch.nn as nn
import time
from ptflops import get_model_complexity_info
from utils import *
from tasks.task_generator import configure_dataset
from utils.utils import *
from agent import *
import hydra
from omegaconf import DictConfig
import numpy as np
import warnings
warnings.filterwarnings("ignore")


def create_model(cfg, agent_config, arch_config, task_config, device):
    """Build the agent named by ``cfg.agent.agent_type``, optionally replay-wrapped.

    Args:
        cfg: Root config. Reads ``cfg.agent.agent_type`` and
            ``cfg.forgetting_mech``; when the latter is truthy also reads
            ``cfg.er_type``, ``cfg.buffer_size`` and
            ``cfg.buffer_batch_size_ratio``. ``cfg.agent.use_ema_target`` is
            read for ``NeuroSyncAgent`` only.
        agent_config: Forwarded to the agent constructor. For
            ``NeuroSyncAgent`` with ``use_ema_target`` set, its ``ema_decay``
            field is overwritten in place before construction.
        arch_config: Forwarded to the agent constructor.
        task_config: Forwarded to the agent constructor; ``epochs``, ``limit``,
            ``batch_size``, ``ema_target`` and ``num_tasks`` are also read here.
        device: Forwarded to the agents that accept a ``device`` keyword and
            to the replay wrappers.

    Returns:
        The constructed agent, wrapped in ``ERWrapper`` or ``AGemWrapper`` when
        ``cfg.forgetting_mech`` is set and ``cfg.er_type`` is ``'er'`` or
        ``'agem'`` respectively.

    Raises:
        ValueError: If ``cfg.agent.agent_type`` is not a known agent name.
        AssertionError: If ``cfg.forgetting_mech`` is set for an agent type
            that is excluded from replay wrapping.
    """
    # Agents constructed from the three config objects only.
    agents_without_device = (
        "BaseAgent",
        "L2Agent",
        "ReDoAgent",
        "LayerNormAgent",
        "HATAgent",
        "DeepFourierAgent",
        "CReLUAgent",
        "PReLUAgent",
    )
    # Agents whose constructor additionally takes ``device``.
    agents_with_device = (
        "L2InitAgent",
        "NeuroSyncAgent",
        "CBPAgent",
        "EWCAgent",
        "L2InitPlusEWCAgent",
        "ShrinkAndPerturbAgent",
        "ViTAgent",
    )

    agent_type = cfg.agent.agent_type
    if agent_type not in agents_without_device and agent_type not in agents_with_device:
        # Report the offending name only, not the whole agent config node.
        raise ValueError(f"Unknown agent: {agent_type}")

    if agent_type == "NeuroSyncAgent" and cfg.agent.use_ema_target:
        # Per-step decay chosen so that compounding it over one task's worth
        # of forward passes yields task_config.ema_target.
        num_forward_passes_per_task = task_config.epochs * (task_config.limit / task_config.batch_size)
        agent_config.ema_decay = float(np.power(task_config.ema_target, 1 / num_forward_passes_per_task))

    ctor_kwargs = {
        "agent_config": agent_config,
        "arch_config": arch_config,
        "task_config": task_config,
    }
    if agent_type in agents_with_device:
        ctor_kwargs["device"] = device

    # The agent classes are module-level names (presumably supplied by the
    # file's star imports). Resolve the selected one by name at call time so
    # that, as with the original if/elif chain, only the chosen class has to
    # exist.
    try:
        agent_cls = globals()[agent_type]
    except KeyError:
        raise NameError(f"name '{agent_type}' is not defined") from None
    model = agent_cls(**ctor_kwargs)

    if cfg.forgetting_mech:
        assert agent_type not in ('EWCAgent', 'L2InitPlusEWCAgent', 'HATAgent'), (
            f"forgetting_mech cannot be combined with {agent_type}"
        )
        if cfg.er_type == 'er':
            model = ERWrapper(
                model,
                buffer_size=cfg.buffer_size,
                minibatch_size=int(cfg.buffer_batch_size_ratio * task_config.batch_size),
                device=device,
            )
        elif cfg.er_type == 'agem':
            model = AGemWrapper(
                model,
                buffer_size=cfg.buffer_size,
                minibatch_size=int(cfg.buffer_batch_size_ratio * task_config.batch_size),
                num_tasks=task_config.num_tasks,
                device=device,
            )
        else:
            # NOTE(review): an unsupported er_type is only reported; the
            # unwrapped model is still returned. Confirm this is intended.
            print(f'{cfg.er_type} is not supported')

    return model


def _sync_device(device):
    """Wait for queued CUDA work to finish; do nothing on a non-CUDA device."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _forward_logits(model, inputs):
    """Run exactly one forward pass and return the logits tensor.

    Handles both models that return the logits directly and models that
    return a tuple/list whose first element is the logits.
    """
    result = model(inputs)
    if isinstance(result, (tuple, list)):
        return result[0]
    return result


@hydra.main(config_path='configs/sl', config_name='config', version_base=None)
def main(cfg: DictConfig):
    """Build the configured agent and print one table row of its cost.

    The printed row contains the agent type, MACs and parameter count as
    reported by ptflops, the size of the trainable parameters in MB, and the
    wall-clock time in milliseconds of a single forward + backward pass on a
    batch of one ``(3, 32, 32)`` input.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(cfg.seed)

    agent_config = cfg.agent
    arch_config = cfg.arch
    task_config = cfg.task

    model = create_model(cfg, agent_config, arch_config, task_config, device)
    model.to(device)
    # NOTE(review): the backward pass below is timed in eval mode.
    model.eval()

    input_size = (3, 32, 32)

    # Get MACs & Params (best effort: the row is still printed on failure).
    try:
        with torch.no_grad():
            macs, params = get_model_complexity_info(
                model, input_size, as_strings=True, print_per_layer_stat=False, verbose=False
            )
    except Exception:
        macs, params = "N/A", "N/A"
    # Guard against a None result, which would break the string format specs
    # used in the final print.
    if macs is None:
        macs = "N/A"
    if params is None:
        params = "N/A"

    # Param size, using each tensor's real element width instead of assuming
    # 4-byte floats.
    try:
        param_bytes = sum(
            p.numel() * p.element_size() for p in model.parameters() if p.requires_grad
        )
        param_size_str = f"{param_bytes / (1024 ** 2):.2f} MB"
    except Exception:
        param_size_str = "N/A"

    # Untimed probe pass to discover the number of output classes.
    dummy_input = torch.randn(1, *input_size, device=device)
    try:
        num_classes = _forward_logits(model, dummy_input).shape[1]
        dummy_target = torch.randint(0, num_classes, (1,), device=device)
    except Exception:
        dummy_target = torch.tensor([0], device=device)

    criterion = nn.CrossEntropyLoss()

    # Measure Forward + Backward Time. Synchronise around the timed region so
    # asynchronous CUDA kernels are included; exactly one forward pass runs
    # inside it regardless of the model's return type.
    _sync_device(device)
    start_time = time.perf_counter()
    output = _forward_logits(model, dummy_input)
    loss = criterion(output, dummy_target)
    loss.backward()
    _sync_device(device)
    end_time = time.perf_counter()
    total_time = f"{(end_time - start_time) * 1000:.1f}"

    # Final formatted row
    print(f"| {cfg.agent.agent_type:<20} | {macs:<10} | {params:<9} | {param_size_str:<10} | {total_time:<8} |")

# Script entry point: run the Hydra-decorated `main` when executed directly.
if __name__ == "__main__":
    main()
