from greenWTE import to_cpu, xp
from greenWTE.base import Material
from greenWTE.iterative import IterativeWTESolver
from greenWTE.sources import source_term_gradT
from greenWTE.tests.defaults import SI_INPUT_PATH
import matplotlib.pyplot as plt

# Spatial frequency of the perturbation (rad/m) and the logarithmic grid of
# temporal frequencies (rad/s) on which the solver is evaluated.
K_FT = 10**2.5
omegas = xp.logspace(7, 14, 20)

# Material properties read from the phono3py test input SI_INPUT_PATH at
# temperature=300 (presumably kelvin — confirm against Material.from_phono3py).
material = Material.from_phono3py(SI_INPUT_PATH, temperature=300)

# Build the temperature-gradient source term from the material quantities it
# needs; the tuple preserves the exact positional order source_term_gradT expects.
_gradient_inputs = (
    material.velocity_operator,
    material.phonon_freq,
    material.linewidth,
    material.heat_capacity,
    material.volume,
)
source = source_term_gradT(K_FT, *_gradient_inputs)

# Iterative WTE solve for every frequency in `omegas`, driven by the gradient
# source, with outer_solver="none" and progress printing enabled.
_solver_options = {
    "source_type": "gradient",
    "outer_solver": "none",
    "print_progress": True,
}
solver = IterativeWTESolver(omegas, K_FT, material, source, **_solver_options)
solver.run()

# Plot the real and imaginary parts of solver.kappa_p against temporal
# frequency on log-log axes.
f, ax = plt.subplots()

# omegas is an xp array; convert it to host memory once and reuse it for the
# axis limits and both curves.
omegas_cpu = to_cpu(omegas)

ax.set_xlim(omegas_cpu[0], omegas_cpu[-1])
# Raw strings: the mathtext labels contain backslashes (\Re, \kappa, \mathrm,
# \Im) that are invalid Python escape sequences in ordinary string literals.
ax.plot(omegas_cpu, to_cpu(xp.real(solver.kappa_p)), "o-", label=r"$\Re(\kappa_\mathrm{P})$", mec="k")
ax.plot(omegas_cpu, to_cpu(xp.imag(solver.kappa_p)), "o-", label=r"$\Im(\kappa_\mathrm{P})$", mec="k")
ax.set_xscale("log")
# NOTE(review): on a logarithmic y-axis, non-positive values are not drawn.
# If the imaginary part can be negative under the solver's sign convention,
# those points are silently dropped — confirm, or plot the magnitude instead.
ax.set_yscale("log")
ax.set_xlabel("temporal frequency [rad/s]")
ax.set_ylabel("thermal conductivity [W/mK]")
ax.legend()

f.tight_layout()
# Interactive display is left disabled (presumably so the script can run
# non-interactively); uncomment to open a window.
# plt.show()