import json
import os
import sys
import traceback
import warnings
from copy import deepcopy
from datetime import datetime
from importlib.metadata import version
from typing import Any, Literal

import pandas as pd
import torch
from ase.constraints import FixSymmetry
from ase.filters import FrechetCellFilter
from ase.io import read
from ase.optimize import FIRE, LBFGS
from ase.optimize.optimize import Optimizer
from ase.spacegroup.symmetrize import check_symmetry
from phono3py.api_phono3py import Phono3py
from phonopy.structure.atoms import PhonopyAtoms
from pymatviz.enums import Key
from tqdm import tqdm

import matbench_discovery.phonons.thermal_conductivity as ltc
from matbench_discovery import phonons
from matbench_discovery.enums import DataFiles
from matbench_discovery.metrics.phonons import calc_kappa_metrics_from_dfs

warnings.filterwarnings("ignore", category=DeprecationWarning, module="spglib")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")

# Model configuration
from metatomic.torch.ase_calculator import MetatomicCalculator, O3AveragedCalculator
from metatomic.torch import load_atomistic_model

# Command line: <script> MODEL_VARIANT DISPLACEMENT_DISTANCE
# Both positional arguments are checked here, before the model is loaded, so
# that a missing argument fails immediately with a usage message instead of an
# IndexError (possibly only after the model has already been loaded).
if len(sys.argv) < 3:
    raise SystemExit(f"Usage: {sys.argv[0]} MODEL_VARIANT DISPLACEMENT_DISTANCE")

model_name = "pet"
# Selects which exported model file is loaded below.
model_variant = sys.argv[1]
# `precision` drives the torch dtype; anything other than "float64" maps to
# float32.
precision = "float64"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float64 if precision == "float64" else torch.float32
model = load_atomistic_model(f"/work/cosmo/bigi/models/model-{model_variant}.pt")
# NOTE(review): overrides the dtype declared in the model capabilities before
# casting — presumably needed so the calculator accepts the cast model; confirm
# against the metatomic documentation.
model.capabilities().dtype = precision
model = model.to(dtype=dtype, device=device)
# Single calculator instance shared by relaxation and force-set evaluation.
calc = MetatomicCalculator(model, device=device, non_conservative=False)

# Relaxation parameters
ase_optimizer: Literal["FIRE", "LBFGS"] = "FIRE"
max_steps = 300  # step limit passed to the optimizer's run()
force_max = 1e-4  # fmax convergence threshold passed to the optimizer's run()
symprec = 1e-5  # symmetry tolerance used for check_symmetry and Phono3py

# Displacement distance handed to phono3py's generate_displacements(); it is
# the second command-line argument. Fail with a readable message when it is
# missing or not a number instead of a bare IndexError/ValueError traceback.
try:
    displacement_distance = float(sys.argv[2])
except (IndexError, ValueError) as exc:
    raise SystemExit(
        "DISPLACEMENT_DISTANCE (second command-line argument) must be a "
        f"number: {exc!r}"
    ) from exc

enforce_relax_symm = True
ignore_broken_symm = False
ignore_imaginary_freqs = True
is_plusminus = True
temperatures = [300]  # Temperatures to calculate conductivity at in Kelvin
save_forces = False  # Save force sets to files
deterministic = False
# `precision` is intentionally NOT re-assigned here: it is defined once next
# to the model configuration, where it selects the torch dtype. A second
# assignment would let the value recorded in the run parameters drift away
# from the dtype that was actually used.

task_type = "LTC"  # lattice thermal conductivity
job_name = (
    f"{model_name}-phononDB-{task_type}-{ase_optimizer}"
    f"_force{force_max}_sym{symprec}_ispm{is_plusminus}"
    f"_disp{displacement_distance}"
)
out_dir = f"./orblike_kappa_results_{model_variant}"
os.makedirs(out_dir, exist_ok=True)
out_path = f"{out_dir}/{job_name}.json.gz"
force_sets_path = f"{out_dir}/force-sets.json.gz"

# Start time of the run in the local time zone, formatted without the offset.
timestamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S")

# Read every structure stored in the PhononDB structures file.
atoms_list = read(DataFiles.phonondb_pbe_103_structures.path, index=":")

# Debugging aid: uncomment to restrict the run to the first 10 structures.
# atoms_list = atoms_list[:10]

# Snapshot of all settings of this run, written next to the results so the
# output directory is self-describing.
run_params = dict(
    timestamp=timestamp,
    model_name=model_name,
    model_variant=model_variant,
    device=device,
    versions={"numpy": version("numpy"), "torch": version("torch")},
    ase_optimizer=ase_optimizer,
    cell_filter="FrechetCellFilter",
    max_steps=max_steps,
    force_max=force_max,
    symprec=symprec,
    enforce_relax_symm=enforce_relax_symm,
    ignore_broken_symm=ignore_broken_symm,
    ignore_imaginary_freqs=ignore_imaginary_freqs,
    temperatures=temperatures,
    displacement_distance=displacement_distance,
    is_plusminus=is_plusminus,
    task_type=task_type,
    job_name=job_name,
    n_structures=len(atoms_list),
    precision=precision,
    deterministic=deterministic,
)

with open(f"{out_dir}/run_params.json", mode="w") as params_file:
    params_file.write(json.dumps(run_params, indent=4))

print(f"Results will be saved to {out_dir}")
print(f"Using {device=}")


# Opt-in: ask torch to use deterministic algorithms only.
if deterministic:
    torch.use_deterministic_algorithms(mode=True)


# Set up the optimizer class from string
optim_cls: type[Optimizer] = {"FIRE": FIRE, "LBFGS": LBFGS}[ase_optimizer]

# Per-material outputs, keyed by material ID and filled by the main loop.
force_results: dict[str, dict[str, Any]] = {}
kappa_results: dict[str, dict[str, Any]] = {}
prog_bar = True  # Enable progress bar

# `enumerate` objects have no length, so tqdm cannot infer the number of
# iterations on its own; pass it explicitly to get a percentage and an ETA.
tqdm_bar = tqdm(
    enumerate(atoms_list),
    total=len(atoms_list),
    desc="Conductivity calculation: ",
    disable=not prog_bar,
)

# Main per-structure pipeline. Every structure goes through three phases, each
# wrapped in its own try/except so that a failure is recorded in
# `kappa_results` (error message + traceback) and the loop moves on to the next
# structure instead of aborting the whole run:
#   1. cell/position relaxation (optionally symmetry-constrained),
#   2. force-set evaluation (fc2, frequencies and, if allowed, fc3),
#   3. thermal conductivity calculation.
for idx, atoms in tqdm_bar:
    # Use the same ID field as in original script; structures without one get
    # a fallback ID derived from the number of results collected so far.
    mat_id = atoms.info.get(Key.mat_id, f"id-{len(kappa_results)}")
    # Snapshot of the metadata so keys missing after relaxation can be restored.
    init_info = deepcopy(atoms.info)
    formula = atoms.info.get("name", "unknown")

    tqdm_bar.set_postfix_str(mat_id, refresh=True)

    # Initialize info dictionary with material details
    info_dict = {
        "name": formula,
        "errors": [],
        "error_traceback": [],
    }

    # Initialize variables that might be needed in error handling: these
    # defaults are what gets stored if relaxation fails before `relax_dict`
    # is rebuilt below.
    relax_dict = {"max_stress": None, "reached_max_steps": False}
    force_results_item = None

    try:
        # Relaxation phase
        atoms.calc = calc
        if enforce_relax_symm:
            atoms.set_constraint(FixSymmetry(atoms))
            # Only the first three of the six mask entries are enabled.
            # NOTE(review): presumably these are the normal (xx, yy, zz) strain
            # components in Voigt order, i.e. shear is frozen — confirm against
            # the ASE filter documentation.
            filtered_atoms = FrechetCellFilter(atoms, mask=[True] * 3 + [False] * 3)
        else:
            filtered_atoms = FrechetCellFilter(atoms)

        # One optimizer log file per structure, named after the loop index.
        optimizer = optim_cls(
            filtered_atoms,
            logfile=f"{out_dir}/relax_{idx}.log",
        )

        # Space-group number before and after relaxation, compared below to
        # detect symmetry breaking.
        pre_sym_group = check_symmetry(atoms, symprec).number
        optimizer.run(fmax=force_max, steps=max_steps)
        post_sym_group = check_symmetry(atoms, symprec).number

        # Using up the whole step budget is reported (not treated as an error).
        reached_max_steps = optimizer.nsteps >= max_steps
        if reached_max_steps:
            print(f"Material {mat_id=} reached {max_steps=} during relaxation")

        # Maximum residual stress component: the stress vector is reshaped to
        # 2x3 and reduced per row, giving two values.
        # NOTE(review): presumably (max normal, max shear) given Voigt order —
        # confirm; note this is the signed maximum, not the maximum magnitude.
        max_stress = atoms.get_stress().reshape((2, 3), order="C").max(axis=1)

        # Detach calculator and constraints from the relaxed structure, then
        # merge metadata: values currently in atoms.info win, keys only
        # present in the initial snapshot are restored.
        atoms.calc = None
        atoms.constraints = None
        atoms.info = init_info | atoms.info

        relax_dict = {
            "max_stress": max_stress,
            "reached_max_steps": reached_max_steps,
            "broken_symmetry": pre_sym_group != post_sym_group,
        }

        # Raised inside the try block on purpose: the handler below records it
        # as a RelaxError (together with the populated relax_dict) and skips
        # the structure.
        if not ignore_broken_symm and pre_sym_group != post_sym_group:
            raise ValueError(
                f"Symmetry group changed from {pre_sym_group} to {post_sym_group}"
            )

    except Exception as exc:
        warnings.warn(f"Failed to relax {formula=}, {mat_id=}: {exc!r}", stacklevel=2)
        traceback.print_exc()
        info_dict["errors"].append(f"RelaxError: {exc!r}")
        info_dict["error_traceback"].append(traceback.format_exc())
        kappa_results[mat_id] = info_dict | relax_dict
        continue

    # Force constants calculation
    try:
        unit_cell = PhonopyAtoms(
            atoms.symbols, cell=atoms.cell, positions=atoms.positions
        )
        # Supercell sizes and the q-point mesh are read from the structure's
        # metadata; a missing key raises KeyError, which is recorded as a
        # ForceConstantError by the handler below.
        ph3 = Phono3py(
            unitcell=unit_cell,
            supercell_matrix=atoms.info["fc3_supercell"],
            phonon_supercell_matrix=atoms.info["fc2_supercell"],
            primitive_matrix="auto",
            symprec=symprec,
        )
        ph3.mesh_numbers = atoms.info["q_point_mesh"]

        # Generate displacements in both positive and negative direction even if
        # symmetrically equivalent (different from other models!)
        ph3.generate_displacements(
            distance=displacement_distance, is_plusminus=is_plusminus
        )
        # Calculate force constants and frequencies
        ph3, fc2_set, freqs = ltc.get_fc2_and_freqs(
            ph3, calculator=calc, pbar_kwargs={"disable": True}
        )

        # Check for imaginary frequencies
        has_imaginary_freqs = phonons.check_imaginary_freqs(freqs)
        freqs_dict = {
            Key.has_imag_ph_modes: has_imaginary_freqs,
            Key.ph_freqs: freqs,
        }

        # Determine if we should continue calculating conductivity: only stop
        # when imaginary frequencies were found AND they are not being ignored.
        continue_computing_conductivity = (
            not has_imaginary_freqs or ignore_imaginary_freqs
        )

        if continue_computing_conductivity:
            fc3_set = ltc.calculate_fc3_set(
                ph3,
                calculator=calc,
                pbar_kwargs={"position": idx},
            )
            ph3.produce_fc3(symmetrize_fc3r=True)
        else:
            # No third-order forces are evaluated for skipped structures.
            fc3_set = []

        # Force sets are only kept in memory when they will be written out.
        if save_forces:
            force_results_item = {"fc2_set": fc2_set, "fc3_set": fc3_set}

        # Early exit for structures with imaginary modes: store relaxation and
        # frequency data (no conductivity) and move on.
        if not continue_computing_conductivity:
            kappa_results[mat_id] = info_dict | relax_dict | freqs_dict
            warnings.warn(
                f"Skipping {mat_id} due to imaginary frequencies", stacklevel=2
            )
            if force_results_item is not None:
                force_results[mat_id] = force_results_item
            continue

    except Exception as exc:
        warnings.warn(f"Failed to calculate force sets {mat_id}: {exc!r}", stacklevel=2)
        traceback.print_exc()
        info_dict["errors"].append(f"ForceConstantError: {exc!r}")
        info_dict["error_traceback"].append(traceback.format_exc())
        # Only relaxation data is stored here; frequencies are not included
        # even if they had already been computed before the failure.
        kappa_results[mat_id] = info_dict | relax_dict
        continue

    # Thermal conductivity calculation
    try:
        ph3, kappa_dict, _cond = ltc.calculate_conductivity(
            ph3, temperatures=temperatures
        )
        kappa_results[mat_id] = info_dict | relax_dict | freqs_dict | kappa_dict
        if force_results_item is not None:
            force_results[mat_id] = force_results_item
    except Exception as exc:
        warnings.warn(
            f"Failed to calculate conductivity {mat_id}: {exc!r}", stacklevel=2
        )
        traceback.print_exc()
        info_dict["errors"].append(f"ConductivityError: {exc!r}")
        info_dict["error_traceback"].append(traceback.format_exc())
        # Keep everything computed so far (relaxation + frequencies) together
        # with the recorded error.
        kappa_results[mat_id] = info_dict | relax_dict | freqs_dict
        if force_results_item is not None:
            force_results[mat_id] = force_results_item

# Write the collected results to disk.
# `kappa_results` maps material ID -> result dict, so the transposed frame has
# one row per material; the ID index is turned into a regular column on export.
df_kappa = pd.DataFrame(kappa_results).transpose()
df_kappa.index.name = Key.mat_id
kappa_table = df_kappa.reset_index()
kappa_table.to_json(out_path)
print(f"Saved kappa results to {out_path}")

if save_forces and force_results:
    # Column-less copy of the kappa frame: it contributes only its index, so
    # the force table ends up with a row for every material in `df_kappa`.
    index_only = df_kappa[[]].copy()
    force_columns = pd.DataFrame(force_results).transpose()
    df_force = pd.concat([index_only, force_columns], axis=1)
    df_force.index.name = Key.mat_id
    force_table = df_force.reset_index()
    force_table.to_json(force_sets_path)
    print(f"Saved force sets to {force_sets_path}")

# Best-effort evaluation against the reference data: any failure here is
# reported (warning + traceback) but does not abort the script, since the raw
# results have already been written.
try:
    print("Computing metrics against reference data...")
    reference_path = DataFiles.phonondb_pbe_103_kappa_no_nac.path
    df_dft = pd.read_json(reference_path).set_index("mp_id")
    df_ml_metrics = calc_kappa_metrics_from_dfs(df_kappa, df_dft)

    # Summary metrics: mean over all rows of the per-material metrics frame.
    # The variable names are part of the printed output (`{name=}` format).
    kappa_sre = df_ml_metrics[Key.sre].mean()
    kappa_srme = df_ml_metrics[Key.srme].mean()
    print(f"{kappa_sre=:.4f}")
    print(f"{kappa_srme=:.4f}")

    metrics_path = f"{out_dir}/metrics.json.gz"
    df_ml_metrics.to_json(metrics_path)
    print(f"Saved metrics to {metrics_path}")
except Exception as metrics_error:
    warnings.warn(f"Failed to calculate metrics: {metrics_error!r}", stacklevel=2)
    traceback.print_exc()
