import unittest
import hydra
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf
from torch import multiprocessing
import os
import random
from collections import defaultdict
import time
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="pandas")
import time
import traceback
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from functools import partial
from enum import Enum
import atexit
import click
from datetime import datetime
import os
import requests
import sys
import yaml
import json
from functools import partial
from collections import deque
import math
from copy import deepcopy

# Compute device shared by the whole module: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Path of the torch-serialized dataset cache read and written by load_all_states().
DATA_CACHE = "Public-Health-Agent2/data/all_states_mort_inf_cumulative.pt"

# Directory that receives the model checkpoints written during test evaluation.
# NOTE: it is created as a side effect of importing this module.
save_dir = "Public-Health-Agent2/code_dump_dir"
os.makedirs(save_dir, exist_ok=True)

# Scenario identifiers assigned in turn to ``model.scenario`` during test evaluation.
ALL_SCENARIOS = ["A", "B", "C", "D", "E", "F"]

def flatten(lst):
    """Lazily yield the leaf items of an arbitrarily nested list, left to right.

    Only ``list`` instances are descended into; any other item (including
    tuples and other iterables) is yielded unchanged.
    """
    # Stack of partially consumed iterators; the top one is the list currently
    # being walked.
    pending = [iter(lst)]
    while pending:
        for element in pending[-1]:
            if isinstance(element, list):
                # Descend into the nested list; the parent iterator keeps its
                # position and is resumed once the child is exhausted.
                pending.append(iter(element))
                break
            yield element
        else:
            pending.pop()

def get_model_parameters(model):
    """Snapshot a model's named parameters as plain Python values.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.

    Returns:
        dict mapping each parameter name to a Python scalar when the tensor
        holds exactly one element, otherwise to a (possibly nested) list.
    """
    return {
        label: (
            tensor.item()
            if tensor.numel() == 1
            else tensor.detach().cpu().tolist()
        )
        for label, tensor in model.named_parameters()
    }

def get_negative_parameters(param_dict):
    """Pick out the parameters that contain values below zero.

    Args:
        param_dict: Mapping of parameter name to either a number or a
            (possibly nested) list of numbers.

    Returns:
        dict with one entry per offending parameter: the number itself for
        scalar parameters, or the flat list of its negative entries for
        list-valued parameters. Values of any other type are ignored.
    """
    below_zero = {}
    for key, val in param_dict.items():
        if isinstance(val, (int, float)):
            if val < 0:
                below_zero[key] = val
            continue
        if not isinstance(val, list):
            continue
        # Nested lists are flattened so every leaf value is inspected.
        offenders = [entry for entry in flatten(val) if entry < 0]
        if offenders:
            below_zero[key] = offenders
    return below_zero





class CovidSEnv:
    """Fits and scores candidate ``StateDifferential`` simulator classes.

    The candidate class is instantiated with the training horizon, moved to
    the module-level ``device`` and called without arguments; each call is
    expected to return a ``(pred_M, pred_I)`` pair of 2-D tensors that are
    sliced as ``[:, :horizon]``.
    """

    def __init__(self):
        pass

    def reset(self, num_patients=1):
        pass

    def evaluate_simulator_code_wrapper(self, StateDifferential, train_data, val_data, test_data, config={}, logger=None, env_name=''):
        """Fit ``StateDifferential`` and return its losses in the caller's format.

        Args:
            StateDifferential: Model class to instantiate and fit.
            train_data, val_data, test_data: ``(y_M, y_I, T, population)`` tuples.
            config: Configuration object; ``config.run.optimizer`` must be
                ``'pytorch'``. The ``{}`` default is kept for signature
                compatibility only; a real config object is required.
            logger: Optional logger forwarded to the optimiser.
            env_name: Environment name; the infected-series loss is echoed to
                stdout when it equals ``'Covid-scenario'``.

        Returns:
            ``(train_loss, val_loss, optimized_parameters, loss_per_dim_dict,
            test_loss, sc_output)`` where ``loss_per_dim_dict`` is
            ``{'infected': <loss>}``.

        Raises:
            ValueError: If ``config.run.optimizer`` is not ``'pytorch'``.
        """
        optimizer_name = config.run.optimizer
        if optimizer_name != 'pytorch':
            # Previously this fell through and failed on unbound locals.
            raise ValueError(
                f"Unsupported optimizer {optimizer_name!r}; only 'pytorch' is implemented"
            )

        train_loss, val_loss, optimized_parameters, loss_per_dim, test_loss, sc_output = self.evaluate_simulator_code_using_pytorch(StateDifferential, train_data, val_data, test_data, config=config, logger=logger, env_name=env_name)

        if env_name == 'Covid-scenario':
            print(loss_per_dim)
        # Built unconditionally so the return statement can never hit an
        # unbound name when env_name differs from 'Covid-scenario'.
        loss_per_dim_dict = {'infected': loss_per_dim}
        return train_loss, val_loss, optimized_parameters, loss_per_dim_dict, test_loss, sc_output

    def evaluate_simulator_code_using_pytorch(self, StateDifferential, train_data, val_data, test_data, config={}, logger=None, env_name=''):
        """Train ``StateDifferential`` with Adam and evaluate it.

        Training is full-batch: every epoch performs exactly one optimiser
        step on the whole training set. Validation runs every
        ``config.run.pytorch_as_optimizer.log_interval`` epochs and drives
        early stopping with ``config.run.optimization.patience``.

        Args:
            StateDifferential: Model class; called as ``StateDifferential(T)``.
            train_data, val_data, test_data: ``(y_M, y_I, T, population)`` tuples.
            config: Configuration object providing ``run.optimize_params``,
                ``run.pytorch_as_optimizer.{epochs, log_interval}`` and
                ``run.optimization.patience``.
            logger: Optional logger used for the early-stopping message.
            env_name: Unused here; accepted for interface symmetry.

        Returns:
            ``(train_loss, val_loss, optimized_parameters, loss_per_dim,
            test_loss, scenario_outputs)``. ``train_loss`` is the loss of the
            last training step, or the placeholder ``1.0`` when no training
            step ran.

        Side effects:
            Writes ``pred_<timestamp>.npz`` to the current working directory
            and a checkpoint to ``save_dir``; sets ``batch_size`` on ``config``;
            enables autograd anomaly detection.
        """
        # Kept from the original behaviour: the caller's config is overwritten
        # in place, although this method never reads batch_size itself.
        config.run.pytorch_as_optimizer.batch_size = 64

        y_train_M, y_train_I, T_train, population_train = train_data

        f_model = StateDifferential(T_train)
        f_model.to(device)
        f_model.train()

        MSE = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(
            f_model.parameters(),
            lr=5e-3,          # IMPORTANT: higher than default
            weight_decay=1e-4,
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=20,
        )

        def train(model, y_M, y_I, population):
            """Run one full-batch optimisation step and return the loss as a float."""
            # Process-wide autograd switch, re-asserted on every step.
            torch.autograd.set_detect_anomaly(True)
            optimizer.zero_grad(True)  # positional True == set_to_none
            pred_M, pred_I = model()
            horizon = len(y_M[0])
            pred_M, pred_I = pred_M[:, :horizon], pred_I[:, :horizon]

            # Step-to-step increments of predictions and targets.
            d_pred_I = pred_I[:, 1:] - pred_I[:, :-1]
            d_true_I = y_I[:, 1:] - y_I[:, :-1]

            d_pred_M = pred_M[:, 1:] - pred_M[:, :-1]
            d_true_M = y_M[:, 1:] - y_M[:, :-1]

            loss_I_inc = (d_pred_I - d_true_I) ** 2        # (states, T-1)
            loss_M_inc = (d_pred_M - d_true_M) ** 2

            loss_I_lvl = (pred_I - y_I) ** 2               # (states, T)
            loss_M_lvl = (pred_M - y_M) ** 2

            pop_w = population.unsqueeze(1)

            # Population-weighted increment error plus level error per state.
            loss_per_state = (
                (pop_w * (loss_I_inc + loss_M_inc)).mean(dim=1) +
                (pop_w * (loss_I_lvl + loss_M_lvl)).mean(dim=1)
            )  # shape: (states,)
            loss = loss_per_state.mean()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=50.0)

            optimizer.step()
            scheduler.step(loss)
            return loss.item()

        def compute_eval_loss(model, dataset):
            """Evaluate ``model`` on ``dataset`` without gradients.

            Returns:
                ``(loss, loss_per_dim)`` as floats. ``loss`` combines the
                mortality and infection MSE; ``loss_per_dim`` is the
                population-weighted squared error of the infection series.
            """
            y_M, y_I, T, population = dataset
            model.eval()
            horizon = y_M.shape[1]

            with torch.no_grad():
                pred_M, pred_I = model()  # [patches, T]

                # NOTE(review): MSELoss() with default arguments reduces to a
                # scalar, so the population factor here scales the mean error
                # rather than weighting individual states - confirm intended.
                loss = (
                    population.unsqueeze(1) *
                    (MSE(pred_M[:, :horizon], y_M) +
                     MSE(pred_I[:, :horizon], y_I))
                ).mean()

                loss_per_dim = (
                    population.unsqueeze(1) *
                    torch.square(pred_I[:, :horizon] - y_I)
                ).mean()

            model.train()

            return loss.item(), loss_per_dim.item()

        def compute_test_loss(model, dataset):
            """Evaluate ``model`` under every scenario in ``ALL_SCENARIOS``.

            Returns:
                ``(mean_loss, mean_loss_per_dim, scenario_outputs)`` where
                ``scenario_outputs`` maps each scenario id to a summary dict.
                Also writes the prediction archive and the model checkpoint.
                The model is left in eval mode with ``scenario`` set to the
                last scenario id.
            """
            y_M, y_I, T, population = dataset
            model.eval()
            horizon = len(y_M[0])

            total_loss, total_loss_per_dim = 0, 0
            scenario_outputs = {}

            with torch.no_grad():
                for scenario_id in ALL_SCENARIOS:
                    model.scenario = scenario_id

                    pred_M, pred_I = model()  # [patches, T]

                    loss = (
                        population.unsqueeze(1) *
                        (MSE(pred_M[:, :horizon], y_M) +
                         MSE(pred_I[:, :horizon], y_I))
                    ).mean()

                    loss_per_dim = (
                        population.unsqueeze(1) *
                        torch.square(pred_I[:, :horizon] - y_I)
                    ).mean()

                    total_loss += loss.item()
                    total_loss_per_dim += loss_per_dim.item()

                    # Totals over the first axis, one value per time step.
                    I_tot = pred_I.sum(dim=0)  # [T]
                    M_tot = pred_M.sum(dim=0)

                    # Growth-rate windows span at most 6 steps.
                    T_eff = I_tot.shape[0]
                    early_w = min(6, T_eff - 1)
                    late_w = min(6, T_eff - 1)

                    # Re-read inside the loop in case the forward pass
                    # modifies parameters in place.
                    optimized_parameters = get_model_parameters(model)
                    try:
                        negative_parameters = get_negative_parameters(optimized_parameters)
                    except Exception:
                        # Best effort: the summary is still produced, with 0
                        # as the "could not determine" marker.
                        negative_parameters = 0
                        print('Couldnt set any neg params')

                    scenario_outputs[scenario_id] = {
                        "loss": loss.item(),
                        "epidemic_dynamics": {
                            "peak_infections": float(I_tot.max()),
                            "peak_week": int(I_tot.argmax()),
                            "early_growth_rate": float(
                                (I_tot[early_w] - I_tot[0]) / early_w
                            ),
                            "late_growth_rate": float(
                                (I_tot[-1] - I_tot[-late_w]) / late_w
                            ),
                        },
                        "burden": {
                            "final_infections": float(I_tot[-1]),
                            "final_deaths": float(M_tot[-1]),
                        },
                        "negative_parameters": negative_parameters,
                    }

                val_loss = total_loss / len(ALL_SCENARIOS)
                loss_per_dim = total_loss_per_dim / len(ALL_SCENARIOS)

            # Only the predictions of the last scenario evaluated are archived,
            # and the archive goes to the current working directory.
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            out_file = f"pred_{timestamp}.npz"
            np.savez_compressed(
                out_file,
                pred_M=pred_M.detach().cpu().numpy(),
                pred_I=pred_I.detach().cpu().numpy(),
                y_true_M=y_M.detach().cpu().numpy(),
                y_true_I=y_I.detach().cpu().numpy(),
            )

            model_path = os.path.join(save_dir, f"StateDifferential_TEST_{timestamp}.pt")
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "T": T,
                    "horizon": horizon,
                    "parameters": get_model_parameters(model)
                },
                model_path,
            )

            return val_loss, loss_per_dim, scenario_outputs

        best_model = None
        # Placeholder train loss, reported whenever no training step runs
        # (optimisation disabled, or zero epochs configured).
        cum_loss = 1.0
        if config.run.optimize_params:
            best_val_loss = float('inf')  # Initialize with a very high value
            patience_counter = 0  # Counter for tracking patience
            epochs = config.run.pytorch_as_optimizer.epochs

            for epoch in range(epochs):
                f_model.train()
                t0 = time.perf_counter()
                cum_loss = train(f_model, y_train_M, y_train_I, population_train)
                time_taken = time.perf_counter() - t0

                # --- Validation phase ---
                if epoch % config.run.pytorch_as_optimizer.log_interval == 0:
                    val_loss, _ = compute_eval_loss(f_model, val_data)
                    print(f"[EPOCH {epoch:03d}] TRAIN LOSS={cum_loss:.6f} | VAL LOSS={val_loss:.6f} | time={time_taken:.2f}s")

                    # Early stopping: patience counts validation rounds, not epochs.
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_model = deepcopy(f_model.state_dict())
                        patience_counter = 0
                    else:
                        patience_counter += 1

                    if patience_counter >= config.run.optimization.patience:
                        if logger:
                            logger.info(f"Early stopping triggered at epoch {epoch}")
                        print(f"Early stopping triggered at epoch {epoch}")
                        break

        # Restore the best validation checkpoint (if any) before final scoring.
        f_model.eval()
        if best_model is not None:
            f_model.load_state_dict(best_model)
            print('Loaded best model')

        val_loss, _ = compute_eval_loss(f_model, val_data)
        print(f'[Train Run completed successfully] MSE VAL LOSS {val_loss:.4f}')
        print('')
        optimized_parameters = get_model_parameters(f_model)
        test_loss, loss_per_dim, sc_output = compute_test_loss(f_model, test_data)

        return cum_loss, val_loss, optimized_parameters, loss_per_dim, test_loss, sc_output



def _load_state_cumulative_series(csv_path, state_name):
    """Return one state's ``value`` column as floats, ordered and indexed by ``time_value``."""
    df = pd.read_csv(csv_path)
    df_state = df[df["geo_value_fullname"] == state_name].copy()
    df_state["time_value"] = pd.to_datetime(df_state["time_value"])
    df_state = df_state.sort_values("time_value")
    return df_state.set_index("time_value")["value"].astype(float)


def run_simulation_mort_I(state_name = "Florida"):
    """Load the per-capita cumulative death and confirmed-case series of a state.

    Args:
        state_name: Value matched against ``geo_value_fullname`` in the
            cumulative CSVs and against ``Province_State`` in the population
            CSV.

    Returns:
        ``(y_M, y_I, T, population)``: float32 tensors on the module-level
        ``device`` holding deaths and confirmed cases divided by the state
        population, the length of ``y_I``, and the summed ``Population``
        column for the state.

    Note:
        ``T`` reflects the confirmed-case series only; the two series are not
        forced to the same length here. A zero population yields inf/NaN
        entries rather than an exception.
    """
    deaths = _load_state_cumulative_series(
        "Public-Health-Agent2/data/deaths_cumulative_num.csv", state_name
    )

    temp = pd.read_csv("Public-Health-Agent2/data/time_series_covid19_deaths_US.csv")
    # Sum the Population column over every row belonging to the state.
    population = temp.loc[temp["Province_State"] == state_name, "Population"].sum()

    # Convert through NumPy so the tensor is built from an array instead of
    # by iterating the date-indexed Series element by element.
    y_M = torch.tensor((deaths / population).to_numpy(), dtype=torch.float32).to(device)

    confirmed = _load_state_cumulative_series(
        "Public-Health-Agent2/data/confirmed_cumulative_num.csv", state_name
    )
    y_I = torch.tensor((confirmed / population).to_numpy(), dtype=torch.float32).to(device)

    T = len(y_I)
    print(len(y_M), len(y_I))

    return y_M, y_I, T, population


def load_all_states():
    """Return the all-states dataset, from the cache when possible.

    The cached file at ``DATA_CACHE`` is tried first. If it cannot be loaded,
    the dataset is rebuilt from the CSV files one state at a time, truncated
    to a common length, written back to ``DATA_CACHE`` and returned.

    Returns:
        ``(train_set, val_set, test_set, description)``. On a rebuild all
        three sets are the same ``(Y_M, Y_I, T_min, populations)`` tuple
        covering every valid state in shuffled order; no hold-out split is
        performed.

    Raises:
        ValueError: If no state yields a usable, non-empty time series.

    Side effects:
        Seeds NumPy's global random generator with 42 on a rebuild.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only point DATA_CACHE
        # at files produced by this function.
        cache = torch.load(DATA_CACHE, map_location=device)

        train_set = cache["train_set"]
        val_set   = cache["val_set"]
        test_set  = cache["test_set"]
        states    = cache["states"]
        description = cache["description"]

        print(f"States: {len(states)} | T = {train_set[2]}")

        return train_set, val_set, test_set, description
    except Exception as exc:
        # Best effort: a missing or unreadable cache triggers a rebuild, but
        # the reason is reported instead of being silently discarded.
        print(f"[Cache] Could not load {DATA_CACHE} ({exc!r}); rebuilding from CSV files")

    # Placeholder dataset description (a single newline).
    description = "\n"

    df = pd.read_csv(
        "Public-Health-Agent2/data/time_series_covid19_deaths_US.csv"
    )
    states = sorted(df["Province_State"].unique())
    print(states)

    Y_M_list = []
    Y_I_list = []
    P_list   = []
    T_list   = []
    valid_states = []

    for state in states:
        try:
            y_M, y_I, T, pop = run_simulation_mort_I(state)
        except Exception as e:
            print(f"Skipping {state}: {e}")
            continue

        if (
            torch.isnan(y_M).any() or torch.isinf(y_M).any() or
            torch.isnan(y_I).any() or torch.isinf(y_I).any()
        ):
            print(f"Skipping {state}: NaN or Inf detected")
            continue

        if not T:
            # An empty series carries no information and would force T_min to 0.
            print(f"Skipping {state}: empty time series")
            continue

        Y_M_list.append(y_M)
        Y_I_list.append(y_I)
        P_list.append(pop)
        T_list.append(T)
        valid_states.append(state)

    if not valid_states:
        raise ValueError("No state produced a usable time series; cannot build the dataset")

    P_list = torch.tensor(P_list, dtype=torch.float32, device=device)

    # T only reflects the infection series, so the mortality lengths are
    # included too; otherwise torch.stack fails on ragged input.
    T_min = min(min(T_list), min(len(y) for y in Y_M_list))

    Y_M = torch.stack([y[:T_min] for y in Y_M_list], dim=0)
    Y_I = torch.stack([y[:T_min] for y in Y_I_list], dim=0)

    num_states = len(valid_states)
    idx = np.arange(num_states)

    # Deterministic shuffle of the state order. No split is performed: every
    # state goes into the training set, which also serves as the validation
    # and test sets. Rows therefore follow the shuffled order, while the
    # cached "states" list keeps the sorted order.
    np.random.seed(42)
    np.random.shuffle(idx)

    train_idx = idx[:]

    train_set = (
        Y_M[train_idx],
        Y_I[train_idx],
        T_min,
        P_list[train_idx],
    )

    val_set = train_set
    test_set = val_set

    cache = {
        "train_set": train_set,
        "val_set": val_set,
        "test_set": test_set,
        "states": valid_states,
        "description": description,
    }

    torch.save(cache, DATA_CACHE)
    print(f"[Saved] Cached dataset to {DATA_CACHE}")

    return train_set, val_set, test_set, description


def load_data_inner():
    """Build train/val/test sets that all hold the Florida series.

    The loader is invoked once per split, so each split is an independently
    loaded tuple.

    Returns:
        ``(train_set, val_set, test_set, description)``.
    """
    splits = [run_simulation_mort_I('Florida') for _ in range(3)]
    train_set, val_set, test_set = splits
    description = "\nTODO\n"
    return train_set, val_set, test_set, description


def load_data():
    """Public entry point: return whatever ``load_all_states()`` produces."""
    return load_all_states()